# Source code for mne_rt.protocols.transfer

"""Cross-session transfer protocol for MNE-RT.

This module provides :class:`TransferProtocol`, which loads baseline
statistics from a previous session's BIDS behavioural JSON file and uses them
to seed a Welford online-mean/variance tracker.  This eliminates the warmup
period of :class:`~mne_rt.protocols.ZScoreProtocol` and provides consistent
cross-day normalisation for longitudinal NF studies.

Classes
-------
TransferProtocol
    Z-score protocol seeded with baseline statistics from a prior session.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Optional, Union

import numpy as np


class TransferProtocol:
    """Cross-session transfer NF protocol seeded from a prior-session file.

    Loads a previous session's NF feature time-series from a BIDS
    behavioural JSON file (``beh/*.json``), computes the prior mean and
    standard deviation, and initialises Welford's online algorithm as if
    those samples had already been observed.  Because the prior already
    contributes to the running statistics, :meth:`evaluate` needs no warmup
    period.

    The JSON file must follow MNE-RT's BIDS behavioural format::

        {
            "meta": {"modalities": ["sensor_power"], ...},
            "data": {"sensor_power": [0.1, 0.2, ...]}
        }

    Parameters
    ----------
    fname : str | Path
        Path to a prior-session BIDS behavioural JSON file.
    modality : str
        Key inside ``data`` from which to read the prior time-series
        (e.g., ``"sensor_power"``).
    direction : {"up", "down"}, default "up"
        "up"  -> reward when z-score >  ``zscore_threshold``.
        "down" -> reward when z-score < -``zscore_threshold``.
    zscore_threshold : float, default 0.5
        Minimum absolute z-score required to issue a reward.  Must be >= 0.
    adapt_rate : float, default 0.0
        Controls how quickly the running statistics adapt to the new session.
        ``0.0`` freezes the prior statistics (pure transfer).  A positive
        value folds every new sample into the Welford accumulator as a
        *fractional* observation of weight ``adapt_rate``, so larger values
        let the statistics drift toward the current session faster.
        Must be in ``[0, 1)``.
    smoothing : float, default 0.0
        EMA smoothing coefficient applied to the raw input before z-scoring.
        Must be in ``[0, 1)``.  ``0.0`` disables smoothing.

    Raises
    ------
    FileNotFoundError
        If ``fname`` does not point to an existing regular file.
    KeyError
        If the ``"data"`` key is absent from the JSON or ``modality`` is not
        found inside ``data``.
    ValueError
        If any numerical parameter is outside its valid range, if the prior
        data is not a 1-D array with at least 2 elements (insufficient for
        a standard deviation), or if it contains non-finite values.

    Notes
    -----
    The Welford accumulator is initialised with::

        W    = len(prior_data)
        mean = mean(prior_data)
        M2   = var(prior_data, ddof=1) * (W - 1)

    so ``std_`` equals the prior standard deviation until the first adaptive
    update.  When ``adapt_rate > 0`` each new (smoothed) sample ``x`` is
    added with weight ``w = adapt_rate``::

        W    <- W + w
        d     = x - mean
        mean <- mean + (w / W) * d
        M2   <- M2 + w * d * (x - mean)

    and the running variance is ``M2 / (W - 1)``.  The accumulated weight,
    the mean and M2 all grow by the same fraction, so the variance estimate
    stays consistent instead of shrinking as samples arrive.

    Non-finite inputs to :meth:`evaluate` never touch the running state.

    Examples
    --------
    Seed today's session from yesterday's baseline::

        from mne_rt.protocols.transfer import TransferProtocol

        proto = TransferProtocol(
            fname="sub-01/ses-01/beh/sub-01_ses-01_task-nf_beh.json",
            modality="sensor_power",
            direction="up",
            zscore_threshold=0.5,
        )
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
            if crossed:
                send_reward(magnitude)

    .. versionadded:: 1.0.0
    """

    # Floor applied to the running standard deviation so that a (near-)
    # constant baseline cannot cause a division by zero when z-scoring.
    _MIN_STD: float = 1e-6

    def __init__(
        self,
        fname: Union[str, Path],
        modality: str,
        direction: str = "up",
        zscore_threshold: float = 0.5,
        adapt_rate: float = 0.0,
        smoothing: float = 0.0,
    ) -> None:
        # --- Parameter validation --------------------------------------------
        # The range checks are written as ``not (lo <= x ...)`` so that NaN,
        # which fails every comparison, is rejected as well.
        if direction not in ("up", "down"):
            raise ValueError(
                f"direction must be 'up' or 'down', got {direction!r}"
            )
        if not (zscore_threshold >= 0.0):
            raise ValueError(
                f"zscore_threshold must be >= 0, got {zscore_threshold}"
            )
        if not (0.0 <= adapt_rate < 1.0):
            raise ValueError(
                f"adapt_rate must be in [0, 1), got {adapt_rate}"
            )
        if not (0.0 <= smoothing < 1.0):
            raise ValueError(
                f"smoothing must be in [0, 1), got {smoothing}"
            )

        # --- Load prior data from file ---------------------------------------
        fname = Path(fname)
        prior_data = self._load_prior(fname, modality)

        # --- Store constructor parameters ------------------------------------
        self.fname: Path = fname
        self.modality: str = modality
        self.direction: str = direction
        self.zscore_threshold: float = zscore_threshold
        self.adapt_rate: float = adapt_rate
        self.smoothing: float = smoothing

        # --- Compute and cache prior statistics ------------------------------
        self._prior_data: np.ndarray = prior_data
        self._prior_mean: float = float(np.mean(prior_data))
        self._prior_std: float = float(np.std(prior_data, ddof=1))
        self._n_prior: int = len(prior_data)

        # --- Initialise running state from the prior -------------------------
        self._seed_from_prior()

    @staticmethod
    def _load_prior(fname: Path, modality: str) -> np.ndarray:
        """Read and validate the prior time-series for ``modality``."""
        # ``is_file`` (not ``exists``) so that a directory path is reported
        # as a missing file instead of failing later inside ``open``.
        if not fname.is_file():
            raise FileNotFoundError(
                f"Prior-session file not found: {fname}"
            )
        with fname.open("r", encoding="utf-8") as fh:
            payload = json.load(fh)

        if not isinstance(payload, dict) or "data" not in payload:
            raise KeyError(
                f"The file {fname} does not contain a top-level 'data' key. "
                "Expected MNE-RT BIDS JSON format: "
                '{"meta": {...}, "data": {"<modality>": [...]}}'
            )
        data_section = payload["data"]
        if not isinstance(data_section, dict) or modality not in data_section:
            available = (
                list(data_section) if isinstance(data_section, dict) else []
            )
            raise KeyError(
                f"Modality {modality!r} not found in {fname}. "
                f"Available modalities: {available}"
            )

        prior_data = np.asarray(data_section[modality], dtype=float)
        if prior_data.ndim != 1 or prior_data.size < 2:
            raise ValueError(
                f"Prior data for modality {modality!r} must be a 1-D array "
                f"with at least 2 elements; got shape {prior_data.shape}."
            )
        # JSON ``null``/``NaN``/``Infinity`` entries would turn the mean and
        # std into NaN and silently disable every reward.
        if not np.all(np.isfinite(prior_data)):
            raise ValueError(
                f"Prior data for modality {modality!r} contains non-finite "
                "values (NaN or inf); cannot derive baseline statistics."
            )
        return prior_data

    def _seed_from_prior(self) -> None:
        """(Re)initialise all running state from the cached prior statistics."""
        self._n_evaluated: int = 0
        self._smoothed: Optional[float] = None
        self._zscore: float = 0.0
        # Accumulated weight; float because adaptive samples add fractions.
        self._welford_n: float = float(self._n_prior)
        self._welford_mean: float = self._prior_mean
        # M2 = sample_var * (n - 1)
        self._welford_m2: float = self._prior_std ** 2 * (self._n_prior - 1)

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and return ``(crossed, magnitude)``.

        Applies the optional EMA smoothing, folds the smoothed sample into
        the running statistics when ``adapt_rate > 0``, z-scores it against
        the (updated) running mean and standard deviation, and checks the
        reward criterion.

        Parameters
        ----------
        value : float
            Current NF feature value.

        Returns
        -------
        crossed : bool
            True if the z-score criterion is met in the target direction.
        magnitude : float
            Absolute value of the z-score when ``crossed`` is True; ``0.0``
            otherwise.

        Notes
        -----
        When ``adapt_rate = 0.0`` the running statistics are frozen at the
        prior values: every evaluation is z-scored against the prior
        mean/std.  When ``adapt_rate > 0`` new samples gradually shift the
        running statistics toward the current session distribution.

        A non-finite ``value`` (NaN or inf) is counted in
        :attr:`n_evaluated` and sets :attr:`zscore` to NaN, but it never
        yields a reward and leaves the smoothing state and the running
        statistics untouched.
        """
        value = float(value)
        self._n_evaluated += 1

        # A single NaN/inf would otherwise poison the EMA and the Welford
        # accumulators for the remainder of the session.
        if not np.isfinite(value):
            self._zscore = float("nan")
            return False, 0.0

        # --- EMA smoothing ---------------------------------------------------
        if self.smoothing > 0.0 and self._smoothed is not None:
            smoothed = (
                (1.0 - self.smoothing) * value
                + self.smoothing * self._smoothed
            )
        else:
            # Smoothing disabled, or first sample: nothing to blend with.
            smoothed = value
        self._smoothed = smoothed

        # --- Weighted Welford update -----------------------------------------
        # With adapt_rate == 0 the prior stays frozen and nothing is updated.
        if self.adapt_rate > 0.0:
            weight = self.adapt_rate
            total_weight = self._welford_n + weight
            delta = smoothed - self._welford_mean
            self._welford_mean += (weight / total_weight) * delta
            # The second factor uses the *updated* mean (Welford's trick).
            # ``max`` only guards against tiny negative rounding residue.
            self._welford_m2 = max(
                self._welford_m2
                + weight * delta * (smoothed - self._welford_mean),
                0.0,
            )
            # Grow the weight by the same fraction as the mean and M2;
            # counting a full sample here would shrink the variance.
            self._welford_n = total_weight

        # --- Z-score and reward criterion ------------------------------------
        self._zscore = (smoothed - self._welford_mean) / self.std_
        if self.direction == "up":
            crossed = self._zscore > self.zscore_threshold
        else:
            crossed = self._zscore < -self.zscore_threshold
        magnitude = abs(self._zscore) if crossed else 0.0
        return crossed, magnitude

    def reset(self) -> None:
        """Reset to the prior statistics without re-reading the file.

        Restores the Welford accumulators to the values computed from the
        prior data during ``__init__``, clears the smoothed state, and resets
        the evaluation counter.  All constructor parameters are preserved.
        """
        self._seed_from_prior()

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def prior_mean(self) -> float:
        """Mean of the prior session data (read-only)."""
        return self._prior_mean

    @property
    def prior_std(self) -> float:
        """Standard deviation of the prior session data (read-only)."""
        return self._prior_std

    @property
    def n_prior(self) -> int:
        """Number of samples in the prior session data (read-only)."""
        return self._n_prior

    @property
    def n_evaluated(self) -> int:
        """Number of values evaluated since init or last :meth:`reset`."""
        return self._n_evaluated

    @property
    def zscore(self) -> float:
        """Z-score computed during the most recent :meth:`evaluate` call.

        Returns ``0.0`` before any evaluation.
        """
        return self._zscore

    @property
    def mean_(self) -> float:
        """Current running mean (prior-seeded).

        Before any adaptations this equals :attr:`prior_mean`.
        """
        return self._welford_mean

    @property
    def std_(self) -> float:
        """Current running standard deviation (prior-seeded).

        Before any adaptations this equals :attr:`prior_std`.  The value is
        floored at ``1e-6`` so it is always safe to divide by.
        """
        # The accumulated weight starts at n_prior (>= 2) and only grows,
        # so the denominator is always >= 1.
        variance = self._welford_m2 / (self._welford_n - 1.0)
        return max(float(np.sqrt(variance)), self._MIN_STD)

    def __repr__(self) -> str:
        return (
            f"TransferProtocol("
            f"fname={str(self.fname)!r}, "
            f"modality={self.modality!r}, "
            f"direction={self.direction!r}, "
            f"prior_mean={self._prior_mean:.4g}, "
            f"prior_std={self._prior_std:.4g}, "
            f"n_prior={self._n_prior}, "
            f"zscore={self._zscore:.4g}, "
            f"mean_={self.mean_:.4g}, "
            f"std_={self.std_:.4g}, "
            f"n_evaluated={self._n_evaluated})"
        )