"""Multi-band simultaneous reward protocol for MNE-RT.

This module provides :class:`MultiBandProtocol`, which wraps two inner
protocols (one for up-regulation, one for down-regulation) and, by default,
issues a combined reward only when both criteria are simultaneously
satisfied — e.g., alpha↑ + theta↓ for focus training or SMR↑ + theta↓ for
ADHD neurofeedback.

Classes
-------
MultiBandProtocol
    Simultaneous two-band reward protocol.

References
----------
Sterman, M. B., & Egner, T. (2006). Foundation and practice of neurofeedback
for the treatment of epilepsy. Applied Psychophysiology and Biofeedback,
31(1), 21–35.
"""
from __future__ import annotations
from typing import Any
import numpy as np
class MultiBandProtocol:
    """Reward protocol for simultaneous two-band control.

    Wraps two inner protocols (one for up-regulation, one for
    down-regulation) and issues a combined reward only when BOTH criteria
    are met simultaneously (or either, if ``require_both=False``).

    The combined ``magnitude`` is the geometric mean of the two individual
    magnitudes so that both bands contribute equally. When either magnitude
    is not strictly positive, the arithmetic mean is used as a fallback so
    that partial rewards are still numerically meaningful. The combined
    magnitude is clamped at zero, so it is never negative.

    Parameters
    ----------
    protocol_up : protocol with .evaluate(value) -> (bool, float)
        Protocol applied to the up-regulation value (e.g., alpha power).
    protocol_down : protocol with .evaluate(value) -> (bool, float)
        Protocol applied to the down-regulation value (e.g., theta power).
    require_both : bool, default True
        If True, both criteria must be met for a reward (AND logic).
        If False, either criterion suffices (OR logic).
    up_label : str, default "up_band"
        Human-readable label for the up-regulation band (stored as an
        attribute and shown in ``__repr__``).
    down_label : str, default "down_band"
        Human-readable label for the down-regulation band (stored as an
        attribute and shown in ``__repr__``).

    Notes
    -----
    Call ``evaluate(up_value, down_value)`` with TWO positional arguments —
    one from each modality/band. Configure ``RTStream`` with
    ``modality=["sensor_power_alpha", "sensor_power_theta"]`` (or similar)
    and unpack the two returned values before each call.

    Both inner protocols are evaluated on every call (there is no
    short-circuit), so any internal state they keep is advanced in lockstep.
    In OR mode the magnitudes reported by both inner protocols are combined,
    even when only one of them crossed its criterion.

    Examples
    --------
    Alpha-up / theta-down simultaneous reward::

        from mne_rt.protocols import ZScoreProtocol
        from mne_rt.protocols.multiband import MultiBandProtocol

        alpha_proto = ZScoreProtocol(direction="up")
        theta_proto = ZScoreProtocol(direction="down")
        proto = MultiBandProtocol(
            protocol_up=alpha_proto,
            protocol_down=theta_proto,
            up_label="alpha",
            down_label="theta",
        )
        for alpha_val, theta_val in zip(alpha_stream, theta_stream):
            crossed, magnitude = proto.evaluate(alpha_val, theta_val)
            if crossed:
                send_reward(magnitude)

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        protocol_up: Any,
        protocol_down: Any,
        require_both: bool = True,
        up_label: str = "up_band",
        down_label: str = "down_band",
    ) -> None:
        # Inner protocols; each must expose ``evaluate(value) -> (bool, float)``.
        self.protocol_up = protocol_up
        self.protocol_down = protocol_down
        # True -> AND combination of the two criteria, False -> OR.
        self.require_both: bool = require_both
        # Display labels, shown in ``__repr__``.
        self.up_label: str = up_label
        self.down_label: str = down_label
        # Number of ``evaluate`` calls since construction or last ``reset``.
        self._n_evaluated: int = 0

    def evaluate(self, up_value: float, down_value: float) -> tuple[bool, float]:
        """Evaluate one pair of NF values and return (crossed, magnitude).

        Delegates to both inner protocols, then combines the results
        according to ``require_both``. The combined magnitude is the
        geometric mean of the two individual magnitudes when both are
        strictly positive; otherwise the arithmetic mean is used as a
        fallback. The result is clamped at zero.

        Parameters
        ----------
        up_value : float
            Current NF feature value for the up-regulation band.
        down_value : float
            Current NF feature value for the down-regulation band.

        Returns
        -------
        crossed : bool
            True if the combined criterion is met (always a built-in bool).
        magnitude : float
            Non-negative combined reward magnitude; ``0.0`` when not crossed.
        """
        # Evaluate both sides unconditionally so both inner protocols see
        # every sample, even when the first one already decides the outcome.
        crossed_up, mag_up = self.protocol_up.evaluate(up_value)
        crossed_down, mag_down = self.protocol_down.evaluate(down_value)
        self._n_evaluated += 1

        # ``and``/``or`` return one of their operands (possibly numpy.bool_
        # or an int); coerce so the return type matches the annotation.
        if self.require_both:
            crossed = bool(crossed_up) and bool(crossed_down)
        else:
            crossed = bool(crossed_up) or bool(crossed_down)
        if not crossed:
            return False, 0.0

        if mag_up > 0.0 and mag_down > 0.0:
            magnitude = float(np.sqrt(mag_up * mag_down))
        else:
            # Geometric mean would collapse to 0 (or be undefined) here.
            magnitude = float((mag_up + mag_down) / 2.0)
        # Honour the documented non-negative contract even if an inner
        # protocol reports a negative magnitude.
        return True, max(magnitude, 0.0)

    def reset(self) -> None:
        """Reset both inner protocols and the evaluation counter.

        Calls ``reset()`` on each inner protocol if that method exists;
        protocols without ``reset`` are left untouched.
        """
        for inner in (self.protocol_up, self.protocol_down):
            if hasattr(inner, "reset"):
                inner.reset()
        self._n_evaluated = 0

    @property
    def n_evaluated(self) -> int:
        """Total number of value-pair evaluations since init or last reset."""
        return self._n_evaluated

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"up={self.up_label!r}→{self.protocol_up!r}, "
            f"down={self.down_label!r}→{self.protocol_down!r}, "
            f"require_both={self.require_both}, "
            f"n_evaluated={self._n_evaluated})"
        )