"""GEneralized Decomposition for Artifact Identification (GEDAI).
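Removes artifacts from EEG/MEG by solving a generalized eigenvalue problem
whose spatial filters maximise a signal covariance relative to a reference
covariance: either a leadfield-based reference (leadfield mode) or the
broadband covariance with band-filtered data as the signal (band-filter mode).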
"""
from __future__ import annotations
import numpy as np
from scipy import linalg
from scipy.signal import sosfiltfilt
from mne_rt.tools.tools import butter_bandpass
class GEDAIDenoiser:
    """Artifact removal via Generalised Eigendecomposition-based Artifact Identification (GEDAI).

    Solves the generalised eigenvalue problem (GEP):

    .. math::

        \\mathbf{C}\\,\\mathbf{w} = \\lambda\\,\\mathbf{R}\\,\\mathbf{w}

    where :math:`\\mathbf{C}` is the signal covariance and :math:`\\mathbf{R}`
    is a reference covariance. Two modes are supported:

    * **Leadfield mode** (recommended, true GEDAI):
      :math:`\\mathbf{R} = \\mathbf{L}\\mathbf{L}^\\top`, where
      :math:`\\mathbf{L}` is the EEG forward/leadfield gain matrix.
      Components with *large* :math:`\\lambda` are well-explained by the
      brain's theoretical source model and are kept; components with *small*
      :math:`\\lambda` are not leadfield-aligned and are treated as artifacts.
      Use :meth:`fit_from_leadfield` to fit in this mode.
    * **Band-filter mode** (Cohen-style GED):
      :math:`\\mathbf{C} = \\mathbf{R}_{\\mathrm{band}}` and
      :math:`\\mathbf{R} = \\mathbf{R}_{\\mathrm{broad}}` (broadband EEG
      covariance, Tikhonov-regularised). Components with large :math:`\\lambda`
      maximise the target-band-to-broadband ratio.
      Use :meth:`fit` or :meth:`fit_from_raw` for this mode.

    In both modes the unmixing matrix :math:`\\mathbf{W}` (spatial filters)
    and activation patterns :math:`\\mathbf{A} = (\\mathbf{W}^\\top)^+` are
    stored after fitting. Denoising zeroes the selected artifact columns of
    :math:`\\mathbf{A}` and reconstructs clean sensor data as
    :math:`\\hat{\\mathbf{x}} = \\mathbf{A}_{\\mathrm{clean}}\\,\\mathbf{W}^\\top\\mathbf{x}`.

    Parameters
    ----------
    n_channels : int
        Number of EEG/MEG channels.
    shrinkage : float, default 0.01
        Tikhonov regularisation strength applied to the reference covariance
        before solving the GEP. Prevents ill-conditioning when the
        covariance matrix is rank-deficient.

    References
    ----------
    Ros, T., Férat, V., Huang, Y., et al. (2025). Return of the GEDAI:
    Unsupervised EEG Denoising based on Leadfield Filtering. *bioRxiv*.
    https://doi.org/10.1101/2025.10.04.680449

    Cohen, M. X. (2022). A tutorial on generalized eigendecomposition for
    denoising, contrast enhancement, and dimension reduction in multichannel
    electrophysiology. *NeuroImage*, 247, 118809.
    """

    def __init__(self, n_channels: int, shrinkage: float = 0.01) -> None:
        self.n_channels = n_channels
        self.shrinkage = shrinkage
        self._W: np.ndarray | None = None  # spatial filters, shape (n_ch, n_ch)
        self._A: np.ndarray | None = None  # activation patterns = pinv(W.T)
        self._eigenvalues: np.ndarray | None = None  # sorted descending

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------
    def fit(
        self,
        data_broadband: np.ndarray,
        data_band: np.ndarray,
    ) -> "GEDAIDenoiser":
        """Estimate spatial filters from baseline data.

        Parameters
        ----------
        data_broadband : ndarray, shape (n_channels, n_samples)
            Broadband (or noise-reference) baseline data.
        data_band : ndarray, shape (n_channels, n_samples)
            Band-filtered baseline data (the signal of interest).

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If either input is not 2-D, has the wrong channel count, or
            contains no samples.
        """
        # Both inputs are validated: a band array with the wrong channel
        # count would otherwise only fail deep inside the eigen-solver.
        data_broadband = self._validate_data(data_broadband)
        data_band = self._validate_data(data_band)
        # Each covariance is normalised by its own sample count.
        R_broad = self._covariance(data_broadband)
        R_band = self._covariance(data_band)
        # Largest eigenvalue = most band-specific component.
        return self._solve_gep(R_band, R_broad)

    def fit_from_leadfield(
        self,
        data: np.ndarray,
        leadfield: np.ndarray,
    ) -> "GEDAIDenoiser":
        """Fit using the forward/leadfield matrix as the reference covariance.

        This is the **true GEDAI** mode described in Ros et al. (2025).
        The reference covariance is constructed as
        :math:`\\mathbf{R} = \\mathbf{L}\\mathbf{L}^\\top` where
        :math:`\\mathbf{L}` is the leadfield gain matrix. Components
        whose eigenvalues are large are well-aligned with the theoretical
        brain source model and are treated as signal; components with small
        eigenvalues are artifact candidates.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_samples)
            Broadband baseline EEG/MEG recording.
        leadfield : ndarray, shape (n_channels, n_sources)
            Forward solution gain matrix :math:`\\mathbf{L}` — the
            ``fwd['sol']['data']`` array from an MNE forward solution.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``data`` or ``leadfield`` has the wrong shape, or if the
            leadfield carries no power (so it cannot be normalised).
        """
        data = self._validate_data(data)
        leadfield = np.asarray(leadfield)
        if leadfield.ndim != 2:
            raise ValueError(
                f"Leadfield must be a 2-D (n_channels, n_sources) array, got {leadfield.ndim}-D"
            )
        if leadfield.shape[0] != self.n_channels:
            raise ValueError(
                f"Leadfield row count ({leadfield.shape[0]}) must equal n_channels ({self.n_channels})"
            )
        C = self._covariance(data)
        # Leadfield-based reference: R = L @ L.T, scaled to unit average power.
        R_lead = leadfield @ leadfield.T
        mean_power = np.trace(R_lead) / self.n_channels
        if not mean_power > 0:
            raise ValueError("Leadfield has zero (or non-finite) power; cannot normalise it")
        # Out-of-place division so an integer-dtype leadfield does not fail.
        R_lead = R_lead / mean_power
        # Largest eigenvalue = most leadfield-aligned = brain signal.
        return self._solve_gep(C, R_lead)

    def fit_from_raw(
        self,
        data: np.ndarray,
        sfreq: float,
        band: tuple[float, float],
    ) -> "GEDAIDenoiser":
        """Convenience wrapper: bandpass-filter ``data`` internally then call ``fit``.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_samples)
            Baseline recording (broadband).
        sfreq : float
            Sampling frequency in Hz.
        band : (low, high)
            Target frequency band in Hz.

        Returns
        -------
        self
        """
        low, high = band[0], band[1]
        sos = butter_bandpass(low, high, sfreq, order=5)
        # Forward-backward filtering; output has the same shape as ``data``.
        data_band = sosfiltfilt(sos, data)
        return self.fit(data, data_band)

    # ------------------------------------------------------------------
    # Transform / reconstruct
    # ------------------------------------------------------------------
    def transform(self, data: np.ndarray) -> np.ndarray:
        """Project sensor data onto the fitted spatial filters.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_samples)

        Returns
        -------
        components : ndarray, shape (n_channels, n_samples)
            Component time courses :math:`\\mathbf{W}^\\top\\mathbf{x}`,
            ordered by descending eigenvalue. Always a new array.
        """
        self._check_fitted()
        data = self._validate_data(data)
        return self._W.T @ data

    def inverse_transform(self, components: np.ndarray) -> np.ndarray:
        """Map component time courses back to sensor space.

        Parameters
        ----------
        components : ndarray, shape (n_channels, n_samples)
            Component time courses, e.g. the output of :meth:`transform`.

        Returns
        -------
        data : ndarray, shape (n_channels, n_samples)
            Sensor-space data :math:`\\mathbf{A}\\,\\mathbf{s}`.
        """
        self._check_fitted()
        return self._A @ np.asarray(components)

    # ------------------------------------------------------------------
    # Artifact identification and removal
    # ------------------------------------------------------------------
    def find_artifact_components(
        self,
        template_map: np.ndarray,
        threshold: float = 0.7,
    ) -> tuple[list[int], np.ndarray]:
        """Identify artifact components by correlation with a spatial template.

        Parameters
        ----------
        template_map : ndarray, shape (n_channels,)
            Known topography of the artifact (e.g., blink, ECG).
        threshold : float
            Absolute-correlation threshold.

        Returns
        -------
        artifact_idx : list of int
            Components whose activation pattern has an absolute Pearson
            correlation with ``template_map`` strictly above ``threshold``.
        corrs : ndarray, shape (n_channels,)
            Signed correlation of every component's pattern with the template.
        """
        self._check_fitted()
        # Each column of A is one component's sensor-space topography.
        corrs = np.array(
            [np.corrcoef(pattern, template_map)[0, 1] for pattern in self._A.T]
        )
        artifact_idx = np.flatnonzero(np.abs(corrs) > threshold).tolist()
        return artifact_idx, corrs

    def find_noise_components(self, n_noise: int = 1) -> list[int]:
        """Return indices of the ``n_noise`` components with smallest eigenvalues.

        These components capture the least band-specific activity (band-filter
        mode) or the least leadfield-aligned activity (leadfield mode).

        Parameters
        ----------
        n_noise : int
            Number of components to flag. Values larger than the number of
            components are clamped to it.

        Returns
        -------
        list of int

        Raises
        ------
        ValueError
            If ``n_noise`` is negative.
        """
        self._check_fitted()
        if n_noise < 0:
            raise ValueError(f"n_noise must be non-negative, got {n_noise}")
        n_components = len(self._eigenvalues)
        # Clamp so the range never starts below zero (negative indices would
        # silently wrap around when used for indexing).
        n_noise = min(n_noise, n_components)
        # Eigenvalues are stored in descending order, so the smallest are last.
        return list(range(n_components - n_noise, n_components))

    def denoise(
        self,
        data: np.ndarray,
        artifact_idx: list[int],
    ) -> np.ndarray:
        """Suppress artifact components and reconstruct sensor data.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_samples)
        artifact_idx : list of int
            Component indices to zero out. An empty selection leaves the
            data unchanged up to numerical round-off.

        Returns
        -------
        data_clean : ndarray, shape (n_channels, n_samples)
        """
        # ``transform`` returns a fresh array, so zeroing rows below never
        # touches the caller's ``data``.
        components = self.transform(data)
        selected = np.asarray(artifact_idx)
        # ``.size`` works for lists and index arrays alike; plain truthiness
        # raises on a multi-element NumPy array.
        if selected.size:
            components[selected, :] = 0.0
        return self.inverse_transform(components)

    def update_and_denoise(
        self,
        data: np.ndarray,
        template_map: np.ndarray,
        threshold: float = 0.7,
    ) -> np.ndarray:
        """Identify blink/artifact components then denoise in one call.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_samples)
        template_map : ndarray, shape (n_channels,)
        threshold : float

        Returns
        -------
        data_clean : ndarray, shape (n_channels, n_samples)
        """
        artifact_idx, _ = self.find_artifact_components(template_map, threshold)
        return self.denoise(data, artifact_idx)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------
    @property
    def eigenvalues(self) -> np.ndarray:
        """Sorted eigenvalues (descending) from the GED."""
        self._check_fitted()
        return self._eigenvalues

    @property
    def spatial_filters(self) -> np.ndarray:
        """Spatial filters W, shape (n_channels, n_channels)."""
        self._check_fitted()
        return self._W

    @property
    def activation_patterns(self) -> np.ndarray:
        """Activation patterns A = pinv(W.T), shape (n_channels, n_channels)."""
        self._check_fitted()
        return self._A

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _check_fitted(self) -> None:
        """Raise ``RuntimeError`` if no fit method has been called yet."""
        if self._W is None:
            raise RuntimeError(
                "Call fit(), fit_from_raw() or fit_from_leadfield() before using this method."
            )

    def _validate_data(self, data: np.ndarray) -> np.ndarray:
        """Return ``data`` as an array after checking it is (n_channels, n_samples)."""
        data = np.asarray(data)
        if data.ndim != 2:
            raise ValueError(
                f"Expected a 2-D (n_channels, n_samples) array, got {data.ndim}-D"
            )
        if data.shape[0] != self.n_channels:
            raise ValueError(
                f"Expected {self.n_channels} channels, got {data.shape[0]}"
            )
        return data

    @staticmethod
    def _covariance(data: np.ndarray) -> np.ndarray:
        """Return the channel covariance of ``data`` after removing channel means."""
        n_samples = data.shape[1]
        if n_samples == 0:
            raise ValueError("Cannot estimate a covariance from zero samples")
        centred = data - data.mean(axis=1, keepdims=True)
        return (centred @ centred.T) / n_samples

    def _solve_gep(
        self,
        signal_cov: np.ndarray,
        reference_cov: np.ndarray,
    ) -> "GEDAIDenoiser":
        """Regularise the reference, solve the GEP and store sorted results."""
        # Tikhonov regularisation scaled by the reference's mean channel power.
        reg = self.shrinkage * np.trace(reference_cov) / self.n_channels
        reference_cov = reference_cov + reg * np.eye(self.n_channels)
        eigenvalues, W = linalg.eigh(signal_cov, reference_cov)
        # eigh returns ascending eigenvalues; store them largest-first.
        order = np.argsort(eigenvalues)[::-1]
        self._eigenvalues = eigenvalues[order]
        self._W = W[:, order]
        self._A = np.linalg.pinv(self._W.T)
        return self