Source code for mne_rt.viz.raw_plot

"""Real-time raw M/EEG channel viewer.

Dark-themed scrolling raw signal display built on PyQt6 + pyqtgraph.

Classes
-------
RawPlot
    Scrolling multi-channel raw M/EEG signal viewer.
"""
from __future__ import annotations

import datetime
from collections import deque
from pathlib import Path

import numpy as np
import pyqtgraph as pg
import pyqtgraph.exporters
from PyQt6.QtCore import QEvent, QObject, Qt, QTimer
from PyQt6.QtWidgets import (
    QCheckBox,
    QComboBox,
    QDoubleSpinBox,
    QGroupBox,
    QHBoxLayout,
    QLabel,
    QMainWindow,
    QPushButton,
    QScrollArea,
    QScrollBar,
    QSpinBox,
    QVBoxLayout,
    QWidget,
)


# ---------------------------------------------------------------------------
# Per-channel trace colours — 20 visually distinct hues for a dark background
# ---------------------------------------------------------------------------

# Channel i is drawn with _TRACE_COLORS[i % len(_TRACE_COLORS)] (see
# RawPlot._resolve_colors), so neighbouring channels never share a colour.
_TRACE_COLORS = [
    "#4fc3f7",  # sky blue
    "#ef9a9a",  # salmon
    "#a5d6a7",  # mint green
    "#fff176",  # yellow
    "#ffab91",  # light orange
    "#ce93d8",  # lavender
    "#80cbc4",  # teal
    "#ffcc80",  # peach
    "#80deea",  # cyan
    "#b39ddb",  # purple
    "#f48fb1",  # pink
    "#c5e1a5",  # lime green
    "#ffd54f",  # amber
    "#81d4fa",  # light blue
    "#dce775",  # yellow-green
    "#ff8a65",  # deep orange
    "#90caf9",  # blue
    "#e6ee9c",  # lime
    "#bcaaa4",  # warm grey
    "#ffe082",  # light amber
]

# Time-window choices (in seconds) offered by the "Display" group's selector.
_TIME_WINDOW_OPTIONS = [5, 10, 20, 30, 60]

# Qt style sheet applied to the whole main window (RawPlot._build_ui calls
# self.setStyleSheet(_QSS)). The darkest shade, #0d0d1a, is also used as the
# pyqtgraph canvas background so the scrollbar and status bar blend with it.
_QSS = """
QMainWindow, QWidget {
    background-color: #1a1a2e;
    color: #e0e0e0;
    font-family: "Segoe UI", sans-serif;
}
QPushButton {
    background-color: #16213e;
    color: #d0d0e8;
    border: 1px solid #0f3460;
    border-radius: 5px;
    padding: 5px 10px;
    font-size: 12px;
}
QPushButton:hover  { background-color: #0f3460; }
QPushButton:pressed { background-color: #533483; }
QPushButton:checked {
    background-color: #533483;
    border-color: #a882dd;
    color: #ffffff;
}
QComboBox {
    background-color: #16213e;
    color: #d0d0e8;
    border: 1px solid #0f3460;
    border-radius: 4px;
    padding: 3px 6px;
}
QComboBox QAbstractItemView {
    background-color: #16213e;
    color: #d0d0e8;
    selection-background-color: #0f3460;
}
QDoubleSpinBox, QSpinBox {
    background-color: #16213e;
    color: #d0d0e8;
    border: 1px solid #0f3460;
    border-radius: 4px;
    padding: 2px 4px;
}
QGroupBox {
    border: 1px solid #2a2a4a;
    border-radius: 6px;
    margin-top: 10px;
    padding-top: 6px;
    font-weight: bold;
    font-size: 11px;
    color: #8888aa;
}
QGroupBox::title {
    subcontrol-origin: margin;
    left: 8px;
    padding: 0 4px;
}
QLabel  { color: #b0b0c8; font-size: 11px; }
QCheckBox { color: #b0b0c8; font-size: 11px; }
QScrollArea { border: none; }
QStatusBar { background-color: #0d0d1a; color: #606080; font-size: 10px; }
QScrollBar:vertical {
    background-color: #0d0d1a;
    width: 14px;
    border: none;
    margin: 0px;
}
QScrollBar::handle:vertical {
    background-color: #2a2a4a;
    border-radius: 4px;
    min-height: 24px;
}
QScrollBar::handle:vertical:hover { background-color: #404060; }
QScrollBar::add-line:vertical, QScrollBar::sub-line:vertical { height: 0px; }
QScrollBar::add-page:vertical, QScrollBar::sub-page:vertical {
    background-color: #0d0d1a;
}
"""


# ---------------------------------------------------------------------------
# Artifact-corrector adapter classes
# ---------------------------------------------------------------------------

class _GEDAIWrapper:
    """Makes GEDAIDenoiser present the .transform(data) interface."""

    def __init__(self, gedai, n_noise: int) -> None:
        self._g = gedai
        self._n = n_noise

    def transform(self, data: np.ndarray) -> np.ndarray:
        idx = self._g.find_noise_components(self._n)
        return self._g.denoise(data, idx)


class _ORICAWrapper:
    """Online ORICA: adapts W on every chunk, suppresses highest-power ICs."""

    def __init__(self, orica, n_remove: int) -> None:
        self._o = orica
        self._n = max(1, int(n_remove))

    def transform(self, data: np.ndarray) -> np.ndarray:
        self._o.partial_fit(data)
        S = self._o.transform(data)
        rms = np.sqrt(np.mean(S ** 2, axis=1))
        noise_idx = np.argsort(rms)[-self._n:].tolist()
        S_clean = S.copy()
        S_clean[noise_idx] = 0.0
        return self._o.inverse_transform(S_clean)


# ---------------------------------------------------------------------------
# Event filter: intercepts wheel events on the plot viewport
# ---------------------------------------------------------------------------

class _WheelFilter(QObject):
    """Qt event filter that turns wheel events into ±1 paging callbacks.

    Installed on the plot viewport so the mouse wheel pages through channels
    instead of reaching pyqtgraph. Every wheel event is consumed.

    Parameters
    ----------
    callback : callable
        Called with ``+1`` when the vertical wheel delta is positive (wheel
        rotated away from the user) and ``-1`` when it is negative. It is not
        called for events without a vertical component.
    parent : QObject | None
        Qt parent object.
    """

    def __init__(self, callback, parent=None):
        super().__init__(parent)
        self._cb = callback

    def eventFilter(self, obj, event):
        if event.type() != QEvent.Type.Wheel:
            return False
        dy = event.angleDelta().y()
        # Horizontal-only scrolling (tilt wheel, trackpad swipe) reports
        # dy == 0. Mapping it to -1 would page the channels down spuriously.
        if dy != 0:
            self._cb(1 if dy > 0 else -1)
        return True


# ---------------------------------------------------------------------------
# RawPlot
# ---------------------------------------------------------------------------

class RawPlot(QMainWindow):
    """Scrolling raw M/EEG channel viewer.

    Displays channels stacked vertically in a dark-themed
    :class:`pyqtgraph.GraphicsLayoutWidget`, colour-coded by channel index.
    Channels are scrollable via the mouse wheel on the plot or via the
    scrollbar on the right edge of the signal area.

    Mouse interactions:

    * Right-click a channel name on the Y-axis to open an MNE sensor-position
      plot for that channel.
    * Left-click a channel name to toggle it as bad (shown with a ``✕`` prefix).
    * Double-click the signal area twice to mark the start and end of a bad
      segment.

    Parameters
    ----------
    ch_names : list of str
        Channel names. One row is shown per channel.
    sfreq : float
        Sampling frequency in Hz.
    time_window : float, default 10.0
        Visible time range in seconds at startup. The Display selector only
        offers 5, 10, 20, 30 and 60 s. It preselects the option closest to
        this value.
    n_shown : int, default 20
        Number of channels visible simultaneously. The value is clamped to at
        least 1 and at most ``len(ch_names)``.
    scale_uv : float, default 100.0
        Amplitude scale in µV. A signal of this peak amplitude occupies half
        the per-channel row height. For MEG (Tesla) pass ``scale_uv=1e-6``
        (i.e., 1 pT per half-row).
    info : mne.Info | None, default None
        If provided, used to resolve channel types, apply SSP projectors, and
        show sensor positions on right-click.
    verbose : bool | str | None, default None
        Verbosity level. See :func:`~mne_rt._logging.set_log_level`.

    See Also
    --------
    mne_rt.viz.NFPlot : Scrolling NF feature monitor.
    mne_rt.RTStream.record_main : Drives the raw display from the NF loop.

    Notes
    -----
    The control panel (right sidebar) provides:

    * **Playback** — pause/resume, clear buffer, screenshot.
    * **Amplitude** — ÷2 / ×2 scale buttons with a live readout.
    * **Display** — time-window selector, grid toggle, DC-removal toggle.
      Changing the time window discards the data currently buffered.
    * **Filter** — online causal high-pass / low-pass / band-pass / notch
      (scipy). It applies to new data from the moment "Apply filter" is
      clicked. Data already in the buffer is not retroactively filtered.
    * **SSP** — shown when ``info`` contains projectors. Applied to new data
      from the moment the checkbox is ticked.
    * **Re-reference** — none, average, mastoid (TP9/TP10), linked mastoid,
      or a single chosen channel. Applied from the moment "Apply from now" is
      clicked.
    * **Artifact Correction** — LMS, ASR, GEDAI, ORICA or Maxwell. Applied
      from the moment "Apply from now" is clicked.
      LMS requires a reference channel.
      ASR calibrates on the current buffer content.
      GEDAI takes a target band and a number of components to suppress.
      ORICA takes the number of highest-power ICs to suppress per chunk.
    * **Bad Segments** — status of manual (double-click) segment marking and
      a button to clear all marked segments.
    * **Auto Bad Seg (Riemann)** — calibrate a Riemannian Potato on the
      current buffer, then enable automatic detection using the chosen
      window length and z-score threshold.

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        ch_names: list[str],
        sfreq: float,
        time_window: float = 10.0,
        n_shown: int = 20,
        scale_uv: float = 100.0,
        info=None,
        verbose=None,
    ) -> None:
        from mne_rt._logging import set_log_level

        set_log_level(verbose)
        super().__init__()
        self._ch_names = list(ch_names)
        # Count the materialised list so any iterable of names is accepted.
        self._n_ch = len(self._ch_names)
        self._sfreq = float(sfreq)
        self._time_window = float(time_window)
        # At least one row: a non-positive n_shown would create no curves and
        # an inverted scrollbar range. At most one row per channel.
        self._n_shown = min(max(1, int(n_shown)), self._n_ch)
        # Stored in SI units (µV -> V); _fmt_scale converts back for display.
        self._scale = float(scale_uv) * 1e-6
        self._info = info
        self._page_start = 0  # index of the first visible channel
        self._paused = False
        self._dc_remove = False

        # Online causal filter — SOS coefficients + per-channel state vector.
        # Both are reset to None when the filter is changed or the buffer is
        # cleared.
        self._filter_sos = None  # ndarray (n_sections, 6) or None
        self._filter_zi = None  # ndarray (n_ch, n_sections, 2) or None

        # SSP projector matrix (n_ch × n_ch), applied to new incoming chunks
        self._ssp_proj = None

        # Artifact corrector — object with a .transform(data) → data method
        self._corrector = None

        # Re-referencing — applied in push() after the corrector.
        # The UI also offers "linked_mastoid"; presumably it is stored here
        # too (see _apply_reref_settings).
        self._reref_type: str = "none"  # "none", "average", "mastoid", "channel"
        self._reref_idx: int = 0  # index of the single reference channel
        self._reref_idxs: list[int] = []  # indices for multi-channel references

        # Bad channels — toggled by left-clicking the channel label
        self._bad_ch_idxs: set[int] = set()

        # Bad segments — marked by double-clicking on the signal canvas
        self._total_pushed: int = 0  # cumulative samples pushed
        self._bad_segs: list[tuple[float, float]] = []  # (abs_start_s, abs_end_s)
        self._bad_seg_overlays: list = []  # pg.LinearRegionItem objects on plot
        self._bad_seg_click1: float | None = None  # absolute session time of first click
        self._bad_seg_start_line = None  # pg.InfiniteLine shown while waiting for end
        self._bad_seg_start_line_on_plot: bool = False

        # Per-channel colours and types
        self._ch_types: list[str] = []
        self._ch_colors: list[str] = []
        self._resolve_colors()

        # Display buffer: one row per channel, time_window seconds long, with
        # a floor of 30 samples so very low sampling rates still draw a line.
        n_pts = max(int(sfreq * time_window), 30)
        self._time_axis = np.linspace(0.0, time_window, n_pts)
        self._buf = np.zeros((self._n_ch, n_pts))

        # Thread-safe data queue: push() (background thread) queues processed
        # chunks here; _flush_data_queue() (main thread, 30 Hz) drains it.
        self._data_queue: deque = deque()

        # Riemannian Potato auto-bad-segment detection
        self._rp_detector = None  # RiemannianPotatoDetector | None
        self._rp_active: bool = False
        self._rp_seg_samples: int = max(2, int(sfreq * 1.0))  # updated from spinbox
        self._rp_last_tested: int = 0  # abs sample idx of last tested window end

        pg.setConfigOptions(antialias=True, foreground="#c0c0d8", background="#0d0d1a")
        self._build_ui()
        self.setWindowTitle("MNE-RT — Raw")
        self.resize(1500, 720)

        # 30 Hz render timer — all Qt widget updates happen in the main thread.
        self._flush_timer = QTimer(self)
        self._flush_timer.setInterval(33)
        self._flush_timer.timeout.connect(self._flush_data_queue)
        self._flush_timer.start()
# ------------------------------------------------------------------ # Colour resolution # ------------------------------------------------------------------ def _resolve_colors(self) -> None: if self._info is not None: try: import mne self._ch_types = [ mne.channel_type(self._info, i) for i in range(self._n_ch) ] except Exception: self._ch_types = ["misc"] * self._n_ch else: self._ch_types = ["misc"] * self._n_ch # Distinct colour per channel by cycling the palette; ensures adjacent # channels are always distinguishable even when all share one type. self._ch_colors = [ _TRACE_COLORS[i % len(_TRACE_COLORS)] for i in range(self._n_ch) ] # ------------------------------------------------------------------ # UI construction # ------------------------------------------------------------------ def _build_ui(self) -> None: self.setStyleSheet(_QSS) central = QWidget() self.setCentralWidget(central) root = QHBoxLayout(central) root.setContentsMargins(8, 8, 4, 8) root.setSpacing(0) root.addWidget(self._build_plot_widget(), stretch=5) # Channel scrollbar — between the plot canvas and the control panel self._ch_scroll = QScrollBar(Qt.Orientation.Vertical) self._ch_scroll.setRange(0, max(0, self._n_ch - self._n_shown)) self._ch_scroll.setPageStep(self._n_shown) self._ch_scroll.setSingleStep(1) self._ch_scroll.setFixedWidth(14) self._ch_scroll.valueChanged.connect( lambda v: self._set_page_start(v, source="scrollbar") ) root.addWidget(self._ch_scroll) root.addSpacing(4) root.addWidget(self._build_control_panel(), stretch=0) self._status = self.statusBar() self._status.showMessage("Waiting for data …") def _build_plot_widget(self) -> pg.GraphicsLayoutWidget: glw = pg.GraphicsLayoutWidget() glw.setBackground("#0d0d1a") self._pi = glw.addPlot(row=0, col=0) self._pi.setMouseEnabled(x=False, y=False) vb = self._pi.getViewBox() vb.setMouseEnabled(x=False, y=False) vb.setMenuEnabled(False) self._pi.showGrid(x=True, y=False, alpha=0.25) for ax_name in ("left", "bottom"): ax = 
self._pi.getAxis(ax_name) ax.setPen(pg.mkPen("#303050")) ax.setTextPen(pg.mkPen("#9090aa")) self._pi.getAxis("left").setWidth(80) self._pi.setLabel("bottom", "Time", units="s", color="#9090aa") self._pi.setXRange(0.0, self._time_window, padding=0.01) self._pi.setYRange(-0.5, self._n_shown - 0.5, padding=0) self._curves: list[pg.PlotCurveItem] = [ pg.PlotCurveItem(pen=pg.mkPen(color=_TRACE_COLORS[0], width=1)) for _ in range(self._n_shown) ] for c in self._curves: self._pi.addItem(c) sep_pen = pg.mkPen(color=(70, 70, 110, 55), width=1, style=Qt.PenStyle.DotLine) for i in range(self._n_shown): self._pi.addItem(pg.InfiniteLine(pos=i, angle=0, pen=sep_pen)) self._update_tick_labels() self._wheel_filter = _WheelFilter(self._on_plot_wheel, self) glw.viewport().installEventFilter(self._wheel_filter) # Right-click on the Y-axis tick label → open sensor position plot glw.scene().sigMouseClicked.connect(self._on_scene_clicked) self._glw = glw return glw def _build_control_panel(self) -> QScrollArea: scroll = QScrollArea() scroll.setWidgetResizable(True) scroll.setFixedWidth(230) scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) panel = QWidget() layout = QVBoxLayout(panel) layout.setSpacing(8) layout.setContentsMargins(6, 6, 6, 6) layout.addWidget(self._grp_playback()) layout.addWidget(self._grp_amplitude()) layout.addWidget(self._grp_display()) layout.addWidget(self._grp_filter()) layout.addWidget(self._grp_reref()) layout.addWidget(self._grp_correction()) layout.addWidget(self._grp_bad_segs()) layout.addWidget(self._grp_potato()) layout.addStretch() scroll.setWidget(panel) return scroll # --- sidebar groups --- def _grp_playback(self) -> QGroupBox: grp = QGroupBox("Playback") lay = QVBoxLayout(grp) self._btn_pause = QPushButton("⏸ Pause") self._btn_pause.setCheckable(True) self._btn_pause.clicked.connect(self._toggle_pause) btn_clear = QPushButton("⟳ Clear") btn_clear.clicked.connect(self._clear) btn_shot = QPushButton("📷 Screenshot") 
btn_shot.clicked.connect(self._screenshot) for w in (self._btn_pause, btn_clear, btn_shot): lay.addWidget(w) return grp def _grp_amplitude(self) -> QGroupBox: grp = QGroupBox("Amplitude") lay = QVBoxLayout(grp) self._scale_lbl = QLabel(self._fmt_scale()) self._scale_lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) self._scale_lbl.setStyleSheet("color:#7ec8e3; font-size:12px; font-weight:bold;") row = QHBoxLayout() btn_dn = QPushButton("÷2") btn_up = QPushButton("×2") for btn in (btn_dn, btn_up): btn.setFixedSize(42, 26) btn_dn.clicked.connect(self._scale_down) btn_up.clicked.connect(self._scale_up) row.addStretch() row.addWidget(btn_dn) row.addWidget(btn_up) row.addStretch() lay.addWidget(self._scale_lbl) lay.addLayout(row) return grp def _grp_display(self) -> QGroupBox: grp = QGroupBox("Display") lay = QVBoxLayout(grp) row = QHBoxLayout() row.addWidget(QLabel("Time window:")) self._cmb_tw = QComboBox() for secs in _TIME_WINDOW_OPTIONS: self._cmb_tw.addItem(f"{secs} s", secs) best = min(_TIME_WINDOW_OPTIONS, key=lambda s: abs(s - self._time_window)) self._cmb_tw.setCurrentIndex(_TIME_WINDOW_OPTIONS.index(best)) self._cmb_tw.currentIndexChanged.connect(self._change_time_window) row.addWidget(self._cmb_tw) lay.addLayout(row) chk_grid = QCheckBox("Show grid") chk_grid.setChecked(True) chk_grid.toggled.connect( lambda on: self._pi.showGrid(x=on, y=False, alpha=0.25 if on else 0.0) ) lay.addWidget(chk_grid) chk_dc = QCheckBox("Remove DC") chk_dc.setChecked(False) chk_dc.toggled.connect(self._set_dc_remove) lay.addWidget(chk_dc) return grp def _grp_filter(self) -> QGroupBox: grp = QGroupBox("Filter") lay = QVBoxLayout(grp) lay.setSpacing(5) type_row = QHBoxLayout() type_row.addWidget(QLabel("Type:")) self._cmb_filter = QComboBox() self._cmb_filter.addItems(["None", "High-pass", "Low-pass", "Band-pass", "Notch"]) self._cmb_filter.currentIndexChanged.connect(self._on_filter_type_changed) type_row.addWidget(self._cmb_filter) lay.addLayout(type_row) self._flo_row = QWidget() 
flo_lay = QHBoxLayout(self._flo_row) flo_lay.setContentsMargins(0, 0, 0, 0) flo_lbl = QLabel("Lo cut:") flo_lbl.setFixedWidth(44) self._flo_spin = QDoubleSpinBox() self._flo_spin.setRange(0.1, 500.0) self._flo_spin.setValue(1.0) self._flo_spin.setSuffix(" Hz") self._flo_spin.setDecimals(1) flo_lay.addWidget(flo_lbl) flo_lay.addWidget(self._flo_spin) self._flo_row.setVisible(False) lay.addWidget(self._flo_row) self._fhi_row = QWidget() fhi_lay = QHBoxLayout(self._fhi_row) fhi_lay.setContentsMargins(0, 0, 0, 0) fhi_lbl = QLabel("Hi cut:") fhi_lbl.setFixedWidth(44) self._fhi_spin = QDoubleSpinBox() self._fhi_spin.setRange(0.1, 500.0) self._fhi_spin.setValue(40.0) self._fhi_spin.setSuffix(" Hz") self._fhi_spin.setDecimals(1) fhi_lay.addWidget(fhi_lbl) fhi_lay.addWidget(self._fhi_spin) self._fhi_row.setVisible(False) lay.addWidget(self._fhi_row) self._fnotch_row = QWidget() fnotch_lay = QHBoxLayout(self._fnotch_row) fnotch_lay.setContentsMargins(0, 0, 0, 0) fnotch_lbl = QLabel("Freq:") fnotch_lbl.setFixedWidth(44) self._fnotch_spin = QDoubleSpinBox() self._fnotch_spin.setRange(1.0, 500.0) self._fnotch_spin.setValue(50.0) self._fnotch_spin.setSuffix(" Hz") self._fnotch_spin.setDecimals(1) fnotch_lay.addWidget(fnotch_lbl) fnotch_lay.addWidget(self._fnotch_spin) self._fnotch_row.setVisible(False) lay.addWidget(self._fnotch_row) btn_apply = QPushButton("Apply filter") btn_apply.setStyleSheet( "background:#132744; color:#80d8ff; border:1px solid #2a6090;" "border-radius:4px; padding:4px; font-size:11px;" ) btn_apply.clicked.connect(self._apply_filter_settings) lay.addWidget(btn_apply) self._filter_status = QLabel("○ No filter") self._filter_status.setStyleSheet("color:#505070; font-size:10px;") self._filter_status.setAlignment(Qt.AlignmentFlag.AlignCenter) lay.addWidget(self._filter_status) # SSP — shown only when info has projectors if self._info is not None: try: projs = self._info.get("projs", []) if projs: n = len(projs) self._ssp_chk = QCheckBox(f"Apply SSP ({n} 
proj.)") self._ssp_chk.setChecked(False) self._ssp_chk.toggled.connect(self._toggle_ssp) lay.addWidget(self._ssp_chk) except Exception: pass return grp def _grp_reref(self) -> QGroupBox: grp = QGroupBox("Re-reference") lay = QVBoxLayout(grp) lay.setSpacing(5) type_row = QHBoxLayout() type_row.addWidget(QLabel("Ref:")) self._cmb_reref = QComboBox() self._cmb_reref.addItems( ["None", "Average", "Mastoid (TP9/TP10)", "Linked Mastoid", "Channel"] ) self._cmb_reref.currentIndexChanged.connect(self._on_reref_type_changed) type_row.addWidget(self._cmb_reref) lay.addLayout(type_row) # Single-channel reference selector (shown only for "Channel") self._reref_ch_row = QWidget() reref_ch_lay = QHBoxLayout(self._reref_ch_row) reref_ch_lay.setContentsMargins(0, 0, 0, 0) reref_ch_lbl = QLabel("Channel:") reref_ch_lbl.setFixedWidth(54) self._reref_ch_cmb = QComboBox() self._reref_ch_cmb.addItems(self._ch_names) self._reref_ch_cmb.setMaxVisibleItems(12) reref_ch_lay.addWidget(reref_ch_lbl) reref_ch_lay.addWidget(self._reref_ch_cmb) self._reref_ch_row.setVisible(False) lay.addWidget(self._reref_ch_row) btn_reref = QPushButton("Apply from now") btn_reref.setStyleSheet( "background:#132744; color:#80d8ff; border:1px solid #2a6090;" "border-radius:4px; padding:4px; font-size:11px;" ) btn_reref.clicked.connect(self._apply_reref_settings) lay.addWidget(btn_reref) self._reref_status = QLabel("○ No re-reference") self._reref_status.setStyleSheet("color:#505070; font-size:10px;") self._reref_status.setAlignment(Qt.AlignmentFlag.AlignCenter) lay.addWidget(self._reref_status) return grp def _grp_correction(self) -> QGroupBox: grp = QGroupBox("Artifact Correction") lay = QVBoxLayout(grp) lay.setSpacing(5) type_row = QHBoxLayout() type_row.addWidget(QLabel("Method:")) self._cmb_corr = QComboBox() self._cmb_corr.addItems(["None", "LMS", "ASR", "GEDAI", "ORICA", "Maxwell"]) self._cmb_corr.currentIndexChanged.connect(self._on_corr_type_changed) type_row.addWidget(self._cmb_corr) 
lay.addLayout(type_row) # LMS — reference channel by name self._lms_ref_row = QWidget() lms_lay = QHBoxLayout(self._lms_ref_row) lms_lay.setContentsMargins(0, 0, 0, 0) lms_lbl = QLabel("Ref ch:") lms_lbl.setFixedWidth(44) self._lms_ref_cmb = QComboBox() self._lms_ref_cmb.addItems(self._ch_names) self._lms_ref_cmb.setMaxVisibleItems(12) self._lms_ref_cmb.setToolTip( "Reference (EOG/ECG) channel used by LMS. " "Pick the channel that best captures the artifact." ) lms_lay.addWidget(lms_lbl) lms_lay.addWidget(self._lms_ref_cmb) self._lms_ref_row.setVisible(False) lay.addWidget(self._lms_ref_row) # ASR — cutoff threshold self._asr_cut_row = QWidget() asr_lay = QHBoxLayout(self._asr_cut_row) asr_lay.setContentsMargins(0, 0, 0, 0) asr_lbl = QLabel("Cutoff σ:") asr_lbl.setFixedWidth(52) self._asr_cut_spin = QDoubleSpinBox() self._asr_cut_spin.setRange(2.0, 20.0) self._asr_cut_spin.setValue(5.0) self._asr_cut_spin.setDecimals(1) self._asr_cut_spin.setToolTip( "ASR rejection threshold in multiples of clean-data RMS.\n" "Lower = more aggressive (3–4); higher = more conservative (8–10)." 
) asr_lay.addWidget(asr_lbl) asr_lay.addWidget(self._asr_cut_spin) self._asr_cut_row.setVisible(False) lay.addWidget(self._asr_cut_row) # GEDAI — target band + number of noise components to remove self._gedai_row = QWidget() gedai_v = QVBoxLayout(self._gedai_row) gedai_v.setContentsMargins(0, 0, 0, 0) gedai_v.setSpacing(3) band_row = QHBoxLayout() band_row.addWidget(QLabel("Band:")) self._gedai_lo_spin = QDoubleSpinBox() self._gedai_lo_spin.setRange(0.1, 200.0) self._gedai_lo_spin.setValue(1.0) self._gedai_lo_spin.setSuffix(" Hz") self._gedai_lo_spin.setDecimals(1) self._gedai_lo_spin.setFixedWidth(70) band_row.addWidget(self._gedai_lo_spin) band_row.addWidget(QLabel("–")) self._gedai_hi_spin = QDoubleSpinBox() self._gedai_hi_spin.setRange(1.0, 500.0) self._gedai_hi_spin.setValue(40.0) self._gedai_hi_spin.setSuffix(" Hz") self._gedai_hi_spin.setDecimals(1) self._gedai_hi_spin.setFixedWidth(70) band_row.addWidget(self._gedai_hi_spin) gedai_v.addLayout(band_row) noise_row = QHBoxLayout() noise_row.addWidget(QLabel("Remove:")) self._gedai_noise_spin = QSpinBox() self._gedai_noise_spin.setRange(1, max(1, self._n_ch - 1)) self._gedai_noise_spin.setValue(1) self._gedai_noise_spin.setSuffix(" comps") self._gedai_noise_spin.setToolTip( "Number of lowest-eigenvalue GEDAI components to suppress.\n" "These capture the least band-specific (artifact) activity." 
) noise_row.addWidget(self._gedai_noise_spin) gedai_v.addLayout(noise_row) self._gedai_row.setVisible(False) lay.addWidget(self._gedai_row) # ORICA — number of ICs to suppress self._orica_row = QWidget() orica_lay = QHBoxLayout(self._orica_row) orica_lay.setContentsMargins(0, 0, 0, 0) orica_lbl = QLabel("Remove:") orica_lbl.setFixedWidth(50) self._orica_n_spin = QSpinBox() self._orica_n_spin.setRange(1, max(1, self._n_ch - 1)) self._orica_n_spin.setValue(1) self._orica_n_spin.setSuffix(" ICs") self._orica_n_spin.setToolTip( "Number of highest-power ICs to suppress per chunk.\n" "Artifacts are typically the highest-power independent components." ) orica_lay.addWidget(orica_lbl) orica_lay.addWidget(self._orica_n_spin) self._orica_row.setVisible(False) lay.addWidget(self._orica_row) btn_apply_corr = QPushButton("Apply from now") btn_apply_corr.setStyleSheet( "background:#132744; color:#80d8ff; border:1px solid #2a6090;" "border-radius:4px; padding:4px; font-size:11px;" ) btn_apply_corr.clicked.connect(self._apply_correction) lay.addWidget(btn_apply_corr) self._corr_status = QLabel("○ No correction") self._corr_status.setStyleSheet("color:#505070; font-size:10px;") self._corr_status.setAlignment(Qt.AlignmentFlag.AlignCenter) lay.addWidget(self._corr_status) return grp def _grp_bad_segs(self) -> QGroupBox: grp = QGroupBox("Bad Segments") lay = QVBoxLayout(grp) lay.setSpacing(5) hint = QLabel("Double-click signal to set\nstart, then end of bad segment") hint.setWordWrap(True) hint.setStyleSheet("color:#7a7a9a; font-size:9px;") lay.addWidget(hint) self._bad_seg_status_lbl = QLabel("Ready") self._bad_seg_status_lbl.setStyleSheet("color:#505070; font-size:10px;") self._bad_seg_status_lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) lay.addWidget(self._bad_seg_status_lbl) self._bad_seg_count_lbl = QLabel("No bad segments") self._bad_seg_count_lbl.setStyleSheet("color:#505070; font-size:10px;") self._bad_seg_count_lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) 
lay.addWidget(self._bad_seg_count_lbl) btn_clear_bad = QPushButton("Clear all bad segments") btn_clear_bad.clicked.connect(self._clear_bad_segs) lay.addWidget(btn_clear_bad) return grp def _grp_potato(self) -> QGroupBox: grp = QGroupBox("Auto Bad Seg (Riemann)") lay = QVBoxLayout(grp) lay.setSpacing(5) # Segment length seg_row = QHBoxLayout() seg_row.addWidget(QLabel("Seg:")) self._rp_seg_spin = QDoubleSpinBox() self._rp_seg_spin.setRange(0.2, 5.0) self._rp_seg_spin.setValue(1.0) self._rp_seg_spin.setSuffix(" s") self._rp_seg_spin.setDecimals(1) self._rp_seg_spin.setSingleStep(0.1) self._rp_seg_spin.setToolTip( "Window length (seconds) for covariance estimation.\n" "Longer = more reliable covariance; shorter = finer time resolution." ) seg_row.addWidget(self._rp_seg_spin) lay.addLayout(seg_row) # Z-threshold thr_row = QHBoxLayout() thr_row.addWidget(QLabel("Z-thr:")) self._rp_thr_spin = QDoubleSpinBox() self._rp_thr_spin.setRange(1.0, 6.0) self._rp_thr_spin.setValue(3.0) self._rp_thr_spin.setDecimals(1) self._rp_thr_spin.setSingleStep(0.1) self._rp_thr_spin.setToolTip( "Z-score threshold for declaring a segment as bad.\n" "Lower = more aggressive (more rejections)." ) thr_row.addWidget(self._rp_thr_spin) lay.addLayout(thr_row) # Calibrate button btn_cal = QPushButton("Calibrate on buffer") btn_cal.setStyleSheet( "background:#132744; color:#80d8ff; border:1px solid #2a6090;" "border-radius:4px; padding:4px; font-size:11px;" ) btn_cal.setToolTip( "Segment the current buffer into clean windows and fit\n" "the Riemannian Potato. Run with artifact-free data." 
) btn_cal.clicked.connect(self._calibrate_potato) lay.addWidget(btn_cal) # Active checkbox self._rp_chk = QCheckBox("Active (auto-detect)") self._rp_chk.setChecked(False) self._rp_chk.setEnabled(False) # enabled after calibration self._rp_chk.toggled.connect(self._toggle_potato_active) lay.addWidget(self._rp_chk) # Status label self._rp_status_lbl = QLabel("Not calibrated") self._rp_status_lbl.setWordWrap(True) self._rp_status_lbl.setStyleSheet("color:#505070; font-size:10px;") self._rp_status_lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) lay.addWidget(self._rp_status_lbl) return grp # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ def _fmt_scale(self) -> str: uv = self._scale * 1e6 if uv >= 1000: return f"{uv / 1000:.4g} mV / row" if uv >= 1: return f"{uv:.4g} µV / row" return f"{uv * 1000:.4g} nV / row" def _update_tick_labels(self) -> None: end = min(self._page_start + self._n_shown, self._n_ch) n_actual = end - self._page_start ticks = [] for i in range(n_actual): ch_idx = self._page_start + i name = self._ch_names[ch_idx] label = f"✕ {name}" if ch_idx in self._bad_ch_idxs else name ticks.append((n_actual - 1 - i, label)) self._pi.getAxis("left").setTicks([ticks, []]) def _set_page_start(self, new_start: int, source: str = "other") -> None: new_start = max(0, min(new_start, max(0, self._n_ch - self._n_shown))) if new_start == self._page_start: return self._page_start = new_start self._update_tick_labels() if source != "scrollbar": self._ch_scroll.blockSignals(True) self._ch_scroll.setValue(new_start) self._ch_scroll.blockSignals(False) self._redraw() def _reset_online_state(self) -> None: """Reset filter and corrector state when the buffer is cleared.""" self._filter_zi = None # ------------------------------------------------------------------ # Callbacks — scroll / wheel # ------------------------------------------------------------------ def 
_on_plot_wheel(self, direction: int) -> None: step = max(1, self._n_shown // 4) self._set_page_start(self._page_start - direction * step, source="wheel") # ------------------------------------------------------------------ # Callbacks — Y-axis right-click → sensor location # ------------------------------------------------------------------ def _on_scene_clicked(self, event) -> None: pos = event.scenePos() btn = event.button() axis = self._pi.getAxis("left") vb = self._pi.getViewBox() # ── Y-axis clicks ────────────────────────────────────────────── if axis.sceneBoundingRect().contains(pos): y_val = vb.mapSceneToView(pos).y() end = min(self._page_start + self._n_shown, self._n_ch) n_actual = end - self._page_start vis_idx = int(round(n_actual - 1 - y_val)) if 0 <= vis_idx < n_actual: ch_idx = self._page_start + vis_idx if btn == Qt.MouseButton.RightButton: self._show_channel_location(self._ch_names[ch_idx]) elif btn == Qt.MouseButton.LeftButton and not event.double(): # Single-click only; defer to keep Qt widget calls in main thread QTimer.singleShot(0, lambda idx=ch_idx: self._toggle_bad_channel(idx)) return # ── Signal-area double-click: bad-segment marking ────────────── if btn == Qt.MouseButton.LeftButton and event.double(): if vb.sceneBoundingRect().contains(pos): x_val = vb.mapSceneToView(pos).x() if 0.0 <= x_val <= self._time_window: self._on_bad_seg_click(x_val) def _show_channel_location(self, ch_name: str) -> None: if self._info is None: self._status.showMessage(f"No Info available — cannot show position for {ch_name}") return try: import mne import matplotlib.pyplot as plt fig = mne.viz.plot_sensors( self._info, show_names=True, title=f"Sensor position — {ch_name}", show=False, ) # Highlight the right-clicked channel in red for ax in fig.axes: for txt in ax.texts: if txt.get_text() == ch_name: txt.set_color("#cc0000") txt.set_fontsize(10) txt.set_fontweight("bold") plt.show(block=False) except Exception as exc: self._status.showMessage(f"Could not show 
sensor position: {exc}") # ------------------------------------------------------------------ # Callbacks — playback / display # ------------------------------------------------------------------ def _toggle_pause(self, checked: bool) -> None: self._paused = checked self._btn_pause.setText("▶ Resume" if checked else "⏸ Pause") def _clear(self) -> None: self._buf[:] = 0.0 self._reset_online_state() self._redraw() def _screenshot(self) -> None: from PyQt6.QtWidgets import QFileDialog ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") default = str(Path.home() / f"raw_plot_{ts}.png") path, _ = QFileDialog.getSaveFileName( self, "Save Screenshot", default, "PNG Image (*.png)" ) if not path: return exp = pg.exporters.ImageExporter(self._glw.scene()) exp.parameters()["width"] = 1920 exp.export(path) def _scale_up(self) -> None: self._scale *= 2.0 self._scale_lbl.setText(self._fmt_scale()) self._redraw() def _scale_down(self) -> None: self._scale /= 2.0 self._scale_lbl.setText(self._fmt_scale()) self._redraw() def _set_dc_remove(self, checked: bool) -> None: self._dc_remove = checked self._redraw() def _change_time_window(self, idx: int) -> None: secs = float(self._cmb_tw.itemData(idx)) self._time_window = secs n_pts = max(int(self._sfreq * secs), 30) self._time_axis = np.linspace(0.0, secs, n_pts) self._buf = np.zeros((self._n_ch, n_pts)) self._reset_online_state() self._pi.setXRange(0.0, secs, padding=0.01) # ------------------------------------------------------------------ # Callbacks — filter # ------------------------------------------------------------------ def _on_filter_type_changed(self, idx: int) -> None: ftypes = ["none", "highpass", "lowpass", "bandpass", "notch"] ft = ftypes[idx] self._flo_row.setVisible(ft in ("highpass", "bandpass")) self._fhi_row.setVisible(ft in ("lowpass", "bandpass")) self._fnotch_row.setVisible(ft == "notch") def _apply_filter_settings(self) -> None: ftypes = ["none", "highpass", "lowpass", "bandpass", "notch"] ft = 
ftypes[self._cmb_filter.currentIndex()] if ft == "none": self._filter_sos = None self._filter_zi = None self._filter_status.setText("○ No filter") self._filter_status.setStyleSheet("color:#505070; font-size:10px;") return try: from scipy.signal import butter, iirnotch, sosfilt, sosfilt_zi, tf2sos nyq = self._sfreq / 2.0 flo = self._flo_spin.value() fhi = self._fhi_spin.value() fn = self._fnotch_spin.value() if ft == "highpass": sos = butter(4, flo / nyq, btype="high", output="sos") label = f"● HP {flo:.1f} Hz" elif ft == "lowpass": sos = butter(4, fhi / nyq, btype="low", output="sos") label = f"● LP {fhi:.1f} Hz" elif ft == "bandpass": sos = butter(4, [flo / nyq, fhi / nyq], btype="band", output="sos") label = f"● BP {flo:.1f}{fhi:.1f} Hz" else: # notch b, a = iirnotch(fn, Q=30, fs=self._sfreq) sos = tf2sos(b, a) label = f"● Notch {fn:.1f} Hz" # Prime the per-channel state from the current buffer tail # (1 s worth of samples) to avoid a transient at the start. zi_base = sosfilt_zi(sos) # (n_sections, 2) self._filter_zi = np.zeros( (self._n_ch, zi_base.shape[0], zi_base.shape[1]) ) n_prime = min(int(self._sfreq), self._buf.shape[1]) if n_prime > 0: for ch in range(self._n_ch): _, self._filter_zi[ch] = sosfilt( sos, self._buf[ch, -n_prime:], zi=zi_base.copy() ) self._filter_sos = sos self._filter_status.setText(label) self._filter_status.setStyleSheet( "color:#80d8ff; font-size:10px; font-weight:bold;" ) except Exception as exc: self._filter_sos = None self._filter_zi = None self._filter_status.setText(f"Error: {exc}") self._filter_status.setStyleSheet("color:#ff8080; font-size:10px;") def _toggle_ssp(self, checked: bool) -> None: if not checked or self._info is None: self._ssp_proj = None return try: import mne active = list(self._info.get("projs", [])) proj, _, _ = mne.make_projector(active, self._info["ch_names"]) self._ssp_proj = proj except Exception: self._ssp_proj = None # ------------------------------------------------------------------ # Callbacks — 
re-referencing # ------------------------------------------------------------------ def _on_reref_type_changed(self, idx: int) -> None: methods = ["none", "average", "mastoid", "linked_mastoid", "channel"] self._reref_ch_row.setVisible(methods[idx] == "channel") def _apply_reref_settings(self) -> None: methods = ["none", "average", "mastoid", "linked_mastoid", "channel"] method = methods[self._cmb_reref.currentIndex()] if method == "none": self._reref_type = "none" self._reref_status.setText("○ No re-reference") self._reref_status.setStyleSheet("color:#505070; font-size:10px;") return try: if method == "average": self._reref_type = "average" label = "● Average ref" elif method == "mastoid": idxs = [ self._ch_names.index(n) for n in ("TP9", "TP10") if n in self._ch_names ] if not idxs: raise RuntimeError("No TP9 or TP10 channels found.") self._reref_type = "mastoid" self._reref_idxs = idxs found = [self._ch_names[i] for i in idxs] label = f"● Mastoid ({'+'.join(found)})" elif method == "linked_mastoid": candidates = ("TP9", "TP10", "M1", "M2", "A1", "A2") idxs = [ self._ch_names.index(n) for n in candidates if n in self._ch_names ] if not idxs: raise RuntimeError( "No mastoid channels found (tried TP9, TP10, M1, M2, A1, A2)." 
) self._reref_type = "mastoid" self._reref_idxs = idxs found = [self._ch_names[i] for i in idxs] label = f"● Linked mastoid ({'+'.join(found)})" else: # channel ref_name = self._reref_ch_cmb.currentText() self._reref_type = "channel" self._reref_idx = self._ch_names.index(ref_name) label = f"● Ref: {ref_name}" self._reref_status.setText(label) self._reref_status.setStyleSheet( "color:#80d8ff; font-size:10px; font-weight:bold;" ) except Exception as exc: self._reref_type = "none" self._reref_status.setText(f"Error: {exc}") self._reref_status.setStyleSheet("color:#ff8080; font-size:10px;") # ------------------------------------------------------------------ # Callbacks — artifact correction # ------------------------------------------------------------------ def _on_corr_type_changed(self, idx: int) -> None: methods = ["none", "lms", "asr", "gedai", "orica", "maxwell"] method = methods[idx] self._lms_ref_row.setVisible(method == "lms") self._asr_cut_row.setVisible(method == "asr") self._gedai_row.setVisible(method == "gedai") self._orica_row.setVisible(method == "orica") def _apply_correction(self) -> None: methods = ["none", "lms", "asr", "gedai", "orica", "maxwell"] method = methods[self._cmb_corr.currentIndex()] self._corrector = None if method == "none": self._corr_status.setText("○ No correction") self._corr_status.setStyleSheet("color:#505070; font-size:10px;") return try: if method == "lms": from mne_rt.tools.lms import AdaptiveLMSFilter ref_name = self._lms_ref_cmb.currentText() ref_idx = self._ch_names.index(ref_name) self._corrector = AdaptiveLMSFilter(ref_ch_idx=ref_idx) label = f"● LMS (ref: {ref_name})" elif method == "asr": from mne_rt.tools.asr import ASRDenoiser cutoff = self._asr_cut_spin.value() nonzero_cols = np.any(self._buf != 0.0, axis=0) n_valid = int(nonzero_cols.sum()) min_needed = max(int(self._sfreq), 30) if n_valid < min_needed: raise RuntimeError( f"Not enough data in buffer ({n_valid} samples < {min_needed}). 
" "Wait for more data and try again." ) asr = ASRDenoiser(cutoff=cutoff) asr.fit(self._buf[:, nonzero_cols], self._sfreq) self._corrector = asr label = f"● ASR (cutoff={cutoff:.1f}σ)" elif method == "gedai": from mne_rt.tools.gedai import GEDAIDenoiser nonzero_cols = np.any(self._buf != 0.0, axis=0) n_valid = int(nonzero_cols.sum()) min_needed = max(int(self._sfreq), 30) if n_valid < min_needed: raise RuntimeError( f"Not enough data in buffer ({n_valid} samples < {min_needed}). " "Wait for more data and try again." ) lo = self._gedai_lo_spin.value() hi = self._gedai_hi_spin.value() n_noise = self._gedai_noise_spin.value() gedai = GEDAIDenoiser(n_channels=self._n_ch) gedai.fit_from_raw( self._buf[:, nonzero_cols], self._sfreq, band=(lo, hi) ) self._corrector = _GEDAIWrapper(gedai, n_noise) label = f"● GEDAI ({lo:.0f}{hi:.0f} Hz, rm {n_noise})" elif method == "orica": from mne_rt.tools.orica import ORICA n_remove = self._orica_n_spin.value() self._corrector = _ORICAWrapper(ORICA(n_channels=self._n_ch), n_remove) label = ( f"● ORICA (rm {n_remove} IC{'s' if n_remove > 1 else ''})" ) elif method == "maxwell": if self._info is None: raise RuntimeError( "Maxwell filter requires mne.Info. " "Pass info= when constructing RawPlot." 
) from mne_rt.tools.maxwell import RTMaxwellFilter mf = RTMaxwellFilter() mf.fit(self._info) self._corrector = mf label = "● Maxwell SSS" self._corr_status.setText(label) self._corr_status.setStyleSheet( "color:#80d8ff; font-size:10px; font-weight:bold;" ) except Exception as exc: self._corrector = None self._corr_status.setText(f"Error: {exc}") self._corr_status.setStyleSheet("color:#ff8080; font-size:10px;") # ------------------------------------------------------------------ # Bad channels # ------------------------------------------------------------------ def _toggle_bad_channel(self, ch_idx: int) -> None: if ch_idx in self._bad_ch_idxs: self._bad_ch_idxs.discard(ch_idx) else: self._bad_ch_idxs.add(ch_idx) # Do NOT write to self._info["bads"] here: mne_lsl's get_data() uses # exclude="bads" by default, so syncing bads into the shared Info object # causes the stream to return one fewer channel on the next get_data() # call, which breaks the circular buffer write in _flush_data_queue. bads = sorted(self._bad_ch_idxs) if bads: self._status.showMessage( f"Bad channels: {', '.join(self._ch_names[i] for i in bads)}" ) else: self._status.showMessage("No bad channels marked") self._update_tick_labels() self._redraw() # ------------------------------------------------------------------ # Bad segments # ------------------------------------------------------------------ def _on_bad_seg_click(self, x_val: float) -> None: # Convert plot x-coordinate to absolute session time immediately, so # that a scrolling buffer between click 1 and click 2 doesn't shift it. 
buf_size = self._buf.shape[1] buf_start_s = max(0.0, (self._total_pushed - buf_size) / self._sfreq) abs_time = buf_start_s + x_val if self._bad_seg_click1 is None: # First double-click — store absolute start time self._bad_seg_click1 = abs_time self._bad_seg_status_lbl.setText(f"Start: {abs_time:.2f} s") self._bad_seg_status_lbl.setStyleSheet( "color:#ff8a65; font-size:10px; font-weight:bold;" ) self._status.showMessage( f"Bad segment start: {abs_time:.2f} s — double-click to set end" ) self._redraw() # redraws start indicator via _update_bad_seg_overlays else: # Second double-click — finalize segment abs_start = self._bad_seg_click1 abs_end = abs_time self._bad_seg_click1 = None if self._bad_seg_start_line is not None: if self._bad_seg_start_line_on_plot: try: self._pi.removeItem(self._bad_seg_start_line) except Exception: pass self._bad_seg_start_line = None self._bad_seg_start_line_on_plot = False if abs_end < abs_start: abs_start, abs_end = abs_end, abs_start self._bad_segs.append((abs_start, abs_end)) n = len(self._bad_segs) self._bad_seg_count_lbl.setText( f"{n} bad segment{'s' if n > 1 else ''}" ) self._bad_seg_count_lbl.setStyleSheet("color:#ff8a65; font-size:10px;") self._bad_seg_status_lbl.setText("Ready") self._bad_seg_status_lbl.setStyleSheet("color:#505070; font-size:10px;") self._status.showMessage( f"Bad segment added: {abs_start:.2f}{abs_end:.2f} s" ) self._redraw() def _update_bad_seg_overlays(self) -> None: for region in self._bad_seg_overlays: self._pi.removeItem(region) self._bad_seg_overlays.clear() buf_size = self._buf.shape[1] buf_start_s = max(0.0, (self._total_pushed - buf_size) / self._sfreq) # ── pending start indicator ──────────────────────────────────── def _remove_start_line() -> None: if self._bad_seg_start_line_on_plot and self._bad_seg_start_line is not None: try: self._pi.removeItem(self._bad_seg_start_line) except Exception: pass self._bad_seg_start_line_on_plot = False if self._bad_seg_click1 is not None: x_ind = 
self._bad_seg_click1 - buf_start_s if 0.0 <= x_ind <= self._time_window: if self._bad_seg_start_line is None: self._bad_seg_start_line = pg.InfiniteLine( pos=x_ind, angle=90, pen=pg.mkPen(color="#ff8a65", width=2, style=Qt.PenStyle.DashLine), ) else: self._bad_seg_start_line.setPos(x_ind) if not self._bad_seg_start_line_on_plot: self._pi.addItem(self._bad_seg_start_line) self._bad_seg_start_line_on_plot = True else: _remove_start_line() # scrolled off-screen; keep click1 state else: _remove_start_line() self._bad_seg_start_line = None # ── completed bad-segment regions ────────────────────────────── for (abs_start, abs_end) in self._bad_segs: x_s = abs_start - buf_start_s x_e = abs_end - buf_start_s if x_e < 0 or x_s > self._time_window: continue x_s = max(0.0, x_s) x_e = min(self._time_window, x_e) region = pg.LinearRegionItem( values=(x_s, x_e), brush=pg.mkBrush(200, 50, 50, 45), movable=False, pen=pg.mkPen(color="#cc3333", width=1), ) self._pi.addItem(region) self._bad_seg_overlays.append(region) def _clear_bad_segs(self) -> None: self._bad_segs.clear() for region in self._bad_seg_overlays: self._pi.removeItem(region) self._bad_seg_overlays.clear() self._bad_seg_click1 = None if self._bad_seg_start_line is not None: if self._bad_seg_start_line_on_plot: try: self._pi.removeItem(self._bad_seg_start_line) except Exception: pass self._bad_seg_start_line = None self._bad_seg_start_line_on_plot = False self._bad_seg_count_lbl.setText("No bad segments") self._bad_seg_count_lbl.setStyleSheet("color:#505070; font-size:10px;") self._bad_seg_status_lbl.setText("Ready") self._bad_seg_status_lbl.setStyleSheet("color:#505070; font-size:10px;") # ------------------------------------------------------------------ # Riemannian Potato auto-detection # ------------------------------------------------------------------ def _calibrate_potato(self) -> None: seg_s = self._rp_seg_spin.value() z_thr = self._rp_thr_spin.value() n_seg = max(2, int(self._sfreq * seg_s)) n_wins = 
self._buf.shape[1] // n_seg if n_wins < 3: self._rp_status_lbl.setText( f"Buffer too short: need ≥3 × {seg_s:.1f} s windows. " "Increase time window or wait for more data." ) self._rp_status_lbl.setStyleSheet("color:#ff8a65; font-size:10px;") return try: from mne_rt.tools import RiemannianPotatoDetector windows = np.stack([ self._buf[:, i * n_seg:(i + 1) * n_seg] for i in range(n_wins) ]) # (n_wins, n_ch, n_seg) det = RiemannianPotatoDetector(threshold=z_thr) det.fit(windows) self._rp_detector = det self._rp_seg_samples = n_seg self._rp_last_tested = self._total_pushed # start detecting from now self._rp_chk.setEnabled(True) self._rp_status_lbl.setText( f"✓ Calibrated\n{n_wins} windows · z>{z_thr:.1f}" ) self._rp_status_lbl.setStyleSheet("color:#69f0ae; font-size:10px;") self._status.showMessage( f"Potato calibrated on {n_wins} windows ({seg_s:.1f} s each)" ) except Exception as exc: self._rp_detector = None self._rp_chk.setEnabled(False) self._rp_status_lbl.setText(f"Error: {exc}") self._rp_status_lbl.setStyleSheet("color:#ff8a65; font-size:10px;") def _toggle_potato_active(self, checked: bool) -> None: self._rp_active = checked if checked: self._rp_last_tested = self._total_pushed self._rp_status_lbl.setStyleSheet( "color:#69f0ae; font-size:10px; font-weight:bold;" ) else: self._rp_status_lbl.setStyleSheet("color:#69f0ae; font-size:10px;") def _run_potato_detection(self) -> None: """Test any newly completed windows against the fitted potato.""" if not self._rp_active or self._rp_detector is None: return n_seg = self._rp_seg_samples buf_size = self._buf.shape[1] buf_start_abs = self._total_pushed - buf_size # absolute sample of buf[:,0] added = False while (self._total_pushed - self._rp_last_tested) >= n_seg: win_start_abs = self._rp_last_tested win_end_abs = win_start_abs + n_seg # Locate in buffer s = win_start_abs - buf_start_abs e = win_end_abs - buf_start_abs self._rp_last_tested = win_end_abs if s < 0 or e > buf_size: continue # window fell outside the 
buffer (edge case at startup) window = self._buf[:, s:e] try: is_clean, z_score = self._rp_detector.detect(window) except Exception: continue if not is_clean: abs_start = win_start_abs / self._sfreq abs_end = win_end_abs / self._sfreq self._bad_segs.append((abs_start, abs_end)) added = True if added: n = len(self._bad_segs) self._bad_seg_count_lbl.setText( f"{n} bad segment{'s' if n > 1 else ''}" ) self._bad_seg_count_lbl.setStyleSheet("color:#ff8a65; font-size:10px;") # ------------------------------------------------------------------ # Redraw — purely displays whatever is in _buf # ------------------------------------------------------------------ def _redraw(self) -> None: end = min(self._page_start + self._n_shown, self._n_ch) visible = list(range(self._page_start, end)) n_actual = len(visible) gain = 1.0 / (self._scale + 1e-300) for vis_idx, ch_idx in enumerate(visible): raw = self._buf[ch_idx].copy() if self._dc_remove: nz = raw[raw != 0] if nz.size > 0: raw -= nz.mean() is_bad = ch_idx in self._bad_ch_idxs color = "#505050" if is_bad else self._ch_colors[ch_idx] width = 1 self._curves[vis_idx].setPen(pg.mkPen(color=color, width=width)) offset = float(n_actual - 1 - vis_idx) self._curves[vis_idx].setData(self._time_axis, offset + raw * gain) for vis_idx in range(n_actual, self._n_shown): self._curves[vis_idx].setData([], []) self._update_bad_seg_overlays() # ------------------------------------------------------------------ # Public interface # ------------------------------------------------------------------
def push(self, data: np.ndarray) -> None:
    """Append a chunk of raw data and refresh the display.

    All active online processing (filter, SSP, artifact correction,
    re-referencing) is applied to the incoming *chunk* before it enters the
    circular buffer. Data already in the buffer is never retroactively
    modified — the display transitions from unprocessed to processed as new
    data arrives.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_samples)
        New raw data chunk. It is copied and converted to float64, so
        integer input is processed without truncation.

    Notes
    -----
    The call is a no-op when the plot is paused, and chunks that are not
    2-D, have the wrong number of channels or contain no samples are
    ignored. Each processing stage is best-effort: if it raises, or returns
    an array of a different shape, the chunk passes through that stage
    unchanged.
    """
    if self._paused:
        return
    # Private float copy: the filter writes results back row by row, which
    # would silently truncate them on an integer array.
    chunk = np.array(data, dtype=np.float64)
    if chunk.ndim != 2 or chunk.shape[0] != self._n_ch or chunk.shape[1] == 0:
        return

    # 1. Online causal filter (sosfilt with persistent state)
    sos, zi = self._filter_sos, self._filter_zi  # snapshot as a pair
    if (
        sos is not None
        and zi is not None
        and zi.shape[:2] == (self._n_ch, sos.shape[0])
    ):
        try:
            from scipy.signal import sosfilt

            filtered = np.empty_like(chunk)
            new_zi = np.empty_like(zi)
            for ch in range(self._n_ch):
                filtered[ch], new_zi[ch] = sosfilt(sos, chunk[ch], zi=zi[ch])
        except Exception:
            pass  # best-effort: leave chunk and state untouched
        else:
            chunk = filtered
            # Commit state only if the filter was not replaced meanwhile.
            if self._filter_sos is sos:
                self._filter_zi = new_zi

    # 2. SSP projection
    proj = self._ssp_proj
    if proj is not None:
        try:
            projected = proj @ chunk
        except Exception:
            projected = None
        if projected is not None and projected.shape == chunk.shape:
            chunk = projected

    # 3. Artifact corrector (.transform must preserve shape)
    corrector = self._corrector
    if corrector is not None:
        try:
            corrected = np.asarray(corrector.transform(chunk), dtype=np.float64)
        except Exception:
            corrected = None
        # A reshaped result would break the buffer write in _flush_data_queue.
        if corrected is not None and corrected.shape == chunk.shape:
            chunk = corrected

    # 4. Re-referencing (spatial; subtract reference signal from every channel)
    try:
        if self._reref_type == "average":
            chunk = chunk - chunk.mean(axis=0)
        elif self._reref_type == "mastoid" and self._reref_idxs:
            chunk = chunk - chunk[self._reref_idxs].mean(axis=0)
        elif self._reref_type == "channel":
            chunk = chunk - chunk[self._reref_idx]
    except Exception:
        pass

    # Queue the processed chunk — buffer write + redraw happen in the
    # main thread via _flush_data_queue() to avoid Qt thread-safety issues.
    self._data_queue.append(chunk)
def _flush_data_queue(self) -> None: """Drain the data queue and redraw — called from the main thread at 30 Hz.""" if not self._data_queue: return while self._data_queue: chunk = self._data_queue.popleft() n = chunk.shape[1] self._buf = np.roll(self._buf, -n, axis=1) self._buf[:, -n:] = chunk self._total_pushed += n self._run_potato_detection() end = min(self._page_start + self._n_shown, self._n_ch) self._status.showMessage( f"Streaming — ch {self._page_start + 1}{end} of {self._n_ch}" ) self._redraw() @property def bad_channels(self) -> list[str]: """Channel names currently marked as bad.""" return [self._ch_names[i] for i in sorted(self._bad_ch_idxs)] @property def bad_segments(self) -> list[tuple[float, float]]: """Bad segments as list of (start_s, end_s) in absolute seconds.""" return list(self._bad_segs)
def to_annotations(self):
    """Convert the recorded bad segments into :class:`mne.Annotations`.

    Every segment becomes one ``"BAD_segment"`` annotation whose onset is the
    segment start and whose duration is ``end - start`` (absolute seconds).

    Returns
    -------
    annotations : mne.Annotations or None
        ``None`` when no bad segments have been marked, or when the
        annotations cannot be built (e.g. MNE is not importable).
    """
    segments = self._bad_segs
    if not segments:
        return None
    try:
        import mne

        starts = [seg[0] for seg in segments]
        lengths = [seg[1] - seg[0] for seg in segments]
        labels = ["BAD_segment"] * len(starts)
        return mne.Annotations(onset=starts, duration=lengths, description=labels)
    except Exception:
        return None
def closeEvent(self, event) -> None:
    """Stop the periodic queue-flush timer, then let Qt close the window."""
    flush_timer = self._flush_timer
    flush_timer.stop()
    super().closeEvent(event)