Source code for mne_rt.rt_epochs

"""Event-triggered real-time epoch accumulation.

Thin orchestration layer on top of :class:`mne_lsl.stream.EpochsStream`,
which handles all buffering, baseline correction, and rejection internally.

Typical workflow
----------------
::

    rt = RTEpochs(
        event_id={"target": 1, "standard": 2},
        event_channels="STI 014",
        tmin=-0.2, tmax=0.8,
    )
    rt.connect_to_lsl()
    rt.run(n_trials=80, show_erp=True)

Classes
-------
RTEpochs
    Event-triggered epoch accumulator backed by mne_lsl.EpochsStream.
"""
from __future__ import annotations

import time
import threading
from collections import defaultdict
from typing import Callable, Optional, Union

import numpy as np

try:
    from mne_lsl.stream import StreamLSL, EpochsStream
    from mne_lsl.player import PlayerLSL
    _mne_lsl_available = True
except ImportError:
    _mne_lsl_available = False

from mne_rt._logging import logger, set_log_level


class RTEpochs:
    """Event-triggered epoch accumulator backed by :class:`mne_lsl.stream.EpochsStream`.

    Connects a :class:`~mne_lsl.stream.StreamLSL` to an
    :class:`~mne_lsl.stream.EpochsStream`, polls for new epochs, and optionally
    drives an :class:`~mne_rt.viz.TopoPlot` that redraws after every new trial.

    Parameters
    ----------
    event_id : dict[str, int]
        Condition label → marker integer, e.g. ``{"target": 1, "standard": 2}``.
    event_channels : str or list of str
        Channel(s) in the LSL stream that carry the event codes (e.g.
        ``"STI 014"`` for a STIM channel, or ``"stim"``).
    tmin : float, default -0.2
        Epoch start in seconds relative to the event.
    tmax : float, default 0.8
        Epoch end in seconds relative to the event.
    baseline : tuple or None, default (None, 0)
        Baseline interval passed to :class:`~mne_lsl.stream.EpochsStream`.
        ``None`` disables correction.
    picks : str or list or None, default None
        Channel selection forwarded to :class:`~mne_lsl.stream.EpochsStream`.
    reject : dict or None, default None
        Peak-to-peak rejection thresholds, e.g. ``{"eeg": 150e-6}``.
    bufsize : int, default 200
        Number of epochs to keep in the :class:`~mne_lsl.stream.EpochsStream`
        internal ring buffer.
    on_trial : callable or None, default None
        Optional callback fired after every accepted epoch::

            def on_trial(n_accepted, data, event_code, condition):
                ...

        ``n_accepted`` is the running count including this epoch; ``data`` is
        this single epoch, shape ``(n_channels, n_times)``, given as a *view*
        into the internal buffer (copy it if you need to keep it);
        ``event_code`` is the integer marker and ``condition`` its label from
        ``event_id`` (or ``str(event_code)`` when the code is not in
        ``event_id``).
    verbose : bool or str or None, default None
        Log level forwarded to ``set_log_level``.

    Attributes
    ----------
    epochs_stream_ : mne_lsl.stream.EpochsStream or None
        The underlying :class:`~mne_lsl.stream.EpochsStream` after
        :meth:`connect_to_lsl` has been called.
    n_accepted_ : int
        Running count of accepted epochs since :meth:`run` started.

    See Also
    --------
    mne_rt.viz.TopoPlot : Live scalp-layout ERP display driven by this class.
    mne_rt.viz.EpochPlot : Scrolling raw viewer with trigger/epoch overlays.
    mne_rt.RTStream : Continuous sliding-window stream processor.

    Examples
    --------
    >>> rt = RTEpochs(
    ...     event_id={"auditory": 1, "visual": 2},
    ...     event_channels="STI 014",
    ...     tmin=-0.2, tmax=0.5,
    ... )
    >>> rt.connect_to_lsl(mock_lsl=True, fname="sample_raw.fif")
    >>> rt.run(n_trials=20, show_erp=True)

    .. versionadded:: 1.0.0
    """

    def __init__(
        self,
        event_id: dict[str, int],
        event_channels: Union[str, list[str]],
        tmin: float = -0.2,
        tmax: float = 0.8,
        baseline: Optional[tuple] = (None, 0),
        picks: Optional[Union[str, list]] = None,
        reject: Optional[dict] = None,
        bufsize: int = 200,
        on_trial: Optional[Callable] = None,
        verbose: Union[bool, str, None] = None,
    ) -> None:
        set_log_level(verbose)
        if not _mne_lsl_available:
            raise ImportError("mne-lsl is required. Install with: pip install mne-lsl")

        self.event_id = event_id
        self.event_channels = event_channels
        self.tmin = tmin
        self.tmax = tmax
        self.baseline = baseline
        self.picks = picks
        self.reject = reject
        self.bufsize = bufsize
        self.on_trial = on_trial

        self._stream: Optional[StreamLSL] = None
        self._player: Optional[PlayerLSL] = None
        self.epochs_stream_: Optional[EpochsStream] = None
        self.n_accepted_: int = 0
        self._stop_event = threading.Event()
        self._connected = False

        # Populated by run() — persists for get_epochs/get_evoked/save
        self._buf_: Optional[np.ndarray] = None  # (n_trials, n_ch, n_t)
        self._cond_list_: list[str] = []
        self._code_list_: list[int] = []

    # ------------------------------------------------------------------
    # Connection
    # ------------------------------------------------------------------

    def connect_to_lsl(
        self,
        stream_name: Optional[str] = None,
        mock_lsl: bool = False,
        fname: Optional[str] = None,
        timeout: float = 10.0,
        verbose: Union[bool, str, None] = None,
    ) -> "RTEpochs":
        """Connect to an LSL stream and set up the EpochsStream.

        If any step fails after resources have been created (mock player,
        stream), everything is torn down via :meth:`disconnect` before the
        original exception is re-raised, so no player is left running.

        Parameters
        ----------
        stream_name : str or None
            LSL stream name. ``None`` picks the first available stream.
        mock_lsl : bool
            Replay ``fname`` via :class:`~mne_lsl.player.PlayerLSL`.
        fname : str or None
            Path to a ``.fif`` file (required when ``mock_lsl=True``).
        timeout : float
            LSL connection timeout in seconds.
        verbose : bool or str or None
            Log level override for this call (``None`` keeps the current one).

        Returns
        -------
        self : RTEpochs

        Raises
        ------
        ValueError
            If ``mock_lsl=True`` and ``fname`` is ``None``.
        """
        if verbose is not None:
            set_log_level(verbose)
        if mock_lsl and fname is None:
            raise ValueError("fname is required when mock_lsl=True.")

        try:
            if mock_lsl:
                logger.info("RTEpochs: starting mock PlayerLSL from %s", fname)
                self._player = PlayerLSL(fname, name="mne_rt_mock", chunk_size=16).start()
                # Give the player a moment to start publishing before we try
                # to resolve it by name.
                time.sleep(1.5)
                stream_name = "mne_rt_mock"

            logger.info("RTEpochs: connecting StreamLSL …")
            # Keep the historical 4 s buffer for short epochs, but make sure the
            # continuous buffer always spans a full epoch (with 2x margin) when
            # tmax - tmin is long.
            stream_bufsize = max(4.0, 2.0 * (self.tmax - self.tmin))
            self._stream = StreamLSL(bufsize=stream_bufsize, name=stream_name)
            self._stream.connect(acquisition_delay=0.005, timeout=timeout)
            logger.info(
                "RTEpochs: stream connected — %d ch @ %.0f Hz",
                self._stream.info["nchan"],
                self._stream.info["sfreq"],
            )

            logger.info("RTEpochs: setting up EpochsStream …")
            self.epochs_stream_ = EpochsStream(
                stream=self._stream,
                bufsize=self.bufsize,
                event_id=self.event_id,
                event_channels=self.event_channels,
                tmin=self.tmin,
                tmax=self.tmax,
                baseline=self.baseline,
                picks=self.picks,
                reject=self.reject,
            ).connect(acquisition_delay=0.005)
        except Exception:
            # Don't leak a running mock player or a half-open stream.
            self.disconnect()
            raise

        self._connected = True
        logger.info("RTEpochs: EpochsStream connected.")
        return self

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def _open_erp_plot(self, es: "EpochsStream"):
        """Create and show the live TopoPlot used by :meth:`run`."""
        from mne_rt.viz.topo_plot import TopoPlot

        erp_plot = TopoPlot(
            ch_names=list(es.info["ch_names"]),
            sfreq=es.info["sfreq"],
            tmin=self.tmin,
            tmax=self.tmax,
            event_id=self.event_id,
            info=es.info,  # pass real Info for accurate layout
            baseline=self.baseline,
        )
        erp_plot.show()
        return erp_plot

    def run(
        self,
        n_trials: int = 100,
        show_erp: bool = False,
        erp_update_every: int = 1,
        poll_interval: float = 0.05,
        verbose: Union[bool, str, None] = None,
    ) -> "RTEpochs":
        """Run the epoch accumulation loop.

        Polls :attr:`~mne_lsl.stream.EpochsStream.n_new_epochs` and retrieves
        data in batches. Blocks until ``n_trials`` accepted epochs have been
        collected or :meth:`stop` is called.

        Parameters
        ----------
        n_trials : int, default 100
            Stop after this many accepted epochs. Must be ``>= 0``.
        show_erp : bool, default False
            Open an :class:`~mne_rt.viz.TopoPlot` that redraws every
            ``erp_update_every`` accepted epochs (and once more at the end if
            the last epochs did not fall on that cadence).
        erp_update_every : int, default 1
            ERP redraw cadence in number of accepted epochs. Must be ``>= 1``
            when ``show_erp`` is True.
        poll_interval : float, default 0.05
            Seconds to sleep between polling :attr:`n_new_epochs`.
        verbose : bool or str or None
            Log level override for this call (``None`` keeps the current one).

        Returns
        -------
        self : RTEpochs

        Raises
        ------
        RuntimeError
            If :meth:`connect_to_lsl` has not been called.
        ValueError
            If ``n_trials`` is negative, or ``show_erp`` is True and
            ``erp_update_every < 1``.
        """
        if verbose is not None:
            set_log_level(verbose)
        if not self._connected:
            raise RuntimeError("Call connect_to_lsl() before run().")
        if n_trials < 0:
            raise ValueError(f"n_trials must be >= 0, got {n_trials}.")
        if show_erp and erp_update_every < 1:
            raise ValueError(
                f"erp_update_every must be >= 1, got {erp_update_every}."
            )

        es = self.epochs_stream_
        erp_plot = self._open_erp_plot(es) if show_erp else None
        inv_event = {v: k for k, v in self.event_id.items()}

        # Pre-allocate epoch buffer — avoids O(N²) np.stack per trial
        n_ch = es.info["nchan"]
        n_times = int(round((self.tmax - self.tmin) * es.info["sfreq"])) + 1
        self._buf_ = np.zeros((n_trials, n_ch, n_times), dtype=np.float32)
        self._cond_list_ = []
        self._code_list_ = []
        self._stop_event.clear()
        self.n_accepted_ = 0

        logger.info("RTEpochs: running — target %d trials …", n_trials)

        while self.n_accepted_ < n_trials and not self._stop_event.is_set():
            n_new = es.n_new_epochs
            if n_new == 0:
                time.sleep(poll_interval)
                continue
            # Epochs older than the ring buffer are gone; never ask for more.
            n_new = min(n_new, self.bufsize)

            # Retrieve all new epochs at once — shape (n_new, n_ch, n_times).
            # NOTE(review): data and events are read in two calls while the
            # acquisition thread keeps running; confirm mne_lsl keeps them
            # aligned if an epoch lands between the two reads.
            data = es.get_data(n_epochs=n_new)
            events = es.events[-n_new:]

            # zip() guards against data/events length mismatches.
            for ep, ev in zip(data, events):
                if self.n_accepted_ >= n_trials:
                    break
                code = int(ev) if events.ndim == 1 else int(ev[2])
                condition = inv_event.get(code, str(code))

                # Write into pre-allocated buffer (O(1) copy); truncate if the
                # stream's epoch is longer than our computed n_times.
                t = min(ep.shape[-1], n_times)
                self._buf_[self.n_accepted_, :, :t] = ep[:, :t]
                self._cond_list_.append(condition)
                self._code_list_.append(code)
                self.n_accepted_ += 1

                if self.on_trial is not None:
                    self.on_trial(
                        self.n_accepted_,
                        self._buf_[self.n_accepted_ - 1],  # view — no copy
                        code,
                        condition,
                    )

                if erp_plot is not None and self.n_accepted_ % erp_update_every == 0:
                    # Pass a view of the filled portion — no copy
                    erp_plot.update(self._buf_[: self.n_accepted_], list(self._cond_list_))

                logger.debug("RTEpochs: accepted %d (%s)", self.n_accepted_, condition)

        # Make sure the plot reflects every accepted epoch, even when the
        # final count is not a multiple of erp_update_every.
        if (
            erp_plot is not None
            and self.n_accepted_ > 0
            and self.n_accepted_ % erp_update_every != 0
        ):
            erp_plot.update(self._buf_[: self.n_accepted_], list(self._cond_list_))

        logger.info("RTEpochs: finished — %d epochs accepted.", self.n_accepted_)
        return self

    def stop(self) -> None:
        """Signal the run loop to stop after the current poll."""
        self._stop_event.set()

    def disconnect(self) -> None:
        """Disconnect EpochsStream, StreamLSL, and stop any mock player.

        Teardown is best-effort: a failure in one component is logged as a
        warning and the remaining components are still shut down.
        ``epochs_stream_`` is kept so :meth:`get_epochs` can still use its Info.
        """
        teardown = (
            ("EpochsStream", self.epochs_stream_, "disconnect"),
            ("StreamLSL", self._stream, "disconnect"),
            ("PlayerLSL", self._player, "stop"),
        )
        for label, component, method_name in teardown:
            if component is None:
                continue
            try:
                getattr(component, method_name)()
            except Exception as exc:  # best-effort: keep tearing down the rest
                logger.warning("RTEpochs: failed to shut down %s: %s", label, exc)
        self._connected = False
        logger.info("RTEpochs: disconnected.")

    # ------------------------------------------------------------------
    # Offline analysis helpers
    # ------------------------------------------------------------------

    def get_epochs(self) -> "mne.EpochsArray":
        """Return accumulated epochs as :class:`mne.EpochsArray`.

        Can be called mid-run or after :meth:`run` completes. The returned
        object contains all epochs accepted so far and uses the real
        :class:`mne.Info` from the underlying stream (including channel
        positions and digitisation points). Only conditions that actually
        occurred are included in ``event_id``, so conditions with no accepted
        epochs yet do not make MNE raise.

        Returns
        -------
        epochs : mne.EpochsArray
            Shape ``(n_accepted, n_channels, n_times)``.

        Raises
        ------
        RuntimeError
            If called before :meth:`connect_to_lsl` / :meth:`run`, or before
            any epoch has been accepted.

        Examples
        --------
        >>> rt.run(n_trials=50, show_erp=True)
        >>> epochs = rt.get_epochs()
        >>> epochs.plot_image()
        """
        import mne

        if self.epochs_stream_ is None or self._buf_ is None:
            raise RuntimeError(
                "No data yet — call connect_to_lsl() then run() first."
            )

        n = self.n_accepted_
        if n == 0:
            raise RuntimeError("No epochs have been accepted yet.")

        codes = np.asarray(self._code_list_[:n], dtype=int)
        events = np.column_stack([
            np.arange(n, dtype=int),
            np.zeros(n, dtype=int),
            codes,
        ])
        # Restrict event_id to codes present in `events`: MNE raises by default
        # for event_id entries that have no matching events.
        present = set(codes.tolist())
        event_id = {k: v for k, v in self.event_id.items() if v in present}

        return mne.EpochsArray(
            self._buf_[:n].astype(np.float64),
            info=self.epochs_stream_.info,
            events=events,
            event_id=event_id or None,
            tmin=self.tmin,
            verbose=False,
        )

    def get_evoked(self) -> "dict[str, mne.EvokedArray]":
        """Return per-condition grand-average as :class:`mne.EvokedArray` objects.

        Useful for immediate offline analysis, plotting with
        :func:`mne.viz.plot_evoked`, or source localisation via
        :meth:`get_source`.

        Returns
        -------
        evoked : dict[str, mne.EvokedArray]
            Mapping ``condition_label → EvokedArray``. Conditions with zero
            accepted epochs are omitted.

        Raises
        ------
        RuntimeError
            Propagated from :meth:`get_epochs` when no data is available.

        Examples
        --------
        >>> evoked = rt.get_evoked()
        >>> mne.viz.plot_evoked(evoked["auditory/left"])
        """
        epochs = self.get_epochs()
        result = {}
        for cond in self.event_id:
            try:
                result[cond] = epochs[cond].average()
            except KeyError:
                # Condition not present among accepted epochs.
                pass
        return result

    def save(self, path: str, overwrite: bool = False) -> None:
        """Save accumulated epochs to a ``-epo.fif`` file mid-run.

        The file can be reloaded offline with ``mne.read_epochs(path)`` and
        the full MNE analysis pipeline applied.

        Parameters
        ----------
        path : str
            Destination path. Should end with ``-epo.fif`` or ``-epo.fif.gz``
            to follow MNE naming conventions.
        overwrite : bool, default False
            Overwrite an existing file.

        Examples
        --------
        >>> rt.run(n_trials=30)
        >>> rt.save("session01-epo.fif", overwrite=True)
        """
        self.get_epochs().save(path, overwrite=overwrite, verbose=False)
        logger.info("RTEpochs: saved %d epochs to %s", self.n_accepted_, path)

    def get_source(
        self,
        inverse_operator,
        lambda2: float = 1.0 / 9.0,
        method: str = "dSPM",
    ) -> "dict[str, mne.SourceEstimate]":
        """Apply a pre-computed inverse operator to the current grand averages.

        Wraps :func:`mne.minimum_norm.apply_inverse` — load an existing
        inverse operator with
        ``mne.minimum_norm.read_inverse_operator(fname)``.

        Parameters
        ----------
        inverse_operator : mne.minimum_norm.InverseOperator
            Pre-computed inverse operator matching the stream's Info
            (same channels, same channel order).
        lambda2 : float, default 1/9
            Regularisation parameter (``1 / SNR²``). Use ``1/9`` for SNR ≈ 3
            (typical ERP), ``1.0`` for noisy single-trial data.
        method : str, default "dSPM"
            Inverse method: ``"MNE"``, ``"dSPM"``, ``"sLORETA"``, or
            ``"eLORETA"``.

        Returns
        -------
        stc_dict : dict[str, mne.SourceEstimate]
            Condition label → source estimate (vertex × time).

        Examples
        --------
        >>> inv_op = mne.minimum_norm.read_inverse_operator("sample-inv.fif")
        >>> stc = rt.get_source(inv_op)
        >>> brain = mne_rt.BrainPlot(subject="sample", subjects_dir=sd)
        >>> brain.update(stc["auditory/left"].data.mean(-1))
        """
        import mne.minimum_norm

        evoked = self.get_evoked()
        return {
            cond: mne.minimum_norm.apply_inverse(
                ev,
                inverse_operator,
                lambda2=lambda2,
                method=method,
                verbose=False,
            )
            for cond, ev in evoked.items()
        }