# Source code for mne_rt.tools.riemannian_potato

"""Riemannian Potato — online artifact detection for streaming M/EEG.

Detects artifactual data windows by checking whether the covariance matrix
of each incoming chunk is too far from the geometric mean of clean data on
the symmetric positive-definite (SPD) matrix manifold.

Classes
-------
RiemannianPotatoDetector
    Online Riemannian artifact detector with calibration-based baseline.

References
----------
Barachant, A. & Congedo, M. (2014). A Plug&Play P300 BCI Using Information
Geometry. *arXiv:1409.0107*.

Barthélemy, Q., Mayaud, L., Ojeda, D., & Congedo, M. (2019). The Riemannian
potato field: a tool for online signal quality index of EEG.
*IRBM*, 40(1), 15–26. https://doi.org/10.1016/j.irbm.2018.09.001
"""
from __future__ import annotations

from typing import Union

import numpy as np

from mne_rt._logging import logger


class RiemannianPotatoDetector:
    """Online EEG/MEG artifact detector based on the Riemannian Potato.

    The *Riemannian Potato* detects artifactual epochs by measuring how far
    each incoming covariance matrix is from the geometric mean of a clean
    calibration set — distance is computed in the Riemannian metric on the
    manifold of symmetric positive-definite (SPD) matrices. A window is
    declared artifactual when this distance (expressed as a z-score) exceeds
    ``threshold``.

    Unlike threshold-based or ICA-based methods, this approach requires no
    reference channel and is robust to gradual amplitude drift (because the
    geometric mean adapts to the signal statistics during calibration rather
    than using a fixed absolute threshold).

    Parameters
    ----------
    threshold : float, default 3.0
        Z-score cutoff (Riemannian distance units). Windows with a distance
        z-score above this value are flagged as artifacts. Typical range:
        2.5–4.0. Lower = more aggressive rejection.
    estimator : str, default "oas"
        Covariance estimator passed to
        :class:`pyriemann.estimation.Covariances`. ``"oas"`` (Oracle
        Approximating Shrinkage) is robust at low sample-to-channel ratios;
        ``"scm"`` gives the raw sample covariance.
    metric : str, default "riemann"
        Riemannian metric used for the potato. ``"riemann"``
        (affine-invariant) is the canonical choice; ``"logeuclid"`` is
        faster but less robust.
    verbose : bool | str | None, default None
        Verbosity level. See :func:`~mne_rt._logging.set_log_level`.

    Attributes
    ----------
    is_fitted_ : bool
        ``True`` after :meth:`fit` has been called successfully.
    n_channels_ : int
        Number of channels seen during :meth:`fit`.
    n_calibration_windows_ : int
        Number of clean windows used to fit the potato.

    Raises
    ------
    ImportError
        If ``pyriemann`` is not installed.
    ValueError
        If ``threshold`` is not strictly positive (NaN included).
    RuntimeError
        If :meth:`detect` is called before :meth:`fit`.

    Examples
    --------
    Calibrate on 60 s of clean data, then detect online::

        detector = RiemannianPotatoDetector(threshold=3.0)
        detector.fit(clean_windows)   # shape (n_windows, n_ch, n_samples)

        while streaming:
            window = stream.get_data(1.0)   # shape (n_ch, n_samples)
            is_clean, z_score = detector.detect(window)
            if is_clean:
                process_for_nf(window)

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        threshold: float = 3.0,
        estimator: str = "oas",
        metric: str = "riemann",
        verbose: Union[bool, str, None] = None,
    ) -> None:
        # Fail early, with an actionable message, when the optional
        # dependency is missing.
        self._check_pyriemann()

        from mne_rt._logging import set_log_level

        set_log_level(verbose)

        # ``not threshold > 0`` (rather than ``threshold <= 0``) also rejects
        # NaN, which would otherwise pass and make every comparison against
        # the cutoff False.
        if not threshold > 0:
            raise ValueError(f"threshold must be > 0, got {threshold}")

        self.threshold = threshold
        self.estimator = estimator
        self.metric = metric

        # Both estimators are created lazily by fit().
        self._potato = None
        self._cov_estimator = None

        self.is_fitted_: bool = False
        self.n_channels_: int = 0
        self.n_calibration_windows_: int = 0

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _check_pyriemann() -> None:
        """Raise an informative ``ImportError`` if pyriemann is missing."""
        try:
            import pyriemann  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "RiemannianPotatoDetector requires pyriemann.\n"
                "Install it with: pip install pyriemann"
            ) from exc

    def _make_estimators(self) -> None:
        """Create fresh, unfitted covariance and potato estimators.

        The new objects are stored on ``self._cov_estimator`` and
        ``self._potato`` and are configured from ``self.estimator``,
        ``self.metric`` and ``self.threshold``.
        """
        from pyriemann.clustering import Potato
        from pyriemann.estimation import Covariances

        self._cov_estimator = Covariances(estimator=self.estimator)
        self._potato = Potato(
            metric=self.metric,
            threshold=self.threshold,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(self, windows: np.ndarray) -> "RiemannianPotatoDetector":
        """Calibrate the potato on clean data windows.

        Computes covariance matrices for each window and fits the geometric
        mean and variance used as the potato centre and spread.

        Parameters
        ----------
        windows : array-like, shape (n_windows, n_channels, n_samples)
            Calibration data. Should contain **clean** (artifact-free)
            segments. At least 10 windows are recommended; 30–60 is typical
            for a 1-second window length at standard EEG sampling rates.

        Returns
        -------
        self : RiemannianPotatoDetector

        Raises
        ------
        ValueError
            If ``windows.ndim != 3`` or fewer than 2 windows are provided.

        Notes
        -----
        If estimator construction or fitting raises, the previously stored
        estimators (and therefore any earlier calibration) are restored
        before the exception propagates.
        """
        # Accept any array-like so nested lists are validated instead of
        # failing with AttributeError on ``.ndim``.
        windows = np.asarray(windows)
        if windows.ndim != 3:
            raise ValueError(
                f"windows must be 3-D (n_windows, n_ch, n_samples), "
                f"got shape {windows.shape}"
            )
        n_windows, n_channels = windows.shape[:2]
        if n_windows < 2:
            raise ValueError("At least 2 calibration windows are required.")

        previous = (self._cov_estimator, self._potato)
        try:
            self._make_estimators()
            covs = self._cov_estimator.fit_transform(windows)
            self._potato.fit(covs)
        except Exception:
            # Roll back so a failed (re)calibration never leaves fresh,
            # unfitted estimators behind a stale ``is_fitted_ = True`` flag.
            self._cov_estimator, self._potato = previous
            raise

        self.is_fitted_ = True
        self.n_channels_ = n_channels
        self.n_calibration_windows_ = n_windows
        logger.info(
            "RiemannianPotatoDetector fitted on %d windows (%d channels).",
            self.n_calibration_windows_,
            self.n_channels_,
        )
        return self

    def detect(self, window: np.ndarray) -> tuple[bool, float]:
        """Test a single window for artifacts.

        Parameters
        ----------
        window : array-like, shape (n_channels, n_samples)
            One data window to evaluate.

        Returns
        -------
        is_clean : bool
            ``True`` when the window is inside the potato (clean);
            ``False`` when it is outside (potential artifact).
        z_score : float
            Riemannian distance z-score. Higher = more anomalous.

        Raises
        ------
        RuntimeError
            If called before :meth:`fit`.
        ValueError
            If the input is not 2-D or the channel count does not match
            the calibration set.
        """
        if not self.is_fitted_:
            raise RuntimeError(
                "Call fit() with clean calibration data before detect()."
            )

        window = np.asarray(window)
        if window.ndim != 2 or window.shape[0] != self.n_channels_:
            raise ValueError(
                f"Expected window shape ({self.n_channels_}, n_samples), "
                f"got {window.shape}."
            )

        # The estimators work on batches, so add a leading axis of length 1.
        cov = self._cov_estimator.transform(window[np.newaxis])

        # The label comes from the potato's own predict(); 1.0 is treated
        # as "clean". The z-score is queried separately via transform().
        label = float(self._potato.predict(cov)[0])
        z_score = float(self._potato.transform(cov)[0])
        return label == 1.0, z_score

    def __repr__(self) -> str:
        """Return a short summary including the calibration status."""
        if self.is_fitted_:
            status = (
                f"fitted on {self.n_calibration_windows_} windows, "
                f"{self.n_channels_} channels"
            )
        else:
            status = "not fitted"
        return (
            f"RiemannianPotatoDetector("
            f"threshold={self.threshold}, "
            f"metric={self.metric!r}, "
            f"{status})"
        )