# Source code for mne_rt.protocols.threshold

"""Threshold-based feedback reward protocol for MNE-RT.

This module provides :class:`ThresholdProtocol`, a lightweight stateful object
that converts a continuous NF feature value into a binary reward signal with an
optional adaptive threshold mechanism.

Classes
-------
ThresholdProtocol
    Threshold comparator with optional EMA smoothing and adaptive threshold.
"""
from __future__ import annotations

import collections
from typing import Optional

import numpy as np


class ThresholdProtocol:
    """Threshold-based NF reward protocol with optional adaptive threshold.

    Converts a continuous NF feature value into a binary reward signal.
    Optionally adapts the threshold over time to maintain a target success
    rate.

    Parameters
    ----------
    threshold : float
        Initial decision threshold.  Default is 0.0.
    direction : {"up", "down"}
        "up"  -> reward when value > threshold (e.g., enhance alpha).
        "down" -> reward when value < threshold (e.g., suppress beta).
        Default is "up".
    adaptive : bool
        If True, slowly adjust threshold to keep hit_rate near
        target_hit_rate.  The hit rate is the plain fraction of successes
        over the last ``history_len`` evaluations.  Default is False.
    adapt_rate : float
        Step size for threshold adaptation (in units of the NF signal's
        running standard deviation). Larger = faster adaptation.  Default is
        0.05.
    target_hit_rate : float
        Desired proportion of windows where the threshold is crossed.
        The adaptive mechanism pushes threshold toward this rate.  Default is
        0.70.
    smoothing : float
        EMA smoothing factor for the input value before thresholding.
        0.0 = no smoothing; 0.1 = light smoothing; 0.5 = heavy smoothing.
        Applied as: smoothed = (1 - smoothing) * new + smoothing * prev.
        Default is 0.0.
    history_len : int
        Number of recent evaluations used to estimate the running hit rate
        and running standard deviation (for adaptive scaling).  Adaptation
        starts once ``min(10, history_len)`` evaluations have been recorded.
        Default is 50.

    Raises
    ------
    ValueError
        If any parameter is outside its valid range.

    Examples
    --------
    Basic usage — reward when alpha power exceeds a fixed threshold::

        proto = ThresholdProtocol(threshold=0.5, direction="up")
        crossed, magnitude = proto.evaluate(0.8)

    Adaptive threshold that targets a 70 % success rate::

        proto = ThresholdProtocol(
            threshold=0.0,
            direction="up",
            adaptive=True,
            target_hit_rate=0.70,
        )
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
    """

    # Preferred number of recorded evaluations before the threshold starts
    # adapting.  Capped by the history window in ``evaluate`` so that short
    # windows (history_len < 10) can still adapt.
    _MIN_ADAPT_SAMPLES = 10

    def __init__(
        self,
        threshold: float = 0.0,
        direction: str = "up",
        adaptive: bool = False,
        adapt_rate: float = 0.05,
        target_hit_rate: float = 0.70,
        smoothing: float = 0.0,
        history_len: int = 50,
    ) -> None:
        if direction not in ("up", "down"):
            raise ValueError(
                f"direction must be 'up' or 'down', got {direction!r}"
            )
        if adapt_rate <= 0:
            raise ValueError(
                f"adapt_rate must be > 0, got {adapt_rate}"
            )
        if not (0 < target_hit_rate < 1):
            raise ValueError(
                f"target_hit_rate must be in (0, 1), got {target_hit_rate}"
            )
        if not (0 <= smoothing < 1):
            raise ValueError(
                f"smoothing must be in [0, 1), got {smoothing}"
            )
        if history_len < 5:
            raise ValueError(
                f"history_len must be >= 5, got {history_len}"
            )

        self._threshold: float = float(threshold)
        self.direction: str = direction
        self.adaptive: bool = adaptive
        self.adapt_rate: float = adapt_rate
        self.target_hit_rate: float = target_hit_rate
        self.smoothing: float = smoothing

        # Sliding windows: hit/miss outcomes and the (smoothed) values seen.
        self._history: collections.deque[bool] = collections.deque(
            maxlen=history_len
        )
        self._values_history: collections.deque[float] = collections.deque(
            maxlen=history_len
        )
        # None until the first sample arrives (or after reset()).
        self._smoothed: Optional[float] = None
        self._n_evaluated: int = 0

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and return (success, reward_magnitude).

        Notes
        -----
        EMA smoothing (if enabled) is applied first, then the smoothed value
        is compared against the current threshold.  The hit/miss result is
        recorded in history, and (if ``adaptive`` is True) the threshold is
        updated before the return value is assembled.

        ``magnitude`` is 0.0 when the threshold was not crossed; a positive
        float proportional to the distance from the threshold (normalised by
        the running standard deviation) when crossed.

        Parameters
        ----------
        value : float
            Current NF feature value.

        Returns
        -------
        crossed : bool
            True if the threshold was crossed in the target direction.
        magnitude : float
            Non-negative reward magnitude.  0 when not crossed;
            ``abs(smoothed - threshold) / (running_std + eps)`` when crossed.
        """
        # Coerce once so the stored state and the returned flag are plain
        # Python types on every path, including the first smoothed sample.
        sample = float(value)
        if self.smoothing > 0.0 and self._smoothed is not None:
            self._smoothed = (
                (1.0 - self.smoothing) * sample
                + self.smoothing * self._smoothed
            )
        else:
            # No smoothing requested, or first sample seeding the EMA.
            self._smoothed = sample
        smoothed = self._smoothed

        if self.direction == "up":
            crossed = smoothed > self._threshold
        else:
            crossed = smoothed < self._threshold

        self._history.append(crossed)
        self._values_history.append(smoothed)
        self._n_evaluated += 1

        # Unit scale until two samples exist; a zero spread (constant signal)
        # also falls back to 1.0 so the step and magnitude stay meaningful.
        running_std = 1.0
        if len(self._values_history) > 1:
            running_std = float(np.std(list(self._values_history))) or 1.0

        # The warm-up cannot exceed the window size, otherwise a short
        # history (history_len < 10) would never trigger adaptation.
        warmup = min(self._MIN_ADAPT_SAMPLES, self._history.maxlen)
        if self.adaptive and len(self._history) >= warmup:
            step = self.adapt_rate * running_std
            # Too many hits makes the task harder: raise the threshold for
            # "up", lower it for "down" (and vice versa for too few hits).
            if self.direction == "up":
                self._threshold += step * (
                    self.hit_rate - self.target_hit_rate
                )
            else:
                self._threshold -= step * (
                    self.hit_rate - self.target_hit_rate
                )

        if crossed:
            # Measured against the (possibly just-updated) threshold.
            magnitude = abs(smoothed - self._threshold) / (running_std + 1e-6)
        else:
            magnitude = 0.0

        return crossed, magnitude

    def reset(self) -> None:
        """Reset hit history and smoothed value, keep threshold.

        Clears ``_history``, ``_values_history``, and ``_smoothed``, and
        resets the evaluation counter to zero.  The current threshold value
        is preserved.
        """
        self._history.clear()
        self._values_history.clear()
        self._smoothed = None
        self._n_evaluated = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of recent windows that crossed the threshold (0–1).

        Returns 0.0 when no evaluations have been recorded yet.
        """
        if not self._history:
            return 0.0
        return sum(self._history) / len(self._history)

    @property
    def threshold(self) -> float:
        """Current threshold value."""
        return self._threshold

    @threshold.setter
    def threshold(self, val: float) -> None:
        """Set threshold directly.

        Parameters
        ----------
        val : float
            New threshold value.
        """
        self._threshold = float(val)

    @property
    def n_evaluated(self) -> int:
        """Total number of values evaluated since init or last reset."""
        return self._n_evaluated

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"threshold={self._threshold:.4g}, "
            f"direction={self.direction!r}, "
            f"adaptive={self.adaptive}, "
            f"hit_rate={self.hit_rate:.2f}, "
            f"n_evaluated={self._n_evaluated})"
        )