# Source code for mne_rt.protocols.operant

"""Operant conditioning reinforcement-schedule wrapper for MNE-RT.

This module provides :class:`OperantProtocol`, which wraps any inner
protocol and gates its reward output through a classical operant conditioning
schedule (fixed ratio, variable ratio, fixed interval, or variable interval).

Classes
-------
OperantProtocol
    Operant conditioning schedule wrapper for any NF protocol.

References
----------
Ferster, C. B., & Skinner, B. F. (1957). Schedules of reinforcement.
Appleton-Century-Crofts.

Skinner, B. F. (1938). The behavior of organisms: An experimental analysis.
Appleton-Century-Crofts.
"""
from __future__ import annotations

import time
from typing import Any, Optional

import numpy as np


# Schedule identifiers accepted by OperantProtocol (matched case-sensitively).
_VALID_SCHEDULES = ("FR", "VR", "FI", "VI")


class OperantProtocol:
    """Operant conditioning reinforcement-schedule wrapper.

    Wraps any inner NF protocol and applies a classical operant conditioning
    schedule to determine whether a reward is actually delivered.  The inner
    protocol is always evaluated so that its internal state (running statistics,
    adaptive threshold, etc.) advances correctly; the schedule only decides
    whether to pass the reward through to the caller.

    .. rubric:: Supported schedules

    ``"FR"`` (Fixed Ratio)
        Deliver reward on every ``ratio``-th hit (hit = inner protocol
        returned ``crossed=True``).
    ``"VR"`` (Variable Ratio)
        Deliver reward with probability ``1 / ratio`` on each hit (geometric
        distribution; same average number of hits per reward as FR).
    ``"FI"`` (Fixed Interval)
        Deliver reward for the **first** hit that occurs after ``interval``
        seconds have elapsed since the last reward.
    ``"VI"`` (Variable Interval)
        Deliver reward for the **first** hit that occurs after a random
        interval sampled from an exponential distribution with mean
        ``interval`` seconds.

    Parameters
    ----------
    base_protocol : any protocol with .evaluate(value) -> (bool, float)
        The inner NF protocol whose output is filtered by the schedule.
    schedule : {"FR", "VR", "FI", "VI"}
        Reinforcement schedule identifier.
    ratio : int
        Required for ``"FR"`` and ``"VR"`` schedules.  For FR: reward is
        delivered on every ``ratio``-th hit.  For VR: reward probability per
        hit is ``1 / ratio``.  Must be >= 1.  Default is 5.
    interval : float
        Required for ``"FI"`` and ``"VI"`` schedules, in seconds.  For FI:
        minimum fixed interval between rewards.  For VI: mean interval
        (exponential distribution).  Must be > 0.  Default is 30.0.
    rng_seed : int | None
        Seed for the NumPy random generator used by ``"VR"`` and ``"VI"``
        schedules.  Default is None (non-deterministic).

    Raises
    ------
    ValueError
        If ``schedule`` is not one of the valid identifiers, ``ratio`` is not
        ``>= 1``, or ``interval`` is not ``> 0``.  NaN is rejected for both
        ``ratio`` and ``interval``.

    Notes
    -----
    The internal clock is started on the **first** call to :meth:`evaluate`
    using :func:`time.monotonic`, not at construction time.  This avoids
    artificially burning interval time between object creation and the start
    of the session.

    ``reset()`` resets all internal counters, hit/reward tallies, and the
    clock.  It also calls ``base_protocol.reset()`` if that method exists.
    The random generator is *not* re-seeded by ``reset()``.

    Examples
    --------
    Fixed-ratio schedule (reward every 5th hit)::

        from mne_rt.protocols import ZScoreProtocol
        from mne_rt.protocols.operant import OperantProtocol

        inner = ZScoreProtocol(direction="up")
        proto = OperantProtocol(inner, schedule="FR", ratio=5)
        for value in nf_stream:
            crossed, magnitude = proto.evaluate(value)
            if crossed:
                deliver_reward(magnitude)

    Variable-interval schedule (reward at most once per ~30 s on average)::

        proto = OperantProtocol(inner, schedule="VI", interval=30.0, rng_seed=0)

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        base_protocol: Any,
        schedule: str = "FR",
        ratio: int = 5,
        interval: float = 30.0,
        rng_seed: Optional[int] = None,
    ) -> None:
        if schedule not in _VALID_SCHEDULES:
            raise ValueError(
                f"schedule must be one of {_VALID_SCHEDULES}, got {schedule!r}"
            )
        # The checks are phrased as "not <valid condition>" rather than
        # "<invalid condition>" on purpose: every ordered comparison with NaN
        # is False, so ``nan < 1`` / ``nan <= 0`` would let NaN slip through
        # and produce an interval schedule that can never release a reward.
        if not ratio >= 1:
            raise ValueError(
                f"ratio must be >= 1, got {ratio}"
            )
        if not interval > 0.0:
            raise ValueError(
                f"interval must be > 0, got {interval}"
            )

        self.base_protocol: Any = base_protocol
        self.schedule: str = schedule
        self.ratio: int = int(ratio)
        self.interval: float = float(interval)

        self._rng = np.random.default_rng(rng_seed)
        self._n_hits: int = 0
        self._n_rewards: int = 0

        # Counter for FR: hits since last reward
        self._fr_counter: int = 0

        # Timing state for FI / VI; both stay None until the first evaluate()
        self._clock_start: Optional[float] = None
        self._next_interval_end: Optional[float] = None

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def evaluate(self, value: float) -> tuple[bool, float]:
        """Evaluate one NF value and apply the reinforcement schedule.

        Delegates to the inner protocol unconditionally, then applies the
        schedule logic to decide whether to pass the reward through.

        Parameters
        ----------
        value : float
            Current NF feature value.

        Returns
        -------
        crossed : bool
            True if the schedule releases a reward on this evaluation.
        magnitude : float
            Reward magnitude from the inner protocol when ``crossed`` is
            True; ``0.0`` otherwise.

        Notes
        -----
        For interval schedules (``"FI"``, ``"VI"``) the clock is initialised
        on the first call.
        """
        # Start the clock on first call (also after reset())
        if self._clock_start is None:
            self._clock_start = time.monotonic()
            self._next_interval_end = self._clock_start + self._draw_interval()

        # Always evaluate the inner protocol so its own state advances,
        # even when the schedule ends up withholding the reward.
        base_crossed, base_magnitude = self.base_protocol.evaluate(value)
        if not base_crossed:
            return False, 0.0

        # Base protocol issued a hit — run schedule logic
        self._n_hits += 1
        if self._apply_schedule():
            self._n_rewards += 1
            return True, base_magnitude
        return False, 0.0

    def reset(self) -> None:
        """Reset all internal state including counters and the clock.

        Also calls ``base_protocol.reset()`` if that method exists and is
        callable.  All constructor parameters are preserved.  The random
        generator is left as-is (it is not re-seeded).
        """
        base_reset = getattr(self.base_protocol, "reset", None)
        if callable(base_reset):
            base_reset()

        self._n_hits = 0
        self._n_rewards = 0
        self._fr_counter = 0
        # Clearing the clock makes the next evaluate() restart the timing.
        self._clock_start = None
        self._next_interval_end = None

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def n_hits(self) -> int:
        """Total hits (crossed=True from base protocol) since init or reset."""
        return self._n_hits

    @property
    def n_rewards(self) -> int:
        """Total rewards delivered by the schedule since init or reset."""
        return self._n_rewards

    @property
    def reward_rate(self) -> float:
        """Fraction of hits that resulted in a reward (0–1).

        Returns 0.0 before any hits are recorded.
        """
        if self._n_hits == 0:
            return 0.0
        return self._n_rewards / self._n_hits

    @property
    def current_threshold(self) -> Optional[float]:
        """Pass-through threshold from the base protocol, if available.

        Returns None if the base protocol does not expose a ``threshold``
        attribute.
        """
        return getattr(self.base_protocol, "threshold", None)

    # ------------------------------------------------------------------
    # Schedule implementation
    # ------------------------------------------------------------------

    def _apply_schedule(self) -> bool:
        """Apply the active schedule and return True if a reward is released."""
        if self.schedule == "FR":
            return self._schedule_fr()
        elif self.schedule == "VR":
            return self._schedule_vr()
        elif self.schedule == "FI":
            return self._schedule_fi()
        else:  # "VI"
            return self._schedule_vi()

    def _schedule_fr(self) -> bool:
        """Fixed Ratio: reward every ratio-th hit."""
        self._fr_counter += 1
        if self._fr_counter >= self.ratio:
            self._fr_counter = 0
            return True
        return False

    def _schedule_vr(self) -> bool:
        """Variable Ratio: each hit rewarded with probability 1/ratio."""
        return bool(self._rng.random() < (1.0 / self.ratio))

    def _schedule_fi(self) -> bool:
        """Fixed Interval: reward first hit after the fixed interval expires."""
        now = time.monotonic()
        if now >= self._next_interval_end:
            # The next interval is measured from this reward, not from the
            # moment the previous interval expired.
            self._next_interval_end = now + self.interval
            return True
        return False

    def _schedule_vi(self) -> bool:
        """Variable Interval: reward first hit after a random interval expires."""
        now = time.monotonic()
        if now >= self._next_interval_end:
            self._next_interval_end = now + self._draw_interval()
            return True
        return False

    def _draw_interval(self) -> float:
        """Return the next interval length in seconds.

        For the ``"VI"`` schedule this is a draw from an exponential
        distribution with mean ``interval``; for every other schedule it is
        the fixed ``interval`` itself (no random number is consumed).
        """
        if self.schedule == "VI":
            return float(self._rng.exponential(scale=self.interval))
        return self.interval

    def __repr__(self) -> str:
        return (
            f"OperantProtocol("
            f"schedule={self.schedule!r}, "
            f"ratio={self.ratio}, "
            f"interval={self.interval}, "
            f"n_hits={self._n_hits}, "
            f"n_rewards={self._n_rewards}, "
            f"reward_rate={self.reward_rate:.2f}, "
            f"base_protocol={self.base_protocol!r})"
        )