# Source code for mne_rt.tools.bad_channel_detector

"""Real-time bad channel detection for streaming M/EEG.

Classes
-------
BadChannelDetector
    Multi-criterion online bad channel monitor with rolling-window voting.

Notes
-----
The approach follows the methods used in EEGLAB's ``clean_rawdata`` plugin
:footcite:p:`mullen2015real` and MNE's automated quality checks, adapted for
chunk-by-chunk streaming data with no look-ahead.

References
----------
Mullen, T. R., et al. (2015). Real-time neuroimaging and cognitive monitoring
using wearable dry EEG. *IEEE Trans. Biomed. Eng.*, 62(11), 2553–2567.
https://doi.org/10.1109/TBME.2015.2481482
"""
from __future__ import annotations

import collections
import warnings
from typing import Optional, Union

import numpy as np


class BadChannelDetector:
    """Multi-criterion real-time bad channel detector for streaming M/EEG.

    Evaluates each incoming data window against up to four independent
    criteria. A channel is declared bad only when it was flagged in at least
    ``min_bad_frac`` of the windows held in its rolling history — this voting
    mechanism avoids false positives from transient artifacts. The vote is
    taken over the windows seen *so far*: until the history has filled up, a
    single flagged window carries more weight (the very first flagged window
    yields a score of ``1.0``).

    Every criterion operates on the per-window *demeaned* signal, so a
    constant DC offset (common on DC-coupled amplifiers) neither hides a flat
    channel nor distorts the variance or spectral statistics.

    **Criteria**

    ``"flat"``
        Channels whose RMS amplitude about the window mean (i.e. their
        standard deviation) falls below ``flat_threshold`` (dead electrode,
        disconnected lead, or lead stuck at a constant value).
    ``"variance"``
        Channels whose RMS amplitude is a statistical outlier across all
        channels, measured by a robust z-score
        (:math:`z = |\\text{rms} - \\text{median}| / (\\text{MAD} \\times 1.4826)`).
        Catches both excessively noisy channels and channels that have gone
        unusually quiet.
    ``"correlation"``
        Channels whose mean *absolute* Pearson correlation with their *K*
        nearest spatial neighbours drops below ``corr_threshold``. A channel
        that has lost contact or broken its reference will de-correlate from
        its neighbours while the neighbours remain correlated with each
        other. Requires channel positions to be set in ``info`` (i.e. montage
        applied); otherwise a :class:`RuntimeWarning` is emitted at
        construction and the criterion is skipped.
    ``"hf_noise"``
        Channels with an abnormally high ratio of high-frequency power
        (> ``hf_cutoff`` Hz) to broadband power. Catches channels
        contaminated by EMG, electrode cable noise, or loose connections.
        Uses the same MAD-based robust z-score as ``"variance"`` but flags
        only the high side.

    Parameters
    ----------
    info : mne.Info
        MNE channel information. Must contain ``ch_names`` and ``sfreq``.
        For ``"correlation"``, channel 3-D positions must be set
        (``set_montage`` applied).
    method : {"all", "flat", "variance", "correlation", "hf_noise"} or list
        Criterion or list of criteria to evaluate. ``"all"`` (on its own or
        as a list element) enables all four criteria. Default is ``"all"``.
    flat_threshold : float, default 1e-7
        Minimum RMS amplitude about the mean (in raw data units — V for EEG,
        T for MEG) for a channel not to be considered flat. Default
        100 nV = ``1e-7`` V. Must be > 0.
    variance_threshold : float, default 5.0
        Robust z-score cutoff for the variance criterion. A channel whose RMS
        deviates by more than this many MAD-units from the channel median is
        flagged. Default ``5.0`` (conservative; reduce to 3–4 to be more
        aggressive).
    corr_threshold : float, default 0.4
        Minimum mean absolute Pearson correlation with spatial neighbours.
        Channels below this value are poorly coupled to their surroundings.
    hf_threshold : float, default 5.0
        Robust z-score cutoff for the HF noise criterion.
    hf_cutoff : float, default 40.0
        Frequency in Hz above which power is classified as *high-frequency*
        for the noise criterion. Should lie strictly between 0 and the
        Nyquist frequency (``sfreq / 2``); otherwise a
        :class:`RuntimeWarning` is emitted when ``"hf_noise"`` is enabled.
    n_neighbors : int, default 4
        Number of nearest spatial neighbours used in the correlation
        criterion. Must be >= 1.
    history_windows : int, default 30
        Number of per-window bad-flags to retain in the rolling history.
        Combined with ``min_bad_frac`` this sets the effective time-scale for
        declaring a channel persistently bad. Must be >= 1.
    min_bad_frac : float, default 0.5
        Fraction of rolling-history windows in which a channel must be
        flagged before it is declared bad; must be in ``(0, 1]``.
        ``0.5`` = majority vote; ``1.0`` = must be bad in every recent window
        (most conservative); ``0.1`` = bad in at least 10 % of windows (most
        aggressive).

    Raises
    ------
    ValueError
        If ``method`` is empty or names an unknown criterion, or if a numeric
        parameter is outside its valid range.

    Attributes
    ----------
    bad_channels_ : list of str
        Channel names currently declared bad. Updated on every
        :meth:`update` call.
    scores_ : dict of str → float
        Per-channel composite badness score in ``[0, 1]``: fraction of
        recent windows in which the channel was flagged by *any* active
        criterion.
    n_windows_ : int
        Total number of windows processed since initialisation or last
        :meth:`reset`.

    Examples
    --------
    Basic usage — update once per NF window and pass bad channels to MNE::

        detector = BadChannelDetector(raw.info, method="all")
        while streaming:
            window = stream.get_data(1.0)  # 1 s chunk
            bad = detector.update(window)
            print("Bad channels:", bad)

    Use only the variance + flat criteria (no montage required)::

        detector = BadChannelDetector(
            info, method=["flat", "variance"], variance_threshold=4.0
        )

    .. versionadded:: 1.0.0
    """

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def __init__(
        self,
        info,
        method: Union[str, list[str]] = "all",
        flat_threshold: float = 1e-7,
        variance_threshold: float = 5.0,
        corr_threshold: float = 0.4,
        hf_threshold: float = 5.0,
        hf_cutoff: float = 40.0,
        n_neighbors: int = 4,
        history_windows: int = 30,
        min_bad_frac: float = 0.5,
    ) -> None:
        self._info = info
        self._ch_names: list[str] = list(info["ch_names"])
        self._sfreq: float = float(info["sfreq"])
        self._n_ch: int = len(self._ch_names)

        self._methods: set[str] = self._resolve_methods(method)

        if flat_threshold <= 0:
            raise ValueError("flat_threshold must be > 0")
        if not (0 < min_bad_frac <= 1):
            raise ValueError("min_bad_frac must be in (0, 1]")
        if history_windows < 1:
            raise ValueError("history_windows must be >= 1")
        if n_neighbors < 1:
            raise ValueError("n_neighbors must be >= 1")
        if "hf_noise" in self._methods and not (0 < hf_cutoff < self._sfreq / 2):
            # Outside (0, Nyquist) the HF band is empty or covers the whole
            # spectrum, so the HF ratio carries no information.
            warnings.warn(
                f"BadChannelDetector: hf_cutoff={hf_cutoff} Hz is outside "
                f"(0, Nyquist={self._sfreq / 2} Hz); the hf_noise criterion "
                "will not be meaningful.",
                RuntimeWarning,
                stacklevel=2,
            )

        self.flat_threshold = flat_threshold
        self.variance_threshold = variance_threshold
        self.corr_threshold = corr_threshold
        self.hf_threshold = hf_threshold
        self.hf_cutoff = hf_cutoff
        self.n_neighbors = n_neighbors
        self.history_windows = history_windows
        self.min_bad_frac = min_bad_frac

        # Rolling history: per-channel deque of bool (True = bad in that window)
        self._history: dict[str, collections.deque] = {
            ch: collections.deque(maxlen=history_windows) for ch in self._ch_names
        }

        self.bad_channels_: list[str] = []
        self.scores_: dict[str, float] = {ch: 0.0 for ch in self._ch_names}
        self.n_windows_: int = 0

        # Pre-compute neighbour indices (None if positions are unavailable)
        self._neighbor_idx: Optional[dict[int, list[int]]] = None
        if "correlation" in self._methods:
            self._neighbor_idx = self._build_neighbor_index()

    @staticmethod
    def _resolve_methods(method) -> set[str]:
        """Validate ``method`` and expand ``"all"`` into the concrete criteria.

        ``"all"`` is honoured both as a bare string and as a list element.
        """
        valid = {"all", "flat", "variance", "correlation", "hf_noise"}
        requested = [method] if isinstance(method, str) else list(method)
        for m in requested:
            if m not in valid:
                raise ValueError(
                    f"Unknown method {m!r}. Choose from {sorted(valid)}."
                )
        if not requested:
            raise ValueError("method must name at least one criterion.")
        if "all" in requested:
            return valid - {"all"}
        return set(requested)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def update(self, data: np.ndarray) -> list[str]:
        """Process one data window and return current bad-channel list.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_samples)
            One analysis window of raw M/EEG data. The channel order must
            match ``info["ch_names"]``. Converted to a float array.

        Returns
        -------
        bad_channels : list of str
            Channel names currently declared bad (vote over recent history).

        Raises
        ------
        ValueError
            If ``data`` is not 2-D, if ``data.shape[0] != n_channels``, or if
            it holds fewer than 2 samples per channel.
        """
        data = np.asarray(data, dtype=float)
        if data.ndim != 2:
            raise ValueError(
                f"data must be 2-D (n_channels, n_samples), got shape {data.shape}."
            )
        if data.shape[0] != self._n_ch:
            raise ValueError(
                f"Expected {self._n_ch} channels, got {data.shape[0]}."
            )
        if data.shape[1] < 2:
            raise ValueError(
                "data must contain at least 2 samples per channel, "
                f"got {data.shape[1]}."
            )

        flagged = np.zeros(self._n_ch, dtype=bool)
        if "flat" in self._methods:
            flagged |= self._criterion_flat(data)
        if "variance" in self._methods:
            flagged |= self._criterion_variance(data)
        if "correlation" in self._methods and self._neighbor_idx is not None:
            flagged |= self._criterion_correlation(data)
        if "hf_noise" in self._methods:
            flagged |= self._criterion_hf_noise(data)

        # Update rolling history
        for ch, is_bad in zip(self._ch_names, flagged):
            self._history[ch].append(bool(is_bad))
        self.n_windows_ += 1

        self._update_bad_list()
        return list(self.bad_channels_)

    def get_bad_channels(self) -> list[str]:
        """Return the current list of declared bad channels.

        Returns
        -------
        bad_channels : list of str
            Channel names that have been bad in ≥ ``min_bad_frac`` of the
            most recent ``history_windows`` windows (a copy).
        """
        return list(self.bad_channels_)

    def get_scores(self) -> dict[str, float]:
        """Return per-channel badness scores in the range [0, 1].

        A score of ``1.0`` means the channel was flagged in every recent
        window; ``0.0`` means it was never flagged.

        Returns
        -------
        scores : dict of str → float
            A copy of :attr:`scores_`.
        """
        return dict(self.scores_)

    def reset(self) -> None:
        """Clear rolling history and reset all counters.

        Constructor parameters and neighbour indices are preserved.
        """
        for dq in self._history.values():
            dq.clear()
        self.bad_channels_ = []
        self.scores_ = {ch: 0.0 for ch in self._ch_names}
        self.n_windows_ = 0

    # ------------------------------------------------------------------
    # Criteria
    # ------------------------------------------------------------------

    @staticmethod
    def _demean(data: np.ndarray) -> np.ndarray:
        """Remove each channel's mean over the window (DC offset)."""
        return data - data.mean(axis=1, keepdims=True)

    @staticmethod
    def _robust_z(values: np.ndarray) -> np.ndarray:
        """Signed robust z-score of ``values`` using the scaled MAD.

        The ``1e-30`` floor avoids division by zero when more than half of
        the values are identical (MAD == 0).
        """
        median = np.median(values)
        mad = np.median(np.abs(values - median)) * 1.4826 + 1e-30
        return (values - median) / mad

    def _criterion_flat(self, data: np.ndarray) -> np.ndarray:
        """Flag channels whose demeaned RMS is below flat_threshold."""
        rms = np.sqrt(np.mean(self._demean(data) ** 2, axis=1))
        return rms < self.flat_threshold

    def _criterion_variance(self, data: np.ndarray) -> np.ndarray:
        """Flag channels whose demeaned RMS is a robust outlier across channels."""
        rms = np.sqrt(np.mean(self._demean(data) ** 2, axis=1))
        return np.abs(self._robust_z(rms)) > self.variance_threshold

    def _criterion_correlation(self, data: np.ndarray) -> np.ndarray:
        """Flag channels whose mean |correlation| with neighbours is low."""
        flagged = np.zeros(self._n_ch, dtype=bool)
        if self._neighbor_idx is None:
            return flagged

        # Demean and scale every row to unit L2 norm; the dot product of two
        # such rows *is* their Pearson correlation (no extra 1/n factor).
        dm = self._demean(data)
        unit = dm / (np.linalg.norm(dm, axis=1, keepdims=True) + 1e-30)

        for i, nbrs in self._neighbor_idx.items():
            if not nbrs:
                continue
            corrs = unit[nbrs] @ unit[i]
            if float(np.mean(np.abs(corrs))) < self.corr_threshold:
                flagged[i] = True
        return flagged

    def _criterion_hf_noise(self, data: np.ndarray) -> np.ndarray:
        """Flag channels with abnormally high HF power fraction."""
        freqs = np.fft.rfftfreq(data.shape[1], d=1.0 / self._sfreq)
        # Demean first so a DC offset does not inflate the broadband total.
        power = np.abs(np.fft.rfft(self._demean(data), axis=1)) ** 2
        total_power = power.sum(axis=1) + 1e-30
        hf_power = power[:, freqs > self.hf_cutoff].sum(axis=1)
        z = self._robust_z(hf_power / total_power)
        return z > self.hf_threshold

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _update_bad_list(self) -> None:
        """Recompute bad_channels_ and scores_ from current rolling history."""
        bad = []
        for ch in self._ch_names:
            history = self._history[ch]
            score = sum(history) / len(history) if history else 0.0
            self.scores_[ch] = score
            if score >= self.min_bad_frac:
                bad.append(ch)
        self.bad_channels_ = bad

    def _build_neighbor_index(self) -> Optional[dict[int, list[int]]]:
        """Build a mapping from channel index to K nearest neighbours.

        Uses the 3-D electrode positions stored in
        ``info['chs'][k]['loc'][:3]``. Channels whose position is non-finite
        or exactly the origin are treated as having no position and get an
        empty neighbour list. Returns ``None`` (after emitting a
        :class:`RuntimeWarning`) when fewer than 2 channels have a valid
        position.
        """
        positions = np.array(
            [np.asarray(ch["loc"][:3], dtype=float) for ch in self._info["chs"]]
        )
        valid = np.isfinite(positions).all(axis=1) & ~(positions == 0).all(axis=1)
        n_valid = int(valid.sum())

        if n_valid < 2:
            warnings.warn(
                "BadChannelDetector: fewer than 2 channels have valid positions; "
                "skipping the correlation criterion.",
                RuntimeWarning,
                # 1 = here, 2 = __init__, 3 = the user's constructor call
                stacklevel=3,
            )
            return None

        k = min(self.n_neighbors, n_valid - 1)
        neighbor_idx: dict[int, list[int]] = {}
        for i in range(self._n_ch):
            if not valid[i]:
                neighbor_idx[i] = []
                continue
            # Euclidean distances to all other valid channels
            dists = np.linalg.norm(positions - positions[i], axis=1)
            dists[i] = np.inf  # exclude self
            dists[~valid] = np.inf  # exclude channels without a position
            neighbor_idx[i] = np.argsort(dists)[:k].tolist()
        return neighbor_idx

    # ------------------------------------------------------------------
    # Dunder
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        return (
            f"BadChannelDetector("
            f"n_channels={self._n_ch}, "
            f"methods={sorted(self._methods)}, "
            f"bad={self.bad_channels_}, "
            f"n_windows={self.n_windows_})"
        )