# Source code for mne_rt.protocols.staircase

"""Up-down adaptive staircase threshold protocol for MNE-RT.

This module provides :class:`UpDownStaircaseProtocol`, a classic psychophysics
adaptive procedure that adjusts the neurofeedback (NF) reward threshold so that
the success rate converges to a target level determined by the ``n_up`` /
``n_down`` rule.

Classes
-------
UpDownStaircaseProtocol
    Up-down adaptive staircase threshold protocol.

References
----------
Levitt, H. (1971). Transformed up-down methods in psychoacoustics.
Journal of the Acoustical Society of America, 49(2B), 467–477.

García-Pérez, M. A. (1998). Forced-choice staircases with fixed step sizes:
Asymptotic and small-sample properties. Vision Research, 38(12), 1861–1881.
"""
from __future__ import annotations

from typing import Optional

import numpy as np


class UpDownStaircaseProtocol:
    """Up-down adaptive staircase threshold protocol.

    Adjusts the reward threshold after each evaluated window based on runs of
    consecutive successes/failures, so that the success rate converges to a
    target level determined by ``n_up`` and ``n_down``.

    Parameters
    ----------
    initial_threshold : float
        Starting threshold. Must be finite.
    direction : {"up", "down"}, default "up"
        "up" → reward when value > threshold; "down" → reward when
        value < threshold.
    n_up : int, default 1
        Consecutive successes needed to increase difficulty (tighten the
        threshold: raise it for ``direction="up"``, lower it for
        ``direction="down"``).
    n_down : int, default 2
        Consecutive failures needed to decrease difficulty (loosen the
        threshold).

        Because *successes* make the task harder in this class, the mapping to
        Levitt's (1971) transformed rules is: ``n_up=2, n_down=1`` (two
        successes → harder, one failure → easier) converges to approximately
        70.7 % success, whereas the default ``n_up=1, n_down=2`` converges to
        approximately 29.3 % success.
    step_size : float, default 0.05
        Initial amount by which the threshold moves each time the up or down
        rule fires.
    step_factor : float, default 0.5
        Multiplicative factor applied to the step size at every reversal once
        at least ``n_reversals_before_halving`` reversals have been recorded
        (Levitt-style zooming in on the threshold). Must be in ``(0, 1]``.
    n_reversals_before_halving : int, default 4
        Number of reversals after which step shrinking starts. Must be >= 0.
    min_step : float, default 1e-4
        Floor for the step size: shrinking clamps to this value so the step
        cannot collapse to zero. A ``step_size`` that already starts below
        ``min_step`` is left unchanged.
    smoothing : float, default 0.0
        EMA coefficient ``a`` for the input value before thresholding:
        ``smoothed = (1 - a) * value + a * previous``. Must be in ``[0, 1)``;
        ``0`` disables smoothing.
    max_reversals : int | None, default None
        If set, stop adapting once this many reversals have been recorded.
        The step that produced the final reversal is still applied; after
        that the threshold is frozen.

    Attributes
    ----------
    threshold : float
        Current threshold.
    n_reversals : int
        Number of reversals so far.
    reversal_thresholds : list[float]
        Threshold value recorded at each reversal point (before the step that
        follows the reversal is applied).

    Raises
    ------
    ValueError
        If any parameter is outside its valid range.

    Examples
    --------
    2-success-up / 1-failure-down staircase targeting ~70.7 % success rate::

        from mne_rt.protocols.staircase import UpDownStaircaseProtocol

        proto = UpDownStaircaseProtocol(
            initial_threshold=0.5,
            direction="up",
            n_up=2,
            n_down=1,
            step_size=0.05,
        )
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        initial_threshold: float,
        direction: str = "up",
        n_up: int = 1,
        n_down: int = 2,
        step_size: float = 0.05,
        step_factor: float = 0.5,
        n_reversals_before_halving: int = 4,
        min_step: float = 1e-4,
        smoothing: float = 0.0,
        max_reversals: Optional[int] = None,
    ) -> None:
        if direction not in ("up", "down"):
            raise ValueError(
                f"direction must be 'up' or 'down', got {direction!r}"
            )
        if n_up < 1:
            raise ValueError(f"n_up must be >= 1, got {n_up}")
        if n_down < 1:
            raise ValueError(f"n_down must be >= 1, got {n_down}")
        if step_size <= 0:
            raise ValueError(f"step_size must be > 0, got {step_size}")
        if not (0 < step_factor <= 1):
            raise ValueError(
                f"step_factor must be in (0, 1], got {step_factor}"
            )
        if min_step <= 0:
            raise ValueError(f"min_step must be > 0, got {min_step}")
        if not (0.0 <= smoothing < 1.0):
            raise ValueError(f"smoothing must be in [0, 1), got {smoothing}")
        if max_reversals is not None and max_reversals < 1:
            raise ValueError(
                f"max_reversals must be >= 1 or None, got {max_reversals}"
            )
        if n_reversals_before_halving < 0:
            raise ValueError(
                "n_reversals_before_halving must be >= 0, "
                f"got {n_reversals_before_halving}"
            )
        initial = float(initial_threshold)
        # A NaN threshold would make every comparison False, so the staircase
        # would never register a success.
        if not np.isfinite(initial):
            raise ValueError(
                f"initial_threshold must be finite, got {initial_threshold}"
            )

        self._initial_threshold: float = initial
        self.threshold: float = initial
        self.direction: str = direction
        self.n_up: int = n_up
        self.n_down: int = n_down
        self._initial_step_size: float = float(step_size)
        self._step_size: float = float(step_size)
        self.step_factor: float = step_factor
        self.n_reversals_before_halving: int = n_reversals_before_halving
        self.min_step: float = float(min_step)
        self.smoothing: float = smoothing
        self.max_reversals: Optional[int] = max_reversals

        self.n_reversals: int = 0
        self.reversal_thresholds: list[float] = []
        self._consecutive_up: int = 0
        self._consecutive_down: int = 0
        # "up" = last change was a threshold increase; "down" = decrease;
        # None = no change applied yet.
        self._last_direction: Optional[str] = None
        self._n_evaluated: int = 0
        self._smoothed: Optional[float] = None

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and update the staircase.

        Applies EMA smoothing, compares against the current threshold,
        updates consecutive-run counters, and adjusts the threshold when an
        up or down rule is triggered. A reversal is recorded whenever the
        direction of threshold change flips.

        Non-finite values (NaN/inf, e.g. dropped samples) are ignored: the
        call returns ``(False, 0.0)`` and leaves all state untouched, so a
        single bad sample cannot poison the EMA.

        Parameters
        ----------
        value : float
            Current NF feature value.

        Returns
        -------
        crossed : bool
            True if the threshold was crossed in the target direction.
        magnitude : float
            Distance from the threshold in effect *before* this update
            (absolute difference) when ``crossed`` is True; ``0.0`` otherwise.
        """
        # Cast once so numpy scalars do not leak np.bool_/np.float64 into
        # the return value.
        value = float(value)
        if not np.isfinite(value):
            return False, 0.0

        if self.smoothing > 0.0 and self._smoothed is not None:
            self._smoothed = (
                (1.0 - self.smoothing) * value
                + self.smoothing * self._smoothed
            )
        else:
            # No smoothing, or first valid sample: seed the EMA with the raw value.
            self._smoothed = value
        smoothed = self._smoothed
        self._n_evaluated += 1

        if self.direction == "up":
            crossed = smoothed > self.threshold
        else:
            crossed = smoothed < self.threshold
        magnitude = abs(smoothed - self.threshold) if crossed else 0.0

        frozen = (
            self.max_reversals is not None
            and self.n_reversals >= self.max_reversals
        )
        if not frozen:
            self._update_staircase(crossed)
        return crossed, magnitude

    def _update_staircase(self, crossed: bool) -> None:
        """Apply the up-down rule and update the threshold.

        A success extends the success run and breaks the failure run (and
        vice versa). When a run reaches ``n_up`` / ``n_down`` it is cleared
        and one step is applied.
        """
        if crossed:
            self._consecutive_up += 1
            self._consecutive_down = 0
            if self._consecutive_up >= self.n_up:
                self._consecutive_up = 0
                self._apply_step(increase_difficulty=True)
        else:
            self._consecutive_down += 1
            self._consecutive_up = 0
            if self._consecutive_down >= self.n_down:
                self._consecutive_down = 0
                self._apply_step(increase_difficulty=False)

    def _apply_step(self, increase_difficulty: bool) -> None:
        """Move the threshold one step and record a reversal if direction flips.

        Parameters
        ----------
        increase_difficulty : bool
            True to tighten the threshold (after ``n_up`` successes), False
            to loosen it (after ``n_down`` failures).
        """
        # Map difficulty to a direction in threshold space: for "up" targets a
        # harder task means a higher threshold; for "down" targets a lower one.
        if self.direction == "up":
            change_dir = "up" if increase_difficulty else "down"
        else:
            change_dir = "down" if increase_difficulty else "up"

        # A reversal is a flip relative to the previous change; the very first
        # change (_last_direction is None) is never a reversal.
        is_reversal = (
            self._last_direction is not None
            and change_dir != self._last_direction
        )
        if is_reversal:
            self.n_reversals += 1
            # Record the turning point, i.e. the threshold before this step.
            self.reversal_thresholds.append(self.threshold)
            if self.n_reversals >= self.n_reversals_before_halving:
                # Shrink, clamping to min_step; never grow a step that
                # already started below min_step.
                shrunk = max(self._step_size * self.step_factor, self.min_step)
                self._step_size = min(self._step_size, shrunk)

        self._last_direction = change_dir
        if change_dir == "up":
            self.threshold += self._step_size
        else:
            self.threshold -= self._step_size

    def reset(self) -> None:
        """Reset all state to initial conditions.

        Restores the threshold to ``initial_threshold``, clears reversal
        history, and resets all counters. The step size is restored to its
        value from construction (stored as ``_initial_step_size``).
        """
        self.threshold = self._initial_threshold
        self._step_size = self._initial_step_size
        self.n_reversals = 0
        self.reversal_thresholds = []
        self._consecutive_up = 0
        self._consecutive_down = 0
        self._last_direction = None
        self._n_evaluated = 0
        self._smoothed = None

    def __repr__(self) -> str:
        return (
            f"UpDownStaircaseProtocol("
            f"threshold={self.threshold:.4g}, "
            f"direction={self.direction!r}, "
            f"n_up={self.n_up}, "
            f"n_down={self.n_down}, "
            f"step_size={self._step_size:.4g}, "
            f"n_reversals={self.n_reversals}, "
            f"n_evaluated={self._n_evaluated})"
        )