Source code for mne_rt.protocols.zscores

"""Z-score-based feedback reward protocol for MNE-RT.

This module provides :class:`ZScoreProtocol`, a stateful protocol that
normalises incoming NF values against a running baseline mean and standard
deviation and returns a reward proportional to the resulting z-score.

Classes
-------
ZScoreProtocol
    Rolling z-score normaliser with configurable direction and warmup.
"""
from __future__ import annotations

import collections
from typing import Optional

import numpy as np


class ZScoreProtocol:
    """Z-score feedback protocol with rolling baseline normalisation.

    Normalises each incoming NF value against the running mean and standard
    deviation accumulated since initialisation (or the last :meth:`reset`).
    A reward is issued when the z-score magnitude exceeds
    ``zscore_threshold`` in the requested ``direction``.

    During the first ``warmup_windows`` calls to :meth:`evaluate` the
    baseline statistics are accumulating; ``crossed`` is always ``False``
    and ``magnitude`` is always ``0.0`` until warmup completes.

    Parameters
    ----------
    direction : {"up", "down"}
        "up"  -> reward when z-score >  ``zscore_threshold``.
        "down" -> reward when z-score < -``zscore_threshold``.
    warmup_windows : int
        Number of initial evaluations used solely to seed the baseline
        statistics before any reward can be issued. Default is 20.
    smoothing : float
        EMA smoothing coefficient applied to the raw input before
        z-scoring.  Must be in ``[0, 1)``.  ``0.0`` disables smoothing.
        Applied as: ``smoothed = (1 - smoothing) * new + smoothing * prev``.
        Default is 0.0.
    min_std : float
        Floor applied to the running standard deviation to prevent
        division by zero or near-zero blowup. Default is 1e-6.
    zscore_threshold : float
        Minimum absolute z-score required to issue a reward. Default is 0.5.

    Raises
    ------
    ValueError
        If any parameter is outside its valid range.

    Notes
    -----
    The running mean and variance are updated with Welford's online algorithm,
    which is numerically stable and requires O(1) memory per update.

    Examples
    --------
    Reward upward alpha-power deviations beyond half a standard deviation::

        proto = ZScoreProtocol(direction="up", zscore_threshold=0.5)
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
            if crossed:
                send_reward(magnitude)

    .. versionadded:: 1.0.0
    """

[docs] def __init__( self, direction: str = "up", warmup_windows: int = 20, smoothing: float = 0.0, min_std: float = 1e-6, zscore_threshold: float = 0.5, ) -> None: if direction not in ("up", "down"): raise ValueError( f"direction must be 'up' or 'down', got {direction!r}" ) if warmup_windows < 1: raise ValueError( f"warmup_windows must be >= 1, got {warmup_windows}" ) if not (0.0 <= smoothing < 1.0): raise ValueError( f"smoothing must be in [0, 1), got {smoothing}" ) if min_std <= 0.0: raise ValueError( f"min_std must be > 0, got {min_std}" ) if zscore_threshold < 0.0: raise ValueError( f"zscore_threshold must be >= 0, got {zscore_threshold}" ) self.direction: str = direction self.warmup_windows: int = warmup_windows self.smoothing: float = smoothing self.min_std: float = min_std self.zscore_threshold: float = zscore_threshold self._n_evaluated: int = 0 self._smoothed: Optional[float] = None self._zscore: float = 0.0 # Welford online mean/variance accumulators self._welford_mean: float = 0.0 self._welford_m2: float = 0.0
[docs] def evaluate(self, value: float) -> tuple[bool, float]: """Evaluate one NF value and return (crossed, magnitude). Updates the running baseline with the (optionally smoothed) input, computes the z-score, and determines whether the reward criterion is met. Parameters ---------- value : float Current NF feature value. Returns ------- crossed : bool True if the z-score criterion is met in the target direction and warmup has completed. magnitude : float Absolute value of the z-score when ``crossed`` is True; ``0.0`` otherwise. Notes ----- During warmup (``n_evaluated < warmup_windows``) this method always returns ``(False, 0.0)`` while still accumulating baseline statistics. """ if self.smoothing > 0.0: if self._smoothed is None: self._smoothed = float(value) else: self._smoothed = ( (1.0 - self.smoothing) * value + self.smoothing * self._smoothed ) else: self._smoothed = float(value) smoothed = self._smoothed self._n_evaluated += 1 # Welford's online update delta = smoothed - self._welford_mean self._welford_mean += delta / self._n_evaluated delta2 = smoothed - self._welford_mean self._welford_m2 += delta * delta2 if self._n_evaluated < 2: std = self.min_std else: std = max( float(np.sqrt(self._welford_m2 / (self._n_evaluated - 1))), self.min_std, ) self._zscore = (smoothed - self._welford_mean) / std if self._n_evaluated <= self.warmup_windows: return False, 0.0 if self.direction == "up": crossed = self._zscore > self.zscore_threshold else: crossed = self._zscore < -self.zscore_threshold magnitude = abs(self._zscore) if crossed else 0.0 return crossed, magnitude
[docs] def reset(self) -> None: """Reset all state. Clears the running baseline statistics, smoothed value, z-score, and evaluation counter. All constructor parameters are preserved. """ self._n_evaluated = 0 self._smoothed = None self._zscore = 0.0 self._welford_mean = 0.0 self._welford_m2 = 0.0
@property def n_evaluated(self) -> int: """Total number of values evaluated since init or last reset.""" return self._n_evaluated @property def zscore(self) -> float: """Z-score computed during the most recent :meth:`evaluate` call. Returns ``0.0`` before any evaluation. """ return self._zscore @property def mean_(self) -> float: """Current running baseline mean. Returns ``0.0`` before any evaluation. """ return self._welford_mean @property def std_(self) -> float: """Current running baseline standard deviation. Returns ``min_std`` before at least two evaluations. """ if self._n_evaluated < 2: return self.min_std return max( float(np.sqrt(self._welford_m2 / (self._n_evaluated - 1))), self.min_std, ) def __repr__(self) -> str: return ( f"ZScoreProtocol(" f"direction={self.direction!r}, " f"warmup_windows={self.warmup_windows}, " f"zscore={self._zscore:.4g}, " f"mean={self.mean_:.4g}, " f"std={self.std_:.4g}, " f"n_evaluated={self._n_evaluated})" )