Source code for mne_rt.protocols.percentile
"""Percentile-based feedback reward protocol for MNE-RT.
This module provides :class:`PercentileProtocol`, a stateful protocol that
rewards the participant when the current NF value crosses the Nth percentile
of a rolling history buffer.
Classes
-------
PercentileProtocol
Rolling-percentile threshold comparator with configurable direction.
"""
from __future__ import annotations
import collections
from typing import Optional
import numpy as np
class PercentileProtocol:
    """Percentile-based NF reward protocol with rolling history.

    Maintains a fixed-length circular buffer of recent (optionally smoothed)
    NF values. At each call to :meth:`evaluate` the Nth percentile of that
    buffer is computed and used as the dynamic reward threshold.
    A reward is issued when the current value exceeds (``"up"``) or falls
    below (``"down"``) that threshold.

    Parameters
    ----------
    percentile : float
        Target percentile in the open range ``(0, 100)`` used to derive the
        dynamic threshold from the history buffer. Default is 75.0.
    direction : {"up", "down"}
        ``"up"`` -> reward when value > percentile threshold.
        ``"down"`` -> reward when value < percentile threshold.
        Default is ``"up"``.
    history_len : int
        Maximum number of recent values retained in the rolling buffer.
        Must be >= 2. Default is 100.
    smoothing : float
        EMA smoothing coefficient applied to the raw input before
        comparison. Must be in ``[0, 1)``. ``0.0`` disables smoothing.
        Applied as: ``smoothed = (1 - smoothing) * new + smoothing * prev``.
        Default is 0.0.

    Raises
    ------
    ValueError
        If any parameter is outside its valid range.

    Notes
    -----
    The ``current_threshold`` is ``nan`` until at least two values have been
    added to the history buffer (``numpy.percentile`` requires at least one
    element, but a single-element buffer is degenerate).

    The current value is appended to the buffer *before* the threshold is
    computed, so the threshold always includes the value being tested.

    Non-finite inputs (``nan``, ``+/-inf``) are rejected by :meth:`evaluate`.
    They count as a non-rewarded evaluation but are not added to the buffer
    or the EMA, so a single bad sample cannot poison later thresholds.

    Examples
    --------
    Reward when alpha power exceeds the 75th percentile of recent history::

        proto = PercentileProtocol(percentile=75.0, direction="up")
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
            if crossed:
                send_reward(magnitude)

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        percentile: float = 75.0,
        direction: str = "up",
        history_len: int = 100,
        smoothing: float = 0.0,
    ) -> None:
        # The chained comparisons also reject NaN, since any comparison
        # involving NaN is False.
        if not (0.0 < percentile < 100.0):
            raise ValueError(
                f"percentile must be in (0, 100), got {percentile}"
            )
        if direction not in ("up", "down"):
            raise ValueError(
                f"direction must be 'up' or 'down', got {direction!r}"
            )
        if history_len < 2:
            raise ValueError(
                f"history_len must be >= 2, got {history_len}"
            )
        if not (0.0 <= smoothing < 1.0):
            raise ValueError(
                f"smoothing must be in [0, 1), got {smoothing}"
            )
        self.percentile: float = percentile
        self.direction: str = direction
        self.history_len: int = history_len
        self.smoothing: float = smoothing
        # Rolling window of (smoothed) values used to derive the threshold.
        self._history: collections.deque[float] = collections.deque(
            maxlen=history_len
        )
        # Rolling window of per-evaluation reward outcomes, for hit_rate.
        self._hits: collections.deque[bool] = collections.deque(
            maxlen=history_len
        )
        # Last EMA output; None until the first finite sample arrives.
        self._smoothed: Optional[float] = None
        self._n_evaluated: int = 0
        self._current_threshold: float = float("nan")

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and return ``(crossed, magnitude)``.

        Appends the (optionally smoothed) value to the rolling buffer,
        recomputes the percentile threshold, and tests the crossing condition.

        Parameters
        ----------
        value : float
            Current NF feature value. Anything accepted by ``float()`` works,
            including NumPy scalars.

        Returns
        -------
        crossed : bool
            True if the current value is on the reward side of the
            percentile threshold. Always a built-in ``bool``.
        magnitude : float
            Absolute distance between the current value and the threshold
            when ``crossed`` is True; ``0.0`` otherwise. Returns ``0.0``
            when the history buffer contains fewer than two entries.

        Notes
        -----
        Non-finite values (``nan``, ``+/-inf``) return ``(False, 0.0)``. They
        are counted in :attr:`n_evaluated` and as a miss in :attr:`hit_rate`,
        but are NOT added to the history buffer or the EMA state.
        :attr:`current_threshold` keeps its previous value.
        """
        # Convert once so all downstream arithmetic and comparisons are on a
        # built-in float. This makes ``crossed``/``magnitude`` plain Python
        # types even when NumPy scalars are passed in.
        sample = float(value)
        self._n_evaluated += 1

        # A nan/inf sample would stick in the EMA forever (nan propagates
        # through every update). It would also make np.percentile return
        # nan until it aged out of the buffer, silently disabling rewards.
        if not np.isfinite(sample):
            self._hits.append(False)
            return False, 0.0

        if self.smoothing > 0.0 and self._smoothed is not None:
            sample = (
                (1.0 - self.smoothing) * sample
                + self.smoothing * self._smoothed
            )
        self._smoothed = sample
        self._history.append(sample)

        if len(self._history) < 2:
            self._current_threshold = float("nan")
            self._hits.append(False)
            return False, 0.0

        buffer = np.fromiter(
            self._history, dtype=float, count=len(self._history)
        )
        threshold = float(np.percentile(buffer, self.percentile))
        self._current_threshold = threshold

        if self.direction == "up":
            crossed = sample > threshold
        else:
            crossed = sample < threshold
        self._hits.append(crossed)
        magnitude = abs(sample - threshold) if crossed else 0.0
        return crossed, magnitude

    def reset(self) -> None:
        """Reset all state.

        Clears the rolling history buffer, hit log, smoothed value,
        evaluation counter and current threshold. All constructor parameters
        are preserved.
        """
        self._history.clear()
        self._hits.clear()
        self._smoothed = None
        self._n_evaluated = 0
        self._current_threshold = float("nan")

    @property
    def n_evaluated(self) -> int:
        """Total number of values evaluated since init or last reset."""
        return self._n_evaluated

    @property
    def current_threshold(self) -> float:
        """Percentile threshold computed during the last :meth:`evaluate` call.

        Returns ``nan`` until at least two values have been accumulated.
        """
        return self._current_threshold

    @property
    def hit_rate(self) -> float:
        """Fraction of recent evaluations that crossed the threshold (0-1).

        Computed over the rolling window defined by ``history_len``.
        Returns ``0.0`` when no evaluations have been recorded yet.
        """
        if not self._hits:
            return 0.0
        return sum(self._hits) / len(self._hits)

    def __repr__(self) -> str:
        return (
            f"PercentileProtocol("
            f"percentile={self.percentile}, "
            f"direction={self.direction!r}, "
            f"current_threshold={self._current_threshold:.4g}, "
            f"hit_rate={self.hit_rate:.2f}, "
            f"n_evaluated={self._n_evaluated})"
        )