Source code for brainsets.datasets.Neuroprobe2025

from __future__ import annotations

from numbers import Integral
from pathlib import Path
import re
from typing import Callable, Literal, Optional, get_args

import numpy as np
from temporaldata import Data, Interval

from torch_brain.dataset import Dataset, MultiChannelDatasetMixin

from ._utils import get_processed_dir

# Closed sets of accepted constructor values, expressed as Literal types so
# the same definition serves static typing and runtime validation.
SubsetTier = Literal["full", "lite", "nano"]
LabelMode = Literal["binary", "multiclass"]
Regime = Literal["SS-SM", "SS-DM", "DS-DM"]
Split = Literal["train", "val", "test"]

# Runtime tuples derived from the Literal aliases above (via typing.get_args);
# used for membership checks and in error messages.
VALID_SUBSET_TIERS = get_args(SubsetTier)
VALID_LABEL_MODES = get_args(LabelMode)
VALID_REGIMES = get_args(Regime)
VALID_SPLITS = get_args(Split)

# Supported Neuroprobe task labels available in processed H5 splits.
# Kept in alphabetical order; the task name is embedded verbatim in split keys.
VALID_TASKS = (
    "delta_volume",
    "face_num",
    "frame_brightness",
    "global_flow",
    "gpt2_surprisal",
    "local_flow",
    "onset",
    "pitch",
    "speech",
    "volume",
    "word_gap",
    "word_head_pos",
    "word_index",
    "word_length",
    "word_part_speech",
)

# Maps each user-facing regime to the eval-setting name used inside split keys.
# Both cross-session (SS-DM) and cross-subject (DS-DM) share the "cross_x" value.
H5_REGIME_BY_REGIME: dict[Regime, str] = {
    "SS-SM": "within_session",
    "SS-DM": "cross_x",
    "DS-DM": "cross_x",
}

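# Split-interval and channel-mask attributes share one selector key:
# <subset_tier>$<label_mode>$<eval_setting>$<task>$fold<k>$<split>
# where <eval_setting> is the H5 regime name from H5_REGIME_BY_REGIME
# ("within_session" or "cross_x").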

# Neuroprobe benchmark constants (mirrors neuroprobe.config)
# Fixed train subject id for benchmark-default DS-DM configuration.
# This subject is therefore rejected as a DS-DM test subject during validation.
DS_DM_TRAIN_SUBJECT_ID = 2
# Fixed train trial id for benchmark-default DS-DM configuration.
# Together with the subject id above it names the single DS-DM train recording.
DS_DM_TRAIN_TRIAL_ID = 4

# Eligible (subject, trial) pairs for Neuroprobe Lite benchmark mode.
# Every listed subject has exactly two trials; lite SS-DM train selection
# relies on that to pick "the other trial" of the test subject.
NEUROPROBE_LITE_SUBJECT_TRIALS = {
    (1, 1),
    (1, 2),
    (2, 0),
    (2, 4),
    (3, 0),
    (3, 1),
    (4, 0),
    (4, 1),
    (7, 0),
    (7, 1),
    (10, 0),
    (10, 1),
}

# Eligible (subject, trial) pairs for Neuroprobe Nano benchmark mode.
# One trial per subject, which is why nano cannot be combined with the
# cross_x regimes (validation rejects that combination).
NEUROPROBE_NANO_SUBJECT_TRIALS = {
    (1, 1),
    (2, 4),
    (3, 1),
    (4, 0),
    (7, 1),
    (10, 1),
}

# Per-subject trials ranked by duration, used by full SS-DM train selection.
# Index 0 is the preferred train trial; index 1 is the fallback when index 0
# is the test trial. Subjects with fewer than two entries (5, 8, 9) are
# rejected for full SS-DM by validation.
NEUROPROBE_LONGEST_TRIALS_FOR_SUBJECT: dict[int, list[int]] = {
    1: [0, 1],
    2: [4, 6],
    3: [2, 1],
    4: [2, 1],
    5: [0],
    6: [0, 2],
    7: [1, 0],
    8: [0],
    9: [0],
    10: [1, 0],
}

# Strict parser for canonical recording ids like "sub_1_trial004".
# Group 1 is the subject number (any digit count); group 2 is the session,
# always exactly three digits.
# NOTE: `$` also matches just before a single trailing newline, so use
# `.fullmatch()` rather than `.match()` when that must be rejected.
_RECORDING_ID_RE = re.compile(r"^sub_(\d+)_trial(\d{3})$")


def _to_recording_id(subject: Integral, session: Integral) -> str:
    """Format a (subject, session) pair as a canonical recording id.

    The result has the form ``sub_<subject>_trial<session>`` with the session
    zero-padded to three digits, e.g. ``(1, 4) -> "sub_1_trial004"``.

    Args:
        subject: Non-negative integer subject id. ``bool`` is rejected.
        session: Integer session/trial id in ``0..999``. ``bool`` is rejected.

    Returns:
        The canonical recording id string.

    Raises:
        ValueError: If either value is not a (non-bool) integer or is out of
            range.
    """

    def _invalid() -> ValueError:
        # One message serves both the type check and the range check; it is
        # only formatted when a failure is actually reported.
        return ValueError(
            "_to_recording_id received invalid subject/session values: "
            f"subject={subject!r}, session={session!r}. Expected subject to be a "
            "non-negative integer and session to be an integer in 0..999."
        )

    # bool is a subclass of int (hence Integral), so it is excluded explicitly.
    for candidate in (subject, session):
        if isinstance(candidate, bool) or not isinstance(candidate, Integral):
            raise _invalid()

    subject_num, session_num = int(subject), int(session)
    if subject_num < 0 or session_num < 0 or session_num > 999:
        raise _invalid()

    return f"sub_{subject_num}_trial{session_num:03d}"


def _from_recording_id(recording_id: str) -> tuple[int, int]:
    """Parse a canonical recording id back into ``(subject, session)`` integers.

    Args:
        recording_id: Id of the form ``sub_<subject>_trial<session>`` where the
            session is zero-padded to exactly three digits, e.g.
            ``"sub_1_trial004"``.

    Returns:
        ``(subject, session)`` as plain ints, e.g. ``(1, 4)``.

    Raises:
        ValueError: If ``recording_id`` is not exactly in canonical form.
    """
    # fullmatch (not match): with .match(), the pattern's `$` anchor would also
    # accept an id followed by a trailing newline, e.g. "sub_1_trial004\n".
    match = _RECORDING_ID_RE.fullmatch(recording_id)
    if match is None:
        raise ValueError(
            f"Invalid recording_id '{recording_id}'. Expected 'sub_<subject>_trial<session>' "
            "with a zero-padded 3-digit session."
        )
    return int(match.group(1)), int(match.group(2))


class Neuroprobe2025(MultiChannelDatasetMixin, Dataset):
    """Neuroprobe 2025 iEEG benchmark dataset.

    .. admonition:: Preprocessing

        To download and prepare this dataset, run
        ``brainsets prepare neuroprobe_2025``.

    Each instance operates in exactly one of two mutually-exclusive modes:

    - Neuroprobe benchmark mode (``recording_ids=None``): splits are resolved
      from Neuroprobe benchmark split generators. Cross-session and
      cross-subject are condensed to 'cross-x' splits that will be selected
      for train and test.
    - Recording id mode (``recording_ids`` provided): no splits are resolved,
      only the specified recording ids are exposed, to be used as continuous
      data.

    **References**

    Zahorodnii, A., Wang, C., Stankovits, B., Moraitaki, C., Chau, G.,
    Barbu, A., Katz, B., & Fiete, I. R.
    *Neuroprobe: Evaluating Intracranial Brain Responses to Naturalistic
    Stimuli.* `arXiv:2509.21671 <https://arxiv.org/abs/2509.21671>`_.

    Data sources: `BrainTreeBank <https://braintreebank.dev>`_ and
    `Neuroprobe Benchmark <https://neuroprobe.dev>`_

    Args:
        root: Root directory containing processed Neuroprobe artifacts.
            Defaults to ``processed_dir`` from brainsets config.
        recording_ids: Optional explicit recording-id subset to expose from
            disk. If omitted, the dataset uses benchmark-required recording
            ids inferred from
            ``subset_tier/test_subject/test_session/split/label_mode/task/regime/fold``.
        transform: Optional sample transform.
        subset_tier: One of ``"full"``, ``"lite"``, ``"nano"``. Required in
            benchmark mode; must be omitted in explicit-recording mode.
            ``"nano"`` cannot be combined with the cross-x regimes
            (``"SS-DM"``, ``"DS-DM"``).
        test_subject: Target test subject id (Neuroprobe semantics). Required
            in benchmark mode; must be omitted in explicit-recording mode.
        test_session: Target test trial/session id (Neuroprobe semantics).
            Required in benchmark mode; must be omitted in explicit-recording
            mode.
        split: One of ``"train"``, ``"val"``, ``"test"``. Required in
            benchmark mode; must be omitted in explicit-recording mode.
        label_mode: One of ``"binary"``, ``"multiclass"``. Defaults to
            ``"binary"`` in benchmark mode.
        task: Neuroprobe task name. Defaults to ``"speech"`` in benchmark
            mode. Supported values are: ``"delta_volume"``, ``"face_num"``,
            ``"frame_brightness"``, ``"global_flow"``, ``"gpt2_surprisal"``,
            ``"local_flow"``, ``"onset"``, ``"pitch"``, ``"speech"``,
            ``"volume"``, ``"word_gap"``, ``"word_head_pos"``,
            ``"word_index"``, ``"word_length"``, ``"word_part_speech"``.
        regime: One of ``"SS-SM"``, ``"SS-DM"``, ``"DS-DM"``. Defaults to
            ``"SS-SM"`` in benchmark mode. Neuroprobe regime semantics:

            - ``"SS-SM"``: single-subject, single-session (within-session split)
            - ``"SS-DM"``: single-subject, different-session (cross-x split)
            - ``"DS-DM"``: different-subject, different-session (cross-x split)
        fold: Fold index used only in benchmark mode. Defaults to ``0`` in
            benchmark mode and must be omitted in explicit-recording mode.
            Allowed values depend on regime:

            - ``within_session`` (``"SS-SM"``): 0 or 1
            - ``cross_x`` (``"SS-DM"``, ``"DS-DM"``): must be 0
        uniquify_channel_ids_with_subject: Whether to prefix channel IDs with
            ``subject.id`` via ``MultiChannelDatasetMixin``. Defaults to
            ``True``.
        uniquify_channel_ids_with_session: Whether to prefix channel IDs with
            ``session.id`` via ``MultiChannelDatasetMixin``. Defaults to
            ``False``.
        dirname: Subdirectory under ``root`` containing recording H5 files.

    Raises:
        ValueError: If the benchmark configuration is invalid, if
            split-selection arguments are combined with ``recording_ids``, or
            if no recording ids can be resolved.
        TypeError: If ``fold``, ``test_subject`` or ``test_session`` is not a
            (non-bool) integer in benchmark mode.
    """

    # Fold indices accepted for each regime; anything else is rejected.
    _ALLOWED_FOLDS_BY_REGIME: dict[Regime, tuple[int, ...]] = {
        "SS-SM": (0, 1),
        "SS-DM": (0,),
        "DS-DM": (0,),
    }

    def __init__(
        self,
        root: Optional[str] = None,
        recording_ids: Optional[list[str]] = None,
        transform: Optional[Callable] = None,
        *,
        subset_tier: SubsetTier | None = None,
        test_subject: int | None = None,
        test_session: int | None = None,
        split: Split | None = None,
        label_mode: LabelMode | None = None,
        task: str | None = None,
        regime: Regime | None = None,
        fold: int | None = None,
        uniquify_channel_ids_with_subject: bool = True,
        uniquify_channel_ids_with_session: bool = False,
        dirname: str = "neuroprobe_2025",
        **kwargs,
    ):
        if root is None:
            root = get_processed_dir()

        # Resolve and validate constructor inputs before touching dataset records.
        self._dataset_dir = Path(root) / dirname

        # XOR recording-source behavior (exactly one source of active recording ids):
        # - no recording_ids => use neuroprobe benchmark split recordings
        # - recording_ids provided => use the explicit subset of recordings
        use_split_selection = recording_ids is None
        self._use_split_selection = use_split_selection

        if use_split_selection:
            # Apply defaults only for omitted (None) arguments. Using
            # `value or default` here would silently replace falsy but invalid
            # inputs (fold=False, fold=0.0, task="") and skip their validation.
            if label_mode is None:
                label_mode = "binary"
            if task is None:
                task = "speech"
            if regime is None:
                regime = "SS-SM"
            if fold is None:
                fold = 0

            self.subset_tier = subset_tier
            self.label_mode = label_mode
            self.task = task
            self.regime = regime
            self.fold = fold
            self.test_subject = test_subject
            self.test_session = test_session
            self.split = split

            self._validate_split_args()
            self.h5_regime = H5_REGIME_BY_REGIME[self.regime]
            active_recording_ids = self._split_recording_ids()
        else:
            unexpected_split_args = [
                name
                for name, value in (
                    ("subset_tier", subset_tier),
                    ("test_subject", test_subject),
                    ("test_session", test_session),
                    ("split", split),
                    ("label_mode", label_mode),
                    ("task", task),
                    ("regime", regime),
                    ("fold", fold),
                )
                if value is not None
            ]
            if unexpected_split_args:
                raise ValueError(
                    "When recording_ids is provided (explicit-recording mode), split-selection args "
                    "must be omitted. Unexpected args: "
                    f"{', '.join(unexpected_split_args)}."
                )
            active_recording_ids = self._resolve_requested_recording_ids(recording_ids)

        if not active_recording_ids:
            raise ValueError(
                "No active recording_ids resolved for Neuroprobe2025 construction."
            )

        super().__init__(
            dataset_dir=self._dataset_dir,
            recording_ids=active_recording_ids,
            transform=transform,
            namespace_attributes=["subject.id", "channels.id"],
            **kwargs,
        )

        # Configure subject/session-based channel-id prefixing behavior.
        self.multichannel_dataset_mixin_uniquify_channel_ids_with_subject = (
            uniquify_channel_ids_with_subject
        )
        self.multichannel_dataset_mixin_uniquify_channel_ids_with_session = (
            uniquify_channel_ids_with_session
        )

    def get_sampling_intervals(self) -> dict[str, Interval]:
        """Return split-specific sampling intervals for this dataset instance.

        Returns:
            Mapping from recording id to that recording's ``splits`` attribute.

        Raises:
            RuntimeError: If the dataset was built in explicit-recording mode.
        """
        if not self._use_split_selection:
            raise RuntimeError(
                "get_sampling_intervals is only available in benchmark mode."
            )

        intervals: dict[str, Interval] = {}
        for rid in self.recording_ids:
            rec = self.get_recording(rid)
            intervals[rid] = rec.splits
        return intervals

    def get_domain_intervals(self) -> dict[str, Interval]:
        """Return full-domain intervals for active recordings."""
        return {rid: self.get_recording(rid).domain for rid in self.recording_ids}

    @property
    def sampling_rate(self) -> float:
        """Recording sampling rate in Hz."""
        return 2048.0

    def get_channel_metadata(self, recording_id: str) -> dict[str, np.ndarray | str]:
        """Return normalized channel metadata arrays for one recording.

        Args:
            recording_id: Id of the recording whose channels are described.

        Returns:
            A dict with keys ``"ids"``, ``"names"`` (string arrays),
            ``"included_mask"`` (bool array), ``"coords"`` (the
            ``localization_L/I/P`` fields stacked along axis 1),
            ``"coords_type"`` (always ``"lip"``) and ``"indices"``
            (``0..len(ids)-1``).

        Raises:
            ValueError: If names, mask or coordinates differ in length from ids.
            AttributeError: If any localization field is missing.
        """
        rec = self.get_recording(recording_id)
        channels = rec.channels

        ids = np.asarray(channels.id).astype(str)
        names = np.asarray(channels.name).astype(str)
        included_mask = np.asarray(channels.included, dtype=bool)

        if len(names) != len(ids):
            raise ValueError(
                f"Channel name length mismatch for recording '{recording_id}': "
                f"len(names)={len(names)} vs len(ids)={len(ids)}"
            )
        if len(included_mask) != len(ids):
            raise ValueError(
                f"Channel mask length mismatch for recording '{recording_id}': "
                f"len(mask)={len(included_mask)} vs len(ids)={len(ids)}"
            )

        try:
            lip = np.stack(
                (
                    np.asarray(channels.localization_L, dtype=float),
                    np.asarray(channels.localization_I, dtype=float),
                    np.asarray(channels.localization_P, dtype=float),
                ),
                axis=1,
            )
        except AttributeError as exc:
            raise AttributeError(
                "Missing required channel localization fields for Neuroprobe2025 "
                f"recording '{recording_id}'. Expected channels.localization_L, "
                "channels.localization_I, and channels.localization_P."
            ) from exc

        if len(lip) != len(ids):
            raise ValueError(
                f"Channel localization length mismatch for recording '{recording_id}': "
                f"len(lip)={len(lip)} vs len(ids)={len(ids)}"
            )

        return {
            "ids": ids,
            "names": names,
            "included_mask": included_mask,
            "coords": lip,
            "coords_type": "lip",
            "indices": np.arange(len(ids), dtype=int),
        }

    def get_recording_hook(self, data: Data):
        """Apply split-specific channel inclusion mask when available.

        In benchmark mode this overwrites ``data.channels.included`` and
        ``data.splits`` with the entries selected by the current split key,
        then defers to the parent hook.

        Raises:
            KeyError: If the channel mask or split intervals for the current
                split key are missing from ``data``.
        """
        # Explicit-recording mode does not apply benchmark split routing.
        if not self._use_split_selection:
            super().get_recording_hook(data)
            return

        recording_id = data.session.id
        channel_split_path = self._channel_split_attr_path()
        interval_path = self._interval_attr_path()

        # Split-selection mode requires both the channel mask and intervals.
        try:
            channel_mask = data.get_nested_attribute(channel_split_path)
            split_interval = data.get_nested_attribute(interval_path)
        except (AttributeError, KeyError) as exc:
            raise KeyError(
                "Missing required split-selection attributes for Neuroprobe2025 "
                f"recording '{recording_id}'. Expected channel mask at "
                f"'{channel_split_path}', "
                f"and split intervals at '{interval_path}'."
            ) from exc

        data.channels.included = channel_mask
        data.splits = split_interval
        super().get_recording_hook(data)

    def describe_selection(self) -> dict[str, object]:
        """Return a compact debug summary of the resolved benchmark selection."""
        summary: dict[str, object] = {
            "uses_split_selection": self._use_split_selection,
            "active_recording_ids": list(self.recording_ids),
        }
        if not self._use_split_selection:
            return summary

        # Expose resolved split internals to make dataset/debug logs self-explanatory.
        summary.update(
            {
                "subset_tier": self.subset_tier,
                "label_mode": self.label_mode,
                "task": self.task,
                "regime": self.regime,
                "h5_regime": self.h5_regime,
                "fold": self.fold,
                "split": self.split,
                "test_subject": self.test_subject,
                "test_session": self.test_session,
                "test_recording_id": _to_recording_id(
                    self.test_subject, self.test_session
                ),
                "split_key": self._split_key(),
            }
        )
        return summary

    # Path/key builders.
    def _split_key(self) -> str:
        """Return the canonical key shared under `splits` and `channel_splits`."""
        return (
            f"{self.subset_tier}${self.label_mode}${self.h5_regime}${self.task}$"
            f"fold{self.fold}${self.split}"
        )

    def _interval_attr_path(self) -> str:
        # Primary split interval path under data.splits.
        return f"splits.{self._split_key()}"

    def _channel_split_attr_path(self) -> str:
        # Return the primary path for split-specific channel masks.
        return f"channel_splits.{self._split_key()}"

    def _validate_split_args(self) -> None:
        """Validate the benchmark-mode configuration stored on ``self``.

        Raises:
            ValueError: For an unknown or incompatible option value.
            TypeError: If ``fold``, ``test_subject`` or ``test_session`` is not
                a (non-bool) integer.
        """
        # Keep constructor strict so invalid benchmark configs fail immediately.
        if self.subset_tier not in VALID_SUBSET_TIERS:
            raise ValueError(
                f"Invalid subset_tier '{self.subset_tier}'. Must be one of {VALID_SUBSET_TIERS}."
            )
        if self.label_mode not in VALID_LABEL_MODES:
            raise ValueError(
                f"Invalid label_mode '{self.label_mode}'. Must be one of {VALID_LABEL_MODES}."
            )
        if self.task not in VALID_TASKS:
            raise ValueError(
                f"Invalid task '{self.task}'. Must be one of {VALID_TASKS}."
            )
        if self.regime not in VALID_REGIMES:
            raise ValueError(
                f"Invalid regime '{self.regime}'. Must be one of {VALID_REGIMES}."
            )

        # bool is an int subclass, so it is rejected explicitly.
        if not isinstance(self.fold, Integral) or isinstance(self.fold, bool):
            raise TypeError(f"fold must be an int, got {type(self.fold).__name__}.")
        allowed_folds = self._ALLOWED_FOLDS_BY_REGIME[self.regime]
        if self.fold not in allowed_folds:
            allowed_values = " or ".join(str(value) for value in allowed_folds)
            raise ValueError(
                f"Fold for regime '{self.regime}' must be {allowed_values}, got {self.fold}."
            )

        if self.split not in VALID_SPLITS:
            raise ValueError(
                f"Invalid split '{self.split}'. Must be one of {VALID_SPLITS}."
            )

        if not isinstance(self.test_subject, Integral) or isinstance(
            self.test_subject, bool
        ):
            raise TypeError(
                "test_subject must be an int, got "
                f"{type(self.test_subject).__name__}."
            )
        if not isinstance(self.test_session, Integral) or isinstance(
            self.test_session, bool
        ):
            raise TypeError(
                "test_session must be an int, got "
                f"{type(self.test_session).__name__}."
            )

        h5_regime = H5_REGIME_BY_REGIME[self.regime]
        if h5_regime == "cross_x" and self.subset_tier == "nano":
            raise ValueError(
                "subset_tier 'nano' is not compatible with cross_x regimes."
            )

        if self.regime == "DS-DM" and self.test_subject == DS_DM_TRAIN_SUBJECT_ID:
            raise ValueError(
                "DS-DM benchmark-default uses subject 2 as fixed train subject; "
                "test_subject cannot be 2."
            )

        # Enforce benchmark-allowed target subject/session pairs per subset-tier/regime.
        requested_pair = (self.test_subject, self.test_session)
        if (
            self.subset_tier == "lite"
            and requested_pair not in NEUROPROBE_LITE_SUBJECT_TRIALS
        ):
            raise ValueError(
                f"Target pair {requested_pair} is not in NEUROPROBE_LITE_SUBJECT_TRIALS."
            )
        if (
            self.subset_tier == "nano"
            and requested_pair not in NEUROPROBE_NANO_SUBJECT_TRIALS
        ):
            raise ValueError(
                f"Target pair {requested_pair} is not in NEUROPROBE_NANO_SUBJECT_TRIALS."
            )

        if self.regime == "SS-DM" and self.subset_tier == "full":
            longest_trials = NEUROPROBE_LONGEST_TRIALS_FOR_SUBJECT.get(
                self.test_subject, []
            )
            if len(longest_trials) < 2:
                raise ValueError(
                    "SS-DM full benchmark-default requires at least two longest trials "
                    f"for subject {self.test_subject}, found {longest_trials}."
                )
            if self.test_session not in longest_trials:
                raise ValueError(
                    "SS-DM full benchmark-default only supports target sessions present "
                    f"in NEUROPROBE_LONGEST_TRIALS_FOR_SUBJECT for subject {self.test_subject}: "
                    f"{longest_trials}."
                )

    @classmethod
    def num_folds_for_regime(cls, regime: str) -> int:
        """Return the number of available folds for one regime.

        Raises:
            ValueError: If ``regime`` is not one of ``VALID_REGIMES``.
        """
        if regime not in VALID_REGIMES:
            raise ValueError(
                f"Invalid regime '{regime}'. Must be one of {VALID_REGIMES}."
            )
        return len(cls._ALLOWED_FOLDS_BY_REGIME[regime])

    def _resolve_requested_recording_ids(self, recording_ids: list[str]) -> list[str]:
        """Return the explicit recording ids sorted and de-duplicated.

        Raises:
            ValueError: If ``recording_ids`` is empty or any id is not in
                canonical ``sub_<subject>_trial<session>`` form.
        """
        # Normalize explicit recording-id subsets to a stable, de-duplicated order.
        if not recording_ids:
            raise ValueError(
                "When using explicit-recording mode, recording_ids must contain at least one id."
            )

        # A non-empty input always yields a non-empty set, so no second
        # emptiness check is needed after de-duplication.
        ids = sorted(set(recording_ids))

        # Parse each id once so errors are raised consistently at construction.
        for rid in ids:
            _from_recording_id(rid)
        return ids

    def _split_recording_ids(self) -> list[str]:
        """Resolve split-participating recording ids for constructor inputs."""
        test_recording_id = _to_recording_id(self.test_subject, self.test_session)

        if self.regime == "SS-SM":
            # Within-session uses a single target recording for all splits.
            return [test_recording_id]

        if self.regime == "SS-DM":
            # Cross-session trains on a different session from the same subject.
            if self.split == "train":
                return [self._ss_dm_train_recording_id_for_selection()]
            # Val/test evaluate on the requested target recording.
            return [test_recording_id]

        # DS-DM
        if self.split == "train":
            # Cross-subject benchmark-default uses a fixed train anchor recording.
            return [_to_recording_id(DS_DM_TRAIN_SUBJECT_ID, DS_DM_TRAIN_TRIAL_ID)]
        # Val/test evaluate on the requested held-out target recording.
        return [test_recording_id]

    def _ss_dm_train_recording_id_for_selection(
        self,
    ) -> str:
        """Return the SS-DM train recording id for the current test target.

        Raises:
            ValueError: If the subject lacks the two trials needed to keep
                train and test distinct, if the test target is not eligible
                for lite, or if ``subset_tier`` is neither ``"lite"`` nor
                ``"full"``.
        """
        # Compute SS-DM train recording using benchmark-default selection rules.
        if self.subset_tier == "lite":
            # Lite mode always defines exactly two eligible trials per subject.
            # Training should use "the other lite trial" relative to the test trial.
            subject_trials = sorted(
                trial
                for subject, trial in NEUROPROBE_LITE_SUBJECT_TRIALS
                if subject == self.test_subject
            )
            if len(subject_trials) != 2:
                raise ValueError(
                    "SS-DM lite benchmark-default expects exactly two lite trials "
                    f"for subject {self.test_subject}, found {subject_trials}."
                )
            if self.test_session not in subject_trials:
                raise ValueError(
                    f"Target (test_subject={self.test_subject}, test_session={self.test_session}) "
                    "is not eligible for lite SS-DM benchmark-default."
                )
            # Start with the first lite trial; if that is the test trial, swap to the second.
            train_session = subject_trials[0]
            if train_session == self.test_session:
                train_session = subject_trials[1]
            return _to_recording_id(self.test_subject, train_session)

        if self.subset_tier == "full":
            # Full mode uses the longest-trial ordering table from the benchmark.
            # Normally pick the longest trial for training.
            longest_trials = NEUROPROBE_LONGEST_TRIALS_FOR_SUBJECT.get(
                self.test_subject, []
            )
            if len(longest_trials) < 2:
                raise ValueError(
                    "SS-DM full benchmark-default requires at least two longest trials "
                    f"for subject {self.test_subject}, found {longest_trials}."
                )
            # If the longest trial is already the test target, fall back to second-longest
            # to keep train/test recordings distinct.
            train_session = longest_trials[0]
            if train_session == self.test_session:
                train_session = longest_trials[1]
            return _to_recording_id(self.test_subject, train_session)

        raise ValueError(
            f"subset_tier '{self.subset_tier}' is not supported for SS-DM train selection."
        )