"""Reinforcement-learning adaptive threshold protocol for MNE-RT.
This module provides :class:`RLProtocol`, an adaptive feedback protocol
that maintains a target hit rate by adjusting the threshold using a proper RL
update rule with epsilon-greedy exploration.
Classes
-------
RLProtocol
RL-adaptive threshold protocol with epsilon-greedy exploration.
"""
from __future__ import annotations
import collections
from typing import Optional
import numpy as np
class RLProtocol:
"""Adaptive NF protocol with reinforcement-learning threshold updates.
Adjusts the decision threshold after every evaluation to maintain a target
hit rate using the update rule::
threshold += lr * (hit_rate - target_hit_rate) * running_std
Unlike :class:`~mne_rt.protocols.ThresholdProtocol` (which also has an
adaptive mode), this protocol tracks a rolling hit rate in a fixed-length
window, scales updates by the running standard deviation of recent values,
and optionally applies epsilon-greedy exploration: with probability
``epsilon`` a reward is given regardless of the threshold. Exploration
trials do **not** count toward the hit-rate used for threshold updates.
During the first ``warmup_windows`` calls to :meth:`evaluate` the
threshold is frozen and ``crossed`` is always ``False``.
Parameters
----------
direction : {"up", "down"}
"up" -> reward when value > threshold (e.g., enhance alpha power).
"down" -> reward when value < threshold (e.g., suppress beta power).
Default is "up".
initial_threshold : float
Starting decision threshold. Default is 0.0.
target_hit_rate : float
Desired proportion of non-exploration windows that cross the
threshold. Must be strictly in ``(0, 1)``. Default is 0.70.
lr : float
Learning rate for threshold updates. Must be > 0. Default is 0.05.
epsilon : float
Exploration probability. On each call to :meth:`evaluate`,
``epsilon`` is the chance of giving a reward regardless of threshold.
Must be in ``[0, 1)``. Default is 0.05.
smoothing : float
EMA smoothing coefficient applied to the raw input before
thresholding. Must be in ``[0, 1)``. ``0.0`` disables smoothing.
Applied as: ``smoothed = (1 - smoothing) * new + smoothing * prev``.
Default is 0.0.
history_len : int
Rolling-window length for hit-rate and running-std estimation.
Must be >= 10. Default is 50.
warmup_windows : int
Number of initial evaluations used solely to seed the rolling
statistics before any reward can be issued or any threshold update
is applied. Must be >= 1. Default is 20.
rng_seed : int | None
Seed for the NumPy random generator used for epsilon draws.
Default is None (non-deterministic).
Raises
------
ValueError
If any parameter is outside its valid range.
Notes
-----
The update rule is direction-aware: when ``direction="up"`` a higher
threshold raises difficulty; when ``direction="down"`` a lower threshold
raises difficulty. The sign of the update is therefore flipped for
"down" protocols.
Examples
--------
RL-adaptive alpha-up protocol targeting 70 % hit rate::
proto = RLProtocol(
direction="up",
initial_threshold=0.5,
target_hit_rate=0.70,
lr=0.05,
epsilon=0.05,
)
for value in nf_stream:
crossed, magnitude = proto.evaluate(value)
if crossed:
send_reward(magnitude)
.. versionadded:: 1.0.0
"""
[docs]
def __init__(
self,
direction: str = "up",
initial_threshold: float = 0.0,
target_hit_rate: float = 0.70,
lr: float = 0.05,
epsilon: float = 0.05,
smoothing: float = 0.0,
history_len: int = 50,
warmup_windows: int = 20,
rng_seed: Optional[int] = None,
) -> None:
if direction not in ("up", "down"):
raise ValueError(
f"direction must be 'up' or 'down', got {direction!r}"
)
if not (0.0 < target_hit_rate < 1.0):
raise ValueError(
f"target_hit_rate must be in (0, 1), got {target_hit_rate}"
)
if lr <= 0.0:
raise ValueError(
f"lr must be > 0, got {lr}"
)
if not (0.0 <= epsilon < 1.0):
raise ValueError(
f"epsilon must be in [0, 1), got {epsilon}"
)
if not (0.0 <= smoothing < 1.0):
raise ValueError(
f"smoothing must be in [0, 1), got {smoothing}"
)
if history_len < 10:
raise ValueError(
f"history_len must be >= 10, got {history_len}"
)
if warmup_windows < 1:
raise ValueError(
f"warmup_windows must be >= 1, got {warmup_windows}"
)
self.direction: str = direction
self.initial_threshold: float = float(initial_threshold)
self.target_hit_rate: float = target_hit_rate
self.lr: float = lr
self.epsilon: float = epsilon
self.smoothing: float = smoothing
self.history_len: int = history_len
self.warmup_windows: int = warmup_windows
self._rng = np.random.default_rng(rng_seed)
self._threshold: float = float(initial_threshold)
self._n_evaluated: int = 0
self._n_explored: int = 0
self._smoothed: Optional[float] = None
# Rolling window for hit/miss of non-exploration evaluations
self._hit_history: collections.deque[bool] = collections.deque(
maxlen=history_len
)
# Rolling window for raw smoothed values (for running std)
self._value_history: collections.deque[float] = collections.deque(
maxlen=history_len
)
[docs]
def evaluate(self, value: float) -> tuple[bool, float]:
"""Evaluate one NF value and return (crossed, magnitude).
Applies optional EMA smoothing, checks the current threshold,
draws for epsilon-greedy exploration, updates the rolling hit history
(exploration draws excluded), and then applies the RL threshold update.
Warmup period suppresses all rewards and threshold updates.
Parameters
----------
value : float
Current NF feature value.
Returns
-------
crossed : bool
True if a reward is issued. May be True due to exploration
even when the threshold was not crossed. Always False during
warmup.
magnitude : float
Absolute distance from the current threshold, normalised by the
running standard deviation. ``0.0`` when not rewarded.
Notes
-----
Exploration trials (where the reward is given due to epsilon-greedy)
are counted in :attr:`n_explored` but are not recorded in the hit
history used for the threshold-update rule.
"""
# --- EMA smoothing ---------------------------------------------------
if self.smoothing > 0.0:
if self._smoothed is None:
self._smoothed = float(value)
else:
self._smoothed = (
(1.0 - self.smoothing) * value
+ self.smoothing * self._smoothed
)
else:
self._smoothed = float(value)
smoothed = self._smoothed
self._n_evaluated += 1
self._value_history.append(smoothed)
# Running std from rolling window (floor at 1e-6)
running_std: float
if len(self._value_history) > 1:
running_std = max(
float(np.std(list(self._value_history))), 1e-6
)
else:
running_std = 1e-6
# Warmup: accumulate statistics, issue no rewards, update no threshold
if self._n_evaluated <= self.warmup_windows:
return False, 0.0
# --- Threshold crossing check ----------------------------------------
if self.direction == "up":
real_crossed = smoothed > self._threshold
else:
real_crossed = smoothed < self._threshold
# --- Epsilon-greedy exploration ---------------------------------------
is_explore = bool(self._rng.random() < self.epsilon)
if is_explore:
self._n_explored += 1
crossed = True
# Exploration does NOT count toward hit rate for threshold update
else:
crossed = real_crossed
# Record in hit history only for non-exploration windows
self._hit_history.append(real_crossed)
# --- RL threshold update -----------------------------------------
if self._hit_history:
current_hit_rate = sum(self._hit_history) / len(self._hit_history)
update = self.lr * (current_hit_rate - self.target_hit_rate) * running_std
# "up" direction: raising threshold increases difficulty
# "down" direction: lowering threshold increases difficulty
if self.direction == "up":
self._threshold += update
else:
self._threshold -= update
# --- Magnitude -------------------------------------------------------
magnitude = (
abs(smoothed - self._threshold) / (running_std + 1e-6)
if crossed
else 0.0
)
return crossed, magnitude
[docs]
def reset(self) -> None:
"""Reset all adaptive state to initial conditions.
Restores the threshold to ``initial_threshold``, clears the rolling
histories, resets counters and the smoothed value. All constructor
parameters (``lr``, ``epsilon``, ``target_hit_rate``, etc.) are
preserved.
"""
self._threshold = self.initial_threshold
self._n_evaluated = 0
self._n_explored = 0
self._smoothed = None
self._hit_history.clear()
self._value_history.clear()
@property
def hit_rate(self) -> float:
"""Rolling hit rate over non-exploration evaluations (0–1).
Returns 0.0 before any non-exploration evaluations are recorded.
"""
if not self._hit_history:
return 0.0
return sum(self._hit_history) / len(self._hit_history)
@property
def threshold(self) -> float:
"""Current decision threshold."""
return self._threshold
@property
def n_evaluated(self) -> int:
"""Total number of evaluations since init or last :meth:`reset`."""
return self._n_evaluated
@property
def n_explored(self) -> int:
"""Number of exploration trials (epsilon draws) since init or reset."""
return self._n_explored
def __repr__(self) -> str:
return (
f"RLProtocol("
f"direction={self.direction!r}, "
f"threshold={self._threshold:.4g}, "
f"target_hit_rate={self.target_hit_rate}, "
f"hit_rate={self.hit_rate:.2f}, "
f"lr={self.lr}, "
f"epsilon={self.epsilon}, "
f"n_evaluated={self._n_evaluated}, "
f"n_explored={self._n_explored})"
)