# Source code for mne_rt.protocols.sham

"""Double-blind sham feedback wrapper protocol for MNE-RT.

This module provides :class:`ShamProtocol`, a stateful wrapper that intercepts
real neurofeedback rewards and substitutes randomly-drawn historical values on a
configurable fraction of windows, enabling within-session double-blind designs
for neurofeedback RCTs.

Classes
-------
ShamProtocol
    Wraps any NF protocol with sham (double-blind) feedback.

References
----------
Thibaut, A., et al. (2018). Sham feedback in neurofeedback research.
Brain Stimulation, 11(3), 459–460.

Zander, T. O., et al. (2016). Towards adaptive classification for BCI.
Journal of Neural Engineering, 13(2), 026005.
"""
from __future__ import annotations

import collections
from typing import Any, Optional

import numpy as np


class ShamProtocol:
    """Wraps any NF protocol with sham (double-blind) feedback.

    In real-time NF, feedback is issued on every incoming window (e.g. every
    1 second). ShamProtocol intercepts those feedback values. On a
    ``sham_rate`` fraction of windows it replaces the real result with one
    drawn at random from a rolling buffer of recent *real* results. This
    creates placebo/sham feedback windows whose values come from the same
    pool of genuine outcomes as real windows, but are decoupled from the
    current signal.

    This enables within-session double-blind designs for neurofeedback RCTs
    without a separate sham session.

    Parameters
    ----------
    inner : any protocol with .evaluate(value) -> (bool, float)
        The real protocol to wrap (ThresholdProtocol, ZScoreProtocol, etc.).
    sham_rate : float, default 0.5
        Fraction of windows that receive sham feedback (0–1).
        0.0 = never sham; 1.0 = always sham.
    buffer_len : int, default 60
        Maximum number of most recent real ``(crossed, magnitude)`` results
        kept in the sham pool. Older results are discarded once it is full.
    rng_seed : int | None, default None
        Random seed for reproducibility.

    Attributes
    ----------
    n_real : int
        Number of real feedback windows so far.
    n_sham : int
        Number of sham feedback windows so far.
    sham_log : list[bool]
        Per-window sham flag (True = was sham). One entry is appended per
        call to :meth:`evaluate`.

    Raises
    ------
    ValueError
        If ``sham_rate`` is not in ``[0, 1]`` or ``buffer_len < 1``.

    Notes
    -----
    The sham pool starts empty and contains only real results produced by
    ``inner``. The current window's real result is recorded before any sham
    draw, so the pool is never empty when sampled. Consequently, the first
    sham window of a session (or after :meth:`reset`) returns the current
    real result.

    Examples
    --------
    Wrap an existing protocol so 50 % of windows receive sham feedback::

        from mne_rt.protocols import ZScoreProtocol
        from mne_rt.protocols.sham import ShamProtocol

        inner = ZScoreProtocol(direction="up")
        proto = ShamProtocol(inner, sham_rate=0.5, rng_seed=42)
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
            # On sham windows, (crossed, magnitude) comes from a historical draw.

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        inner: Any,
        sham_rate: float = 0.5,
        buffer_len: int = 60,
        rng_seed: Optional[int] = None,
    ) -> None:
        # ``not (lo <= x <= hi)`` also rejects NaN, since every NaN
        # comparison is False.
        if not (0.0 <= sham_rate <= 1.0):
            raise ValueError(
                f"sham_rate must be in [0, 1], got {sham_rate}"
            )
        if buffer_len < 1:
            raise ValueError(
                f"buffer_len must be >= 1, got {buffer_len}"
            )
        self.inner = inner
        self.sham_rate: float = sham_rate
        self.buffer_len: int = buffer_len
        self._rng = np.random.default_rng(rng_seed)
        self._buffer: collections.deque[tuple[bool, float]] = self._new_buffer()
        self.n_real: int = 0
        self.n_sham: int = 0
        self.sham_log: list[bool] = []

    def _new_buffer(self) -> collections.deque[tuple[bool, float]]:
        """Return an empty sham pool holding at most ``buffer_len`` entries."""
        # Deliberately empty rather than pre-filled with (False, 0.0)
        # placeholders. Placeholders would dominate early sham draws and make
        # sham windows systematically unrewarded, i.e. distinguishable from
        # real ones, which defeats the blinding.
        return collections.deque(maxlen=self.buffer_len)

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and return (crossed, magnitude).

        Always delegates to the inner protocol first so that its state
        (running statistics, adaptive threshold, etc.) advances correctly.
        The real result is added to the sham pool. Then, on a ``sham_rate``
        fraction of calls, the real result is silently discarded and a
        uniformly random entry from the pool is returned instead.

        Parameters
        ----------
        value : float
            Current NF feature value.

        Returns
        -------
        crossed : bool
            True if the criterion was met. On sham windows this value was
            drawn from the historical buffer, not the current signal.
        magnitude : float
            Non-negative reward magnitude. On sham windows this value was
            drawn from the historical buffer.
        """
        # Advance the inner protocol on every window, sham or not.
        real_crossed, real_magnitude = self.inner.evaluate(value)
        # Record the real outcome before drawing, so the pool is never empty
        # at sampling time. The bounded deque drops its oldest entry when full.
        self._buffer.append((real_crossed, real_magnitude))

        is_sham = bool(self._rng.random() < self.sham_rate)
        self.sham_log.append(is_sham)
        if not is_sham:
            self.n_real += 1
            return real_crossed, real_magnitude

        self.n_sham += 1
        idx = int(self._rng.integers(0, len(self._buffer)))
        # Index the deque directly; copying it to a list on every sham
        # window is unnecessary.
        crossed, magnitude = self._buffer[idx]
        return crossed, magnitude

    def reset(self) -> None:
        """Reset sham counters, log, and buffer; also resets the inner protocol.

        The inner protocol's own ``reset()`` is called if it provides one.
        The sham pool is emptied so it refills with real results only. The
        random generator is not re-seeded, so the sham schedule continues
        rather than repeating from the start.
        """
        if hasattr(self.inner, "reset"):
            self.inner.reset()
        self._buffer = self._new_buffer()
        self.n_real = 0
        self.n_sham = 0
        self.sham_log = []

    @property
    def sham_fraction(self) -> float:
        """Observed sham fraction so far (0–1).

        Returns ``0.0`` when no evaluations have been recorded.
        """
        total = self.n_real + self.n_sham
        return self.n_sham / total if total > 0 else 0.0

    def __repr__(self) -> str:
        return (
            f"ShamProtocol("
            f"inner={self.inner!r}, "
            f"sham_rate={self.sham_rate}, "
            f"n_real={self.n_real}, "
            f"n_sham={self.n_sham}, "
            f"sham_fraction={self.sham_fraction:.2f})"
        )