"""Base pipeline classes for OpenNeuro datasets."""
from abc import ABC
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Literal, Optional
import logging
import sys
import h5py
import numpy as np
import pandas as pd
from temporaldata import Data, Interval
try:
from mne_bids import read_raw_bids
MNE_BIDS_AVAILABLE = True
except ImportError:
read_raw_bids = None
MNE_BIDS_AVAILABLE = False
from brainsets import serialize_fn_map
from brainsets.descriptions import (
BrainsetDescription,
DeviceDescription,
SessionDescription,
SubjectDescription,
)
from brainsets.pipeline import BrainsetPipeline
from brainsets.utils.bids_utils import (
build_bids_path,
fetch_eeg_recordings,
fetch_ieeg_recordings,
check_eeg_recording_files_exist,
check_ieeg_recording_files_exist,
get_subject_info,
)
from brainsets.utils.mne_utils import (
extract_signal,
extract_measurement_date,
extract_channels,
)
from brainsets.utils.split import generate_string_kfold_assignment
from brainsets.utils.openneuro import (
construct_s3_url_from_path,
download_dataset_description,
download_recording,
fetch_all_filenames,
fetch_participants_tsv,
fetch_species,
fetch_latest_snapshot_tag,
)
# Command-line flags shared by every OpenNeuro pipeline.
base_openneuro_parser = ArgumentParser()

# Two plain boolean switches: force a fresh download / force reprocessing.
for _switch in ("--redownload", "--reprocess"):
    base_openneuro_parser.add_argument(_switch, action="store_true")
del _switch  # do not leave the loop variable in the module namespace

# Policy applied when the pipeline's pinned origin_version is not the latest
# snapshot published on OpenNeuro.
base_openneuro_parser.add_argument(
    "--on-version-mismatch",
    default="prompt",
    choices=["abort", "continue", "prompt"],
    help=(
        "Behavior when origin_version differs from latest OpenNeuro version: 'abort' "
        "raises an error, 'continue' proceeds with warning, 'prompt' asks for "
        "confirmation in interactive sessions."
    ),
)

# Recording modalities an OpenNeuro pipeline can declare.
OpenNeuroDataModality = Literal["eeg", "ieeg"]
def _require_mne_bids(func_name: str) -> None:
    """Fail fast when a caller needs mne-bids but the import did not succeed.

    Args:
        func_name: Name of the calling function, used as the prefix of the
            error message.

    Raises:
        ImportError: If the module-level ``MNE_BIDS_AVAILABLE`` flag is falsy.
    """
    if MNE_BIDS_AVAILABLE:
        return

    message = (
        f"{func_name} requires mne-bids, which is not installed. "
        "Install it with `pip install mne-bids`."
    )
    raise ImportError(message)
[docs]
class OpenNeuroPipeline(BrainsetPipeline, ABC):
    """Abstract base class for OpenNeuro dataset pipelines.

    This class provides foundational tools and conventions for preprocessing and
    handling `OpenNeuro <https://openneuro.org/>`_ datasets within the Brainsets
    framework. It is designed to be subclassed for specific datasets and supports
    both EEG and iEEG modalities.

    **Attributes (to be defined by subclasses):**

    - :attr:`dataset_id`: Identifier for the OpenNeuro dataset (e.g., "ds005555").
    - :attr:`brainset_id`: Unique local identifier for the brainset.
    - :attr:`origin_version`: Version string corresponding to the raw source dataset.
    - :attr:`derived_version`: Version or tag indicating the processing version of
      the derived data.
    - :attr:`description`: Optional textual description of the dataset.
    - :attr:`modality`: Data modality for this pipeline (``"eeg"`` or ``"ieeg"``).
      Must be overridden by subclasses.

    **Customization points:**

    This class supports and encourages dataset-specific customizations via:

    - :attr:`CHANNEL_NAME_REMAPPING`: Map original to standardized channel names.
    - :attr:`TYPE_CHANNELS_REMAPPING`: Map channel types to specific channel names.
    - :attr:`IGNORE_CHANNELS`: List channels to exclude from processing.

    These can be set as class attributes or managed dynamically by overriding the
    following methods:

    - :meth:`get_channel_name_remapping`
    - :meth:`get_type_channels_remapping`

    The :meth:`process_common` method implements the standard steps and routines
    shared by all OpenNeuro datasets. This provides a consistent entry point for
    all dataset processing. Subclasses may extend or override the :meth:`process`
    method to implement dataset-specific processing logic.

    **Documentation can be found in the official brainsets docs:**

    See `Creating an OpenNeuro Pipeline
    <https://brainsets.readthedocs.io/en/latest/concepts/openneuro_pipeline.html>`_
    for the complete guide on building OpenNeuro pipelines.
    """

    # NOTE(review): this binds the module-level parser object itself, not a copy,
    # so arguments added through this attribute are added to that shared parser.
    parser = base_openneuro_parser
    """Argument parser for common OpenNeuro pipeline flags."""

    # The attributes below without a value are annotation-only declarations;
    # concrete pipelines are expected to assign them.
    modality: OpenNeuroDataModality
    """Data modality for this pipeline. Must be overridden by subclasses."""
    dataset_id: str
    """OpenNeuro dataset identifier (e.g., "ds005555", "ds006914")."""
    brainset_id: str
    """Unique identifier for the brainset."""
    origin_version: str
    """Version of the original data. Must be specified by the author of each pipeline."""
    derived_version: str
    """Version of the processed data. Must be specified by the author of each pipeline."""
    description: Optional[str] = None
    """Optional description of the dataset."""

    # Channel customization hooks; all default to None (no customization).
    CHANNEL_NAME_REMAPPING: Optional[dict[str, str]] = None
    """Optional dict mapping original channel name to new standardized name.
For more complex configurations (e.g., per-recording mappings), override
get_channel_name_remapping() instead.
"""
    TYPE_CHANNELS_REMAPPING: Optional[dict[str, list[str]]] = None
    """Optional dict mapping channel types to lists of channel names.
For more complex configurations (e.g., per-recording mappings), override
get_type_channels_remapping() instead.
"""
    IGNORE_CHANNELS: Optional[list[str]] = None
    """Optional list of channel names to ignore.
Channel names should be specified as they appear in the original namespace of
the raw object (i.e., prior to any remapping or type changes).
"""
[docs]
@staticmethod
def validate_dataset_id(dataset_id: str) -> None:
"""Validate OpenNeuro dataset identifier format.
OpenNeuro dataset IDs follow the format 'ds' followed by exactly 6 digits,
where the numeric portion ranges from 000001 to 009999.
Args:
dataset_id: The dataset identifier in strict format:
- Must be lowercase 'ds' followed by exactly 6 digits.
- Numeric portion must be between 000001 and 009999.
Raises:
ValueError: If the dataset ID format is invalid, does not match strict format,
or the numeric part is outside the valid range.
"""
if (
not isinstance(dataset_id, str)
or len(dataset_id) != 8
or not dataset_id.startswith("ds")
or not dataset_id[2:].isdigit()
):
raise ValueError(
f"Invalid dataset ID format: '{dataset_id}'. Expected 'ds' followed by exactly 6 digits."
)
numeric_part = int(dataset_id[2:])
if numeric_part < 1 or numeric_part > 9999:
raise ValueError(
f"Dataset ID '{dataset_id}' has invalid numeric portion. Must be between 000001 and 009999."
)
@classmethod
def _validate_dataset_version(
cls,
latest_snapshot_tag: str,
on_mismatch: Literal["abort", "continue", "prompt"] = "prompt",
) -> None:
"""Validate origin version against the latest OpenNeuro snapshot tag.
Args:
latest_snapshot_tag: The latest snapshot tag available on OpenNeuro for this dataset.
on_mismatch: Policy when ``origin_version`` differs from latest
(``"abort"``, ``"continue"``, or ``"prompt"``). If a mismatch is detected, the
``on_mismatch`` parameter determines the behavior (default: ``"prompt"``):
- ``"abort"``: Raises an error and exits the pipeline.
- ``"continue"``: Logs a warning and proceeds with the latest version.
- ``"prompt"``: Prompts the user for confirmation and proceeds if confirmed.
Raises:
SystemExit: If mismatch policy aborts execution or user declines prompt.
"""
def user_confirms(
prompt: str,
) -> bool:
"""Return True if the user confirms continuation, False otherwise."""
answer = input(prompt).strip().lower()
return answer in {"y", "yes"}
if latest_snapshot_tag != cls.origin_version:
if on_mismatch == "continue":
logging.warning(
f"⚠️ Dataset version '{cls.origin_version}' was used to create the brainset pipeline for dataset '{cls.dataset_id}', "
f"but the latest available version on OpenNeuro is '{latest_snapshot_tag}'. "
"Downloading data or running the pipeline now will use the latest version, "
"which may differ from the original version used, potentially causing errors or inconsistencies. "
"Check the CHANGES file of the dataset for details about the differences between versions."
)
elif on_mismatch == "abort":
raise SystemExit(
"🛑 Aborting pipeline due to dataset version mismatch."
)
elif on_mismatch == "prompt":
prompt_message = (
f"⚠️ Dataset '{cls.dataset_id}' pipeline version is '{cls.origin_version}', "
f"but latest on OpenNeuro is '{latest_snapshot_tag}'. "
"👉 Continue with latest version? [y/N]: "
)
if not user_confirms(prompt_message):
raise SystemExit(
"🛑 Aborted by user due to dataset version mismatch."
)
@staticmethod
def _validate_on_mismatch_policy(on_version_mismatch: str) -> None:
"""Validate that on_version_mismatch policy is compatible with execution mode.
In non-interactive sessions, the 'prompt' policy is invalid because it requires
user input. This validation runs early to provide a clear error message.
Args:
on_version_mismatch: Policy value ('abort', 'continue', or 'prompt').
Raises:
ValueError: If on_version_mismatch='prompt' in non-interactive mode.
"""
if on_version_mismatch == "prompt" and not sys.stdin.isatty():
raise ValueError(
"Cannot use --on-version-mismatch='prompt' in non-interactive mode. "
"The program is running without a TTY and cannot prompt for user input. "
"Set --on-version-mismatch to either 'continue' (warn and proceed) or 'abort' (fail on mismatch)."
)
@staticmethod
def _normalize_species(species: str | None) -> str | None:
"""Normalize species names to ``"HOMO_SAPIENS"`` or None.
Args:
species: The input species name (string or None).
Returns:
``"HOMO_SAPIENS"`` for recognized human aliases, otherwise None.
"""
if not isinstance(species, str):
return None
normalized_species = species.strip().lower()
homo_sapiens_aliases = {
"homo",
"homo sapiens",
"human",
"humans",
"h. sapiens",
}
if normalized_species in homo_sapiens_aliases:
return "HOMO_SAPIENS"
return None
[docs]
@classmethod
def get_manifest(cls, raw_dir: Path, args: Optional[Namespace]) -> pd.DataFrame:
    """Generate a manifest DataFrame by discovering recordings from OpenNeuro.

    This implementation queries OpenNeuro S3 and parses BIDS-compliant
    filenames to discover recordings for the pipeline modality.

    Args:
        raw_dir: Raw data directory assigned to this brainset (not read by
            this implementation).
        args: Pipeline-specific arguments parsed from the command line. May be
            ``None`` or lack ``on_version_mismatch``; the policy then defaults
            to ``"prompt"``.

    Returns:
        DataFrame indexed by ``recording_id`` with columns:

        - subject_id: Subject identifier (e.g., 'sub-01')
        - s3_url: S3 URL for downloading
        - latest_snapshot_tag: Latest snapshot tag reported for the dataset
        - age: ``age`` entry of the subject's participants info (may be None)
        - sex: ``sex`` entry of the subject's participants info (may be None)
        - species: ``"HOMO_SAPIENS"`` or None (see :meth:`_normalize_species`)

    Raises:
        ValueError: If the policy is ``"prompt"`` in a non-interactive session,
            the dataset ID is malformed, the modality is unknown, or no
            recording of the pipeline modality is found.
        SystemExit: If the version-mismatch policy aborts the run.
    """
    # `args` is Optional: fall back to the parser's default policy instead of
    # failing with AttributeError when it is None or lacks the attribute.
    on_version_mismatch = getattr(args, "on_version_mismatch", "prompt")
    cls._validate_on_mismatch_policy(on_version_mismatch)

    # Validate that dataset ID has the correct format
    cls.validate_dataset_id(cls.dataset_id)

    # Compare the pinned origin_version with the latest OpenNeuro snapshot
    latest_snapshot_tag = fetch_latest_snapshot_tag(cls.dataset_id)
    cls._validate_dataset_version(
        latest_snapshot_tag, on_mismatch=on_version_mismatch
    )

    # Dataset-level metadata shared by every manifest row
    species = cls._normalize_species(fetch_species(cls.dataset_id))
    participants_data = fetch_participants_tsv(cls.dataset_id)

    # Fetch all filenames in the dataset from OpenNeuro S3
    all_files = fetch_all_filenames(cls.dataset_id)

    # Depending on modality, extract a list of recordings
    if cls.modality == "eeg":
        recordings = fetch_eeg_recordings(all_files)
    elif cls.modality == "ieeg":
        recordings = fetch_ieeg_recordings(all_files)
    else:
        raise ValueError(f"Unknown modality: {cls.modality}")

    manifest_list = []
    for rec in recordings:
        subject_id = rec["subject_id"]
        recording_id = rec["recording_id"]

        # Construct the S3 URL for the recording
        s3_url = construct_s3_url_from_path(
            cls.dataset_id,
            rec["fpath"],
            recording_id,
        )

        # Fetch the subject information from the participants.tsv file
        subject_info = get_subject_info(subject_id, participants_data)

        manifest_list.append(
            {
                "subject_id": subject_id,
                "recording_id": recording_id,
                "s3_url": s3_url,
                "latest_snapshot_tag": latest_snapshot_tag,
                "age": subject_info.get("age"),
                "sex": subject_info.get("sex"),
                "species": species,
            }
        )

    if not manifest_list:
        raise ValueError(
            f"No {cls.modality.upper()} recordings found in dataset {cls.dataset_id}"
        )

    # Create a DataFrame for the manifest and set 'recording_id' as its index
    return pd.DataFrame(manifest_list).set_index("recording_id")
[docs]
def download(self, manifest_item) -> pd.Series:
    """Fetch the files of one recording from OpenNeuro S3 into ``raw_dir``.

    ``dataset_description.json`` is fetched first when it is missing (or when
    ``--redownload`` is set) because mne-bids needs it to recognize a valid
    BIDS dataset. The recording itself is skipped when its files are already
    present and ``--redownload`` is not set.

    Args:
        manifest_item: A single row of the manifest, read through the
            attributes ``Index`` (recording id), ``subject_id`` and ``s3_url``.

    Returns:
        ``manifest_item`` unchanged, i.e. the row carrying ``subject_id``,
        ``recording_id``, ``s3_url``, ``latest_snapshot_tag``, ``age``,
        ``sex``, and ``species``.

    Raises:
        RuntimeError: If downloading the recording fails; the original
            exception is chained as the cause.
    """
    self.update_status("DOWNLOADING")
    self.raw_dir.mkdir(exist_ok=True, parents=True)

    subject_id = manifest_item.subject_id
    recording_id = manifest_item.Index
    s3_url = manifest_item.s3_url
    root_dir = self.raw_dir
    force_download = getattr(self.args, "redownload", False)

    description_file = root_dir / "dataset_description.json"
    if force_download or not description_file.exists():
        download_dataset_description(self.dataset_id, root_dir)

    # Per-modality "are the files already on disk?" check; a modality without
    # an entry always goes straight to the download.
    presence_checks = {
        "eeg": check_eeg_recording_files_exist,
        "ieeg": check_ieeg_recording_files_exist,
    }
    files_present = presence_checks.get(self.modality)
    if (
        not force_download
        and files_present is not None
        and files_present(root_dir, recording_id)
    ):
        self.update_status("Already Downloaded")
        return manifest_item

    try:
        download_recording(s3_url, root_dir)
    except Exception as e:
        raise RuntimeError(
            f"Failed to download data for {subject_id} from {self.dataset_id}: {str(e)}"
        ) from e

    return manifest_item
[docs]
def process_common(self, download_output: pd.Series) -> Optional[tuple[Data, Path]]:
    """Process data files and create a Data object.

    This method handles common OpenNeuro processing tasks:

    1. Loads BIDS-structured data files using MNE-BIDS
    2. Extracts metadata (subject, session, device, brainset descriptions)
    3. Extracts signal and channel information
    4. Creates a Data object

    Nothing is written to disk here apart from creating ``processed_dir``;
    persisting the result is left to the caller.

    Args:
        download_output: Row returned by download(); read through the
            attributes ``Index`` (recording id), ``subject_id``, ``species``,
            ``age``, ``sex`` and ``latest_snapshot_tag``.

    Returns:
        Tuple of ``(data, store_path)``, or ``None`` if processing is skipped
        because ``store_path`` already exists and ``--reprocess`` is not set.

    Raises:
        ImportError: If mne-bids is not installed.
    """
    self.processed_dir.mkdir(exist_ok=True, parents=True)

    recording_id = download_output.Index
    subject_id = download_output.subject_id
    species = download_output.species
    age = download_output.age
    sex = download_output.sex

    store_path = self.processed_dir / f"{recording_id}.h5"
    if not getattr(self.args, "reprocess", False) and store_path.exists():
        self.update_status("Already Processed")
        return None

    # Report this method's real name in the error (it was "_process_common").
    _require_mne_bids("process_common")

    self.update_status(f"Loading {self.modality.upper()} file")
    bids_path = build_bids_path(self.raw_dir, recording_id, self.modality)
    raw = read_raw_bids(
        bids_path,
        on_ch_mismatch="reorder",
        verbose="CRITICAL",
    )

    self.update_status("Extracting Metadata")
    source = f"https://openneuro.org/datasets/{self.dataset_id}"
    # Fall back to a generic description when the pipeline defines none.
    dataset_description = (
        self.description or f"OpenNeuro dataset {self.dataset_id}"
    )

    brainset_description = BrainsetDescription(
        id=self.brainset_id,
        # Record the snapshot captured in the manifest, which may differ from
        # the class-level origin_version pinned by the pipeline author.
        origin_version=download_output.latest_snapshot_tag,
        derived_version=self.derived_version,
        source=source,
        description=dataset_description,
    )

    subject_description = SubjectDescription(
        id=subject_id,
        species=species,
        age=age,
        sex=sex,
    )

    meas_date = extract_measurement_date(raw)
    session_description = SessionDescription(
        id=recording_id, recording_date=meas_date
    )

    device_description = DeviceDescription(id=recording_id)

    self.update_status(f"Extracting {self.modality.upper()} Signal")
    signal = extract_signal(
        raw,
        ignore_channels=self.IGNORE_CHANNELS,
    )

    self.update_status("Building Channels")
    channels = extract_channels(
        raw,
        channel_names_mapping=self.get_channel_name_remapping(recording_id),
        type_channels_mapping=self.get_type_channels_remapping(recording_id),
        ignore_channels=self.IGNORE_CHANNELS,
    )

    self.update_status("Creating Data Object")
    data_kwargs = {
        "brainset": brainset_description,
        "subject": subject_description,
        "session": session_description,
        "device": device_description,
        "channels": channels,
        "domain": signal.domain,
    }
    # The signal is stored under a key named after the modality ("eeg"/"ieeg").
    data_kwargs[self.modality] = signal

    data = Data(**data_kwargs)

    return data, store_path
[docs]
def process(self, download_output: pd.Series) -> None:
    """Process and save the dataset.

    Default implementation calls :meth:`process_common` and persists the
    result as HDF5. Subclasses can override to add dataset-specific processing.

    The file is first written next to its destination under a ``.tmp`` suffix
    and then renamed. An interrupted or failed write therefore never leaves a
    partial file at ``store_path``, which :meth:`process_common` would
    otherwise treat as "Already Processed" on the next run.

    Args:
        download_output: Series returned by download()
    """
    result = self.process_common(download_output)
    if result is None:
        # process_common decided to skip this recording.
        return

    data, store_path = result
    store_path = Path(store_path)
    tmp_path = store_path.with_name(f"{store_path.name}.tmp")

    self.update_status("Storing")
    try:
        with h5py.File(tmp_path, "w") as file:
            data.to_hdf5(file, serialize_fn_map=serialize_fn_map)
        # Rename only after the HDF5 file has been fully written and closed.
        tmp_path.replace(store_path)
    except BaseException:
        # Remove the partial temporary file, then let the error propagate.
        tmp_path.unlink(missing_ok=True)
        raise
[docs]
def get_channel_name_remapping(
self,
recording_id: str | None = None,
) -> Optional[dict[str, str]]:
"""Return channel name remapping for a given recording.
Override this method to provide per-recording channel name remappings.
The default implementation returns the class-level CHANNEL_NAME_REMAPPING attribute.
Args:
recording_id: The recording identifier
Returns:
Mapping from original channel names to standardized names, or
``None``.
"""
return self.CHANNEL_NAME_REMAPPING
[docs]
def get_type_channels_remapping(
self,
recording_id: str | None = None,
) -> Optional[dict[str, list[str]]]:
"""Return channel type remapping for a given recording.
Override this method to provide per-recording channel type remappings.
The default implementation returns the class-level TYPE_CHANNELS_REMAPPING attribute.
Args:
recording_id: The recording identifier
Returns:
Mapping from channel type to channel name list, or ``None``.
"""
return self.TYPE_CHANNELS_REMAPPING