# Source code for mne_rt.viz.tfr_plot

"""Real-time time-frequency representation (TFR) display.

Computes Morlet wavelet power for selected channels and displays
as colour-coded heatmaps (time × frequency) after each new batch
of epochs arrives via :meth:`TFRPlot.update`.

Classes
-------
TFRPlot
    Real-time TFR display with interactive sidebar.
"""
from __future__ import annotations

import math
import threading
from typing import Optional, Union

import numpy as np

try:
    from PyQt6.QtCore import Qt, QRectF, pyqtSignal
    from PyQt6.QtGui  import QFont, QColor, QTransform
    from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget,
        QVBoxLayout, QHBoxLayout, QLabel, QCheckBox, QSlider, QPushButton,
        QFrame, QSizePolicy, QScrollArea, QFileDialog, QComboBox,
        QDoubleSpinBox)
    _qt_available = True
except ImportError:
    _qt_available = False

try:
    import pyqtgraph as pg
    _pg_available = True
except ImportError:
    _pg_available = False

try:
    import mne
    import mne.time_frequency
    _mne_available = True
except ImportError:
    _mne_available = False

from mne_rt._logging import logger, set_log_level


# ---------------------------------------------------------------------------
# Palette  (shared with ERPPlot)
# ---------------------------------------------------------------------------
_BG      = "#0d1117"
_SURFACE = "#161b22"
_BORDER  = "#30363d"
_TEXT    = "#e6edf3"
_DIM     = "#8b949e"
_ACCENT  = "#3b82f6"

_COND_COLORS = [
    "#3b82f6",   # blue
    "#ec4899",   # pink
    "#10b981",   # green
    "#f59e0b",   # amber
    "#8b5cf6",   # violet
    "#06b6d4",   # cyan
]

_SIDEBAR_W = 210   # px

# Fallback thermal colormap stops when matplotlib is unavailable
_THERMAL_POS   = [0.0, 0.25, 0.5, 0.75, 1.0]
_THERMAL_COLOR = ["#000000", "#1a237e", "#e53935", "#ffeb3b", "#ffffff"]


# ---------------------------------------------------------------------------
# Sidebar helpers  (same API as erp_plot.py)
# ---------------------------------------------------------------------------

def _sep(parent: QWidget) -> QFrame:
    """Return a horizontal rule, drawn in the border colour, parented to *parent*."""
    line = QFrame(parent)
    line.setFrameShape(QFrame.Shape.HLine)
    line.setStyleSheet(f"color:{_BORDER};")
    return line


def _section(text: str, parent: QWidget) -> QLabel:
    """Return a dimmed, bold, letter-spaced heading label for a sidebar section."""
    style = (
        f"color:{_DIM}; font-size:10px; font-weight:700; "
        "letter-spacing:1px; padding-top:6px;"
    )
    heading = QLabel(text, parent)
    heading.setStyleSheet(style)
    return heading


def _row(parent: QWidget, spacing: int = 5) -> tuple[QWidget, QHBoxLayout]:
    """Create a horizontal row container.

    Returns the container widget and its ``QHBoxLayout`` (1 px top/bottom
    margins, *spacing* px between items) so callers can add children.
    """
    container = QWidget(parent)
    layout = QHBoxLayout(container)
    layout.setContentsMargins(0, 1, 0, 1)
    layout.setSpacing(spacing)
    return container, layout


def _val_lbl(text: str, parent: QWidget, color: str = _ACCENT) -> QLabel:
    """Return a bold, right-aligned value label in *color* (accent by default)."""
    align = Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter
    label = QLabel(text, parent)
    label.setStyleSheet(f"color:{color}; font-size:11px; font-weight:600;")
    label.setAlignment(align)
    return label


def _key_lbl(text: str, parent: QWidget) -> QLabel:
    """Return a plain text label for the key (left-hand) side of a row."""
    label = QLabel(text, parent)
    label.setStyleSheet(f"color:{_TEXT}; font-size:11px;")
    return label


def _slider(parent: QWidget, lo: int, hi: int, val: int) -> QSlider:
    """Return a horizontal integer slider with range [lo, hi] set to *val*."""
    slider = QSlider(Qt.Orientation.Horizontal, parent)
    slider.setRange(lo, hi)
    slider.setValue(val)
    return slider


# ---------------------------------------------------------------------------
# Main class
# ---------------------------------------------------------------------------

[docs] class TFRPlot(QMainWindow): """Real-time time-frequency representation (TFR). Computes Morlet wavelet power for selected channels and displays as colour-coded heatmaps (time × frequency) after each new batch of epochs arrives via :meth:`update`. Two modes: - **induced**: average of per-epoch TFR → total power (including non-phase-locked oscillations). - **evoked**: TFR of the trial average → only phase-locked power. Baseline correction uses dB change: ``10 * log10(power / baseline_mean)``. Parameters ---------- ch_names : list of str Electrode names in data order. sfreq : float Sampling frequency in Hz. tmin : float Epoch start (s). tmax : float Epoch end (s). event_id : dict[str, int] Condition label → marker integer. freqs : ndarray or None Frequencies of interest (Hz). Default ``np.arange(4, 50, 2)``. n_cycles : ndarray, float, or None Number of cycles per frequency for Morlet wavelets. Default ``freqs / 2`` (half-cycle per frequency). channels : list of str or None Channels to display. When ``None``, up to 4 channels are auto-selected (see :meth:`_auto_channels`). mode : {'induced', 'evoked'} Computation mode. ``'induced'`` averages per-epoch TFRs; ``'evoked'`` computes the TFR of the trial average. baseline : tuple or None, default (None, 0) Baseline interval ``(tmin, tmax)`` in seconds used for dB normalisation. ``None`` on either side means epoch edge. decim : int, default 4 Decimation factor applied along the time axis before display. info : mne.Info or None Unused at present; reserved for future layout features. window_size : tuple of int, default (1440, 900) Initial window size in pixels. verbose : bool, str, or None .. versionadded:: 1.0.0 See Also -------- mne_rt.RTEpochs : Drives this plot via :meth:`update`. """ # Emitted from the worker thread; Qt delivers it to _redraw on the main # thread, keeping all widget mutations on the GUI thread. _redraw_sig = pyqtSignal(int)
[docs] def __init__( self, ch_names: list[str], sfreq: float, tmin: float, tmax: float, event_id: dict[str, int], freqs: Optional[np.ndarray] = None, n_cycles: Union[np.ndarray, float, None] = None, channels: Optional[list[str]] = None, mode: str = "induced", baseline: Optional[tuple] = (None, 0), decim: int = 4, info=None, montage: str = "standard_1020", window_size: tuple[int, int] = (1440, 900), verbose: Union[bool, str, None] = None, ) -> None: if not _qt_available or not _pg_available: raise ImportError( "PyQt6 and pyqtgraph are required for TFRPlot.\n" "Install with: pip install 'mne-rt[full]'" ) _app = QApplication.instance() or QApplication([]) # noqa: F841 super().__init__() self._redraw_sig.connect(self._redraw) set_log_level(verbose) # ── Public attributes ──────────────────────────────────────────── self.ch_names = list(ch_names) self.sfreq = float(sfreq) self.tmin = float(tmin) self.tmax = float(tmax) self.event_id = event_id self.mode = mode.lower() self.baseline = baseline self.decim = max(1, int(decim)) # ── Conditions ─────────────────────────────────────────────────── self._conditions = list(event_id.keys()) self._cmap = { c: _COND_COLORS[i % len(_COND_COLORS)] for i, c in enumerate(self._conditions) } # ── Time axis ──────────────────────────────────────────────────── self._n_t = int(round((tmax - tmin) * sfreq)) + 1 self._times = np.linspace(tmin, tmax, self._n_t) # Decimated time axis — mirrors what tfr_array_morlet returns with # decim applied along the last axis. 
self._times_dec = self._times[::self.decim] self._n_t_dec = len(self._times_dec) # ── Frequencies ────────────────────────────────────────────────── if freqs is None: freqs = np.arange(4, 50, 2, dtype=float) self._freqs = np.asarray(freqs, dtype=float) if n_cycles is None: self._n_cycles = self._freqs / 2.0 else: self._n_cycles = n_cycles self._clip_freqs() self.montage = montage # ── Display channels ───────────────────────────────────────────── self._display_chs = self._auto_channels(channels) self._display_idx = [self.ch_names.index(c) for c in self._display_chs] # Scalp positions for topomap channel selector (normalised, yn=0=frontal) self._norm_pos = self._compute_positions(info, montage) self._topo_scatter: Optional[pg.ScatterPlotItem] = None # ── Normalisation mode ─────────────────────────────────────────── # 'db' — dB change from baseline (default) # 'raw' — raw power in µV²/Hz (no normalisation) self._norm_mode = "db" # ── Thread-safety state ────────────────────────────────────────── self._computing = False self._latest_data: Optional[np.ndarray] = None self._latest_conds: list[str] = [] self._tfr_result: dict[str, np.ndarray] = {} # ── Colormap & color limits ────────────────────────────────────── self._cmap_name = "hot" self._cmap_lut = self._build_colormap() self._vmin: Optional[float] = None # None = auto self._vmax: Optional[float] = None # ── Widget dicts (populated in _build_ui) ──────────────────────── self._image_items: list[list[pg.ImageItem]] = [] # [cond_i][ch_i] self._plot_items: list[list[pg.PlotItem]] = [] # [cond_i][ch_i] self._ch_row_checks: dict[str, QCheckBox] = {} self._cond_n_lbl: dict[str, QLabel] = {} # ── Build window ───────────────────────────────────────────────── self.setWindowTitle("MNE-RT — TFR Plot") self.resize(*window_size) self._apply_styles() self._build_ui() logger.info( "TFRPlot: %d display channels, %d conditions, " "freqs %.0f%.0f Hz, mode=%s", len(self._display_chs), len(self._conditions), self._freqs[0], 
self._freqs[-1], self.mode, )
# ----------------------------------------------------------------------- # Frequency clipping # ----------------------------------------------------------------------- def _clip_freqs(self) -> None: """Ensure no Morlet wavelet exceeds the epoch length. MNE uses ``n_sigmas = 5`` in :func:`mne.time_frequency.morlet`, so the wavelet fits when ``n_cycles < (n_t - 1) * π * f / (5 * sfreq)``. Strategy: first remove frequencies where even 2 cycles would not fit (too few for a meaningful TFR); then clip ``n_cycles`` on the remaining frequencies so the wavelet stays within the epoch. """ nc = (np.asarray(self._n_cycles, float) if not np.isscalar(self._n_cycles) else np.full(len(self._freqs), float(self._n_cycles))) # Maximum n_cycles that fits: nc < (n_t-1)*pi*f / (5*sfreq) nc_max = (self._n_t - 1) * np.pi * self._freqs / (5.0 * self.sfreq) # Drop frequencies where even 2 cycles cannot fit mask = nc_max >= 2.0 if not mask.all(): logger.info( "TFRPlot: removed %d freq(s) below %.1f Hz " "(epoch too short for ≥2 cycles).", int((~mask).sum()), float(self._freqs[mask][0]) if mask.any() else 0.0, ) self._freqs = self._freqs[mask] nc = nc[mask] nc_max = nc_max[mask] if len(self._freqs) == 0: # Epoch is very short — use the highest plausible frequency range f_min = max(10.0 * self.sfreq / ((self._n_t - 1) * np.pi), 8.0) self._freqs = np.arange(f_min, min(f_min + 30.0, 80.0), 2.0) nc_max = (self._n_t - 1) * np.pi * self._freqs / (5.0 * self.sfreq) self._n_cycles = np.maximum(nc_max * 0.90, 1.0) logger.warning( "TFRPlot: epoch too short for standard TFR — " "using %.0f%.0f Hz with reduced n_cycles.", float(self._freqs[0]), float(self._freqs[-1]), ) return # Clip n_cycles to fit within epoch (10 % safety margin) self._n_cycles = np.minimum(nc, np.maximum(nc_max * 0.90, 1.0)) # ----------------------------------------------------------------------- # Scalp layout (for topomap channel selector) # ----------------------------------------------------------------------- def 
_compute_positions( self, info, montage_name: str ) -> list[tuple[float, float]]: import math as _math if _mne_available: if info is not None: pos = self._from_layout(mne.channels.find_layout(info)) if pos is not None: return pos try: tmp = mne.create_info( self.ch_names, sfreq=1.0, ch_types="eeg", verbose=False ) mont = mne.channels.make_standard_montage(montage_name) tmp.set_montage(mont, on_missing="ignore", verbose=False) pos = self._from_layout(mne.channels.find_layout(tmp)) if pos is not None: return pos except Exception as exc: logger.debug("TFRPlot: montage layout failed: %s", exc) n = len(self.ch_names) return [ (0.5 + 0.42 * _math.cos(2 * _math.pi * i / n - _math.pi / 2), 0.5 + 0.42 * _math.sin(2 * _math.pi * i / n - _math.pi / 2)) for i in range(n) ] def _from_layout(self, layout) -> Optional[list[tuple[float, float]]]: if layout is None: return None name_xy: dict[str, tuple[float, float]] = {} for name, pos in zip(layout.names, layout.pos): xc = float(pos[0] + pos[2] / 2.0) yc = float(pos[1] + pos[3] / 2.0) name_xy[name] = (xc, 1.0 - yc) n_matched = sum(1 for c in self.ch_names if c in name_xy) if n_matched < len(self.ch_names) // 2: return None positions: list[tuple[float, float]] = [] fb = 0 for ch in self.ch_names: if ch in name_xy: positions.append(name_xy[ch]) else: positions.append((0.02, fb * 0.05)) fb += 1 return positions # ----------------------------------------------------------------------- # Grid rebuild (called when channel selection changes) # ----------------------------------------------------------------------- def _rebuild_grid(self) -> None: """Clear and recreate TFR grid for the current _display_chs.""" self._gl_widget.clear() self._image_items = [[] for _ in range(len(self._conditions))] self._plot_items = [[] for _ in range(len(self._conditions))] self._display_idx = [self.ch_names.index(c) for c in self._display_chs] self._build_tfr_grid() if self._latest_data is not None and not self._computing: 
threading.Thread(target=self._compute_and_emit, daemon=True).start() # ----------------------------------------------------------------------- # Channel auto-selection # ----------------------------------------------------------------------- def _auto_channels(self, channels: Optional[list[str]]) -> list[str]: """Return up to 4 channels to display. If *channels* is given, return those that exist in ``ch_names`` (warn about missing ones). Otherwise prefer the canonical set ``['Cz','Pz','Oz','Fz','C3','C4']``, padding with the first available channels when fewer than 4 are found. """ if channels is not None: valid = [c for c in channels if c in self.ch_names] missing = [c for c in channels if c not in self.ch_names] if missing: logger.warning( "TFRPlot: channels not found in ch_names and will be " "ignored: %s", missing ) return valid or self.ch_names[:4] preferred = ["Cz", "Pz", "Oz", "Fz", "C3", "C4"] found: list[str] = [] # Case-sensitive first pass for c in preferred: if c in self.ch_names and c not in found: found.append(c) # Case-insensitive second pass ch_upper = [c.upper() for c in self.ch_names] for c in preferred: if c not in found: try: idx = ch_upper.index(c.upper()) found.append(self.ch_names[idx]) except ValueError: pass # Pad with the first N channels if fewer than 4 found for c in self.ch_names: if len(found) >= 4: break if c not in found: found.append(c) return found[:4] # ----------------------------------------------------------------------- # Colormap # ----------------------------------------------------------------------- def _build_colormap(self) -> np.ndarray: """Return a 256×3 uint8 LUT for the current ``_cmap_name``.""" name = getattr(self, "_cmap_name", "hot") try: cmap = pg.colormap.get(name, source="matplotlib") return cmap.getLookupTable(0.0, 1.0, 256) except Exception: pass # Inline fallbacks for the most important maps if name == "RdBu_r": stops = [0.0, 0.25, 0.5, 0.75, 1.0] colors = ["#053061", "#4393c3", "#f7f7f7", "#d6604d", 
"#67001f"] elif name == "viridis": stops = [0.0, 0.33, 0.67, 1.0] colors = ["#440154", "#31688e", "#35b779", "#fde725"] elif name == "plasma": stops = [0.0, 0.33, 0.67, 1.0] colors = ["#0d0887", "#cc4778", "#f89441", "#f0f921"] else: stops = _THERMAL_POS colors = _THERMAL_COLOR cmap = pg.ColorMap(stops, [QColor(c).getRgb()[:3] for c in colors]) return cmap.getLookupTable(0.0, 1.0, 256) # ----------------------------------------------------------------------- # Style sheet # ----------------------------------------------------------------------- def _apply_styles(self) -> None: self.setStyleSheet(f""" QMainWindow, QWidget {{ background:{_BG}; color:{_TEXT}; }} QLabel {{ color:{_TEXT}; font-size:12px; }} QCheckBox {{ color:{_TEXT}; font-size:12px; spacing:6px; }} QCheckBox::indicator {{ width:14px; height:14px; border-radius:3px; border:1px solid {_BORDER}; background:{_SURFACE}; }} QCheckBox::indicator:checked {{ background:{_ACCENT}; border-color:{_ACCENT}; }} QPushButton {{ background:{_SURFACE}; color:{_TEXT}; border:1px solid {_BORDER}; border-radius:5px; padding:4px 10px; font-size:11px; }} QPushButton:hover {{ background:{_BORDER}; }} QComboBox {{ background:{_SURFACE}; color:{_TEXT}; border:1px solid {_BORDER}; border-radius:4px; padding:2px 6px; font-size:11px; }} QComboBox::drop-down {{ border:none; }} QSlider::groove:horizontal {{ height:4px; background:{_BORDER}; border-radius:2px; }} QSlider::handle:horizontal {{ width:14px; height:14px; margin:-5px 0; border-radius:7px; background:{_ACCENT}; }} QScrollArea {{ border: none; }} QScrollBar:vertical {{ background:{_BG}; width:6px; }} QScrollBar::handle:vertical {{ background:{_BORDER}; border-radius:3px; }} """) # ----------------------------------------------------------------------- # UI build # ----------------------------------------------------------------------- def _build_ui(self) -> None: central = QWidget() self.setCentralWidget(central) root = QHBoxLayout(central) root.setSpacing(0) 
root.setContentsMargins(0, 0, 0, 0) pg.setConfigOptions(antialias=False, background=_BG, foreground=_DIM) # ── Canvas ────────────────────────────────────────────────────── self._gl_widget = pg.GraphicsLayoutWidget() self._gl_widget.setBackground(_BG) self._gl_widget.setSizePolicy( QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding ) root.addWidget(self._gl_widget, stretch=1) # ── Sidebar ───────────────────────────────────────────────────── root.addWidget(self._build_sidebar()) # ── TFR grid ───────────────────────────────────────────────────── self._build_tfr_grid() def _build_tfr_grid(self) -> None: """Populate the GraphicsLayoutWidget with PlotItem / ImageItem cells.""" layout = self._gl_widget.ci # the central GraphicsLayout n_conds = len(self._conditions) n_chs = len(self._display_chs) # Row 0: condition title labels for cond_i, cond in enumerate(self._conditions): col = _COND_COLORS[cond_i % len(_COND_COLORS)] title_lbl = layout.addLabel( cond, row=0, col=cond_i, color=col ) title_lbl.setText( f'<span style="color:{col};font-size:10pt;font-weight:700;">' f'{cond}</span>' ) # Rows 1…n_chs: one row per display channel, one column per condition self._image_items = [[] for _ in range(n_conds)] self._plot_items = [[] for _ in range(n_conds)] is_last_row = lambda ch_i: ch_i == n_chs - 1 # noqa: E731 is_first_col = lambda cond_i: cond_i == 0 # noqa: E731 for ch_i, ch_name in enumerate(self._display_chs): for cond_i in range(n_conds): show_x = is_last_row(ch_i) show_y = is_first_col(cond_i) ch_title = ch_name # show channel label in every column plot = layout.addPlot(row=ch_i + 1, col=cond_i) self._style_plot(plot, show_x=show_x, show_y=show_y, title=ch_title) img = pg.ImageItem() img.setLookupTable(self._cmap_lut) plot.addItem(img) # Set axis transform: x = time in ms, y = frequency in Hz self._set_image_transform(img) self._image_items[cond_i].append(img) self._plot_items[cond_i].append(plot) # ── Align all axes explicitly 
───────────────────────────────────── f_min = float(self._freqs[0]) if len(self._freqs) else 0.0 f_max = float(self._freqs[-1]) if len(self._freqs) else 100.0 t0_ms = float(self._times_dec[0] * 1000.0) t1_ms = float(self._times_dec[-1] * 1000.0) for ch_i in range(n_chs): for cond_i in range(n_conds): self._plot_items[cond_i][ch_i].setXRange(t0_ms, t1_ms, padding=0) self._plot_items[cond_i][ch_i].setYRange(f_min, f_max, padding=0) def _set_image_transform(self, img: "pg.ImageItem") -> None: """Apply QTransform so image axes show ms / Hz values.""" t0_ms = self._times_dec[0] * 1000.0 t1_ms = self._times_dec[-1] * 1000.0 f0_hz = float(self._freqs[0]) f1_hz = float(self._freqs[-1]) n_t = max(1, len(self._times_dec)) n_f = max(1, len(self._freqs)) tr = QTransform() tr.translate(t0_ms, f0_hz) tr.scale((t1_ms - t0_ms) / n_t, (f1_hz - f0_hz) / n_f) img.setTransform(tr) def _style_plot( self, plot: "pg.PlotItem", show_x: bool = False, show_y: bool = False, title: str = "", ) -> None: plot.setMenuEnabled(False) plot.hideButtons() plot.setMouseEnabled(x=False, y=False) plot.getViewBox().disableAutoRange() # prevent auto-range from breaking alignment plot.showAxis("top", False) plot.showAxis("right", False) plot.showAxis("bottom", show_x) plot.showAxis("left", show_y) if show_x: plot.getAxis("bottom").setLabel("Time", units="ms") plot.getAxis("bottom").setStyle(tickFont=QFont("Helvetica", 7)) if show_y: plot.getAxis("left").setLabel("Freq", units="Hz") plot.getAxis("left").setStyle(tickFont=QFont("Helvetica", 7)) if title: plot.setTitle(title, color=_DIM, size="9pt") plot.setContentsMargins(0, 0, 0, 0) # ----------------------------------------------------------------------- # Sidebar # ----------------------------------------------------------------------- def _build_sidebar(self) -> QScrollArea: scroll = QScrollArea() scroll.setFixedWidth(_SIDEBAR_W + 14) scroll.setWidgetResizable(True) scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) 
scroll.setStyleSheet( f"QScrollArea {{ background:{_SURFACE}; " f"border-left:1px solid {_BORDER}; }}" ) sb = QWidget() sb.setStyleSheet(f"background:{_SURFACE};") ly = QVBoxLayout(sb) ly.setSpacing(4) ly.setContentsMargins(10, 12, 10, 12) # ── Header ─────────────────────────────────────────────────────── hdr = QLabel("TFR CONTROLS") hdr.setStyleSheet( f"color:{_TEXT}; font-size:11px; font-weight:700; letter-spacing:1.5px;" ) ly.addWidget(hdr) ly.addWidget(_sep(sb)) # ── CHANNELS (topomap selector) ─────────────────────────────────── ly.addWidget(_section("CHANNELS", sb)) self._sel_lbl = QLabel(", ".join(self._display_chs)) self._sel_lbl.setStyleSheet( f"color:{_ACCENT}; font-size:10px; font-weight:600;" ) self._sel_lbl.setWordWrap(True) ly.addWidget(self._sel_lbl) cap_lbl = QLabel("click to toggle (≤8 for performance)") cap_lbl.setStyleSheet(f"color:{_DIM}; font-size:9px;") ly.addWidget(cap_lbl) ly.addSpacing(4) topo_w = self._build_topo_widget(sb) topo_row = QWidget(sb) topo_lay = QHBoxLayout(topo_row) topo_lay.setContentsMargins(0, 0, 0, 0) topo_lay.addStretch() topo_lay.addWidget(topo_w) topo_lay.addStretch() ly.addWidget(topo_row) ly.addWidget(_sep(sb)) # ── CONDITIONS ─────────────────────────────────────────────────── ly.addWidget(_section("CONDITIONS", sb)) self._cond_n_lbl = {} for cond in self._conditions: col = self._cmap[cond] row_w, row_l = _row(sb) dot = QLabel(f"● {cond}") dot.setStyleSheet(f"color:{col}; font-size:11px; font-weight:600;") dot.setWordWrap(True) n_lbl = _val_lbl("n = 0", sb, color=_DIM) self._cond_n_lbl[cond] = n_lbl row_l.addWidget(dot, stretch=1) row_l.addWidget(n_lbl) ly.addWidget(row_w) ly.addWidget(_sep(sb)) # ── MODE ───────────────────────────────────────────────────────── ly.addWidget(_section("MODE", sb)) self._mode_combo = QComboBox(sb) self._mode_combo.addItem("Induced (total power)") self._mode_combo.addItem("Evoked (phase-locked)") self._mode_combo.setCurrentIndex(0 if self.mode == "induced" else 1) 
self._mode_combo.currentIndexChanged.connect(self._on_mode_change) ly.addWidget(self._mode_combo) ly.addWidget(_sep(sb)) # ── FREQ RANGE ─────────────────────────────────────────────────── ly.addWidget(_section("FREQ RANGE", sb)) r_fstart, l_fstart = _row(sb) l_fstart.addWidget(_key_lbl("Start", sb), stretch=1) self._fstart_lbl = _val_lbl(f"{int(self._freqs[0])} Hz", sb) l_fstart.addWidget(self._fstart_lbl) ly.addWidget(r_fstart) self._fstart_sl = _slider(sb, 2, 50, int(self._freqs[0])) self._fstart_sl.valueChanged.connect(self._on_freq_range) ly.addWidget(self._fstart_sl) r_fend, l_fend = _row(sb) l_fend.addWidget(_key_lbl("End", sb), stretch=1) self._fend_lbl = _val_lbl(f"{int(self._freqs[-1])} Hz", sb) l_fend.addWidget(self._fend_lbl) ly.addWidget(r_fend) self._fend_sl = _slider(sb, 4, 100, int(self._freqs[-1])) self._fend_sl.valueChanged.connect(self._on_freq_range) ly.addWidget(self._fend_sl) ly.addWidget(_sep(sb)) # ── NORMALISATION ──────────────────────────────────────────────── ly.addWidget(_section("NORMALIZATION", sb)) self._norm_combo = QComboBox(sb) self._norm_combo.addItem("dB change (baseline)") self._norm_combo.addItem("Raw power (µV²/Hz)") self._norm_combo.setCurrentIndex(0) self._norm_combo.currentIndexChanged.connect(self._on_norm_change) ly.addWidget(self._norm_combo) ly.addWidget(_sep(sb)) # ── COLORMAP ───────────────────────────────────────────────────── ly.addWidget(_section("COLORMAP", sb)) self._cmap_combo = QComboBox(sb) for label in ["Hot", "RdBu (div.)", "Viridis", "Plasma", "Turbo", "Greys"]: self._cmap_combo.addItem(label) self._cmap_combo.currentIndexChanged.connect(self._on_cmap_change) ly.addWidget(self._cmap_combo) ly.addWidget(_sep(sb)) # ── COLOR LIMITS ───────────────────────────────────────────────── ly.addWidget(_section("COLOR LIMITS", sb)) self._auto_lvl_chk = QCheckBox("Auto") self._auto_lvl_chk.setChecked(True) self._auto_lvl_chk.toggled.connect(self._on_auto_levels) ly.addWidget(self._auto_lvl_chk) r_vmin, l_vmin = 
_row(sb) l_vmin.addWidget(_key_lbl("vmin", sb), stretch=1) self._vmin_spin = QDoubleSpinBox(sb) self._vmin_spin.setRange(-200.0, 0.0) self._vmin_spin.setValue(-3.0) self._vmin_spin.setSingleStep(0.5) self._vmin_spin.setDecimals(1) self._vmin_spin.setEnabled(False) self._vmin_spin.setStyleSheet( f"QDoubleSpinBox{{background:{_SURFACE};color:{_TEXT};" f"border:1px solid {_BORDER};border-radius:3px;padding:1px 4px;}}" ) self._vmin_spin.valueChanged.connect(self._on_vmin_change) l_vmin.addWidget(self._vmin_spin) ly.addWidget(r_vmin) r_vmax, l_vmax = _row(sb) l_vmax.addWidget(_key_lbl("vmax", sb), stretch=1) self._vmax_spin = QDoubleSpinBox(sb) self._vmax_spin.setRange(0.0, 200.0) self._vmax_spin.setValue(3.0) self._vmax_spin.setSingleStep(0.5) self._vmax_spin.setDecimals(1) self._vmax_spin.setEnabled(False) self._vmax_spin.setStyleSheet( f"QDoubleSpinBox{{background:{_SURFACE};color:{_TEXT};" f"border:1px solid {_BORDER};border-radius:3px;padding:1px 4px;}}" ) self._vmax_spin.valueChanged.connect(self._on_vmax_change) l_vmax.addWidget(self._vmax_spin) ly.addWidget(r_vmax) ly.addWidget(_sep(sb)) # ── DATA ───────────────────────────────────────────────────────── ly.addWidget(_section("DATA", sb)) self._total_lbl = QLabel("Total: 0 trials") self._total_lbl.setStyleSheet(f"color:{_TEXT}; font-size:11px;") ly.addWidget(self._total_lbl) ly.addSpacing(4) export_btn = QPushButton("Export PNG …") export_btn.setToolTip("Save the current TFR plot as a PNG image") export_btn.clicked.connect(self._export_png) ly.addWidget(export_btn) ly.addStretch() scroll.setWidget(sb) return scroll # ----------------------------------------------------------------------- # Sidebar callbacks # ----------------------------------------------------------------------- # ----------------------------------------------------------------------- # Topomap channel selector # ----------------------------------------------------------------------- def _build_topo_widget(self, parent: QWidget) -> 
pg.PlotWidget: pw = pg.PlotWidget(parent=parent) pw.setFixedSize(184, 184) pw.setBackground(_SURFACE) pw.hideAxis("bottom") pw.hideAxis("left") pw.getViewBox().setMouseEnabled(x=False, y=False) pw.getViewBox().setAspectLocked(True) pw.getViewBox().setRange( xRange=(-0.06, 1.06), yRange=(-0.06, 1.14), padding=0, ) theta = np.linspace(0, 2 * np.pi, 160) pw.plot(0.5 + 0.48 * np.cos(theta), 0.5 + 0.48 * np.sin(theta), pen=pg.mkPen(_BORDER, width=1.5)) pw.plot([0.47, 0.5, 0.53, 0.47], [0.97, 1.06, 0.97, 0.97], pen=pg.mkPen(_BORDER, width=1.2)) spots = [] for i, ch in enumerate(self.ch_names): xn, yn = self._norm_pos[i] tx = 0.5 + (xn - 0.5) * 0.9 ty = 1.0 - (0.5 + (yn - 0.5) * 0.9) selected = ch in self._display_chs spots.append({ "pos": (tx, ty), "data": ch, "brush": pg.mkBrush(_ACCENT if selected else _BORDER), "pen": pg.mkPen(None), "size": 10 if selected else 6, }) self._topo_scatter = pg.ScatterPlotItem( spots=spots, hoverable=True, tip=lambda x, y, data: str(data) if data else "", ) self._topo_scatter.sigClicked.connect(self._on_topo_click) pw.addItem(self._topo_scatter) hint = pg.TextItem("click to select", color=_DIM, anchor=(0.5, 1)) hint.setFont(QFont("Helvetica", 7)) hint.setPos(0.5, -0.02) pw.addItem(hint) return pw def _update_topo_colors(self) -> None: if self._topo_scatter is None: return spots = [] for i, ch in enumerate(self.ch_names): xn, yn = self._norm_pos[i] tx = 0.5 + (xn - 0.5) * 0.9 ty = 1.0 - (0.5 + (yn - 0.5) * 0.9) selected = ch in self._display_chs spots.append({ "pos": (tx, ty), "data": ch, "brush": pg.mkBrush(_ACCENT if selected else _BORDER), "pen": pg.mkPen(None), "size": 10 if selected else 6, }) self._topo_scatter.setData(spots=spots) def _on_topo_click(self, *args) -> None: points = args[-2] if len(args) >= 2 else args[0] for pt in points: ch = pt.data() if ch is None or ch not in self.ch_names: continue if ch in self._display_chs: if len(self._display_chs) > 1: self._display_chs.remove(ch) else: if len(self._display_chs) < 8: 
self._display_chs.append(ch) self._update_topo_colors() self._sel_lbl.setText(", ".join(self._display_chs)) self._rebuild_grid() def _on_mode_change(self, idx: int) -> None: self.mode = "induced" if idx == 0 else "evoked" self._maybe_recompute() def _on_freq_range(self) -> None: fmin = self._fstart_sl.value() fmax = self._fend_sl.value() if fmin >= fmax: return self._fstart_lbl.setText(f"{fmin} Hz") self._fend_lbl.setText(f"{fmax} Hz") self._freqs = np.arange(fmin, fmax, 2, dtype=float) if len(self._freqs) == 0: self._freqs = np.array([float(fmin)]) self._n_cycles = self._freqs / 2.0 self._clip_freqs() # Refresh transforms and realign all axes f_min = float(self._freqs[0]) if len(self._freqs) else 0.0 f_max = float(self._freqs[-1]) if len(self._freqs) else 100.0 for cond_i in range(len(self._conditions)): for ch_i in range(len(self._display_chs)): self._set_image_transform(self._image_items[cond_i][ch_i]) self._plot_items[cond_i][ch_i].setXRange( float(self._times_dec[0]*1000), float(self._times_dec[-1]*1000), padding=0 ) self._plot_items[cond_i][ch_i].setYRange(f_min, f_max, padding=0) self._maybe_recompute() def _on_norm_change(self, idx: int) -> None: self._norm_mode = "db" if idx == 0 else "raw" self._maybe_recompute() def _on_cmap_change(self, idx: int) -> None: _map = ["hot", "RdBu_r", "viridis", "plasma", "turbo", "greys"] self._cmap_name = _map[idx % len(_map)] self._cmap_lut = self._build_colormap() for ci in range(len(self._conditions)): for chi in range(len(self._display_chs)): self._image_items[ci][chi].setLookupTable(self._cmap_lut) def _on_auto_levels(self, checked: bool) -> None: self._vmin_spin.setEnabled(not checked) self._vmax_spin.setEnabled(not checked) if checked: self._vmin = None self._vmax = None else: self._vmin = self._vmin_spin.value() self._vmax = self._vmax_spin.value() self._maybe_recompute() def _on_vmin_change(self, val: float) -> None: if not self._auto_lvl_chk.isChecked(): self._vmin = val self._maybe_recompute() def 
_on_vmax_change(self, val: float) -> None: if not self._auto_lvl_chk.isChecked(): self._vmax = val self._maybe_recompute() def _maybe_recompute(self) -> None: """Trigger recompute if we already have data.""" if self._latest_data is not None and not self._computing: threading.Thread( target=self._compute_and_emit, daemon=True ).start() def _export_png(self) -> None: path, _ = QFileDialog.getSaveFileName( self, "Export TFR Plot", "tfr_plot.png", "PNG Image (*.png);;JPEG Image (*.jpg)", ) if path: self.grab().save(path) # ----------------------------------------------------------------------- # TFR computation # ----------------------------------------------------------------------- def _compute_tfr( self, data: np.ndarray, conditions: list[str] ) -> "dict[str, np.ndarray]": """Compute Morlet TFR for each condition. Parameters ---------- data : ndarray, shape (n_epochs, n_channels, n_times) All accumulated epochs. conditions : list of str Condition label for each epoch. Returns ------- dict Mapping condition → power array of shape ``(n_display_ch, n_freqs, n_times_dec)``. 
""" result: dict[str, np.ndarray] = {} for cond in self._conditions: mask = np.array([c == cond for c in conditions]) if not mask.any(): result[cond] = np.zeros( (len(self._display_idx), len(self._freqs), self._n_t_dec) ) continue ep = data[mask] # (n_ep, n_ch, n_t) if self.mode == "evoked": ep = ep.mean(0, keepdims=True) # TFR of trial average power = mne.time_frequency.tfr_array_morlet( ep.astype(np.float64), sfreq=self.sfreq, freqs=self._freqs, n_cycles=self._n_cycles, output="power", decim=self.decim, zero_mean=True, ) # (n_ep, n_ch, n_freqs, n_times_dec) avg_power = power.mean(0) # (n_ch, n_freqs, n_times_dec) if self._norm_mode == "db": # Baseline correction: dB change from mean baseline power bl_mask = self._times_dec <= 0 if bl_mask.any(): bl_power = ( avg_power[:, :, bl_mask].mean(-1, keepdims=True) + 1e-30 ) avg_power = 10.0 * np.log10(avg_power / bl_power) # If no pre-stimulus baseline, still convert to log scale else: avg_power = 10.0 * np.log10(avg_power + 1e-30) result[cond] = avg_power[self._display_idx] # select display chs return result # ----------------------------------------------------------------------- # Threading # -----------------------------------------------------------------------
[docs] def update(self, data: np.ndarray, conditions: list[str]) -> None: """Receive new epoch data and schedule a TFR recompute. Thread-safe. If a computation is already running the new data is stored and will be used at the next opportunity; the current run is not interrupted. Parameters ---------- data : ndarray, shape (n_epochs, n_channels, n_times) All accepted epochs accumulated so far. conditions : list of str Condition label for each epoch; ``len(conditions) == data.shape[0]``. """ self._latest_data = data.copy() self._latest_conds = list(conditions) if not self._computing: threading.Thread( target=self._compute_and_emit, daemon=True ).start()
def _compute_and_emit(self) -> None:
    """Compute the TFR off the GUI thread and signal a redraw.

    Runs on the daemon thread started by :meth:`update` or
    :meth:`_maybe_recompute`.  ``_latest_data`` and ``_latest_conds`` are
    read once per pass, so the computation, the stored result and the
    emitted trial count all refer to the same snapshot.  :meth:`update`
    does not start a second worker while ``_computing`` is set.  Therefore,
    if it replaced the data during a pass, another pass is run on the newer
    data instead of silently dropping it.  Errors are logged rather than
    raised, so a bad batch does not leave ``_computing`` stuck.
    """
    self._computing = True
    try:
        while True:
            data = self._latest_data
            conds = self._latest_conds
            if data is None:
                break
            try:
                tfr = self._compute_tfr(data, conds)
            except Exception as exc:  # worker boundary: report, keep going
                logger.warning("TFRPlot compute error: %s", exc)
            else:
                self._tfr_result = tfr
                self._redraw_sig.emit(len(conds))
            # Stop only when no newer batch arrived while we were busy.
            if self._latest_data is data and self._latest_conds is conds:
                break
    finally:
        self._computing = False

# -----------------------------------------------------------------------
# Redraw (main thread)
# -----------------------------------------------------------------------

def _redraw(self, n_total: int) -> None:
    """Update all image items from the latest TFR result.

    Intended to run on the GUI thread; the worker reaches it through
    ``_redraw_sig``.  Uses a **shared** colour scale across all conditions
    so the heatmaps are directly comparable.  Unless both ``_vmin`` and
    ``_vmax`` are set, the scale spans the 5th–95th percentile of all
    finite, not-all-zero channel maps (default ``(-3, 3)`` when there are
    none).

    Parameters
    ----------
    n_total : int
        Total number of trials, shown in the total label.
    """
    self._total_lbl.setText(f"Total: {n_total} trials")
    for cond in self._conditions:
        self._cond_n_lbl[cond].setText(f"n = {self._latest_conds.count(cond)}")

    n_ch = len(self._display_chs)

    # ── Shared colour limits across ALL conditions and channels ───────
    if self._vmin is not None and self._vmax is not None:
        g_vmin, g_vmax = float(self._vmin), float(self._vmax)
    else:
        samples = []
        for cond in self._conditions:
            pwr = self._tfr_result.get(cond)
            if pwr is None:
                continue
            for ch_i in range(n_ch):
                flat = pwr[ch_i].ravel()
                # -inf / NaN would turn the percentiles (and hence the
                # image levels) non-finite, so only finite values count.
                flat = flat[np.isfinite(flat)]
                # All-zero maps (e.g. the placeholders used for conditions
                # without epochs) would drag the scale towards 0.
                if flat.any():
                    samples.append(flat)
        if samples:
            all_vals = np.concatenate(samples)
            g_vmin = float(np.percentile(all_vals, 5))
            g_vmax = float(np.percentile(all_vals, 95))
            if g_vmin == g_vmax:
                g_vmin, g_vmax = g_vmin - 1.0, g_vmax + 1.0
        else:
            g_vmin, g_vmax = -3.0, 3.0

    # ── Update images ─────────────────────────────────────────────────
    for cond_i, cond in enumerate(self._conditions):
        pwr = self._tfr_result.get(cond)
        if pwr is None:
            continue
        for ch_i in range(n_ch):
            # ImageItem expects (n_times_dec, n_freqs)
            self._image_items[cond_i][ch_i].setImage(
                pwr[ch_i].T, levels=(g_vmin, g_vmax)
            )

    # ── Re-lock axis ranges on every cell ─────────────────────────────
    if self._plot_items and len(self._freqs) and len(self._times_dec):
        f_min = float(self._freqs[0])
        f_max = float(self._freqs[-1])
        t0_ms = float(self._times_dec[0] * 1000.0)
        t1_ms = float(self._times_dec[-1] * 1000.0)
        # Slicing skips rows/cells that do not exist, like the old bounds checks.
        for row in self._plot_items[: len(self._conditions)]:
            for plot in row[:n_ch]:
                plot.setXRange(t0_ms, t1_ms, padding=0)
                plot.setYRange(f_min, f_max, padding=0)