Source code for mne_rt.viz.topo_plot

"""Real-time scalp-layout ERP / evoked-potential display.

Live-updating equivalent of :func:`mne.viz.plot_evoked_topo`: channels
are placed at their exact 2-D scalp positions (from
:func:`mne.channels.find_layout`), using PyQtGraph's scene for absolute
positioning rather than a collapsible grid.  Redraws after every
:meth:`update` call as new epochs arrive from :class:`~mne_rt.RTEpochs`.

Classes
-------
TopoPlot
    Real-time scalp-layout ERP display with interactive sidebar.
"""
from __future__ import annotations

import math
from typing import Optional, Union

import numpy as np

try:
    from PyQt6.QtCore import Qt, QRectF, pyqtSignal
    from PyQt6.QtGui  import QFont, QColor
    from PyQt6.QtWidgets import (
        QApplication, QMainWindow, QWidget,
        QVBoxLayout, QHBoxLayout, QLabel,
        QCheckBox, QSlider, QPushButton,
        QFrame, QSizePolicy, QScrollArea,
        QFileDialog, QComboBox,
    )
    _qt_available = True
except ImportError:
    _qt_available = False

try:
    import pyqtgraph as pg
    _pg_available = True
except ImportError:
    _pg_available = False

try:
    import mne
    _mne_available = True
except ImportError:
    _mne_available = False

from mne_rt._logging import logger, set_log_level


# ---------------------------------------------------------------------------
# Palette
# ---------------------------------------------------------------------------
_BG       = "#0d1117"
_SURFACE  = "#161b22"
_BORDER   = "#30363d"
_TEXT     = "#e6edf3"
_DIM      = "#8b949e"
_ACCENT   = "#3b82f6"

_COND_COLORS = [
    "#3b82f6",   # blue
    "#ec4899",   # pink
    "#10b981",   # green
    "#f59e0b",   # amber
    "#8b5cf6",   # violet
    "#06b6d4",   # cyan
]

_BG_PRESETS = [
    ("Dark",  "#0d1117", _TEXT),
    ("Navy",  "#050d1a", _TEXT),
    ("Slate", "#1e2030", _TEXT),
    ("Dim",   "#2d333b", _TEXT),
    ("Light", "#f1f5f9", "#111827"),
]

_SIDEBAR_W = 230

_SW, _SH = 1000, 920
_PW, _PH = 76, 64

_MASTOID_NAMES = frozenset(
    ["M1", "M2", "TP9", "TP10", "A1", "A2", "Mastoid", "mastoid"]
)


# ---------------------------------------------------------------------------
# Sidebar helpers
# ---------------------------------------------------------------------------

def _sep(parent: QWidget) -> QFrame:
    f = QFrame(parent)
    f.setFrameShape(QFrame.Shape.HLine)
    f.setStyleSheet(f"color:{_BORDER};")
    return f


def _section(text: str, parent: QWidget) -> QLabel:
    lbl = QLabel(text, parent)
    lbl.setStyleSheet(
        f"color:{_DIM}; font-size:10px; font-weight:700; "
        "letter-spacing:1px; padding-top:6px;"
    )
    return lbl


def _row(parent: QWidget, spacing: int = 5) -> tuple[QWidget, QHBoxLayout]:
    w = QWidget(parent)
    lay = QHBoxLayout(w)
    lay.setContentsMargins(0, 1, 0, 1)
    lay.setSpacing(spacing)
    return w, lay


def _val_lbl(text: str, parent: QWidget, color: str = _ACCENT) -> QLabel:
    lbl = QLabel(text, parent)
    lbl.setStyleSheet(
        f"color:{color}; font-size:11px; font-weight:600;"
    )
    lbl.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
    return lbl


def _key_lbl(text: str, parent: QWidget) -> QLabel:
    lbl = QLabel(text, parent)
    lbl.setStyleSheet(f"color:{_TEXT}; font-size:11px;")
    return lbl


def _slider(parent, lo: int, hi: int, val: int) -> QSlider:
    sl = QSlider(Qt.Orientation.Horizontal, parent)
    sl.setRange(lo, hi)
    sl.setValue(val)
    return sl


# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------

[docs] class TopoPlot(QMainWindow): """Real-time scalp-layout ERP display. One mini :class:`pyqtgraph.PlotItem` per electrode, placed at the channel's true 2-D scalp position from :func:`mne.channels.find_layout`. Condition averages (with optional ±1 SEM shading) are redrawn after every :meth:`update` call. Parameters ---------- ch_names : list of str Electrode names in data order. sfreq : float Sampling frequency in Hz. tmin : float Epoch start (s). tmax : float Epoch end (s). event_id : dict[str, int] Condition label → marker integer. info : mne.Info or None When provided, :func:`mne.channels.find_layout` is called on this object for exact scalp positioning and channel-type detection (EEG → µV, MEG mag → fT, MEG grad → fT/cm). Pass ``epochs_stream.info`` from :class:`~mne_rt.RTEpochs`. montage : str, default "standard_1020" Fallback montage when ``info`` is not given or has no dig points. baseline : tuple or None, default (None, 0) Baseline interval — drawn as a shaded region. window_size : tuple of int, default (1440, 900) Initial window size in pixels. verbose : bool or str or None .. versionadded:: 1.0.0 See Also -------- mne_rt.RTEpochs : Drives this plot via :meth:`update`. mne_rt.viz.ButterflyPlot : All-channel overlay alternative. mne_rt.viz.CompareEvoked : Large per-channel view with SEM ribbons. """ _redraw_sig = pyqtSignal(int)
[docs] def __init__( self, ch_names: list[str], sfreq: float, tmin: float, tmax: float, event_id: dict[str, int], info=None, montage: str = "standard_1020", baseline: Optional[tuple] = (None, 0), window_size: tuple[int, int] = (1440, 900), verbose: Union[bool, str, None] = None, ) -> None: if not _qt_available or not _pg_available: raise ImportError( "PyQt6 and pyqtgraph are required for TopoPlot.\n" "Install with: pip install 'mne-rt[full]'" ) _app = QApplication.instance() or QApplication([]) # noqa: F841 super().__init__() self._redraw_sig.connect(self._redraw) set_log_level(verbose) self.ch_names = list(ch_names) self.sfreq = sfreq self.tmin = tmin self.tmax = tmax self.event_id = event_id self.montage = montage self.baseline = baseline self._info = info self._conditions = list(event_id.keys()) self._cmap = { c: _COND_COLORS[i % len(_COND_COLORS)] for i, c in enumerate(self._conditions) } self._n_ch = len(ch_names) self._n_t = int(round((tmax - tmin) * sfreq)) + 1 self._times = np.linspace(tmin, tmax, self._n_t) self._epoch_buf: dict[str, list[np.ndarray]] = { c: [] for c in self._conditions } self._n_per: dict[str, int] = {c: 0 for c in self._conditions} # Display state self._yscale = 1.0 self._linewidth = 1.6 self._smooth_ms = 0.0 self._show_sem = False self._plot_bg = _BG self._x_start = tmin self._x_end = tmax # Unit / re-reference self._unit, self._unit_scale = self._detect_unit(info) self._reref_mode = "none" self._mastoid_idx = self._find_mastoids() self._norm_pos = self._compute_positions(info, montage) self.setWindowTitle("MNE-RT — Topo ERP") self.resize(*window_size) self._apply_styles() self._build_ui() logger.info( "TopoPlot(ERP): %d ch, %.0f%.0f ms, unit=%s, layout=%s", self._n_ch, tmin * 1000, tmax * 1000, self._unit, "from info" if info is not None else "montage/fallback", )
# ----------------------------------------------------------------------- # Unit / mastoid helpers # ----------------------------------------------------------------------- def _detect_unit(self, info) -> tuple[str, float]: if info is None or not _mne_available: return "µV", 1e6 try: ct = mne.channel_type(info, 0) if ct == "mag": return "fT", 1e15 elif ct == "grad": return "fT/cm", 1e13 except Exception: pass return "µV", 1e6 def _find_mastoids(self) -> list[int]: return [ i for i, ch in enumerate(self.ch_names) if ch in _MASTOID_NAMES ] # ----------------------------------------------------------------------- # Scalp layout # ----------------------------------------------------------------------- def _compute_positions( self, info, montage_name: str ) -> list[tuple[float, float]]: if _mne_available: if info is not None: pos = self._from_layout(mne.channels.find_layout(info)) if pos is not None: return pos try: tmp = mne.create_info( self.ch_names, sfreq=1.0, ch_types="eeg", verbose=False ) mont = mne.channels.make_standard_montage(montage_name) tmp.set_montage(mont, on_missing="ignore", verbose=False) pos = self._from_layout(mne.channels.find_layout(tmp)) if pos is not None: return pos except Exception as exc: logger.debug("TopoPlot: montage layout failed: %s", exc) logger.warning("TopoPlot: falling back to circular layout.") return self._circular_fallback() def _from_layout(self, layout) -> Optional[list[tuple[float, float]]]: if layout is None: return None name_xy: dict[str, tuple[float, float]] = {} for name, pos in zip(layout.names, layout.pos): xc = float(pos[0] + pos[2] / 2.0) yc = float(pos[1] + pos[3] / 2.0) name_xy[name] = (xc, 1.0 - yc) n_matched = sum(1 for c in self.ch_names if c in name_xy) if n_matched < self._n_ch // 2: return None positions: list[tuple[float, float]] = [] fb = 0 for ch in self.ch_names: if ch in name_xy: positions.append(name_xy[ch]) else: positions.append((0.02, fb * 0.05)); fb += 1 return positions def _circular_fallback(self) -> list[tuple[float, float]]: n = self._n_ch return [ (0.5 + 0.42 * math.cos(2 * math.pi * i / n - math.pi / 2), 0.5 + 0.42 * math.sin(2 * math.pi * i / n - math.pi / 2)) for i in range(n) ] # ----------------------------------------------------------------------- # UI build # ----------------------------------------------------------------------- def _apply_styles(self) -> None: self.setStyleSheet(f""" QMainWindow, QWidget {{ background:{_BG}; color:{_TEXT}; }} QLabel {{ color:{_TEXT}; font-size:12px; }} QCheckBox {{ color:{_TEXT}; font-size:12px; spacing:6px; }} QCheckBox::indicator {{ width:14px; height:14px; border-radius:3px; border:1px solid {_BORDER}; background:{_SURFACE}; }} QCheckBox::indicator:checked {{ background:{_ACCENT}; border-color:{_ACCENT}; }} QPushButton {{ background:{_SURFACE}; color:{_TEXT}; border:1px solid {_BORDER}; border-radius:5px; padding:4px 10px; font-size:11px; }} QPushButton:hover {{ background:{_BORDER}; }} QSlider::groove:horizontal {{ height:4px; background:{_BORDER}; border-radius:2px; }} QSlider::handle:horizontal {{ width:14px; height:14px; margin:-5px 0; border-radius:7px; background:{_ACCENT}; }} QComboBox {{ background:{_SURFACE}; color:{_TEXT}; border:1px solid {_BORDER}; border-radius:4px; padding:2px 6px; font-size:11px; }} QComboBox::drop-down {{ border:none; width:20px; }} QScrollArea {{ border: none; }} QScrollBar:vertical {{ background:{_BG}; width:6px; }} QScrollBar::handle:vertical {{ background:{_BORDER}; border-radius:3px; }} """) def _build_ui(self) -> None: central = QWidget() self.setCentralWidget(central) root = QHBoxLayout(central) root.setSpacing(0) root.setContentsMargins(0, 0, 0, 0) pg.setConfigOptions(antialias=True, background=_BG, foreground=_DIM) self._gview = pg.GraphicsView() self._gview.setBackground(_BG) self._gview.setSizePolicy( QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding ) root.addWidget(self._gview, stretch=1) self._scene = self._gview.sceneObj root.addWidget(self._build_sidebar()) self._plots: list[pg.PlotItem] = [] self._ch_labels: list[pg.TextItem] = [] self._curves: dict[str, list[pg.PlotCurveItem]] = { c: [] for c in self._conditions } self._sem_upper: dict[str, list[pg.PlotCurveItem]] = { c: [] for c in self._conditions } self._sem_lower: dict[str, list[pg.PlotCurveItem]] = { c: [] for c in self._conditions } self._sem_fills: dict[str, list] = {c: [] for c in self._conditions} self._zl_items: list[pg.InfiniteLine] = [] for ch_idx, ch in enumerate(self.ch_names): xn, yn = self._norm_pos[ch_idx] margin = 0.05 cx = (margin + xn * (1.0 - 2 * margin)) * _SW cy = (margin + yn * (1.0 - 2 * margin)) * _SH plot = pg.PlotItem() plot.setGeometry(QRectF(cx - _PW / 2, cy - _PH / 2, _PW, _PH)) self._scene.addItem(plot) lbl = self._style_plot(plot, ch) self._ch_labels.append(lbl) zl = pg.InfiniteLine( pos=0, angle=90, pen=pg.mkPen(_BORDER, width=1, style=Qt.PenStyle.DashLine), ) plot.addItem(zl) self._zl_items.append(zl) for cond in self._conditions: col = self._cmap[cond] curve = plot.plot( self._times, np.zeros(self._n_t), pen=pg.mkPen(col, width=self._linewidth), ) self._curves[cond].append(curve) upper = plot.plot(self._times, np.zeros(self._n_t), pen=None) lower = plot.plot(self._times, np.zeros(self._n_t), pen=None) qcol = QColor(col) qcol.setAlpha(55) fill = pg.FillBetweenItem(upper, lower, brush=pg.mkBrush(qcol)) fill.setVisible(False) plot.addItem(fill) self._sem_upper[cond].append(upper) self._sem_lower[cond].append(lower) self._sem_fills[cond].append(fill) self._plots.append(plot) def _style_plot(self, plot: pg.PlotItem, ch: str) -> pg.TextItem: plot.setMenuEnabled(False) plot.hideButtons() plot.setMouseEnabled(x=False, y=False) for ax in ("bottom", "left", "top", "right"): plot.showAxis(ax, False) plot.setContentsMargins(0, 0, 0, 0) lbl = pg.TextItem(ch, color=_DIM, anchor=(0, 1)) lbl.setFont(QFont("Helvetica", 6)) plot.addItem(lbl) lbl.setPos(self.tmin, 0) return lbl def _fit_view(self) -> None: if not self._plots: return rects = [p.geometry() for p in self._plots] x0 = min(r.x() for r in rects) y0 = min(r.y() for r in rects) x1 = max(r.x() + r.width() for r in rects) y1 = max(r.y() + r.height() for r in rects) pad_x = (x1 - x0) * 0.08 pad_y = (y1 - y0) * 0.10 self._gview.fitInView( QRectF(x0 - pad_x, y0 - pad_y, x1 - x0 + 2 * pad_x, y1 - y0 + 2 * pad_y), Qt.AspectRatioMode.KeepAspectRatio, )
[docs] def showEvent(self, event) -> None: # noqa: N802 super().showEvent(event) self._fit_view()
[docs] def resizeEvent(self, event) -> None: # noqa: N802 super().resizeEvent(event) self._fit_view()
# ----------------------------------------------------------------------- # Sidebar build # ----------------------------------------------------------------------- def _build_sidebar(self) -> QScrollArea: scroll = QScrollArea() scroll.setFixedWidth(_SIDEBAR_W + 14) scroll.setWidgetResizable(True) scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) scroll.setStyleSheet( f"QScrollArea {{ background:{_SURFACE}; " f"border-left:1px solid {_BORDER}; }}" ) sb = QWidget() sb.setStyleSheet(f"background:{_SURFACE};") ly = QVBoxLayout(sb) ly.setSpacing(4) ly.setContentsMargins(10, 12, 10, 12) # ── Header ─────────────────────────────────────────────────────── hdr = QLabel("ERP CONTROLS") hdr.setStyleSheet( f"color:{_TEXT}; font-size:11px; font-weight:700; letter-spacing:1.5px;" ) ly.addWidget(hdr) # Unit badge next to header self._unit_lbl = QLabel(f"[{self._unit}]") self._unit_lbl.setStyleSheet( f"color:{_ACCENT}; font-size:10px; font-weight:600; " f"background:{_SURFACE}; border:1px solid {_BORDER}; " "border-radius:3px; padding:1px 5px;" ) row_hdr, lhdr = _row(sb) lhdr.addWidget(hdr, stretch=1) lhdr.addWidget(self._unit_lbl) ly.addWidget(row_hdr) ly.addWidget(_sep(sb)) # ── CONDITIONS ─────────────────────────────────────────────────── ly.addWidget(_section("CONDITIONS", sb)) self._cond_checks: dict[str, QCheckBox] = {} self._cond_n_lbl: dict[str, QLabel] = {} for cond in self._conditions: col = self._cmap[cond] rw, rl = _row(sb) cb = QCheckBox() cb.setChecked(True) cb.setStyleSheet( f"QCheckBox::indicator:checked{{" f"background:{col};border-color:{col};}}" ) cb.toggled.connect(lambda chk, c=cond: self._toggle_cond(c, chk)) self._cond_checks[cond] = cb dot = QLabel(f"● {cond}") dot.setStyleSheet(f"color:{col};font-size:11px;font-weight:600;") dot.setWordWrap(True) n_lbl = _val_lbl("n = 0", sb, color=_DIM) self._cond_n_lbl[cond] = n_lbl rl.addWidget(cb); rl.addWidget(dot, stretch=1); rl.addWidget(n_lbl) ly.addWidget(rw) ly.addWidget(_sep(sb)) # ── AMPLITUDE ──────────────────────────────────────────────────── ly.addWidget(_section("AMPLITUDE", sb)) r1, l1 = _row(sb) l1.addWidget(_key_lbl("Y scale", sb), stretch=1) self._sv_lbl = _val_lbl("×1.0", sb) l1.addWidget(self._sv_lbl) ly.addWidget(r1) self._scale_sl = _slider(sb, 1, 50, 5) self._scale_sl.valueChanged.connect(self._on_scale) ly.addWidget(self._scale_sl) self._yscale = 1.0 r2, l2 = _row(sb) l2.addWidget(_key_lbl("Line width", sb), stretch=1) self._lw_lbl = _val_lbl("1.5", sb) l2.addWidget(self._lw_lbl) ly.addWidget(r2) self._lw_sl = _slider(sb, 1, 8, 3) self._lw_sl.valueChanged.connect(self._on_linewidth) ly.addWidget(self._lw_sl) rbt, lbt = _row(sb, 6) auto_btn = QPushButton("Auto scale") auto_btn.clicked.connect(self._auto_scale) lbt.addWidget(auto_btn) ly.addWidget(rbt) ly.addWidget(_sep(sb)) # ── SMOOTHING ──────────────────────────────────────────────────── ly.addWidget(_section("SMOOTHING", sb)) r3, l3 = _row(sb) l3.addWidget(_key_lbl("Window", sb), stretch=1) self._sm_lbl = _val_lbl("Off", sb) l3.addWidget(self._sm_lbl) ly.addWidget(r3) self._smooth_sl = _slider(sb, 0, 50, 0) self._smooth_sl.valueChanged.connect(self._on_smooth) ly.addWidget(self._smooth_sl) ly.addWidget(_sep(sb)) # ── RE-REFERENCE ───────────────────────────────────────────────── ly.addWidget(_section("RE-REFERENCE", sb)) self._reref_cb = QComboBox(sb) self._reref_cb.addItem("None (raw)") self._reref_cb.addItem("Average reference") mastoid_names = [self.ch_names[i] for i in self._mastoid_idx] if mastoid_names: self._reref_cb.addItem(f"Mastoids ({', '.join(mastoid_names)})") else: self._reref_cb.addItem("Mastoids (not found)") self._reref_cb.model().item(2).setEnabled(False) self._reref_cb.currentIndexChanged.connect(self._on_reref) ly.addWidget(self._reref_cb) ly.addWidget(_sep(sb)) # ── TIME WINDOW ────────────────────────────────────────────────── ly.addWidget(_section("TIME WINDOW", sb)) tmin_ms = int(self.tmin * 1000) tmax_ms = int(self.tmax * 1000) r4, l4 = _row(sb) l4.addWidget(_key_lbl("Start", sb), stretch=1) self._xstart_lbl = _val_lbl(f"{tmin_ms} ms", sb) l4.addWidget(self._xstart_lbl) ly.addWidget(r4) self._xstart_sl = _slider(sb, tmin_ms, 0, tmin_ms) self._xstart_sl.valueChanged.connect(self._on_xrange) ly.addWidget(self._xstart_sl) r5, l5 = _row(sb) l5.addWidget(_key_lbl("End", sb), stretch=1) self._xend_lbl = _val_lbl(f"{tmax_ms} ms", sb) l5.addWidget(self._xend_lbl) ly.addWidget(r5) self._xend_sl = _slider(sb, 0, tmax_ms, tmax_ms) self._xend_sl.valueChanged.connect(self._on_xrange) ly.addWidget(self._xend_sl) ly.addWidget(_sep(sb)) # ── APPEARANCE ─────────────────────────────────────────────────── ly.addWidget(_section("APPEARANCE", sb)) ly.addWidget(_key_lbl("Background", sb)) sw_row, sw_l = _row(sb, 5) self._bg_swatches: list[QPushButton] = [] for label, hex_col, _ in _BG_PRESETS: btn = QPushButton() btn.setFixedSize(28, 22) btn.setToolTip(label) active = hex_col == _BG border = f"2px solid {_ACCENT}" if active else f"1px solid {_BORDER}" btn.setStyleSheet( f"background:{hex_col}; border:{border}; border-radius:4px;" ) btn.clicked.connect(lambda _, c=hex_col: self._set_bg(c)) sw_l.addWidget(btn) self._bg_swatches.append(btn) sw_l.addStretch() ly.addWidget(sw_row) ly.addSpacing(4) self._sem_chk = QCheckBox("±1 SEM shading") self._sem_chk.setChecked(False) self._sem_chk.toggled.connect(self._toggle_sem) ly.addWidget(self._sem_chk) self._labels_chk = QCheckBox("Channel labels") self._labels_chk.setChecked(True) self._labels_chk.toggled.connect(self._toggle_labels) ly.addWidget(self._labels_chk) self._zl_chk = QCheckBox("Stimulus line") self._zl_chk.setChecked(True) self._zl_chk.toggled.connect(self._toggle_zl) ly.addWidget(self._zl_chk) ly.addWidget(_sep(sb)) # ── DATA ───────────────────────────────────────────────────────── ly.addWidget(_section("DATA", sb)) self._total_lbl = QLabel("Total: 0 trials") self._total_lbl.setStyleSheet(f"color:{_TEXT};font-size:11px;") ly.addWidget(self._total_lbl) ly.addSpacing(4) reset_btn = QPushButton("Reset epochs") reset_btn.setToolTip("Clear all accumulated epochs") reset_btn.clicked.connect(self._reset_epochs) ly.addWidget(reset_btn) export_btn = QPushButton("Export PNG …") export_btn.setToolTip("Save current plot as PNG") export_btn.clicked.connect(self._export_png) ly.addWidget(export_btn) ly.addStretch() scroll.setWidget(sb) return scroll # ----------------------------------------------------------------------- # Sidebar callbacks # ----------------------------------------------------------------------- def _toggle_cond(self, cond: str, visible: bool) -> None: for c in self._curves[cond]: c.setVisible(visible) for f in self._sem_fills[cond]: f.setVisible(visible and self._show_sem) def _on_scale(self, value: int) -> None: self._yscale = value / 5.0 self._sv_lbl.setText(f{self._yscale:.1f}") self._apply_y_range() def _apply_y_range(self) -> None: avgs = [] for cond in self._conditions: if self._cond_checks.get(cond, QCheckBox()).isChecked(): buf = self._epoch_buf[cond] if buf: avg = np.mean(np.stack(buf, 0), 0) * self._unit_scale avgs.append(self._apply_reref(avg)) if not avgs: return amp = float(np.percentile(np.abs(np.stack(avgs, 0)), 99)) or 1e-12 half = amp * self._yscale for plot in self._plots: plot.setYRange(-half, half, padding=0.05) def _auto_scale(self) -> None: self._scale_sl.setValue(5) # triggers _on_scale → _apply_y_range def _on_linewidth(self, value: int) -> None: self._linewidth = value * 0.5 self._lw_lbl.setText(f"{self._linewidth:.1f}") for cond in self._conditions: col = self._cmap[cond] for curve in self._curves[cond]: curve.setPen(pg.mkPen(col, width=self._linewidth)) def _on_smooth(self, value: int) -> None: self._smooth_ms = float(value) self._sm_lbl.setText("Off" if value == 0 else f"{value} ms") total = sum(self._n_per.values()) if total > 0: self._redraw(total) def _smooth(self, y: np.ndarray) -> np.ndarray: if self._smooth_ms <= 0: return y n = max(1, int(self._smooth_ms * 1e-3 * self.sfreq)) if n < 2: return y return np.convolve(y, np.ones(n) / n, mode="same") def _on_reref(self, index: int) -> None: modes = ["none", "average", "mastoids"] self._reref_mode = modes[index] if index < len(modes) else "none" total = sum(self._n_per.values()) if total > 0: self._redraw(total) def _apply_reref(self, avg: np.ndarray) -> np.ndarray: """Apply re-referencing to avg (n_ch, n_times). Returns copy.""" if self._reref_mode == "average": return avg - avg.mean(0, keepdims=True) if self._reref_mode == "mastoids" and self._mastoid_idx: ref = avg[self._mastoid_idx].mean(0) return avg - ref return avg def _on_xrange(self) -> None: x1 = self._xstart_sl.value() / 1000.0 x2 = self._xend_sl.value() / 1000.0 if x1 >= x2: return self._x_start = x1 self._x_end = x2 self._xstart_lbl.setText(f"{int(x1 * 1000)} ms") self._xend_lbl.setText(f"{int(x2 * 1000)} ms") for plot in self._plots: plot.setXRange(x1, x2, padding=0) def _set_bg(self, color: str) -> None: self._plot_bg = color self._gview.setBackground(color) for plot in self._plots: plot.getViewBox().setBackgroundColor(color) for btn, (_, hex_col, _) in zip(self._bg_swatches, _BG_PRESETS): active = hex_col == color border = f"2px solid {_ACCENT}" if active else f"1px solid {_BORDER}" btn.setStyleSheet( f"background:{hex_col}; border:{border}; border-radius:4px;" ) def _toggle_sem(self, visible: bool) -> None: self._show_sem = visible for cond in self._conditions: checked = self._cond_checks.get(cond, QCheckBox()).isChecked() for f in self._sem_fills[cond]: f.setVisible(visible and checked) if visible: total = sum(self._n_per.values()) if total > 0: self._redraw(total) def _toggle_labels(self, visible: bool) -> None: for lbl in self._ch_labels: lbl.setVisible(visible) def _toggle_zl(self, v: bool) -> None: for item in self._zl_items: item.setVisible(v) def _reset_epochs(self) -> None: for cond in self._conditions: self._epoch_buf[cond] = [] self._n_per[cond] = 0 self._cond_n_lbl[cond].setText("n = 0") for curve in self._curves[cond]: curve.setData(self._times, np.zeros(self._n_t)) for upper, lower, fill in zip( self._sem_upper[cond], self._sem_lower[cond], self._sem_fills[cond] ): upper.setData(self._times, np.zeros(self._n_t)) lower.setData(self._times, np.zeros(self._n_t)) fill.setVisible(False) self._total_lbl.setText("Total: 0 trials") def _export_png(self) -> None: path, _ = QFileDialog.getSaveFileName( self, "Export Topo ERP Plot", "topo_plot.png", "PNG Image (*.png);;JPEG Image (*.jpg)", ) if path: self.grab().save(path) # ----------------------------------------------------------------------- # Public API # -----------------------------------------------------------------------
[docs] def update( self, data: np.ndarray, conditions: list[str], ) -> None: """Redraw all channel plots with updated condition averages. Thread-safe — may be called from the acquisition thread. Parameters ---------- data : ndarray, shape (n_epochs, n_channels, n_times) All accepted epochs so far. conditions : list of str Condition label for each epoch; ``len(conditions) == data.shape[0]``. """ n_total = len(conditions) for cond in self._conditions: mask = np.array([c == cond for c in conditions]) self._epoch_buf[cond] = list(data[mask]) if mask.any() else [] self._n_per[cond] = int(mask.sum()) self._redraw_sig.emit(n_total)
def _redraw(self, n_total: int) -> None: self._total_lbl.setText(f"Total: {n_total} trials") for cond in self._conditions: self._cond_n_lbl[cond].setText(f"n = {self._n_per[cond]}") buf = self._epoch_buf[cond] n = len(buf) if buf: stack = np.stack(buf, 0) avg = np.mean(stack, 0) # (n_ch, n_t) sem = (np.std(stack, 0, ddof=1) / np.sqrt(n) if n >= 2 else np.zeros_like(avg)) else: avg = np.zeros((self._n_ch, self._n_t)) sem = np.zeros_like(avg) # Apply re-reference then unit scaling avg = self._apply_reref(avg) * self._unit_scale sem = sem * self._unit_scale checked = self._cond_checks.get(cond, QCheckBox()).isChecked() for ch_i, curve in enumerate(self._curves[cond]): y = avg[ch_i] if len(y) != self._n_t: y = np.interp( self._times, np.linspace(self.tmin, self.tmax, len(y)), y ) curve.setData(self._times, self._smooth(y)) if self._show_sem and n >= 2: s = sem[ch_i] self._sem_upper[cond][ch_i].setData(self._times, y + s) self._sem_lower[cond][ch_i].setData(self._times, y - s) self._sem_fills[cond][ch_i].setVisible(checked) self._apply_y_range()