# Source code for brainsets.datasets.FlintSlutzkyAccurate2012

from typing import Callable, Optional, Literal
from pathlib import Path

from torch_brain.dataset import Dataset, SpikingDatasetMixin

from ._utils import get_processed_dir


class FlintSlutzkyAccurate2012(SpikingDatasetMixin, Dataset):
    """Motor cortex (M1) spiking activity and reaching kinematics from 1 monkey
    performing center-out reaching tasks.

    .. admonition:: Preprocessing

        To download and prepare this dataset, run

        .. code:: shell

            brainsets prepare flint_slutzky_accurate_2012

    **Tasks:** Center-Out

    **Brain Regions:** M1

    **Dataset Statistics**

    - **Subjects:** 1
    - **Total Sessions:** 5
    - **Total Units:** 957
    - **Events:** ~7.9M spikes, ~319k behavioral timestamps

    **Links**

    - Paper: `Flint et al. (2012) – Journal of Neural Engineering <https://doi.org/10.1088/1741-2560/9/4/046006>`_
    - Dataset: `CRCNS Flint 2012 dataset <https://portal.nersc.gov/project/crcns/download/dream/data_sets/Flint_2012>`_

    **Reference**

    Flint, R. D., Lindberg, E. W., Jordan, L. R., Miller, L. E., & Slutzky, M. W. (2012).
    *Accurate decoding of reaching movements from field potentials in the absence of spikes.*
    `Journal of Neural Engineering <https://doi.org/10.1088/1741-2560/9/4/046006>`_, 9(4), 046006.

    Args:
        root (str, optional): Root directory for the dataset. Defaults to
            ``processed_dir`` from brainsets config.
        recording_ids (list[str], optional): List of recording IDs to load.
        transform (Callable, optional): Data transformation to apply.
        split_type (str, optional): Which split type to use. Defaults to
            "hand_velocity". With "hand_velocity", sampling intervals are
            restricted to where both hand and spike data are defined; with
            ``None``, the split domains are used unmodified.
        dirname (str, optional): Subdirectory for the dataset. Defaults to
            "flint_slutzky_accurate_2012".
        **kwargs: Forwarded unchanged to the base ``Dataset`` constructor.

    Raises:
        ValueError: If ``split_type`` is neither "hand_velocity" nor ``None``.
    """

    def __init__(
        self,
        root: Optional[str] = None,
        recording_ids: Optional[list[str]] = None,
        transform: Optional[Callable] = None,
        split_type: Optional[Literal["hand_velocity"]] = "hand_velocity",
        dirname: str = "flint_slutzky_accurate_2012",
        **kwargs,
    ):
        # Validate before the base class touches the filesystem: an unknown
        # split_type would otherwise be accepted silently and simply skip the
        # hand/spikes domain intersection in get_sampling_intervals.
        if split_type not in (None, "hand_velocity"):
            raise ValueError(
                f"Unsupported split_type {split_type!r}; "
                "expected 'hand_velocity' or None."
            )

        if root is None:
            root = get_processed_dir()

        super().__init__(
            dataset_dir=Path(root) / dirname,
            recording_ids=recording_ids,
            transform=transform,
            namespace_attributes=["session.id", "subject.id", "units.id"],
            **kwargs,
        )
        self.spiking_dataset_mixin_uniquify_unit_ids = True
        self.split_type = split_type

    def get_sampling_intervals(
        self,
        split: Optional[Literal["train", "valid", "test"]] = None,
    ):
        """Return the sampling intervals of every loaded recording.

        Args:
            split: ``None`` selects each recording's full ``domain``; otherwise
                the ``"<split>_domain"`` attribute of the recording is used.

        Returns:
            dict mapping each recording id in ``self.recording_ids`` to its
            interval object. When ``self.split_type == "hand_velocity"`` the
            interval is intersected (``&``) with ``data.hand.domain`` and
            ``data.spikes.domain``.

        Raises:
            ValueError: If ``split`` is not ``None``, "train", "valid" or "test".
        """
        # Fail fast, before any recording is loaded, instead of surfacing an
        # opaque AttributeError from getattr() midway through the loop.
        if split not in (None, "train", "valid", "test"):
            raise ValueError(
                f"Unknown split {split!r}; expected None, 'train', 'valid' or 'test'."
            )

        domain_key = "domain" if split is None else f"{split}_domain"
        # Loop-invariant: whether to clip intervals to hand & spike coverage.
        restrict_to_hand_and_spikes = self.split_type == "hand_velocity"

        ans = {}
        for rid in self.recording_ids:
            data = self.get_recording(rid)
            intervals = getattr(data, domain_key)
            if restrict_to_hand_and_spikes:
                intervals = intervals & data.hand.domain & data.spikes.domain
            ans[rid] = intervals
        return ans