"""Real-time multi-condition evoked comparison for selected channels.
Shows N user-selected channels as large individual plots with all
conditions overlaid, ±1 SEM shading, visible time/amplitude axes, and
auto-detected peak markers. Channels are chosen by clicking on a
mini scalp-topomap in the sidebar. Redraws after every :meth:`update`
call as new epochs arrive from :class:`~mne_rt.RTEpochs`.
Classes
-------
CompareEvoked
Real-time per-channel condition-overlay display with SEM shading,
peak detection, and interactive topomap channel selector.
"""
from __future__ import annotations
import math
from typing import Optional, Union
import numpy as np
try:
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtGui import QFont, QColor
from PyQt6.QtWidgets import (
QApplication, QMainWindow, QWidget,
QVBoxLayout, QHBoxLayout, QLabel,
QCheckBox, QSlider, QPushButton,
QFrame, QSizePolicy, QScrollArea,
QFileDialog,
)
_qt_available = True
except ImportError:
_qt_available = False
try:
import pyqtgraph as pg
_pg_available = True
except ImportError:
_pg_available = False
try:
import mne
_mne_available = True
except ImportError:
_mne_available = False
from mne_rt._logging import logger, set_log_level
# ---------------------------------------------------------------------------
# Palette
# ---------------------------------------------------------------------------
_BG = "#0d1117"
_SURFACE = "#161b22"
_BORDER = "#30363d"
_TEXT = "#e6edf3"
_DIM = "#8b949e"
_ACCENT = "#3b82f6"
_COND_COLORS = [
"#3b82f6", # blue
"#ec4899", # pink
"#10b981", # green
"#f59e0b", # amber
"#8b5cf6", # violet
"#06b6d4", # cyan
]
# Auto-channel preference list (case-insensitive matching)
_PREF_CHANNELS = ["cz", "pz", "oz", "fz", "fcz", "cpz"]
_SIDEBAR_W = 210 # px
# ---------------------------------------------------------------------------
# Sidebar helpers (mirrors erp_plot.py)
# ---------------------------------------------------------------------------
def _sep(parent: QWidget) -> QFrame:
f = QFrame(parent)
f.setFrameShape(QFrame.Shape.HLine)
f.setStyleSheet(f"color:{_BORDER};")
return f
def _section(text: str, parent: QWidget) -> QLabel:
lbl = QLabel(text, parent)
lbl.setStyleSheet(
f"color:{_DIM}; font-size:10px; font-weight:700; "
"letter-spacing:1px; padding-top:6px;"
)
return lbl
def _row(parent: QWidget, spacing: int = 5) -> tuple[QWidget, QHBoxLayout]:
w = QWidget(parent)
lay = QHBoxLayout(w)
lay.setContentsMargins(0, 1, 0, 1)
lay.setSpacing(spacing)
return w, lay
def _val_lbl(text: str, parent: QWidget, color: str = _ACCENT) -> QLabel:
lbl = QLabel(text, parent)
lbl.setStyleSheet(
f"color:{color}; font-size:11px; font-weight:600;"
)
lbl.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter)
return lbl
def _key_lbl(text: str, parent: QWidget) -> QLabel:
lbl = QLabel(text, parent)
lbl.setStyleSheet(f"color:{_TEXT}; font-size:11px;")
return lbl
def _slider(parent: QWidget, lo: int, hi: int, val: int) -> QSlider:
sl = QSlider(Qt.Orientation.Horizontal, parent)
sl.setRange(lo, hi)
sl.setValue(val)
return sl
# ---------------------------------------------------------------------------
# Unit/scale helpers
# ---------------------------------------------------------------------------
def _detect_unit(info, ch_names: list[str]) -> tuple[str, float]:
if info is None or not _mne_available:
return ("µV", 1e6)
try:
ch_type = mne.channel_type(info, 0)
if ch_type == "eeg":
return ("µV", 1e6)
elif ch_type == "mag":
return ("fT", 1e15)
elif ch_type == "grad":
return ("fT/cm", 1e13)
else:
return ("µV", 1e6)
except Exception:
return ("µV", 1e6)
def _auto_channels(ch_names: list[str]) -> list[str]:
lower_to_orig = {ch.lower(): ch for ch in ch_names}
selected: list[str] = []
for pref in _PREF_CHANNELS:
if pref in lower_to_orig:
selected.append(lower_to_orig[pref])
if len(selected) == 3:
break
if not selected:
selected = ch_names[:3]
return selected
# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------
[docs]
class CompareEvoked(QMainWindow):
"""Real-time per-channel condition overlay with SEM shading and peak markers.
Shows N user-selected channels as large individual :class:`pyqtgraph.PlotItem`
rows. Each plot overlays all conditions with solid curves, ±1 SEM shading,
and a scatter point marking the peak latency in the post-stimulus window.
Channels are chosen interactively via a clickable mini scalp-topomap
in the sidebar: click any electrode dot to add or remove it from the
display (up to :data:`_MAX_DISP_CH` channels simultaneously).
Parameters
----------
ch_names : list of str
Electrode names in data order.
sfreq : float
Sampling frequency in Hz.
tmin : float
Epoch start (s).
tmax : float
Epoch end (s).
event_id : dict[str, int]
Condition label → marker integer.
channels : list of str or None
Channels shown on startup. If ``None``, auto-selects from
``['Cz','Pz','Oz','Fz','FCz','CPz']`` or the first 3 channels.
info : mne.Info or None
Used for unit/scale detection and scalp layout.
montage : str, default ``"standard_1020"``
Fallback montage for electrode positions when *info* has no dig.
baseline : tuple or None, default ``(None, 0)``
Baseline correction interval (informational only).
window_size : tuple of int, default ``(1200, 800)``
Initial window size in pixels.
verbose : bool, str, or None
.. versionadded:: 1.0.0
See Also
--------
mne_rt.RTEpochs : Drives this plot via :meth:`update`.
"""
# Emitted from any thread; Qt delivers it to _redraw on the main thread.
_redraw_sig = pyqtSignal(int)
[docs]
def __init__(
self,
ch_names: list[str],
sfreq: float,
tmin: float,
tmax: float,
event_id: dict[str, int],
channels: Optional[list[str]] = None,
info=None,
montage: str = "standard_1020",
baseline: Optional[tuple] = (None, 0),
window_size: tuple[int, int] = (1200, 800),
verbose: Union[bool, str, None] = None,
) -> None:
if not _qt_available or not _pg_available:
raise ImportError(
"PyQt6 and pyqtgraph are required for CompareEvoked.\n"
"Install with: pip install 'mne-rt[full]'"
)
_app = QApplication.instance() or QApplication([]) # noqa: F841
super().__init__()
self._redraw_sig.connect(self._redraw)
set_log_level(verbose)
self.ch_names = list(ch_names)
self.sfreq = sfreq
self.tmin = tmin
self.tmax = tmax
self.event_id = event_id
self.montage = montage
self.baseline = baseline
self._info = info
self._conditions = list(event_id.keys())
self._cmap = {
c: _COND_COLORS[i % len(_COND_COLORS)]
for i, c in enumerate(self._conditions)
}
self._n_ch = len(ch_names)
self._n_t = int(round((tmax - tmin) * sfreq)) + 1
self._times = np.linspace(tmin, tmax, self._n_t)
# Display channels (mutable — changed by topomap clicks)
if channels is None:
self._disp_channels: list[str] = _auto_channels(self.ch_names)
else:
self._disp_channels = [ch for ch in channels if ch in self.ch_names]
if not self._disp_channels:
logger.warning(
"CompareEvoked: none of the requested channels found; "
"falling back to auto-selection."
)
self._disp_channels = _auto_channels(self.ch_names)
# (no hard cap — let the user decide via the topomap)
# Index in ch_names for each displayed channel
self._disp_idx: list[int] = [
self.ch_names.index(ch) for ch in self._disp_channels
]
# Unit / scale
self._unit_label, self._unit_scale = _detect_unit(info, self.ch_names)
# Epoch buffer (always over all conditions)
self._epoch_buf: dict[str, list[np.ndarray]] = {
c: [] for c in self._conditions
}
self._n_per: dict[str, int] = {c: 0 for c in self._conditions}
# Display state
self._yscale = 1.0
self._show_sem = True
self._show_peak = True
# Scalp positions for all channels (normalised, yn=0=frontal)
self._norm_pos = self._compute_positions(info, montage)
# Topomap scatter item (assigned in _build_topo_widget)
self._topo_scatter: Optional[pg.ScatterPlotItem] = None
# Canvas data structures (rebuilt by _rebuild_canvas)
self._ch_plots: dict[str, pg.PlotItem] = {}
self._curves: dict[str, dict[str, pg.PlotCurveItem]] = {}
self._sem_upper: dict[str, dict[str, pg.PlotCurveItem]] = {}
self._sem_lower: dict[str, dict[str, pg.PlotCurveItem]] = {}
self._sem_fills: dict[str, dict[str, pg.FillBetweenItem]] = {}
self._peaks: dict[str, dict[str, pg.ScatterPlotItem]] = {}
self._t0_lines: list[pg.InfiniteLine] = []
self.setWindowTitle("MNE-RT — Compare Evoked")
self.resize(*window_size)
self._apply_styles()
self._build_ui()
logger.info(
"CompareEvoked: %d channels displayed (%s), %d conditions, "
"unit=%s",
len(self._disp_channels),
", ".join(self._disp_channels),
len(self._conditions),
self._unit_label,
)
# -----------------------------------------------------------------------
# Scalp layout
# -----------------------------------------------------------------------
def _compute_positions(
self, info, montage_name: str
) -> list[tuple[float, float]]:
"""Return normalised (xn, yn) for each channel, yn=0=frontal."""
if _mne_available:
if info is not None:
pos = self._from_layout(mne.channels.find_layout(info))
if pos is not None:
return pos
try:
tmp = mne.create_info(
self.ch_names, sfreq=1.0, ch_types="eeg", verbose=False
)
mont = mne.channels.make_standard_montage(montage_name)
tmp.set_montage(mont, on_missing="ignore", verbose=False)
pos = self._from_layout(mne.channels.find_layout(tmp))
if pos is not None:
return pos
except Exception as exc:
logger.debug("CompareEvoked: montage layout failed: %s", exc)
logger.warning("CompareEvoked: falling back to circular layout.")
return self._circular_fallback()
def _from_layout(self, layout) -> Optional[list[tuple[float, float]]]:
if layout is None:
return None
name_xy: dict[str, tuple[float, float]] = {}
for name, pos in zip(layout.names, layout.pos):
xc = float(pos[0] + pos[2] / 2.0)
yc = float(pos[1] + pos[3] / 2.0)
name_xy[name] = (xc, 1.0 - yc) # yn=0 = frontal
n_matched = sum(1 for c in self.ch_names if c in name_xy)
if n_matched < self._n_ch // 2:
return None
positions: list[tuple[float, float]] = []
fb = 0
for ch in self.ch_names:
if ch in name_xy:
positions.append(name_xy[ch])
else:
positions.append((0.02, fb * 0.05))
fb += 1
return positions
def _circular_fallback(self) -> list[tuple[float, float]]:
n = self._n_ch
return [
(0.5 + 0.42 * math.cos(2 * math.pi * i / n - math.pi / 2),
0.5 + 0.42 * math.sin(2 * math.pi * i / n - math.pi / 2))
for i in range(n)
]
# -----------------------------------------------------------------------
# Styles
# -----------------------------------------------------------------------
def _apply_styles(self) -> None:
self.setStyleSheet(f"""
QMainWindow, QWidget {{ background:{_BG}; color:{_TEXT}; }}
QLabel {{ color:{_TEXT}; font-size:12px; }}
QCheckBox {{ color:{_TEXT}; font-size:12px; spacing:6px; }}
QCheckBox::indicator {{ width:14px; height:14px; border-radius:3px;
border:1px solid {_BORDER};
background:{_SURFACE}; }}
QCheckBox::indicator:checked {{ background:{_ACCENT};
border-color:{_ACCENT}; }}
QPushButton {{ background:{_SURFACE}; color:{_TEXT};
border:1px solid {_BORDER};
border-radius:5px; padding:4px 10px;
font-size:11px; }}
QPushButton:hover {{ background:{_BORDER}; }}
QSlider::groove:horizontal {{
height:4px; background:{_BORDER}; border-radius:2px; }}
QSlider::handle:horizontal {{
width:14px; height:14px; margin:-5px 0;
border-radius:7px; background:{_ACCENT}; }}
QScrollArea {{ border: none; }}
QScrollBar:vertical {{ background:{_BG}; width:6px; }}
QScrollBar::handle:vertical {{ background:{_BORDER}; border-radius:3px; }}
""")
# -----------------------------------------------------------------------
# UI build
# -----------------------------------------------------------------------
def _build_ui(self) -> None:
central = QWidget()
self.setCentralWidget(central)
root = QHBoxLayout(central)
root.setSpacing(0)
root.setContentsMargins(0, 0, 0, 0)
pg.setConfigOptions(antialias=True, background=_BG, foreground=_DIM)
self._glw = pg.GraphicsLayoutWidget()
self._glw.setBackground(_BG)
self._glw.setSizePolicy(
QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding
)
root.addWidget(self._glw, stretch=1)
root.addWidget(self._build_sidebar())
self._rebuild_canvas()
def _rebuild_canvas(self) -> None:
"""Create or recreate PlotItems for the current _disp_channels."""
self._glw.clear()
self._ch_plots = {}
self._curves = {}
self._sem_upper = {}
self._sem_lower = {}
self._sem_fills = {}
self._peaks = {}
self._t0_lines = []
self._disp_idx = [self.ch_names.index(ch) for ch in self._disp_channels]
n_rows = len(self._disp_channels)
times_ms = self._times * 1000.0
for row_idx, ch in enumerate(self._disp_channels):
is_last = row_idx == n_rows - 1
plot = self._glw.addPlot(row=row_idx, col=0)
self._style_plot(plot, ch, is_last)
self._ch_plots[ch] = plot
t0_line = pg.InfiniteLine(
pos=0.0, angle=90,
pen=pg.mkPen(_BORDER, width=1, style=Qt.PenStyle.DashLine),
)
plot.addItem(t0_line)
self._t0_lines.append(t0_line)
self._curves[ch] = {}
self._sem_upper[ch] = {}
self._sem_lower[ch] = {}
self._sem_fills[ch] = {}
self._peaks[ch] = {}
for cond in self._conditions:
col = self._cmap[cond]
qcol = QColor(col)
r, g, b = qcol.red(), qcol.green(), qcol.blue()
curve = pg.PlotCurveItem(
times_ms, np.zeros(self._n_t),
pen=pg.mkPen(col, width=1.8), antialias=True,
)
plot.addItem(curve)
self._curves[ch][cond] = curve
sem_up = pg.PlotCurveItem(
times_ms, np.zeros(self._n_t),
pen=pg.mkPen((r, g, b, 0), width=0.5),
)
sem_lo = pg.PlotCurveItem(
times_ms, np.zeros(self._n_t),
pen=pg.mkPen((r, g, b, 0), width=0.5),
)
plot.addItem(sem_up)
plot.addItem(sem_lo)
self._sem_upper[ch][cond] = sem_up
self._sem_lower[ch][cond] = sem_lo
fill = pg.FillBetweenItem(
sem_up, sem_lo, brush=pg.mkBrush(r, g, b, 40),
)
fill.setVisible(False)
plot.addItem(fill)
self._sem_fills[ch][cond] = fill
scatter = pg.ScatterPlotItem(
size=8, pen=pg.mkPen(col, width=1.5),
brush=pg.mkBrush(r, g, b, 200), symbol="o",
)
scatter.setVisible(False)
plot.addItem(scatter)
self._peaks[ch][cond] = scatter
for row_idx in range(n_rows):
self._glw.ci.layout.setRowStretchFactor(row_idx, 1)
self._update_x_ticks()
# If data already accumulated, refresh immediately
total = sum(self._n_per.values())
if total > 0:
self._redraw(total)
def _style_plot(
self, plot: pg.PlotItem, ch_name: str, is_last: bool
) -> None:
plot.setMenuEnabled(False)
plot.hideButtons()
plot.setMouseEnabled(x=False, y=False)
plot.getViewBox().setBackgroundColor(_BG)
plot.showAxis("left", True)
left_ax = plot.getAxis("left")
left_ax.setStyle(tickLength=4, showValues=False)
left_ax.setPen(pg.mkPen(_BORDER, width=1))
left_ax.setTextPen(pg.mkPen(_DIM))
plot.showAxis("top", False)
plot.showAxis("right", False)
plot.showAxis("bottom", is_last)
if is_last:
bottom_ax = plot.getAxis("bottom")
bottom_ax.setPen(pg.mkPen(_BORDER, width=1))
bottom_ax.setTextPen(pg.mkPen(_DIM))
bottom_ax.setLabel("Time (ms)", color=_DIM)
plot.setContentsMargins(0, 0, 0, 0)
plot.setXRange(self.tmin * 1000.0, self.tmax * 1000.0, padding=0)
title = pg.TextItem(ch_name, color=_TEXT, anchor=(0, 0))
title.setFont(QFont("Helvetica", 10, QFont.Weight.Bold))
plot.addItem(title)
title.setPos(self.tmin * 1000.0, 0)
# -----------------------------------------------------------------------
# Custom time axis ticks
# -----------------------------------------------------------------------
def _update_x_ticks(self) -> None:
if not self._disp_channels:
return
last_ch = self._disp_channels[-1]
plot = self._ch_plots[last_ch]
bottom_ax = plot.getAxis("bottom")
tmin_ms = self.tmin * 1000.0
tmax_ms = self.tmax * 1000.0
span_ms = tmax_ms - tmin_ms
for interval in [25, 50, 100, 200, 250, 500]:
if 4 <= span_ms / interval <= 10:
break
else:
interval = 100
start = math.ceil(tmin_ms / interval) * interval
ticks: list[tuple[float, str]] = []
t = start
while t <= tmax_ms + 1e-6:
ticks.append((t, str(int(t))))
t += interval
bottom_ax.setTicks([ticks])
# -----------------------------------------------------------------------
# Topomap channel selector
# -----------------------------------------------------------------------
def _build_topo_widget(self, parent: QWidget) -> pg.PlotWidget:
"""Return a 184×184 pyqtgraph PlotWidget with clickable electrode dots."""
pw = pg.PlotWidget(parent=parent)
pw.setFixedSize(184, 184)
pw.setBackground(_SURFACE)
pw.hideAxis("bottom")
pw.hideAxis("left")
pw.getViewBox().setMouseEnabled(x=False, y=False)
pw.getViewBox().setAspectLocked(True)
# View range with a small margin so the nose isn't clipped
pw.getViewBox().setRange(
xRange=(-0.06, 1.06), yRange=(-0.06, 1.14), padding=0,
)
# ── Head circle ──────────────────────────────────────────────────
theta = np.linspace(0, 2 * np.pi, 160)
cx = 0.5 + 0.48 * np.cos(theta)
cy = 0.5 + 0.48 * np.sin(theta)
pw.plot(cx, cy, pen=pg.mkPen(_BORDER, width=1.5))
# ── Nose (small triangle at top, y > 0.98) ───────────────────────
nose_x = [0.47, 0.5, 0.53, 0.47]
nose_y = [0.97, 1.06, 0.97, 0.97]
pw.plot(nose_x, nose_y, pen=pg.mkPen(_BORDER, width=1.2))
# ── Left/right ear bumps ─────────────────────────────────────────
for side in (-1, 1):
ear_x_vals = np.linspace(0.48 * side, 0.56 * side, 8)
ear_y_vals = 0.5 + 0.06 * np.sin(np.linspace(0, np.pi, 8))
pw.plot(
0.5 + ear_x_vals, ear_y_vals,
pen=pg.mkPen(_BORDER, width=1.2),
)
# ── Channel dots ─────────────────────────────────────────────────
# yn=0 = frontal = top of display (y large in pg's y-up coords)
spots = []
for i, ch in enumerate(self.ch_names):
xn, yn = self._norm_pos[i]
# Map into the circle: keep within 0.5±0.45
tx = 0.5 + (xn - 0.5) * 0.9
ty = 0.5 + (yn - 0.5) * 0.9 # yn=0→top, yn=1→bottom; y-axis: 0=bottom
# pg default: y increases upward, so we flip yn
ty = 1.0 - ty # now ty=1 → frontal top, ty=0 → occipital bottom
selected = ch in self._disp_channels
spots.append({
"pos": (tx, ty),
"data": ch,
"brush": pg.mkBrush(_ACCENT if selected else _BORDER),
"pen": pg.mkPen(None),
"size": 10 if selected else 6,
})
self._topo_scatter = pg.ScatterPlotItem(
spots=spots, hoverable=True,
tip=lambda x, y, data: str(data) if data else "",
)
self._topo_scatter.sigClicked.connect(self._on_topo_click)
pw.addItem(self._topo_scatter)
# ── Hint text ────────────────────────────────────────────────────
hint = pg.TextItem("click to select", color=_DIM, anchor=(0.5, 1))
hint.setFont(QFont("Helvetica", 7))
hint.setPos(0.5, -0.02)
pw.addItem(hint)
return pw
def _update_topo_colors(self) -> None:
"""Refresh dot appearance after selection changes."""
if self._topo_scatter is None:
return
spots = []
for i, ch in enumerate(self.ch_names):
xn, yn = self._norm_pos[i]
tx = 0.5 + (xn - 0.5) * 0.9
ty = 1.0 - (0.5 + (yn - 0.5) * 0.9)
selected = ch in self._disp_channels
spots.append({
"pos": (tx, ty),
"data": ch,
"brush": pg.mkBrush(_ACCENT if selected else _BORDER),
"pen": pg.mkPen(None),
"size": 10 if selected else 6,
})
self._topo_scatter.setData(spots=spots)
def _on_topo_click(self, *args) -> None:
"""Handle a click on the topomap scatter plot.
PyQtGraph passes ``(scatter, points, event)``; we accept ``*args``
for version compatibility.
"""
# points is the second-to-last argument across pg versions
points = args[-2] if len(args) >= 2 else args[0]
for pt in points:
ch = pt.data()
if ch is None or ch not in self.ch_names:
continue
if ch in self._disp_channels:
if len(self._disp_channels) > 1: # always keep at least 1
self._disp_channels.remove(ch)
else:
self._disp_channels.append(ch)
self._update_topo_colors()
# Update title
self._sel_lbl.setText(
f"{', '.join(self._disp_channels)}"
)
self._rebuild_canvas()
# -----------------------------------------------------------------------
# Sidebar build
# -----------------------------------------------------------------------
def _build_sidebar(self) -> QScrollArea:
scroll = QScrollArea()
scroll.setFixedWidth(_SIDEBAR_W + 14)
scroll.setWidgetResizable(True)
scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
scroll.setStyleSheet(
f"QScrollArea {{ background:{_SURFACE}; "
f"border-left:1px solid {_BORDER}; }}"
)
sb = QWidget()
sb.setStyleSheet(f"background:{_SURFACE};")
ly = QVBoxLayout(sb)
ly.setSpacing(4)
ly.setContentsMargins(10, 12, 10, 12)
# ── Header ───────────────────────────────────────────────────────
hdr = QLabel("COMPARE EVOKED")
hdr.setStyleSheet(
f"color:{_TEXT}; font-size:11px; font-weight:700; letter-spacing:1.5px;"
)
ly.addWidget(hdr)
ly.addWidget(_sep(sb))
# ── CHANNEL SELECTOR (topomap) ────────────────────────────────────
ly.addWidget(_section("CHANNELS", sb))
# Hint: how many selected / max
self._sel_lbl = QLabel(", ".join(self._disp_channels))
self._sel_lbl.setStyleSheet(
f"color:{_ACCENT}; font-size:10px; font-weight:600;"
)
self._sel_lbl.setWordWrap(True)
ly.addWidget(self._sel_lbl)
cap_lbl = QLabel("click to toggle")
cap_lbl.setStyleSheet(f"color:{_DIM}; font-size:9px;")
ly.addWidget(cap_lbl)
ly.addSpacing(4)
topo_w = self._build_topo_widget(sb)
# Centre the topomap in the sidebar
topo_row = QWidget(sb)
topo_lay = QHBoxLayout(topo_row)
topo_lay.setContentsMargins(0, 0, 0, 0)
topo_lay.addStretch()
topo_lay.addWidget(topo_w)
topo_lay.addStretch()
ly.addWidget(topo_row)
ly.addWidget(_sep(sb))
# ── CONDITIONS ───────────────────────────────────────────────────
ly.addWidget(_section("CONDITIONS", sb))
self._cond_checks: dict[str, QCheckBox] = {}
self._cond_n_lbl: dict[str, QLabel] = {}
for cond in self._conditions:
col = self._cmap[cond]
row_w, row_l = _row(sb)
cb = QCheckBox()
cb.setChecked(True)
cb.setStyleSheet(
f"QCheckBox::indicator:checked{{"
f"background:{col};border-color:{col};}}"
)
cb.toggled.connect(lambda chk, c=cond: self._toggle_cond(c, chk))
self._cond_checks[cond] = cb
dot = QLabel(f"● {cond}")
dot.setStyleSheet(f"color:{col};font-size:11px;font-weight:600;")
dot.setWordWrap(True)
n_lbl = _val_lbl("n = 0", sb, color=_DIM)
self._cond_n_lbl[cond] = n_lbl
row_l.addWidget(cb)
row_l.addWidget(dot, stretch=1)
row_l.addWidget(n_lbl)
ly.addWidget(row_w)
ly.addWidget(_sep(sb))
# ── DISPLAY ──────────────────────────────────────────────────────
ly.addWidget(_section("DISPLAY", sb))
self._sem_chk = QCheckBox("SEM shading")
self._sem_chk.setChecked(True)
self._sem_chk.toggled.connect(self._toggle_sem)
ly.addWidget(self._sem_chk)
self._peak_chk = QCheckBox("Peak markers")
self._peak_chk.setChecked(True)
self._peak_chk.toggled.connect(self._toggle_peaks)
ly.addWidget(self._peak_chk)
ly.addSpacing(4)
r1, l1 = _row(sb)
l1.addWidget(_key_lbl("Y scale", sb), stretch=1)
self._sv_lbl = _val_lbl("×1.0", sb)
l1.addWidget(self._sv_lbl)
ly.addWidget(r1)
self._scale_sl = _slider(sb, 1, 50, 5)
self._scale_sl.valueChanged.connect(self._on_scale)
ly.addWidget(self._scale_sl)
ra, la = _row(sb, spacing=6)
auto_btn = QPushButton("Auto scale")
auto_btn.clicked.connect(self._auto_scale)
la.addWidget(auto_btn)
ly.addWidget(ra)
ly.addWidget(_sep(sb))
# ── DATA ─────────────────────────────────────────────────────────
ly.addWidget(_section("DATA", sb))
self._total_lbl = QLabel("Total: 0 trials")
self._total_lbl.setStyleSheet(f"color:{_TEXT};font-size:11px;")
ly.addWidget(self._total_lbl)
ly.addSpacing(4)
export_btn = QPushButton("Export PNG …")
export_btn.setToolTip("Save the current compare-evoked plot as a PNG image")
export_btn.clicked.connect(self._export_png)
ly.addWidget(export_btn)
ly.addStretch()
scroll.setWidget(sb)
return scroll
# -----------------------------------------------------------------------
# Sidebar callbacks
# -----------------------------------------------------------------------
def _toggle_cond(self, cond: str, visible: bool) -> None:
for ch in self._disp_channels:
self._curves[ch][cond].setVisible(visible)
if visible and self._show_sem:
self._sem_fills[ch][cond].setVisible(True)
else:
self._sem_fills[ch][cond].setVisible(False)
if not visible:
self._peaks[ch][cond].setVisible(False)
def _toggle_sem(self, visible: bool) -> None:
self._show_sem = visible
for ch in self._disp_channels:
for cond in self._conditions:
cond_vis = self._cond_checks.get(cond, QCheckBox()).isChecked()
self._sem_fills[ch][cond].setVisible(visible and cond_vis)
def _toggle_peaks(self, visible: bool) -> None:
self._show_peak = visible
for ch in self._disp_channels:
for cond in self._conditions:
cond_vis = self._cond_checks.get(cond, QCheckBox()).isChecked()
self._peaks[ch][cond].setVisible(
visible and cond_vis and bool(self._epoch_buf[cond])
)
def _on_scale(self, value: int) -> None:
self._yscale = value * 0.2
self._sv_lbl.setText(f"×{self._yscale:.1f}")
self._apply_y_range()
def _apply_y_range(self) -> None:
avgs_scaled: list[np.ndarray] = []
for cond in self._conditions:
if self._cond_checks.get(cond, QCheckBox()).isChecked():
buf = self._epoch_buf[cond]
if buf:
avg = np.mean(np.stack(buf, 0), 0) * self._unit_scale
avgs_scaled.append(avg)
if not avgs_scaled:
return
amp = float(np.percentile(np.abs(np.stack(avgs_scaled, 0)), 99)) or 1e-12
half = amp * self._yscale
for plot in self._ch_plots.values():
if plot.isVisible():
plot.setYRange(-half, half, padding=0.05)
def _auto_scale(self) -> None:
self._scale_sl.setValue(5)
for plot in self._ch_plots.values():
if plot.isVisible():
plot.enableAutoRange(axis="y")
def _export_png(self) -> None:
path, _ = QFileDialog.getSaveFileName(
self, "Export Compare Evoked", "compare_evoked.png",
"PNG Image (*.png);;JPEG Image (*.jpg)",
)
if path:
self.grab().save(path)
# -----------------------------------------------------------------------
# Public API
# -----------------------------------------------------------------------
[docs]
def update(
self,
data: np.ndarray,
conditions: list[str],
) -> None:
"""Redraw all channel plots with updated condition averages.
Thread-safe — may be called from the acquisition thread.
Parameters
----------
data : ndarray, shape (n_epochs, n_channels, n_times)
All accepted epochs so far.
conditions : list of str
Condition label for each epoch; length == ``data.shape[0]``.
"""
n_total = len(conditions)
for cond in self._conditions:
mask = np.array([c == cond for c in conditions])
self._epoch_buf[cond] = list(data[mask]) if mask.any() else []
self._n_per[cond] = int(mask.sum())
self._redraw_sig.emit(n_total)
def _redraw(self, n_total: int) -> None:
"""Slot — always runs on the main/GUI thread."""
times_ms = self._times * 1000.0
t0_idx = int(np.searchsorted(times_ms, 0.0))
for cond in self._conditions:
buf = self._epoch_buf[cond]
n = len(buf)
cond_vis = self._cond_checks.get(cond, QCheckBox()).isChecked()
if buf:
stack = np.stack(buf, 0)
avg = np.mean(stack, 0)
sem = (np.std(stack, axis=0, ddof=1) / math.sqrt(n)
if n >= 2 else np.zeros_like(avg))
else:
avg = np.zeros((self._n_ch, self._n_t))
sem = np.zeros((self._n_ch, self._n_t))
for disp_pos, ch in enumerate(self._disp_channels):
ch_i = self._disp_idx[disp_pos]
avg_scaled = avg[ch_i] * self._unit_scale
sem_scaled = sem[ch_i] * self._unit_scale
self._curves[ch][cond].setData(times_ms, avg_scaled)
self._curves[ch][cond].setVisible(cond_vis)
if n >= 2 and self._show_sem and cond_vis:
self._sem_upper[ch][cond].setData(times_ms, avg_scaled + sem_scaled)
self._sem_lower[ch][cond].setData(times_ms, avg_scaled - sem_scaled)
self._sem_fills[ch][cond].setVisible(True)
else:
self._sem_fills[ch][cond].setVisible(False)
if buf and self._show_peak and cond_vis:
post = avg_scaled[t0_idx:]
if post.size > 0:
pk_local = int(np.argmax(np.abs(post)))
self._peaks[ch][cond].setData(
[times_ms[t0_idx + pk_local]],
[avg_scaled[t0_idx + pk_local]],
)
self._peaks[ch][cond].setVisible(True)
else:
self._peaks[ch][cond].setVisible(False)
self._cond_n_lbl[cond].setText(f"n = {self._n_per[cond]}")
self._total_lbl.setText(f"Total: {n_total} trials")
if self._scale_sl.value() != 5:
self._apply_y_range()