# Source code for mne_rt.modalities

"""ModalityMixin — real-time feature extraction for all modalities.

Mixed into :class:`RTStream`. Assumes the following attributes exist on
``self`` at call time:

- ``rec_info``      — MNE :class:`~mne.Info`
- ``data_type``     — ``"eeg"`` | ``"meg"``
- ``_sfreq``        — sampling frequency (Hz)
- ``winsize``       — analysis window length (s)
- ``picks``         — channel selection (may be ``None``)
- ``params``        — current modality parameter dict (set before each prep call)
- ``subject_fs_id`` — FreeSurfer subject identifier
- ``subjects_fs_dir``
- ``subject_dir``   — :class:`~pathlib.Path` to subject data folder
- ``visit``         — visit number
- ``raw_baseline``  — baseline :class:`~mne.io.RawArray` (source modalities)
- ``fwd``, ``noise_cov`` — forward solution / covariance (LCMV)
- ``_prepare_raw_array(data)`` — wraps data in ``RawArray``, applies EEG ref if needed

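New modalities added:
- ``erd_ers``            — event-related desynchronisation/synchronisation (%)
- ``laterality``         — inter-hemispheric power asymmetry index
- ``laterality_erd_ers`` — right-minus-left difference of per-hemisphere ERD/ERS (%)
- ``hjorth``             — mean of Hjorth mobility and complexity (no FFT)
- ``spectral_centroid``  — frequency-weighted centre-of-mass of the PSD within a band
- ``scp``                — slow cortical potential (mean of the low-pass-filtered signal)
- ``peak_alpha_freq``    — exponentially smoothed peak frequency within a band (Hz)
- ``connectivity_ratio`` — ratio of connectivity between two channel pairs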
"""
from __future__ import annotations

from typing import Optional
from warnings import warn

import numpy as np
from scipy.optimize import curve_fit
from scipy.signal import butter, sosfiltfilt, welch
from pactools import Comodulogram

import mne
from mne import read_labels_from_annot
from mne.minimum_norm import apply_inverse_raw, read_inverse_operator
from mne.beamformer import apply_lcmv_raw, make_lcmv
from mne_connectivity import spectral_connectivity_time
from mne_features.univariate import (
    compute_app_entropy,
    compute_samp_entropy,
    compute_spect_entropy,
    compute_svd_entropy,
    
)

from mne_rt._logging import logger
from mne_rt.tools import (
    butter_bandpass,
    compute_bandpower,
    compute_fft,
    estimate_aperiodic_component,
    log_degree_barrier,
    timed,
)


class ModalityMixin:
    """Feature-extraction engine for all MNE-RT NF modalities.

    :class:`ModalityMixin` is mixed into :class:`~mne_rt.RTStream` and
    provides the **prep / compute** pair for every modality:

    * **Prep** (``_<modality>_prep()``) — runs *once* before the main loop.
      Returns a ``dict`` of pre-computed artefacts (filter coefficients,
      index arrays, connectivity indices, …) that are passed as keyword
      arguments to the compute step.
    * **Compute** (``_<modality>(data, **prep_kwargs)``) — runs *every
      window* inside a thread-pool worker. Decorated with
      :func:`~mne_rt.tools.timed` so it returns ``(value, elapsed_seconds)``.

    .. rubric:: Supported modalities (20 total)

    **Sensor-space power & time-domain**
        ``sensor_power``, ``band_ratio``, ``erd_ers``, ``laterality``,
        ``laterality_erd_ers``, ``hjorth``, ``spectral_centroid``,
        ``argmax_freq``, ``individual_peak_power``, ``entropy``,
        ``instantaneous_phase``, ``scp``, ``peak_alpha_freq``

    **Sensor-space connectivity & graph**
        ``sensor_connectivity``, ``connectivity_ratio``, ``cfc_sensor``,
        ``sensor_graph``

    **Source-space**
        ``source_power``, ``source_connectivity``, ``source_graph``

    Notes
    -----
    All methods are private (single-underscore prefix) and are called
    internally by :meth:`~mne_rt.RTStream.record_main`. To extend MNE-RT
    with a custom modality, sub-class :class:`~mne_rt.RTStream` and add a
    matching ``_<name>_prep`` / ``_<name>`` pair following the same pattern.

    Compute methods take their settings from the prep dict instead of
    reading ``self.params``. ``self.params`` is reassigned before each prep
    call, so by the time compute steps run it may hold another modality's
    settings.

    NOTE(review): ``instantaneous_phase`` is listed above but no
    ``_instantaneous_phase_prep`` / ``_instantaneous_phase`` pair is defined
    in this class — confirm where it is implemented.
    """

    # ------------------------------------------------------------------
    # Shared helpers (not modalities; used by several prep/compute pairs)
    # ------------------------------------------------------------------
    @staticmethod
    def _hemisphere_indices(
        ch_names: list, fallback_message: str
    ) -> tuple[list[int], list[int]]:
        """Split channel indices into left / right hemisphere groups.

        A channel is assigned by the *last* digit in its name: odd → left,
        even → right (10-20 numbering, e.g. ``C3`` / ``C4``). Names with no
        digit (e.g. ``Cz``) go into neither group.

        If either group is empty, ``fallback_message`` is emitted as a
        :class:`UserWarning`. The channels are then split by position: the
        first half is treated as left and the second half as right.

        Returns
        -------
        lh_idx, rh_idx : list of int
            Indices into ``ch_names``.
        """
        lh_idx: list[int] = []
        rh_idx: list[int] = []
        for idx, name in enumerate(ch_names):
            last_digit = next((c for c in reversed(name) if c.isdigit()), None)
            if last_digit is None:
                continue
            if int(last_digit) % 2 == 1:
                lh_idx.append(idx)
            else:
                rh_idx.append(idx)

        if not lh_idx or not rh_idx:
            # stacklevel=3: report the caller of the *prep* method, as the
            # inline version of this logic (stacklevel=2 in the prep) did.
            warn(fallback_message, UserWarning, stacklevel=3)
            mid = len(ch_names) // 2
            lh_idx = list(range(mid))
            rh_idx = list(range(mid, len(ch_names)))
        return lh_idx, rh_idx

    @staticmethod
    def _mean_signal_psd(
        sig: np.ndarray, sfreq: float, method: str
    ) -> tuple[np.ndarray, np.ndarray]:
        """PSD of a 1-D signal for the peak-alpha-frequency tracker.

        ``method == "welch"`` uses :func:`scipy.signal.welch` with
        ``nperseg = min(256, len(sig))``. Any other value returns the
        unnormalised squared magnitude of the real FFT.

        Returns
        -------
        freqs, psd : np.ndarray
        """
        if method == "welch":
            return welch(sig, fs=sfreq, nperseg=min(256, sig.shape[-1]))
        n_samples = sig.shape[-1]
        return (
            np.fft.rfftfreq(n_samples, d=1.0 / sfreq),
            np.abs(np.fft.rfft(sig)) ** 2,
        )

    def _mean_pair_connectivity(
        self,
        data: np.ndarray,
        indices: tuple,
        freqs: np.ndarray,
        fmin: float,
        fmax: float,
        mode: str,
        method: str,
    ) -> float:
        """Mean spectral connectivity over the connections in ``indices``.

        Runs :func:`mne_connectivity.spectral_connectivity_time` on ``data``
        as a single epoch (band-averaged, ``n_cycles=5``). The squeezed dense
        output is then indexed with ``indices`` and averaged.
        """
        con = spectral_connectivity_time(
            data=data[np.newaxis, :],
            freqs=freqs,
            indices=indices,
            average=False,
            sfreq=self._sfreq,
            fmin=fmin,
            fmax=fmax,
            faverage=True,
            mode=mode,
            method=method,
            n_cycles=5,
        )
        return float(np.squeeze(con.get_data(output="dense"))[indices].mean())

    # ------------------------------------------------------------------
    # Prep methods (run once before the main loop, return kwargs dict)
    # ------------------------------------------------------------------
    def _sensor_power_prep(self) -> dict:
        """Collect band-power settings for :meth:`_sensor_power`."""
        return {
            "sfreq": self.rec_info["sfreq"],
            "frange": self.params["frange"],
            "method": self.params["method"],
            "relative": self.params["relative"],
        }

    def _band_ratio_prep(self) -> dict:
        """Collect both bands and the PSD method for :meth:`_band_ratio`."""
        return {
            "sfreq": self.rec_info["sfreq"],
            "frange_1": self.params["frange_1"],
            "frange_2": self.params["frange_2"],
            "method": self.params["method"],
        }

    def _individual_peak_power_prep(self) -> dict:
        """Pick the individual peak frequency inside ``params["frange"]``.

        Uses the peak parameters returned by ``estimate_aperiodic_component``
        on the baseline. ``p[0]`` of each peak is compared to the band and
        used as the centre frequency. Exactly one peak strictly inside the
        band is required. Otherwise the band midpoint is used and a
        :class:`UserWarning` is emitted.

        Raises
        ------
        RuntimeError
            If no baseline recording is available.
        """
        if getattr(self, "raw_baseline", None) is None:
            raise RuntimeError(
                "individual_peak_power requires a completed baseline recording. "
                "Call record_baseline() first."
            )
        _, peak_params_ = estimate_aperiodic_component(
            raw_baseline=self.raw_baseline,
            picks=self.picks,
            method=self.params["method"],
        )
        fmin, fmax = self.params["frange"][0], self.params["frange"][1]
        candidates = [p[0] for p in peak_params_ if fmin < p[0] < fmax]
        if len(candidates) == 1:
            cf = candidates[0]
        else:
            cf = (fmin + fmax) / 2.0
            warn(
                "individual_peak_power: center frequency defaulted to mid-range "
                f"({cf:.1f} Hz); found {len(candidates)} peak(s) in band.",
                UserWarning,
                stacklevel=2,
            )
        # freq_var: half-width (Hz) of the band around cf used at compute time.
        return {"sfreq": self._sfreq, "freq_var": 2.0, "cf": cf}

    def _entropy_prep(self) -> dict:
        """Build the order-5 band-pass SOS filter used by :meth:`_entropy`."""
        sos = butter_bandpass(
            self.params["frange"][0],
            self.params["frange"][1],
            self._sfreq,
            order=5,
        )
        return {
            "sos": sos,
            "method": self.params["method"],
            "psd_method": self.params["psd_method"],
        }

    def _argmax_freq_prep(self) -> dict:
        """Precompute the FFT window and aperiodic model for :meth:`_argmax_freq`.

        The aperiodic model ``10**ap[0] / f**ap[1]`` is evaluated on the rFFT
        bins of a ``winsize``-second window that fall inside
        ``params["frange"]``.

        Raises
        ------
        RuntimeError
            If no baseline recording is available.
        """
        if getattr(self, "raw_baseline", None) is None:
            raise RuntimeError(
                "Baseline recording must be completed before using 'argmax_freq'."
            )
        ap_params, _ = estimate_aperiodic_component(
            raw_baseline=self.raw_baseline,
            picks=self.picks,
            method=self.params["method"],
        )
        frange = tuple(self.params["frange"])
        n_samples = int(self.winsize * self._sfreq)
        fft_window = np.hanning(n_samples)
        freqs = np.fft.rfftfreq(n_samples, d=1.0 / self._sfreq)
        mask = (freqs >= frange[0]) & (freqs <= frange[1])
        freqs_band = freqs[mask]
        ap_model = (10 ** ap_params[0]) / (freqs_band ** ap_params[1])

        def _gaussian(x: np.ndarray, a: float, mu: float, sigma: float) -> np.ndarray:
            return a * np.exp(-(x - mu) ** 2 / (2 * sigma ** 2))

        return {
            "fft_window": fft_window,
            "ap_model": ap_model,
            "gaussian": _gaussian,
            # Passed explicitly so the compute step never reads self.params.
            "frange": frange,
        }

    def _source_power_prep(self) -> dict:
        """Load the target label and inverse operator for :meth:`_source_power`.

        For the minimum-norm family the inverse operator is read from
        ``<subject_dir>/inv/visit_<visit>-inv.fif``. For ``"LCMV"``, beamformer
        filters are built with :func:`mne.beamformer.make_lcmv` and returned
        under the same ``"inverse_operator"`` key.

        Raises
        ------
        ValueError
            If ``params["method"]`` is not a supported source method.
        """
        fft_window, _, freq_band_idxs, _ = compute_fft(
            sfreq=self._sfreq,
            winsize=self.winsize,
            freq_range=self.params["frange"],
        )
        bls = read_labels_from_annot(
            subject=self.subject_fs_id,
            parc=self.params["atlas"],
            subjects_dir=self.subjects_fs_dir,
        )
        brain_label = bls[[bl.name for bl in bls].index(self.params["brain_label"])]

        method = self.params["method"]
        if method in ("MNE", "dSPM", "sLORETA", "eLORETA"):
            inverse_operator = read_inverse_operator(
                fname=self.subject_dir / "inv" / f"visit_{self.visit}-inv.fif"
            )
        elif method == "LCMV":
            inverse_operator = make_lcmv(
                self.rec_info,
                self.fwd,
                self.noise_cov,
                reg=0.05,
                pick_ori="max-power",
                weight_norm="unit-noise-gain",
                rank=None,
            )
        else:
            raise ValueError(
                f"Unknown source method: {method!r}. "
                "Expected one of 'MNE', 'dSPM', 'sLORETA', 'eLORETA', 'LCMV'."
            )
        return {
            "fft_window": fft_window,
            "freq_band_idxs": freq_band_idxs,
            "brain_label": brain_label,
            "inverse_operator": inverse_operator,
            "method": method,
        }

    def _sensor_connectivity_prep(self) -> dict:
        """Build channel indices and frequency grid for :meth:`_sensor_connectivity`.

        ``params["channels"]`` is zipped column-wise (``zip(chs[0], chs[1])``).
        Each resulting name pair becomes one index array in ``indices``.
        NOTE(review): mne-connectivity reads ``indices`` as
        ``(seeds, targets)``; confirm the expected layout of ``channels``.
        """
        ch_names = self.rec_info["ch_names"]
        chs = self.params["channels"]
        indices = tuple(
            np.array([ch_names.index(ch1), ch_names.index(ch2)])
            for ch1, ch2 in zip(chs[0], chs[1])
        )
        freqs = np.linspace(self.params["frange"][0], self.params["frange"][1], 6)
        return {
            "indices": indices,
            "freqs": freqs,
            "fmin": self.params["frange"][0],
            "fmax": self.params["frange"][1],
            "mode": self.params["mode"],
            "method": self.params["method"],
        }

    def _source_connectivity_prep(self) -> dict:
        """Merge the two hemispheric labels and load the inverse operator.

        Raises
        ------
        ValueError
            If ``brain_label_1`` does not end with ``-lh`` or ``brain_label_2``
            does not end with ``-rh`` (the compute step reads ``lh_data`` /
            ``rh_data`` respectively).
        """
        lbl1, lbl2 = self.params["brain_label_1"], self.params["brain_label_2"]
        if not lbl1.endswith("-lh"):
            raise ValueError(f"brain_label_1 must end with '-lh', got {lbl1!r}.")
        if not lbl2.endswith("-rh"):
            raise ValueError(f"brain_label_2 must end with '-rh', got {lbl2!r}.")
        bls = read_labels_from_annot(
            subject=self.subject_fs_id,
            parc=self.params["atlas"],
            subjects_dir=self.subjects_fs_dir,
        )
        bl_names = [bl.name for bl in bls]
        merged_label = bls[bl_names.index(lbl1)] + bls[bl_names.index(lbl2)]
        inverse_operator = read_inverse_operator(
            fname=self.subject_dir / "inv" / f"visit_{self.visit}-inv.fif"
        )
        fmin, fmax = self.params["frange"][0], self.params["frange"][1]
        freqs = np.linspace(fmin, fmax, 6)
        return {
            "merged_label": merged_label,
            "inverse_operator": inverse_operator,
            "freqs": freqs,
            # Passed explicitly so the compute step never reads self.params.
            "fmin": fmin,
            "fmax": fmax,
            "mode": self.params["mode"],
            "method": self.params["method"],
        }

    def _sensor_graph_prep(self) -> dict:
        """Build channel-pair indices and band-pass filter for :meth:`_sensor_graph`."""
        ch_names = self.rec_info["ch_names"]
        chs = self.params["channels"]
        indices = tuple(
            np.array([ch_names.index(ch1), ch_names.index(ch2)])
            for ch1, ch2 in zip(chs[0], chs[1])
        )
        sos = butter_bandpass(
            self.params["frange"][0],
            self.params["frange"][1],
            self._sfreq,
            order=5,
        )
        return {
            "indices": indices,
            "sos": sos,
            "dist_type": self.params["dist_type"],
            "alpha": self.params["alpha"],
            "beta": self.params["beta"],
        }

    def _source_graph_prep(self) -> dict:
        """Load atlas labels, inverse operator and filter for :meth:`_source_graph`."""
        bls = read_labels_from_annot(
            subject=self.subject_fs_id,
            parc=self.params["atlas"],
            subjects_dir=self.subjects_fs_dir,
        )
        bl_names = [bl.name for bl in bls]
        bl_idxs = (
            bl_names.index(self.params["brain_label_1"]),
            bl_names.index(self.params["brain_label_2"]),
        )
        inverse_operator = read_inverse_operator(
            fname=self.subject_dir / "inv" / f"visit_{self.visit}-inv.fif"
        )
        sos = butter_bandpass(
            self.params["frange"][0],
            self.params["frange"][1],
            self._sfreq,
            order=5,
        )
        return {
            "bls": bls,
            "bl_idxs": bl_idxs,
            "inverse_operator": inverse_operator,
            "sos": sos,
            # Passed explicitly so the compute step never reads self.params.
            "dist_type": self.params["dist_type"],
            "alpha": self.params["alpha"],
            "beta": self.params["beta"],
        }

    def _cfc_sensor_prep(self) -> dict:
        """Build a pactools Comodulogram (5×5 frequency grid, no surrogates)."""
        comod = Comodulogram(
            fs=self._sfreq,
            low_fq_range=np.linspace(
                self.params["frange_1"][0], self.params["frange_1"][1], 5
            ),
            high_fq_range=np.linspace(
                self.params["frange_2"][0], self.params["frange_2"][1], 5
            ),
            method=self.params["method"],
            n_surrogates=0,
        )
        return {"comod": comod}

    # ------------------------------------------------------------------
    # Feature-extraction methods (decorated with @timed)
    # ------------------------------------------------------------------
    @timed
    def _sensor_power(
        self,
        data: np.ndarray,
        sfreq: float,
        frange: tuple,
        method: str = "welch",
        relative: bool = False,
    ) -> float:
        """Mean band power across channels at sensor level."""
        bp = compute_bandpower(data, sfreq, frange, method=method, relative=relative)
        return float(bp.mean())

    @timed
    def _band_ratio(
        self,
        data: np.ndarray,
        sfreq: float,
        frange_1: tuple,
        frange_2: tuple,
        method: str = "welch",
    ) -> float:
        """Ratio of mean absolute band power: ``frange_1`` / ``frange_2``."""
        bp1 = compute_bandpower(data, sfreq, tuple(frange_1), method=method, relative=False)
        bp2 = compute_bandpower(data, sfreq, tuple(frange_2), method=method, relative=False)
        # Small epsilon guards against division by zero.
        return float(bp1.mean() / (bp2.mean() + 1e-30))

    @timed
    def _individual_peak_power(
        self,
        data: np.ndarray,
        sfreq: float,
        freq_var: float,
        cf: float,
    ) -> float:
        """Mean Welch band power in ``[cf - freq_var, cf + freq_var]``."""
        bp = compute_bandpower(
            data,
            sfreq,
            (cf - freq_var, cf + freq_var),
            method="welch",
            relative=False,
        )
        return float(bp.mean())

    @timed
    def _entropy(
        self,
        data: np.ndarray,
        sos: np.ndarray,
        method: str,
        psd_method: Optional[str] = None,
    ) -> float:
        """Mean entropy of band-filtered M/EEG signals.

        Raises
        ------
        ValueError
            If ``method`` is not one of ``'AppEn'``, ``'SampEn'``,
            ``'Spectral'``, ``'SVD'``.
        """
        data_filt = sosfiltfilt(sos, data)
        if method == "AppEn":
            ents = compute_app_entropy(data_filt)
        elif method == "SampEn":
            ents = compute_samp_entropy(data_filt)
        elif method == "Spectral":
            ents = compute_spect_entropy(
                sfreq=self._sfreq, data=data_filt, psd_method=psd_method
            )
        elif method == "SVD":
            ents = compute_svd_entropy(data_filt)
        else:
            raise ValueError(
                f"Unknown entropy method: {method!r}. "
                "Expected one of 'AppEn', 'SampEn', 'Spectral', 'SVD'."
            )
        # NOTE(review): constant -2 offset kept as-is; its origin is not shown here.
        return float(ents.mean() - 2)

    @timed
    def _argmax_freq(
        self,
        data: np.ndarray,
        fft_window: np.ndarray,
        ap_model: np.ndarray,
        gaussian,
        frange: Optional[tuple] = None,
    ) -> float:
        """Individual peak frequency via aperiodic subtraction + Gaussian fit.

        The channel-averaged FFT power inside ``frange`` minus ``ap_model``
        is fitted with ``gaussian``. The fitted centre ``mu`` is returned.
        If the fit fails, a :class:`RuntimeWarning` is emitted and ``0.0``
        is returned.
        """
        if frange is None:
            # Only for prep dicts that predate the "frange" key.
            frange = self.params["frange"]
        n_samples = data.shape[-1]
        spectrum = np.abs(np.fft.rfft(data * fft_window, axis=1) / n_samples)
        freqs = np.fft.rfftfreq(n_samples, d=1.0 / self._sfreq)
        mask = (freqs >= frange[0]) & (freqs <= frange[1])
        freqs_band = freqs[mask]
        periodic_power = np.mean(np.square(spectrum[:, mask]), axis=0) - ap_model
        p0 = [periodic_power.max(), freqs_band[np.argmax(periodic_power)], 1.0]
        try:
            popt, _ = curve_fit(gaussian, freqs_band, periodic_power, p0=p0)
        except (RuntimeError, ValueError):
            # RuntimeError: no convergence; ValueError: non-finite input
            # (e.g. the aperiodic model is infinite at a 0 Hz bin).
            warn(
                "argmax_freq: Gaussian fit failed; returning 0 Hz.",
                RuntimeWarning,
                stacklevel=2,
            )
            return 0.0
        return float(popt[1])

    @timed
    def _source_power(
        self,
        data: np.ndarray,
        fft_window: np.ndarray,
        freq_band_idxs: np.ndarray,
        brain_label,
        inverse_operator,
        method: str,
    ) -> float:
        """Source-level band power in a brain label.

        Source time courses restricted to ``brain_label`` are windowed with
        ``fft_window``. The result is the mean squared FFT magnitude at
        ``freq_band_idxs`` across all label vertices.
        """
        raw_data = self._prepare_raw_array(data)
        if method in ("MNE", "dSPM", "sLORETA", "eLORETA"):
            stc = apply_inverse_raw(
                raw_data,
                inverse_operator,
                lambda2=1.0 / 9,
                method=method,
                pick_ori="normal",
                label=brain_label,
            )
        else:
            # apply_lcmv_raw has no label argument; restrict the estimate to
            # the label so both branches measure the same region.
            stc = apply_lcmv_raw(raw_data, inverse_operator).in_label(brain_label)
        stc_data = stc.data * fft_window
        fft_val = np.abs(np.fft.rfft(stc_data, axis=1) / stc_data.shape[-1])
        return float(np.mean(np.square(fft_val[:, freq_band_idxs])))

    @timed
    def _sensor_connectivity(
        self,
        data: np.ndarray,
        indices: tuple,
        freqs: np.ndarray,
        fmin: float,
        fmax: float,
        mode: str,
        method: str,
    ) -> float:
        """Sensor-level spectral connectivity averaged over ``indices``."""
        return self._mean_pair_connectivity(
            data, indices, freqs, fmin, fmax, mode, method
        )

    @timed
    def _source_connectivity(
        self,
        data: np.ndarray,
        merged_label,
        inverse_operator,
        freqs: np.ndarray,
        fmin: Optional[float] = None,
        fmax: Optional[float] = None,
        mode: Optional[str] = None,
        method: Optional[str] = None,
    ) -> float:
        """Source-level connectivity between two brain labels.

        The inverse is applied within ``merged_label``. Connectivity is then
        computed between the mean left-hemisphere and mean right-hemisphere
        time courses.
        """
        # Fallbacks only for prep dicts that predate these keys.
        if fmin is None:
            fmin = self.params["frange"][0]
        if fmax is None:
            fmax = self.params["frange"][1]
        if mode is None:
            mode = self.params["mode"]
        if method is None:
            method = self.params["method"]

        raw_data = self._prepare_raw_array(data)
        stcs = apply_inverse_raw(
            raw_data,
            inverse_operator,
            lambda2=1.0 / 9,
            pick_ori="normal",
            label=merged_label,
        )
        # One epoch with two nodes: node 0 = left label, node 1 = right label.
        node_data = np.array([[stcs.lh_data.mean(axis=0), stcs.rh_data.mean(axis=0)]])
        con = spectral_connectivity_time(
            data=node_data,
            freqs=freqs,
            indices=None,
            average=False,
            sfreq=self._sfreq,
            fmin=fmin,
            fmax=fmax,
            faverage=True,
            mode=mode,
            method=method,
            n_cycles=5,
        )
        # Entry [1][0]: row = right-label node, column = left-label node.
        return float(np.squeeze(con.get_data(output="dense"))[1][0])

    @timed
    def _sensor_graph(
        self,
        data: np.ndarray,
        indices: tuple,
        sos: np.ndarray,
        dist_type: str,
        alpha: float,
        beta: float,
    ) -> float:
        """Mean learned-graph edge weight over the configured channel pairs."""
        data_filt = sosfiltfilt(sos, data)
        graph_matrix = log_degree_barrier(
            data_filt, dist_type=dist_type, alpha=alpha, beta=beta
        )
        # Each entry of ``indices`` is a length-2 index array (i, j). Index
        # element-wise: ``graph_matrix[idxs]`` with a 1-D array would select
        # the whole rows i and j instead of the (i, j) edge weight.
        edge_weights = [graph_matrix[idxs[0], idxs[1]] for idxs in indices]
        # NOTE(review): constant 0.025 offset kept as-is; its origin is not shown here.
        return float(np.mean(edge_weights) - 0.025)

    @timed
    def _source_graph(
        self,
        data: np.ndarray,
        bls: list,
        bl_idxs: tuple,
        inverse_operator,
        sos: np.ndarray,
        dist_type: Optional[str] = None,
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
    ) -> float:
        """Learned-graph edge weight between two atlas labels in source space.

        Label time courses (``mode="mean_flip"``) for every label in ``bls``
        are band-filtered. A graph is learned over them, and the weight
        between labels ``bl_idxs[0]`` and ``bl_idxs[1]`` is returned.
        """
        # Fallbacks only for prep dicts that predate these keys.
        if dist_type is None:
            dist_type = self.params["dist_type"]
        if alpha is None:
            alpha = self.params["alpha"]
        if beta is None:
            beta = self.params["beta"]

        raw_data = self._prepare_raw_array(data)
        stcs = apply_inverse_raw(
            raw_data,
            inverse_operator,
            lambda2=1.0 / 9,
            pick_ori="normal",
        )
        tcs = stcs.extract_label_time_course(
            bls,
            src=inverse_operator["src"],
            mode="mean_flip",
            allow_empty=True,
        )
        tcs_filt = sosfiltfilt(sos, tcs)
        graph_matrix = log_degree_barrier(
            tcs_filt, dist_type=dist_type, alpha=alpha, beta=beta
        )
        return float(graph_matrix[bl_idxs[0], bl_idxs[1]])

    @timed
    def _cfc_sensor(self, data: np.ndarray, comod) -> float:
        """Cross-frequency coupling (modulation index) at sensor level."""
        comod.fit(data)
        return float(comod.comod_.mean())

    # ------------------------------------------------------------------
    # ERD/ERS
    # ------------------------------------------------------------------
    def _erd_ers_prep(self) -> dict:
        """Compute the mean baseline band power used as the ERD/ERS reference.

        Raises
        ------
        RuntimeError
            If no baseline recording is available.
        """
        if getattr(self, "raw_baseline", None) is None:
            raise RuntimeError(
                "erd_ers requires a completed baseline recording. "
                "Call record_baseline() first."
            )
        baseline_power = compute_bandpower(
            self.raw_baseline.get_data(),
            sfreq=self._sfreq,
            band=tuple(self.params["frange"]),
            method=self.params["method"],
            relative=False,
        ).mean()
        return {
            "sfreq": self._sfreq,
            "frange": self.params["frange"],
            "method": self.params["method"],
            "baseline_power": float(baseline_power),
        }

    @timed
    def _erd_ers(
        self,
        data: np.ndarray,
        sfreq: float,
        frange: tuple,
        method: str,
        baseline_power: float,
    ) -> float:
        """Event-related desynchronisation / synchronisation (%).

        ``100 * (P_now - P_baseline) / P_baseline`` on channel-averaged band
        power. Positive values = synchronisation (ERS); negative =
        desynchronisation (ERD).
        """
        current_power = compute_bandpower(
            data, sfreq, tuple(frange), method=method, relative=False
        ).mean()
        return float((current_power - baseline_power) / (baseline_power + 1e-300) * 100.0)

    # ------------------------------------------------------------------
    # Laterality
    # ------------------------------------------------------------------
    def _laterality_prep(self) -> dict:
        """Detect hemispheric channel groups for :meth:`_laterality`."""
        lh_idx, rh_idx = self._hemisphere_indices(
            self.rec_info["ch_names"],
            "laterality: could not auto-detect left/right channels from names; "
            "splitting by index instead.",
        )
        return {
            "sfreq": self._sfreq,
            "frange": self.params["frange"],
            "method": self.params["method"],
            "lh_idx": lh_idx,
            "rh_idx": rh_idx,
        }

    @timed
    def _laterality(
        self,
        data: np.ndarray,
        sfreq: float,
        frange: tuple,
        method: str,
        lh_idx: list,
        rh_idx: list,
    ) -> float:
        """Inter-hemispheric power asymmetry: log(P_right) − log(P_left).

        Positive → right dominance; negative → left dominance.
        """
        lh_power = compute_bandpower(
            data[lh_idx], sfreq, tuple(frange), method=method, relative=False
        ).mean()
        rh_power = compute_bandpower(
            data[rh_idx], sfreq, tuple(frange), method=method, relative=False
        ).mean()
        return float(np.log(rh_power + 1e-300) - np.log(lh_power + 1e-300))

    # ------------------------------------------------------------------
    # Hjorth parameters
    # ------------------------------------------------------------------
    def _hjorth_prep(self) -> dict:
        """Build the order-5 band-pass SOS filter used by :meth:`_hjorth`."""
        sos = butter_bandpass(
            self.params["frange"][0],
            self.params["frange"][1],
            self._sfreq,
            order=5,
        )
        return {"sos": sos}

    @timed
    def _hjorth(self, data: np.ndarray, sos: np.ndarray) -> float:
        """Mean of Hjorth mobility and complexity across channels.

        Mobility ≈ dominant frequency proxy; complexity ≈ signal
        irregularity. No FFT required — pure time-domain.
        """
        x = sosfiltfilt(sos, data)  # filtered along the last axis (samples)
        dx = np.diff(x, axis=1)
        ddx = np.diff(dx, axis=1)
        # Epsilon keeps the ratios finite for flat channels.
        var_x = np.var(x, axis=1) + 1e-300
        var_dx = np.var(dx, axis=1) + 1e-300
        var_ddx = np.var(ddx, axis=1) + 1e-300
        mobility = np.sqrt(var_dx / var_x)
        mobility_d = np.sqrt(var_ddx / var_dx)
        complexity = mobility_d / mobility
        return float(0.5 * (mobility.mean() + complexity.mean()))

    # ------------------------------------------------------------------
    # Spectral centroid
    # ------------------------------------------------------------------
    def _spectral_centroid_prep(self) -> dict:
        """Collect sampling rate and band for :meth:`_spectral_centroid`."""
        return {
            "sfreq": self._sfreq,
            "frange": self.params["frange"],
        }

    @timed
    def _spectral_centroid(
        self,
        data: np.ndarray,
        sfreq: float,
        frange: tuple,
    ) -> float:
        """Frequency-weighted centre-of-mass of the PSD within a band (Hz).

        High centroid → activity shifted towards the upper edge of the band
        (useful for tracking alpha-peak drift or SMR centre-frequency).
        Computed per channel from the squared rFFT magnitude, then averaged.
        """
        n_samples = data.shape[1]
        freqs = np.fft.rfftfreq(n_samples, d=1.0 / sfreq)
        mask = (freqs >= frange[0]) & (freqs <= frange[1])
        freqs_band = freqs[mask]
        psd_band = (np.abs(np.fft.rfft(data, axis=1)) ** 2)[:, mask]
        total = psd_band.sum(axis=1) + 1e-300
        centroid_per_ch = (psd_band * freqs_band[np.newaxis, :]).sum(axis=1) / total
        return float(centroid_per_ch.mean())

    # ------------------------------------------------------------------
    # ERD/ERS laterality index
    # ------------------------------------------------------------------
    def _laterality_erd_ers_prep(self) -> dict:
        """Prep: detect hemispheric channel indices + compute baseline powers.

        Raises
        ------
        RuntimeError
            If no baseline recording is available.
        """
        if getattr(self, "raw_baseline", None) is None:
            raise RuntimeError(
                "laterality_erd_ers requires a completed baseline recording. "
                "Call record_baseline() first."
            )
        lh_idx, rh_idx = self._hemisphere_indices(
            self.rec_info["ch_names"],
            "laterality_erd_ers: hemispheric auto-detection failed; splitting by index.",
        )

        baseline_data = self.raw_baseline.get_data()
        frange = tuple(self.params["frange"])
        method = self.params["method"]
        baseline_lh = float(
            compute_bandpower(
                baseline_data[lh_idx], self._sfreq, frange, method=method, relative=False
            ).mean()
        )
        baseline_rh = float(
            compute_bandpower(
                baseline_data[rh_idx], self._sfreq, frange, method=method, relative=False
            ).mean()
        )
        return {
            "sfreq": self._sfreq,
            "frange": frange,
            "method": method,
            "lh_idx": lh_idx,
            "rh_idx": rh_idx,
            "baseline_lh": baseline_lh,
            "baseline_rh": baseline_rh,
        }

    @timed
    def _laterality_erd_ers(
        self,
        data: np.ndarray,
        sfreq: float,
        frange: tuple,
        method: str,
        lh_idx: list,
        rh_idx: list,
        baseline_lh: float,
        baseline_rh: float,
    ) -> float:
        """Baseline-normalised inter-hemispheric ERD/ERS asymmetry (%).

        Computes the ERD/ERS ratio for each hemisphere separately
        (normalised by its own baseline power) and returns the signed
        difference::

            feature = ERD_ERS_right − ERD_ERS_left

        * Positive → right hemisphere more activated (ERS) or less suppressed.
        * Negative → left hemisphere more activated (or right more suppressed).

        Motor imagery example: right-hand imagery produces left-hemisphere
        alpha ERD, so the feature becomes strongly negative during the task
        and recovers toward zero at rest.
        """
        lh_now = compute_bandpower(
            data[lh_idx], sfreq, frange, method=method, relative=False
        ).mean()
        rh_now = compute_bandpower(
            data[rh_idx], sfreq, frange, method=method, relative=False
        ).mean()
        erd_lh = (lh_now - baseline_lh) / (baseline_lh + 1e-300) * 100.0
        erd_rh = (rh_now - baseline_rh) / (baseline_rh + 1e-300) * 100.0
        return float(erd_rh - erd_lh)

    # ------------------------------------------------------------------
    # Slow Cortical Potentials (SCP)
    # ------------------------------------------------------------------
    def _scp_prep(self) -> dict:
        """Prep: build SOS low-pass (and optional high-pass) Butterworth filters.

        The high-pass filter is built only when ``params["highpass"]`` is a
        positive number; a missing key or ``None`` disables it.
        """
        sfreq = self.rec_info["sfreq"]
        lowpass = self.params["lowpass"]
        # ``or 0.0`` also maps an explicit None to "no high-pass" instead of
        # failing on the comparison below.
        highpass = self.params.get("highpass") or 0.0
        reference = self.params.get("reference", "mean")
        nyq = sfreq / 2.0
        sos_lp = butter(4, lowpass / nyq, btype="low", output="sos")
        sos_hp = None
        if highpass > 0.0:
            sos_hp = butter(4, highpass / nyq, btype="high", output="sos")
        return {
            "sos_lp": sos_lp,
            "sos_hp": sos_hp,
            "reference": reference,
        }

    @timed
    def _scp(
        self,
        data: np.ndarray,
        sos_lp: np.ndarray,
        sos_hp,
        reference: str,
    ) -> float:
        """Slow Cortical Potential: mean amplitude of the DC-coupled slow signal.

        Applies a low-pass (and optional high-pass) zero-phase Butterworth
        filter to extract the slow envelope. Channels are then collapsed by
        median (``reference == "median"``) or mean (anything else), and the
        temporal mean of the result is returned.

        Positive SCP → cortical deactivation; negative SCP → activation.
        """
        sig = data.copy()
        # Optional high-pass first (removes very slow drifts if DC not coupled)
        if sos_hp is not None:
            sig = sosfiltfilt(sos_hp, sig)
        # Low-pass to extract the slow cortical potential
        sig = sosfiltfilt(sos_lp, sig)
        if reference == "median":
            channel_summary = np.median(sig, axis=0)
        else:
            channel_summary = np.mean(sig, axis=0)
        return float(channel_summary.mean())

    # ------------------------------------------------------------------
    # Peak Alpha Frequency (PAF) tracker
    # ------------------------------------------------------------------
    def _peak_alpha_freq_prep(self) -> dict:
        """Prep: initialise EMA state for the real-time PAF tracker.

        The initial estimate is the PSD peak of the channel-averaged baseline
        within ``frange`` when a baseline is available; otherwise (or when no
        PSD bin falls in the band) it is the band midpoint.
        """
        sfreq = self.rec_info["sfreq"]
        frange = self.params["frange"]
        method = self.params.get("method", "welch")
        smoothing = self.params.get("smoothing", 0.85)
        midpoint = float((frange[0] + frange[1]) / 2.0)

        initial_paf = midpoint
        if getattr(self, "raw_baseline", None) is not None:
            mean_sig = self.raw_baseline.get_data().mean(axis=0)
            freqs_bl, psd_bl = self._mean_signal_psd(mean_sig, sfreq, method)
            mask_bl = (freqs_bl >= frange[0]) & (freqs_bl <= frange[1])
            if mask_bl.any():
                initial_paf = float(freqs_bl[mask_bl][np.argmax(psd_bl[mask_bl])])

        return {
            "sfreq": float(sfreq),
            "frange": list(frange),
            "method": method,
            "smoothing": float(smoothing),
            "_paf_state": [initial_paf],  # mutable reference cell for EMA state
        }

    @timed
    def _peak_alpha_freq(
        self,
        data: np.ndarray,
        sfreq: float,
        frange: list,
        method: str,
        smoothing: float,
        _paf_state: list,
    ) -> float:
        """Real-time peak alpha frequency (PAF) with exponential smoothing.

        Computes the PSD of the channel-averaged window and finds the dominant
        peak within *frange*. It then updates an exponential moving average
        ``(1 - smoothing) * peak + smoothing * previous``, stored in
        ``_paf_state[0]``. Returns the smoothed PAF in Hz.
        """
        mean_sig = data.mean(axis=0)
        freqs, psd = self._mean_signal_psd(mean_sig, sfreq, method)
        mask = (freqs >= frange[0]) & (freqs <= frange[1])
        if mask.any():
            peak_freq = float(freqs[mask][np.argmax(psd[mask])])
        else:
            peak_freq = float(_paf_state[0])  # fallback: keep current estimate
        # EMA update — mutate the state cell so state persists across windows
        new_paf = (1.0 - smoothing) * peak_freq + smoothing * _paf_state[0]
        _paf_state[0] = new_paf
        return float(new_paf)

    # ------------------------------------------------------------------
    # Connectivity Ratio
    # ------------------------------------------------------------------
    def _connectivity_ratio_prep(self) -> dict:
        """Prep: build connectivity indices for numerator and denominator pairs.

        Raises
        ------
        ValueError
            If a channel named in ``channels_num`` / ``channels_den`` is not
            present in the recording.
        """
        ch_names = self.rec_info["ch_names"]

        def _pair_to_indices(pair):
            a, b = pair[0], pair[1]
            if a not in ch_names or b not in ch_names:
                missing = [c for c in [a, b] if c not in ch_names]
                raise ValueError(
                    f"connectivity_ratio: channels {missing} not found in recording. "
                    f"Available: {ch_names}"
                )
            # (seeds, targets) with a single connection each.
            return (np.array([ch_names.index(a)]), np.array([ch_names.index(b)]))

        indices_num = _pair_to_indices(self.params["channels_num"])
        indices_den = _pair_to_indices(self.params["channels_den"])
        freqs = np.linspace(self.params["frange"][0], self.params["frange"][1], 6)
        return {
            "indices_num": indices_num,
            "indices_den": indices_den,
            "freqs": freqs,
            "fmin": float(self.params["frange"][0]),
            "fmax": float(self.params["frange"][1]),
            "mode": self.params["mode"],
            "method": self.params["method"],
        }

    @timed
    def _connectivity_ratio(
        self,
        data: np.ndarray,
        indices_num: tuple,
        indices_den: tuple,
        freqs: np.ndarray,
        fmin: float,
        fmax: float,
        mode: str,
        method: str,
    ) -> float:
        """Ratio of functional connectivity between two channel pairs.

        Useful for laterality of connectivity, e.g. ipsilateral /
        contralateral. Returns ``conn_num / conn_den``, with a small epsilon
        added to the denominator.
        """
        conn_num = self._mean_pair_connectivity(
            data, indices_num, freqs, fmin, fmax, mode, method
        )
        conn_den = self._mean_pair_connectivity(
            data, indices_den, freqs, fmin, fmax, mode, method
        )
        return float(conn_num / (conn_den + 1e-30))