# Source code for mne_rt.protocols.linear_trend
"""Linear-trend feedback protocol for MNE-RT.

This module provides :class:`LinearTrendProtocol`, a stateful protocol that
rewards continuous improvement by detecting a sustained upward (or downward)
trend in the recent NF signal history using ordinary least-squares
regression. A trend qualifies when its slope clears a configurable threshold
and the fit passes an optional goodness-of-fit (R²) gate; no formal
significance test is performed.

Classes
-------
LinearTrendProtocol
    OLS-based trend detector with configurable window, slope threshold,
    and minimum goodness-of-fit (R²).
"""
from __future__ import annotations
import collections
from typing import Optional
import numpy as np
class LinearTrendProtocol:
    """Reward protocol that detects a sustained directional NF trend.

    Instead of rewarding for exceeding a fixed threshold at a single time
    point, this protocol fits an ordinary least-squares (OLS) line through
    the last ``window`` NF values and issues a reward when:

    1. The fitted slope is strictly in the target ``direction`` **and** its
       absolute value is at least ``slope_threshold``, **and**
    2. The regression goodness-of-fit R² ≥ ``min_r2`` (optional quality
       gate — set to 0.0 to disable).

    No hypothesis test (p-value) is computed; the slope threshold and the
    R² gate are the only acceptance criteria.

    This is particularly useful in clinical neurofeedback where participants
    may not reach a threshold in every window but should be encouraged for
    sustained directional change across multiple windows.

    Parameters
    ----------
    direction : {"up", "down"}
        "up" -> reward when the slope is positive (signal trending upward).
        "down" -> reward when the slope is negative (signal trending downward).
        Default is ``"up"``.
    window : int
        Number of most-recent NF values used for each regression. Must be
        ≥ 3. Larger values give more stable estimates but react more slowly.
        Default is 20.
    slope_threshold : float
        Minimum absolute slope (in NF-signal units per sample) required to
        issue a reward. Set to 0.0 to reward any non-zero trend in the
        target direction (a perfectly flat history is never rewarded).
        Default is 0.0.
    min_r2 : float
        Minimum coefficient of determination (R²) of the OLS fit required
        before a reward is issued. Values in ``[0.0, 1.0]``. Set to 0.0
        to disable the quality gate. Default is 0.0.
    warmup_windows : int, optional
        Number of :meth:`evaluate` calls before the first regression is
        computed and rewards can be issued. Must be ≥ ``window`` so that the
        history buffer is full. Defaults to ``window``.
    smoothing : float
        EMA coefficient in ``[0, 1)`` applied to the raw input before it is
        added to the history:
        ``s = (1 - smoothing) * value + smoothing * s_prev``.
        ``0.0`` disables smoothing. Default is 0.0.

    Raises
    ------
    ValueError
        If any parameter is outside its valid range (NaN is rejected too).

    Notes
    -----
    The OLS slope and R² are computed in closed form with NumPy in
    O(window) time per evaluation.

    ``magnitude`` returned by :meth:`evaluate` is the absolute slope divided
    by the population standard deviation of the history buffer (1.0 is used
    when that deviation is zero), giving a dimensionless measure of trend
    strength.

    Examples
    --------
    Reward sustained alpha-power increase over the last 20 windows::

        proto = LinearTrendProtocol(direction="up", window=20)
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
            if crossed:
                send_reward(magnitude)

    Require a clear trend (R² ≥ 0.5) with a non-trivial slope::

        proto = LinearTrendProtocol(
            direction="up",
            window=15,
            slope_threshold=0.01,
            min_r2=0.5,
        )

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        direction: str = "up",
        window: int = 20,
        slope_threshold: float = 0.0,
        min_r2: float = 0.0,
        warmup_windows: Optional[int] = None,
        smoothing: float = 0.0,
    ) -> None:
        if direction not in ("up", "down"):
            raise ValueError(
                f"direction must be 'up' or 'down', got {direction!r}"
            )
        if window < 3:
            raise ValueError(f"window must be >= 3, got {window}")
        # Written as ``not (x >= 0)`` so that NaN is rejected as well; a NaN
        # threshold would otherwise silently disable every reward.
        if not slope_threshold >= 0.0:
            raise ValueError(
                f"slope_threshold must be >= 0, got {slope_threshold}"
            )
        if not (0.0 <= min_r2 <= 1.0):
            raise ValueError(
                f"min_r2 must be in [0.0, 1.0], got {min_r2}"
            )
        if not (0.0 <= smoothing < 1.0):
            raise ValueError(
                f"smoothing must be in [0, 1), got {smoothing}"
            )
        _warmup = warmup_windows if warmup_windows is not None else window
        if _warmup < window:
            raise ValueError(
                f"warmup_windows ({_warmup}) must be >= window ({window})"
            )
        self.direction = direction
        self.window = window
        self.slope_threshold = slope_threshold
        self.min_r2 = min_r2
        self.warmup_windows = _warmup
        self.smoothing = smoothing
        # Bounded buffer: appending beyond ``window`` drops the oldest value.
        self._history: collections.deque[float] = collections.deque(maxlen=window)
        self._n_evaluated: int = 0
        self._smoothed: Optional[float] = None
        self._slope: float = 0.0
        self._r2: float = 0.0

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and return ``(crossed, magnitude)``.

        Parameters
        ----------
        value : float
            Current NF feature value.

        Returns
        -------
        crossed : bool
            ``True`` when all of the following hold:

            * warmup is complete,
            * the OLS slope is strictly in the target direction and its
              absolute value is ≥ ``slope_threshold``,
            * R² ≥ ``min_r2``.
        magnitude : float
            Absolute slope normalised by the history standard deviation
            (dimensionless trend strength). ``0.0`` when not crossed.
        """
        value = float(value)
        if self.smoothing > 0.0 and self._smoothed is not None:
            # Exponential moving average: larger ``smoothing`` = more inertia.
            self._smoothed = (
                (1.0 - self.smoothing) * value + self.smoothing * self._smoothed
            )
        else:
            # Smoothing disabled, or first sample: pass the value through.
            self._smoothed = value
        self._history.append(self._smoothed)
        self._n_evaluated += 1
        if self._n_evaluated < self.warmup_windows:
            return False, 0.0

        y = np.array(self._history, dtype=np.float64)
        n = len(y)
        # Sample indices centred on their mean, (n - 1) / 2.
        x_c = np.arange(n, dtype=np.float64) - (n - 1) / 2.0
        y_c = y - y.mean()
        sxx = float(np.sum(x_c * x_c))
        sxy = float(np.sum(x_c * y_c))
        syy = float(np.sum(y_c * y_c))

        slope = sxy / sxx if sxx > 0.0 else 0.0
        self._slope = slope

        # For OLS with an intercept, R² == sxy² / (sxx * syy). This form is
        # never negative (unlike 1 - ss_res/ss_tot under rounding), so
        # ``min_r2=0.0`` really disables the gate. A constant history is a
        # perfect (flat) fit and keeps the previous convention of R² = 1.0.
        if sxx > 0.0 and syy > 0.0:
            self._r2 = min((sxy * sxy) / (sxx * syy), 1.0)
        else:
            self._r2 = 1.0

        std = float(np.std(y)) or 1.0

        # Strict sign test: a zero slope is not a trend in either direction,
        # even when slope_threshold is 0.0.
        if self.direction == "up":
            in_direction = slope > 0.0 and slope >= self.slope_threshold
        else:
            in_direction = slope < 0.0 and -slope >= self.slope_threshold
        # Plain Python bool (not numpy.bool_) so callers can serialise it.
        crossed = bool(in_direction and self._r2 >= self.min_r2)
        magnitude = abs(slope) / std if crossed else 0.0
        return crossed, magnitude

    def reset(self) -> None:
        """Clear history and counters, preserving all constructor parameters."""
        self._history.clear()
        self._n_evaluated = 0
        self._smoothed = None
        self._slope = 0.0
        self._r2 = 0.0

    @property
    def slope(self) -> float:
        """OLS slope from the most recent :meth:`evaluate` call (0.0 before warmup)."""
        return self._slope

    @property
    def r2(self) -> float:
        """R² from the most recent :meth:`evaluate` call (0.0 before warmup)."""
        return self._r2

    @property
    def n_evaluated(self) -> int:
        """Total number of values evaluated since init or last :meth:`reset`."""
        return self._n_evaluated

    def __repr__(self) -> str:
        return (
            f"LinearTrendProtocol("
            f"direction={self.direction!r}, "
            f"window={self.window}, "
            f"slope={self._slope:.4g}, "
            f"r2={self._r2:.3f}, "
            f"n_evaluated={self._n_evaluated})"
        )