Working with Datasets

Overview

This guide explains how to work with datasets in the project. The system is designed to handle multiple modalities (ECG, CMR) through a unified pipeline that consists of three main stages:

  1. Raw Data Handling: Loading and validating original data files
  2. Preprocessing: Converting data into standardized tensor format
  3. Unified Access: Bringing everything together for analysis and training
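
As a rough sketch, the flow from raw file to standardized tensor looks like this (`load_raw` and `preprocess` are invented stand-ins for illustration, not the project's actual API):

```python
# Illustrative sketch of the three pipeline stages with stubbed functions.
import numpy as np

def load_raw(record_id: str) -> np.ndarray:
    """Stage 1: load and validate an original data file (stubbed)."""
    return np.zeros((12, 5000), dtype=np.float32)  # e.g. a 12-lead ECG

def preprocess(raw: np.ndarray) -> np.ndarray:
    """Stage 2: convert to a standardized tensor format (stubbed normalization)."""
    return (raw - raw.mean()) / (raw.std() + 1e-8)

# Stage 3: a unified dataset would expose the preprocessed tensor per record ID.
record = preprocess(load_raw("rec_001"))
```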

Data Organization

The project follows a consistent directory structure across all datasets:

data/
├── raw/                    # Original dataset files
│   ├── cmr/                    # Cardiac MRI datasets
│   │   └── acdc/               # ACDC dataset files
│   └── ecg/                    # ECG datasets
│       └── ptbxl/              # PTB-XL dataset files
├── interim/                # Preprocessed tensors (*.pt files)
│   ├── acdc/                   # Preprocessed ACDC records
│   └── ptbxl/                  # Preprocessed PTB-XL records
├── processed/              # Final dataset artifacts
│   ├── acdc/                   # ACDC splits and metadata
│   │   ├── splits.json         # Train/val/test splits
│   │   └── metadata.db         # Record metadata
│   └── ptbxl/                  # PTB-XL splits and metadata
└── embeddings/             # Pre-computed embeddings (optional)
    ├── type1/                  # Embeddings from model 1
    └── type2/                  # Embeddings from model 2

Dataset Components

1. Raw Dataset Handlers

Raw dataset handlers provide the interface to original data files. They handle:

  • Data loading and validation
  • Metadata extraction
  • Label processing
  • Format standardization

Handlers are built on the following base classes:

src.data.raw.data

RawDataset

Bases: BaseDataset, ABC

Base class for handling raw medical data.

Source code in src/data/raw/data.py
class RawDataset(BaseDataset, ABC):
    """Base class for handling raw medical data."""

    def __init__(self, data_root: Path):
        super().__init__(data_root)
        if not self.paths["raw"].exists():
            raise ValueError(f"Raw data path does not exist: {self.paths['raw']}")

    @abstractmethod
    def verify_data(self) -> None:
        """Verify raw data structure and contents. Throw error if invalid."""
        pass

    @abstractmethod
    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset-specific metadata."""
        pass

    @abstractmethod
    def get_target_labels(self) -> List[str]:
        """Get list of target labels."""
        pass

    @abstractmethod
    def get_all_record_ids(self) -> List[str]:
        """Get all available record IDs without loading data."""
        pass

    @abstractmethod
    def load_record(self, record_id: str) -> RawRecord:
        """Load a single record."""
        pass

    @abstractmethod
    def get_stream(self) -> Generator[RawRecord, None, None]:
        """Stream raw records one at a time."""
        pass
get_all_record_ids() abstractmethod

Get all available record IDs without loading data.

Source code in src/data/raw/data.py
@abstractmethod
def get_all_record_ids(self) -> List[str]:
    """Get all available record IDs without loading data."""
    pass
get_metadata() abstractmethod

Get dataset-specific metadata.

Source code in src/data/raw/data.py
@abstractmethod
def get_metadata(self) -> Dict[str, Any]:
    """Get dataset-specific metadata."""
    pass
get_stream() abstractmethod

Stream raw records one at a time.

Source code in src/data/raw/data.py
@abstractmethod
def get_stream(self) -> Generator[RawRecord, None, None]:
    """Stream raw records one at a time."""
    pass
get_target_labels() abstractmethod

Get list of target labels.

Source code in src/data/raw/data.py
@abstractmethod
def get_target_labels(self) -> List[str]:
    """Get list of target labels."""
    pass
load_record(record_id) abstractmethod

Load a single record.

Source code in src/data/raw/data.py
@abstractmethod
def load_record(self, record_id: str) -> RawRecord:
    """Load a single record."""
    pass
verify_data() abstractmethod

Verify raw data structure and contents. Throw error if invalid.

Source code in src/data/raw/data.py
@abstractmethod
def verify_data(self) -> None:
    """Verify raw data structure and contents. Throw error if invalid."""
    pass
RawRecord dataclass

Bases: DatasetRecord

Data class for raw samples.

Source code in src/data/raw/data.py
@dataclass
class RawRecord(DatasetRecord):
    """Data class for raw samples."""

    data: ndarray
    target_labels: List[str] | None
    metadata: Dict[str, Any]

    def __str__(self) -> str:
        return f"RawRecord(id={self.id}, targets={';'.join(self.target_labels)})"

    def __post_init__(self):
        if not isinstance(self.data, ndarray):
            raise ValueError(f"Data must be a numpy array, got {type(self.data)}")
        if not isinstance(self.target_labels, list):
            raise ValueError(f"Targets must be a list, got {type(self.target_labels)}")
        if not isinstance(self.metadata, dict):
            raise ValueError(
                f"Metadata must be a dictionary, got {type(self.metadata)}"
            )

Dataset handlers are registered using:

src.data.raw.registry

RawDatasetRegistry

Registry for raw data handlers.

Source code in src/data/raw/registry.py
class RawDatasetRegistry:
    """Registry for raw data handlers."""

    _registry: Dict[str, Dict[str, Type[RawDataset]]] = {"ecg": {}, "cmr": {}}

    @classmethod
    def register(cls, modality: str, dataset_key: str):
        """Decorator to register a raw data handler."""

        def decorator(raw_data_class: Type[RawDataset]):
            if modality not in cls._registry:
                cls._registry[modality] = {}
            cls._registry[modality][dataset_key] = raw_data_class
            return raw_data_class

        return decorator

    @classmethod
    def get_handler(cls, modality: str, dataset_key: str) -> Type[RawDataset]:
        """Get raw data handler by modality and key."""
        if modality not in cls._registry:
            raise ValueError(f"Unknown modality: {modality}")
        if dataset_key not in cls._registry[modality]:
            raise ValueError(f"Unknown dataset key for {modality}: {dataset_key}")
        return cls._registry[modality][dataset_key]

    @classmethod
    def get_modality(cls, dataset_key: str) -> DatasetModality:
        """Get modality by dataset key."""
        for modality, datasets in cls._registry.items():
            if dataset_key in datasets:
                return DatasetModality(modality)
        raise ValueError(f"Unknown dataset key: {dataset_key}")

    @classmethod
    def list_datasets(cls) -> Dict[str, List[str]]:
        """List available datasets per modality."""
        return {
            modality: list(datasets.keys())
            for modality, datasets in cls._registry.items()
        }

    @classmethod
    def list_modalities(cls) -> List[str]:
        """List available modalities."""
        return list(cls._registry.keys())
get_handler(modality, dataset_key) classmethod

Get raw data handler by modality and key.

Source code in src/data/raw/registry.py
@classmethod
def get_handler(cls, modality: str, dataset_key: str) -> Type[RawDataset]:
    """Get raw data handler by modality and key."""
    if modality not in cls._registry:
        raise ValueError(f"Unknown modality: {modality}")
    if dataset_key not in cls._registry[modality]:
        raise ValueError(f"Unknown dataset key for {modality}: {dataset_key}")
    return cls._registry[modality][dataset_key]
get_modality(dataset_key) classmethod

Get modality by dataset key.

Source code in src/data/raw/registry.py
@classmethod
def get_modality(cls, dataset_key: str) -> DatasetModality:
    """Get modality by dataset key."""
    for modality, datasets in cls._registry.items():
        if dataset_key in datasets:
            return DatasetModality(modality)
    raise ValueError(f"Unknown dataset key: {dataset_key}")
list_datasets() classmethod

List available datasets per modality.

Source code in src/data/raw/registry.py
@classmethod
def list_datasets(cls) -> Dict[str, List[str]]:
    """List available datasets per modality."""
    return {
        modality: list(datasets.keys())
        for modality, datasets in cls._registry.items()
    }
list_modalities() classmethod

List available modalities.

Source code in src/data/raw/registry.py
@classmethod
def list_modalities(cls) -> List[str]:
    """List available modalities."""
    return list(cls._registry.keys())
register(modality, dataset_key) classmethod

Decorator to register a raw data handler.

Source code in src/data/raw/registry.py
@classmethod
def register(cls, modality: str, dataset_key: str):
    """Decorator to register a raw data handler."""

    def decorator(raw_data_class: Type[RawDataset]):
        if modality not in cls._registry:
            cls._registry[modality] = {}
        cls._registry[modality][dataset_key] = raw_data_class
        return raw_data_class

    return decorator

2. Unified Dataset System

The unified dataset system brings everything together, providing:

  • Access to raw, preprocessed, and embedded data
  • Automatic data integrity validation
  • Efficient caching
  • Comprehensive metadata management

src.data.unified

UnifiedDataset

Bases: BaseDataset

Source code in src/data/unified.py
class UnifiedDataset(BaseDataset):
    def __init__(self, data_root: Path, modality: DatasetModality, dataset_key: str):
        super().__init__(data_root)
        self._dataset_key = dataset_key
        self._modality = modality

    @property
    def dataset_key(self) -> str:
        return self._dataset_key

    @property
    def modality(self) -> DatasetModality:
        return self._modality

    @cached_property
    def metadata_store(self) -> MetadataStore:
        return MetadataStore(data_root=self.paths["processed"])

    @cached_property
    def raw_dataset(self) -> RawDataset:
        return RawDatasetRegistry.get_handler(self.modality.value, self.dataset_key)(
            self.data_root
        )

    def has_dataset_info(self) -> bool:
        return self.paths["misc"]["dataset_info"].exists()

    def get_dataset_info(self) -> Dict[str, Any]:
        if not self.has_dataset_info():
            raise ValueError("No dataset info found for this dataset.")

        with open(self.paths["misc"]["dataset_info"], "r") as f:
            return json.load(f)

    def has_splits(self) -> bool:
        return self.paths["misc"]["splits_file"].exists()

    def get_splits(self) -> Dict[str, Dict[str, Any]]:
        if not self.has_splits():
            raise ValueError("No splits found for this dataset.")

        with open(self.paths["misc"]["splits_file"], "r") as f:
            splits = json.load(f)
        return splits

    def get_split_by_record_id(self, record_id: str) -> str:
        for split_name, record_ids in self.get_splits().items():
            if record_id in record_ids:
                return split_name
        raise ValueError(f"Record ID '{record_id}' not found in any split.")

    def __get_all_record_ids_from_splits(self) -> List[str]:
        return [
            record_id for split in self.get_splits().values() for record_id in split
        ]

    def __get_all_record_ids_from_interim(self) -> List[str]:
        interim_files = self.paths["interim"].glob("*.pt")
        return [f.stem for f in interim_files]

    @cache
    def get_all_record_ids(self) -> List[str]:
        # if self.has_splits():
        #    return self.__get_all_record_ids_from_splits()
        return (
            self.__get_all_record_ids_from_interim()
        )  # interim is the ground truth for record ids

    @cache
    def __load_embeddings(self, embeddings_type: str) -> Dict[str, torch.Tensor]:
        if embeddings_type not in self.paths["embeddings"].keys():
            raise ValueError(f"Embeddings type '{embeddings_type}' not found.")

        embeddings_path = self.paths["embeddings"][embeddings_type]
        if not embeddings_path.exists():
            return {}

        return torch.load(embeddings_path, map_location="cpu")  # type: ignore

    @lru_cache(maxsize=1000)
    def __load_preprocessed_record(self, record_id: str) -> PreprocessedRecord:
        if record_id not in self.get_all_record_ids():
            raise ValueError(f"Record ID '{record_id}' not found in dataset.")

        interim_file = self.paths["interim"] / f"{record_id}.pt"
        if not interim_file.exists():
            raise ValueError(f"Interim file not found for record ID '{record_id}'.")
        return torch.load(interim_file, map_location="cpu")  # type: ignore

    def get_embeddings(
        self, record_id: str, embeddings_type: str | None = None
    ) -> Dict[str, torch.Tensor] | torch.Tensor:
        if embeddings_type is None:
            return {
                k: self.__load_embeddings(k).get(record_id, None)
                for k in self.paths["embeddings"].keys()
            }

        return self.__load_embeddings(embeddings_type).get(record_id, None)

    def available_metadata_fields(self) -> set[str]:
        return self.metadata_store.available_fields()

    @lru_cache(maxsize=1000)
    def __getitem__(self, record_id: str) -> UnifiedRecord:
        preprocessed_record = self.__load_preprocessed_record(record_id)

        embeddings = self.get_embeddings(record_id)
        raw_record = self.raw_dataset.load_record(record_id)

        return UnifiedRecord(
            id=record_id,
            raw_record=raw_record,
            preprocessed_record=preprocessed_record,
            embeddings=embeddings,
        )

    def verify_integrity(self) -> None:
        """Validate dataset integrity using interim files as ground truth.

        Performs the following assertions:
            1. At least one interim file exists (raises ValueError if empty)
            2. If splits exist:
                a. All split IDs must exist in interim files (raises on missing IDs)
                b. All interim IDs must exist in splits (raises on coverage mismatch)
                c. No duplicate IDs across splits (raises on duplicates)
                d. Split data must be stored as lists (raises on format error)
            3. Embeddings files (if exist):
                a. Must not contain IDs absent from interim files (raises on extra IDs)
            4. Metadata:
                a. Must exist for every interim ID (raises on missing metadata)

        Raises:
            ValueError: For any integrity check failure, with detailed message
            FileNotFoundError: If critical path components are missing

        Note:
            Interim files (*.pt in interim directory) are considered the canonical
            source of truth for valid record IDs. All other components (splits,
            embeddings, metadata) must align with these IDs.
        """
        interim_ids = set(self.get_all_record_ids())

        # First check there's at least one interim record
        if not interim_ids:
            raise ValueError(f"No interim files found in {self.paths['interim']}")

        # Check if every raw record was preprocessed
        raw_record_ids = set(self.raw_dataset.get_all_record_ids())
        missing_preprocessed = raw_record_ids - interim_ids
        if missing_preprocessed:
            raise ValueError(
                f"{len(missing_preprocessed)} raw records missing from interim files. "
                f"First 5: {sorted(missing_preprocessed)[:5]}"
            )

        # Split consistency checks (if splits exist)
        if self.has_splits():
            splits = self.get_splits()
            split_ids = set()
            all_split_ids = []

            # Collect all split IDs and check per-split validity
            for split_name, split_records in splits.items():
                if not isinstance(split_records, list):
                    raise ValueError(
                        f"Split '{split_name}' should be a list of IDs, got {type(split_records)}"
                    )

                current_split = set(split_records)
                split_ids.update(current_split)
                all_split_ids.extend(split_records)

                # Check individual split validity
                missing_in_interim = current_split - interim_ids
                if missing_in_interim:
                    raise ValueError(
                        f"Split '{split_name}' contains {len(missing_in_interim)} "
                        f"IDs missing from interim files. First 5: {sorted(missing_in_interim)[:5]}"
                    )

            # Check full split coverage
            if split_ids != interim_ids:
                missing_in_splits = interim_ids - split_ids
                extra_in_splits = split_ids - interim_ids
                error_msg = []
                if missing_in_splits:
                    error_msg.append(
                        f"{len(missing_in_splits)} interim IDs missing from splits"
                    )
                if extra_in_splits:
                    error_msg.append(
                        f"{len(extra_in_splits)} extra IDs in splits not in interim"
                    )
                raise ValueError("Split coverage mismatch: " + ", ".join(error_msg))

            # Check for duplicate IDs across splits
            duplicate_ids = {id for id in all_split_ids if all_split_ids.count(id) > 1}
            if duplicate_ids:
                raise ValueError(
                    f"{len(duplicate_ids)} duplicate IDs found across splits. "
                    f"First 5: {sorted(duplicate_ids)[:5]}"
                )

        # Embeddings consistency checks
        for emb_type, emb_path in self.paths["embeddings"].items():
            if emb_path.exists():
                embeddings = self.__load_embeddings(emb_type)
                if not embeddings:
                    continue  # Skip empty embeddings files

                emb_ids = set(embeddings.keys())
                extra_emb_ids = emb_ids - interim_ids
                if extra_emb_ids:
                    raise ValueError(
                        f"{emb_type} embeddings contain {len(extra_emb_ids)} IDs "
                        f"not in interim files. First 5: {sorted(extra_emb_ids)[:5]}"
                    )

        # Metadata completeness check
        missing_metadata = []
        for record_id in interim_ids:
            if not self.metadata_store.get(record_id):
                missing_metadata.append(record_id)
            if len(missing_metadata) >= 5:  # Early exit for large datasets
                break
        if missing_metadata:
            raise ValueError(
                f"Missing metadata for {len(missing_metadata)} records. "
                f"First 5: {missing_metadata[:5]}"
            )

    def __str__(self) -> str:
        return f"UnifiedDataset(data_root={self.data_root}, modality={self.modality}, dataset_key={self.dataset_key})"
verify_integrity()

Validate dataset integrity using interim files as ground truth.

Performs the following assertions:

  1. At least one interim file exists (raises ValueError if empty)
  2. If splits exist:
      a. All split IDs must exist in interim files (raises on missing IDs)
      b. All interim IDs must exist in splits (raises on coverage mismatch)
      c. No duplicate IDs across splits (raises on duplicates)
      d. Split data must be stored as lists (raises on format error)
  3. Embeddings files (if exist):
      a. Must not contain IDs absent from interim files (raises on extra IDs)
  4. Metadata:
      a. Must exist for every interim ID (raises on missing metadata)

Raises:

  • ValueError: For any integrity check failure, with detailed message
  • FileNotFoundError: If critical path components are missing

Note

Interim files (*.pt in interim directory) are considered the canonical source of truth for valid record IDs. All other components (splits, embeddings, metadata) must align with these IDs.

Source code in src/data/unified.py
def verify_integrity(self) -> None:
    """Validate dataset integrity using interim files as ground truth.

    Performs the following assertions:
        1. At least one interim file exists (raises ValueError if empty)
        2. If splits exist:
            a. All split IDs must exist in interim files (raises on missing IDs)
            b. All interim IDs must exist in splits (raises on coverage mismatch)
            c. No duplicate IDs across splits (raises on duplicates)
            d. Split data must be stored as lists (raises on format error)
        3. Embeddings files (if exist):
            a. Must not contain IDs absent from interim files (raises on extra IDs)
        4. Metadata:
            a. Must exist for every interim ID (raises on missing metadata)

    Raises:
        ValueError: For any integrity check failure, with detailed message
        FileNotFoundError: If critical path components are missing

    Note:
        Interim files (*.pt in interim directory) are considered the canonical
        source of truth for valid record IDs. All other components (splits,
        embeddings, metadata) must align with these IDs.
    """
    interim_ids = set(self.get_all_record_ids())

    # First check there's at least one interim record
    if not interim_ids:
        raise ValueError(f"No interim files found in {self.paths['interim']}")

    # Check if every raw record was preprocessed
    raw_record_ids = set(self.raw_dataset.get_all_record_ids())
    missing_preprocessed = raw_record_ids - interim_ids
    if missing_preprocessed:
        raise ValueError(
            f"{len(missing_preprocessed)} raw records missing from interim files. "
            f"First 5: {sorted(missing_preprocessed)[:5]}"
        )

    # Split consistency checks (if splits exist)
    if self.has_splits():
        splits = self.get_splits()
        split_ids = set()
        all_split_ids = []

        # Collect all split IDs and check per-split validity
        for split_name, split_records in splits.items():
            if not isinstance(split_records, list):
                raise ValueError(
                    f"Split '{split_name}' should be a list of IDs, got {type(split_records)}"
                )

            current_split = set(split_records)
            split_ids.update(current_split)
            all_split_ids.extend(split_records)

            # Check individual split validity
            missing_in_interim = current_split - interim_ids
            if missing_in_interim:
                raise ValueError(
                    f"Split '{split_name}' contains {len(missing_in_interim)} "
                    f"IDs missing from interim files. First 5: {sorted(missing_in_interim)[:5]}"
                )

        # Check full split coverage
        if split_ids != interim_ids:
            missing_in_splits = interim_ids - split_ids
            extra_in_splits = split_ids - interim_ids
            error_msg = []
            if missing_in_splits:
                error_msg.append(
                    f"{len(missing_in_splits)} interim IDs missing from splits"
                )
            if extra_in_splits:
                error_msg.append(
                    f"{len(extra_in_splits)} extra IDs in splits not in interim"
                )
            raise ValueError("Split coverage mismatch: " + ", ".join(error_msg))

        # Check for duplicate IDs across splits
        duplicate_ids = {id for id in all_split_ids if all_split_ids.count(id) > 1}
        if duplicate_ids:
            raise ValueError(
                f"{len(duplicate_ids)} duplicate IDs found across splits. "
                f"First 5: {sorted(duplicate_ids)[:5]}"
            )

    # Embeddings consistency checks
    for emb_type, emb_path in self.paths["embeddings"].items():
        if emb_path.exists():
            embeddings = self.__load_embeddings(emb_type)
            if not embeddings:
                continue  # Skip empty embeddings files

            emb_ids = set(embeddings.keys())
            extra_emb_ids = emb_ids - interim_ids
            if extra_emb_ids:
                raise ValueError(
                    f"{emb_type} embeddings contain {len(extra_emb_ids)} IDs "
                    f"not in interim files. First 5: {sorted(extra_emb_ids)[:5]}"
                )

    # Metadata completeness check
    missing_metadata = []
    for record_id in interim_ids:
        if not self.metadata_store.get(record_id):
            missing_metadata.append(record_id)
        if len(missing_metadata) >= 5:  # Early exit for large datasets
            break
    if missing_metadata:
        raise ValueError(
            f"Missing metadata for {len(missing_metadata)} records. "
            f"First 5: {missing_metadata[:5]}"
        )
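
The split-coverage check (item 2 in the docstring) boils down to set arithmetic; a standalone illustration with made-up record IDs:

```python
# Illustration of the set operations behind the split-coverage check.
interim_ids = {"rec_001", "rec_002", "rec_003"}
split_ids = {"rec_001", "rec_002", "rec_004"}

missing_in_splits = interim_ids - split_ids  # interim IDs absent from splits
extra_in_splits = split_ids - interim_ids    # split IDs with no interim file
```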
UnifiedRecord dataclass

Bases: DatasetRecord

Source code in src/data/unified.py
@dataclass
class UnifiedRecord(DatasetRecord):
    raw_record: RawRecord = None
    preprocessed_record: PreprocessedRecord = None
    embeddings: Optional[torch.Tensor] = None

    def __str__(self) -> str:
        return f"UnifiedRecord(id={self.id}, has_embeddings={self.embeddings is not None})"
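
UnifiedDataset memoizes record and embedding loads with `functools.cache`/`lru_cache`, so repeated access to the same record ID does not hit disk again. A minimal stand-in (`TinyDataset` is invented) shows the effect:

```python
# Demonstrates the lru_cache pattern UnifiedDataset uses on __getitem__.
from functools import lru_cache

class TinyDataset:
    def __init__(self):
        self.load_count = 0  # counts actual (non-cached) loads

    @lru_cache(maxsize=1000)
    def __getitem__(self, record_id: str):
        self.load_count += 1  # only incremented on a cache miss
        return {"id": record_id}

ds = TinyDataset()
first = ds["rec_001"]
second = ds["rec_001"]  # served from the cache; no second load
```

One trade-off of this pattern is that the cache keeps references to loaded records (and the instance) alive, which is why bounding it with `maxsize` matters for large datasets.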

Implementation Guide

Now that you have a basic understanding of the unified dataset system, let's go through the implementation steps for creating a new dataset handler given some raw data.

Creating a New Dataset Handler

1. Choose the Appropriate Location:

  • ECG handlers: src/data/raw/ecg/<dataset_name>.py
  • CMR handlers: src/data/raw/cmr/<dataset_name>.py

2. Implement Required Methods:

All handlers must implement the RawDataset abstract base class, which defines the required interface for accessing the raw data.

src.data.raw.data

RawDataset

Bases: BaseDataset, ABC

Base class for handling raw medical data.

Source code in src/data/raw/data.py
class RawDataset(BaseDataset, ABC):
    """Base class for handling raw medical data."""

    def __init__(self, data_root: Path):
        super().__init__(data_root)
        if not self.paths["raw"].exists():
            raise ValueError(f"Raw data path does not exist: {self.paths['raw']}")

    @abstractmethod
    def verify_data(self) -> None:
        """Verify raw data structure and contents. Throw error if invalid."""
        pass

    @abstractmethod
    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset-specific metadata."""
        pass

    @abstractmethod
    def get_target_labels(self) -> List[str]:
        """Get list of target labels."""
        pass

    @abstractmethod
    def get_all_record_ids(self) -> List[str]:
        """Get all available record IDs without loading data."""
        pass

    @abstractmethod
    def load_record(self, record_id: str) -> RawRecord:
        """Load a single record."""
        pass

    @abstractmethod
    def get_stream(self) -> Generator[RawRecord, None, None]:
        """Stream raw records one at a time."""
        pass
get_all_record_ids() abstractmethod

Get all available record IDs without loading data.

Source code in src/data/raw/data.py
@abstractmethod
def get_all_record_ids(self) -> List[str]:
    """Get all available record IDs without loading data."""
    pass
get_metadata() abstractmethod

Get dataset-specific metadata.

Source code in src/data/raw/data.py
@abstractmethod
def get_metadata(self) -> Dict[str, Any]:
    """Get dataset-specific metadata."""
    pass
get_stream() abstractmethod

Stream raw records one at a time.

Source code in src/data/raw/data.py
@abstractmethod
def get_stream(self) -> Generator[RawRecord, None, None]:
    """Stream raw records one at a time."""
    pass
get_target_labels() abstractmethod

Get list of target labels.

Source code in src/data/raw/data.py
@abstractmethod
def get_target_labels(self) -> List[str]:
    """Get list of target labels."""
    pass
load_record(record_id) abstractmethod

Load a single record.

Source code in src/data/raw/data.py
@abstractmethod
def load_record(self, record_id: str) -> RawRecord:
    """Load a single record."""
    pass
verify_data() abstractmethod

Verify raw data structure and contents. Throw error if invalid.

Source code in src/data/raw/data.py
@abstractmethod
def verify_data(self) -> None:
    """Verify raw data structure and contents. Throw error if invalid."""
    pass
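
As a starting point, a new handler might look like the following sketch. The simplified base class and `MyDataset` are stand-ins for illustration, not project code:

```python
# Hypothetical handler skeleton; the abstract methods mirror a subset of
# the RawDataset ABC above, using a simplified stand-in base class.
from abc import ABC, abstractmethod
from typing import Any, Dict, List

class RawDatasetStub(ABC):  # simplified stand-in for src.data.raw.data.RawDataset
    @abstractmethod
    def verify_data(self) -> None: ...
    @abstractmethod
    def get_metadata(self) -> Dict[str, Any]: ...
    @abstractmethod
    def get_all_record_ids(self) -> List[str]: ...

class MyDataset(RawDatasetStub):
    """Invented handler backed by an in-memory list of record IDs."""

    def __init__(self, record_ids: List[str]):
        self._ids = list(record_ids)

    def verify_data(self) -> None:
        if not self._ids:
            raise ValueError("No records found")

    def get_metadata(self) -> Dict[str, Any]:
        return {"name": "my_dataset", "num_records": len(self._ids)}

    def get_all_record_ids(self) -> List[str]:
        return list(self._ids)

ds = MyDataset(["rec_001", "rec_002"])
```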

The RawRecord dataclass represents a single record in the dataset. It holds the raw data, target labels, and metadata for that record, ready for preprocessing.

src.data.raw.data

RawRecord dataclass

Bases: DatasetRecord

Data class for raw samples.

Source code in src/data/raw/data.py
@dataclass
class RawRecord(DatasetRecord):
    """Data class for raw samples."""

    data: ndarray
    target_labels: List[str] | None
    metadata: Dict[str, Any]

    def __str__(self) -> str:
        return f"RawRecord(id={self.id}, targets={';'.join(self.target_labels)})"

    def __post_init__(self):
        if not isinstance(self.data, ndarray):
            raise ValueError(f"Data must be a numpy array, got {type(self.data)}")
        if not isinstance(self.target_labels, list):
            raise ValueError(f"Targets must be a list, got {type(self.target_labels)}")
        if not isinstance(self.metadata, dict):
            raise ValueError(
                f"Metadata must be a dictionary, got {type(self.metadata)}"
            )
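
For illustration, constructing such a record might look like this. The dataclasses are redeclared so the snippet is self-contained; the real DatasetRecord base may carry additional fields:

```python
# Constructing a RawRecord-like object with made-up values.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import numpy as np

@dataclass
class DatasetRecord:  # minimal stand-in for the project's base class
    id: str

@dataclass
class RawRecord(DatasetRecord):
    data: np.ndarray
    target_labels: Optional[List[str]]
    metadata: Dict[str, Any]

rec = RawRecord(
    id="patient_0001",
    data=np.zeros((12, 5000), dtype=np.float32),  # e.g. a 12-lead ECG
    target_labels=["NORM"],
    metadata={"age": 54, "sex": "F"},
)
```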

3. Register the Dataset Handler:

Registration happens as a side effect of importing the handler module, so to make your new dataset handler available, import it in the src/data/__init__.py module:

# Import datasets in order to register them
import src.data.raw.cmr.acdc
import src.data.raw.ecg.arrhythmia
import src.data.raw.ecg.grouped_arrhythmia
import src.data.raw.ecg.shandong
import src.data.raw.ecg.ptb_xl

Example Implementations of Dataset Handlers

src.data.raw.ecg.ptb_xl

PTBXL

Bases: RawDataset

Handler for raw PTB-XL ECG data at 500 Hz.

Expected data structure in the raw data directory (self.paths["raw"]):

ptbxl/
├── ptbxl_database.csv         # Contains one row per record (indexed by ecg_id) with extensive metadata
├── scp_statements.csv         # Contains SCP-ECG annotation mappings (optional)
└── records500/                # WFDB records at 500 Hz

The ptbxl_database.csv file contains many columns, including:

  • ecg_id, patient_id, age, sex, height, weight, nurse, site, device, recording_date, etc.
  • scp_codes: a string representation of a dictionary mapping SCP-ECG statement codes to likelihoods.
  • filename_hr: path (relative to the ptbxl folder) to the WFDB record for 500 Hz.
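
For example, parsing a scp_codes value with `ast.literal_eval`, as the handler below does (the value shown is made up):

```python
# scp_codes is stored in the CSV as a dict literal string.
import ast

raw_value = "{'NORM': 100.0, 'SR': 0.0}"
scp_codes = ast.literal_eval(raw_value)  # safe parsing of Python literals
```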

Source code in src/data/raw/ecg/ptb_xl.py
@RawDatasetRegistry.register(DatasetModality.ECG.value, "ptbxl")
class PTBXL(RawDataset):
    """Handler for raw PTB-XL ECG data at 500 Hz.

    Expected data structure in the raw data directory (self.paths["raw"]):
        ptbxl/
        ├── ptbxl_database.csv         # Contains one row per record (indexed by ecg_id) with extensive metadata
        ├── scp_statements.csv         # Contains SCP-ECG annotation mappings (optional)
        └── records500/                # WFDB records at 500 Hz

    The ptbxl_database.csv file contains many columns including:
        - ecg_id, patient_id, age, sex, height, weight, nurse, site, device, recording_date, etc.
        - scp_codes: a string representation of a dictionary mapping SCP-ECG statement codes to likelihoods.
        - filename_hr: path (relative to the ptbxl folder) to the WFDB record for 500 Hz.
    """

    dataset_key = "ptbxl"
    modality = DatasetModality.ECG

    def __init__(self, data_root: Path):
        """
        Args:
            data_root: Path to the dataset root folder.
                      It should contain the ptbxl folder with the PTB-XL files.
        """
        super().__init__(data_root)
        self.root = self.paths["raw"]

        self.database_csv = self.root / "ptbxl_database.csv"
        self.scp_csv = self.root / "scp_statements.csv"

        if not self.database_csv.exists():
            raise FileNotFoundError(f"Database CSV not found: {self.database_csv}")
        if not self.scp_csv.exists():
            logger.warning(f"SCP statements CSV not found: {self.scp_csv}")

        # Use "ecg_id" as the index for fast lookup.
        self.metadata_df = pd.read_csv(self.database_csv, index_col="ecg_id")

        # ecg_id is unique, but patient_id is not. Group by patient_id and use random record
        self.metadata_df = (
            self.metadata_df.groupby("patient_id")
            .sample(n=1, random_state=1337)
            .reset_index(drop=True)
        )

        self.metadata_df["patient_id"] = self.metadata_df["patient_id"].astype(int)
        self.metadata_df["scp_codes"] = self.metadata_df["scp_codes"].apply(
            lambda x: ast.literal_eval(x)
            if isinstance(x, str)
            else x  # scp codes are stored as dict strings
        )

        self.statements_dict = pd.read_csv(self.scp_csv, index_col=0).to_dict(
            orient="index"
        )

        self.metadata_dict: Dict[str, Any] = self.metadata_df.to_dict(orient="index")
        self.records_folder = (
            self.root / "records500"
        )  # we are only interested in the 500Hz data
        if not self.records_folder.exists():
            raise FileNotFoundError(f"Records folder not found: {self.records_folder}")

        self.filename_col = (
            "filename_hr"  # 500 Hz records / for 100 hz use "filename_lr"
        )

    def verify_data(self) -> None:
        """Verify the raw data structure and contents."""
        if not self.database_csv.exists():
            raise ValueError(f"Database CSV does not exist: {self.database_csv}")
        if not self.records_folder.exists():
            raise ValueError(f"Records folder does not exist: {self.records_folder}")

        for record_id in self.get_all_record_ids():
            file_path = self._get_file_path(record_id)
            header_path = file_path.with_suffix(".hea")
            dat_path = file_path.with_suffix(".dat")
            if not header_path.exists():
                raise FileNotFoundError(
                    f"WFDB header file for record {record_id} not found: {header_path}"
                )

            if not dat_path.exists():
                raise FileNotFoundError(
                    f"WFDB data file for record {record_id} not found: {dat_path}"
                )

        # based on https://physionet.org/content/ptb-xl/1.0.3/, we want only unique patients with one record per patient
        assert (
            len(self.metadata_df) == 18869
        ), f"Expected 18869 records, got {len(self.metadata_df)}"

    def get_metadata(self) -> Dict[str, Any]:
        """Get dataset-specific metadata.

        Returns:
            A dictionary containing dataset details.
        """
        return {
            "num_leads": 12,
            "sampling_rate": 500,
            "record_length_seconds": 10,
            "total_records": len(self.metadata_df),
            "num_patients": int(self.metadata_df["patient_id"].nunique()),
        }

    def _get_file_path(self, record_id: str) -> Path:
        """Get the full path to the WFDB record file for a given record ID.

        Args:
            record_id: The unique ECG record identifier.

        Returns:
            Full path to the WFDB record file.
        """
        return self.root / self.metadata_dict[int(record_id)][self.filename_col]

    @cache
    def get_target_labels(self) -> List[str]:
        """Extract all unique SCP codes (with likelihood > 0) from the metadata.

        Returns:
            Sorted list of unique target label codes.
        """
        labels = set()
        for record_id in self.get_all_record_ids():
            codes = self._extract_scp_codes(int(record_id))
            labels.update(self._extract_target_labels(codes))
        return sorted(list(labels))

    def get_all_record_ids(self) -> List[str]:
        """Get all available record IDs without loading the actual data.

        Returns:
            List of record identifiers (ecg_id) as strings.
        """
        return self.metadata_df.index.astype(str).tolist()

    def get_stream(self) -> Generator[RawRecord, None, None]:
        """Stream raw records one at a time.

        Yields:
            RawRecord objects for each record.
        """
        for record_id in self.get_all_record_ids():
            try:
                yield self.load_record(record_id)
            except Exception as e:
                logger.error(f"Error loading record {record_id}: {e}")
                continue

    @cache
    def load_record(self, record_id: str) -> RawRecord:
        """Load a single ECG record by its identifier.

        Args:
            record_id: The unique ECG record identifier.

        Returns:
            A RawRecord object containing the ECG signal data, target labels, and metadata.
        """
        record_id = int(record_id)
        if record_id not in self.metadata_dict:
            raise ValueError(f"Record {record_id} not found in metadata.")
        meta = self.metadata_dict[record_id]

        file_path = self._get_file_path(record_id)
        record = self._read_wfdb_record(record_id)
        data = record[0]

        codes = self._extract_scp_codes(record_id)
        target_labels = self._extract_target_labels(codes)

        record_metadata = dict(meta)
        record_metadata["wfdb_file"] = str(file_path.relative_to(self.root))
        record_metadata["signal_fields"] = record[1]

        record_metadata["scp_statements"] = {
            code: self._extract_scp_statement(code) for code in codes
        }

        return RawRecord(
            id=record_id,
            data=data,
            target_labels=target_labels,
            metadata=record_metadata,
        )

    def _extract_scp_statement(self, code: str) -> str:
        """Extract the SCP statement for a given code from the SCP statements CSV.

        Args:
            code: SCP code to look up.

        Returns:
            SCP statement description.
        """
        if code in self.statements_dict:
            return self.statements_dict[code]
        raise ValueError(f"SCP code {code} not found in statements CSV.")

    def _extract_target_labels(self, scp_codes: Dict[str, float]) -> List[str]:
        """Extract target labels from the SCP codes dictionary.

        Args:
            scp_codes: Dictionary mapping SCP codes to likelihoods.

        Returns:
            List of target labels; codes with unknown likelihood (stored as 0) are included.
        """
        return [
            code for code, likelihood in scp_codes.items() if float(likelihood) >= 0
        ]  # include unknown codes: "... where likelihood is set to 0 if unknown ..."

    def _extract_scp_codes(self, record_id: int) -> Dict[str, float]:
        """Extract SCP codes and likelihoods for a given record ID.

        Args:
            record_id: The unique ECG record identifier.

        Returns:
            Dictionary mapping SCP codes to likelihoods.
        """
        codes = self.metadata_dict[record_id].get("scp_codes")
        if not isinstance(codes, dict):
            raise ValueError(f"SCP codes not found for record {record_id}")
        return codes

    def _read_wfdb_record(self, record_id: int) -> Tuple[np.ndarray, Dict[str, Any]]:
        """Read the WFDB record data for a given record ID.

        Args:
            record_id: The unique ECG record identifier.

        Returns:
            Tuple of the signal array, shaped (num_leads, num_samples), and the WFDB header fields.
        """
        file_path = self._get_file_path(record_id)
        try:
            record = wfdb.rdsamp(str(file_path))
            data = np.array(record[0], dtype=np.float32)
            data = np.swapaxes(data, 0, 1)
            return data, record[1]
        except Exception as e:
            raise RuntimeError(f"Error reading WFDB record {file_path}: {e}") from e
__init__(data_root)

Parameters:

Name Type Description Default
data_root Path

Path to the dataset root folder. It should contain the ptbxl folder with the PTB-XL files.

required
Source code in src/data/raw/ecg/ptb_xl.py
def __init__(self, data_root: Path):
    """
    Args:
        data_root: Path to the dataset root folder.
                  It should contain the ptbxl folder with the PTB-XL files.
    """
    super().__init__(data_root)
    self.root = self.paths["raw"]

    self.database_csv = self.root / "ptbxl_database.csv"
    self.scp_csv = self.root / "scp_statements.csv"

    if not self.database_csv.exists():
        raise FileNotFoundError(f"Database CSV not found: {self.database_csv}")
    if not self.scp_csv.exists():
        logger.warning(f"SCP statements CSV not found: {self.scp_csv}")

    # Use "ecg_id" as the index for fast lookup.
    self.metadata_df = pd.read_csv(self.database_csv, index_col="ecg_id")

    # ecg_id is unique, but patient_id is not. Group by patient_id and use random record
    self.metadata_df = (
        self.metadata_df.groupby("patient_id")
        .sample(n=1, random_state=1337)
        .reset_index(drop=True)
    )

    self.metadata_df["patient_id"] = self.metadata_df["patient_id"].astype(int)
    self.metadata_df["scp_codes"] = self.metadata_df["scp_codes"].apply(
        lambda x: ast.literal_eval(x)
        if isinstance(x, str)
        else x  # scp codes are stored as dict strings
    )

    self.statements_dict = pd.read_csv(self.scp_csv, index_col=0).to_dict(
        orient="index"
    )

    self.metadata_dict: Dict[str, Any] = self.metadata_df.to_dict(orient="index")
    self.records_folder = (
        self.root / "records500"
    )  # we are only interested in the 500Hz data
    if not self.records_folder.exists():
        raise FileNotFoundError(f"Records folder not found: {self.records_folder}")

    self.filename_col = (
        "filename_hr"  # 500 Hz records / for 100 hz use "filename_lr"
    )
get_all_record_ids()

Get all available record IDs without loading the actual data.

Returns:

Type Description
List[str]

List of record identifiers (ecg_id) as strings.

Source code in src/data/raw/ecg/ptb_xl.py
def get_all_record_ids(self) -> List[str]:
    """Get all available record IDs without loading the actual data.

    Returns:
        List of record identifiers (ecg_id) as strings.
    """
    return self.metadata_df.index.astype(str).tolist()
get_metadata()

Get dataset-specific metadata.

Returns:

Type Description
Dict[str, Any]

A dictionary containing dataset details.

Source code in src/data/raw/ecg/ptb_xl.py
def get_metadata(self) -> Dict[str, Any]:
    """Get dataset-specific metadata.

    Returns:
        A dictionary containing dataset details.
    """
    return {
        "num_leads": 12,
        "sampling_rate": 500,
        "record_length_seconds": 10,
        "total_records": len(self.metadata_df),
        "num_patients": int(self.metadata_df["patient_id"].nunique()),
    }
get_stream()

Stream raw records one at a time.

Yields:

Type Description
RawRecord

RawRecord objects for each record.

Source code in src/data/raw/ecg/ptb_xl.py
def get_stream(self) -> Generator[RawRecord, None, None]:
    """Stream raw records one at a time.

    Yields:
        RawRecord objects for each record.
    """
    for record_id in self.get_all_record_ids():
        try:
            yield self.load_record(record_id)
        except Exception as e:
            logger.error(f"Error loading record {record_id}: {e}")
            continue
get_target_labels() cached

Extract all unique SCP codes from the metadata, including codes whose likelihood is unknown (stored as 0).

Returns:

Type Description
List[str]

Sorted list of unique target label codes.

Source code in src/data/raw/ecg/ptb_xl.py
@cache
def get_target_labels(self) -> List[str]:
    """Extract all unique SCP codes (with likelihood > 0) from the metadata.

    Returns:
        Sorted list of unique target label codes.
    """
    labels = set()
    for record_id in self.get_all_record_ids():
        codes = self._extract_scp_codes(int(record_id))
        labels.update(self._extract_target_labels(codes))
    return sorted(list(labels))
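The extraction rule behind this method keeps every code, including those whose likelihood is unknown (PTB-XL stores unknown likelihoods as 0). A minimal sketch of the filter with an illustrative scp_codes dict:

```python
# Illustrative scp_codes dict; unknown likelihoods are stored as 0
scp_codes = {"NORM": 100.0, "SR": 0.0}

# Same filter as _extract_target_labels: likelihood >= 0 keeps unknown codes too
labels = sorted(
    code for code, likelihood in scp_codes.items() if float(likelihood) >= 0
)
```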
load_record(record_id) cached

Load a single ECG record by its identifier.

Parameters:

Name Type Description Default
record_id str

The unique ECG record identifier.

required

Returns:

Type Description
RawRecord

A RawRecord object containing the ECG signal data, target labels, and metadata.

Source code in src/data/raw/ecg/ptb_xl.py
@cache
def load_record(self, record_id: str) -> RawRecord:
    """Load a single ECG record by its identifier.

    Args:
        record_id: The unique ECG record identifier.

    Returns:
        A RawRecord object containing the ECG signal data, target labels, and metadata.
    """
    record_id = int(record_id)
    if record_id not in self.metadata_dict:
        raise ValueError(f"Record {record_id} not found in metadata.")
    meta = self.metadata_dict[record_id]

    file_path = self._get_file_path(record_id)
    record = self._read_wfdb_record(record_id)
    data = record[0]

    codes = self._extract_scp_codes(record_id)
    target_labels = self._extract_target_labels(codes)

    record_metadata = dict(meta)
    record_metadata["wfdb_file"] = str(file_path.relative_to(self.root))
    record_metadata["signal_fields"] = record[1]

    record_metadata["scp_statements"] = {
        code: self._extract_scp_statement(code) for code in codes
    }

    return RawRecord(
        id=record_id,
        data=data,
        target_labels=target_labels,
        metadata=record_metadata,
    )
verify_data()

Verify the raw data structure and contents.

Source code in src/data/raw/ecg/ptb_xl.py
def verify_data(self) -> None:
    """Verify the raw data structure and contents."""
    if not self.database_csv.exists():
        raise ValueError(f"Database CSV does not exist: {self.database_csv}")
    if not self.records_folder.exists():
        raise ValueError(f"Records folder does not exist: {self.records_folder}")

    for record_id in self.get_all_record_ids():
        file_path = self._get_file_path(record_id)
        header_path = file_path.with_suffix(".hea")
        dat_path = file_path.with_suffix(".dat")
        if not header_path.exists():
            raise FileNotFoundError(
                f"WFDB header file for record {record_id} not found: {header_path}"
            )

        if not dat_path.exists():
            raise FileNotFoundError(
                f"WFDB data file for record {record_id} not found: {dat_path}"
            )

    # based on https://physionet.org/content/ptb-xl/1.0.3/, we want only unique patients with one record per patient
    assert (
        len(self.metadata_df) == 18869
    ), f"Expected 18869 records, got {len(self.metadata_df)}"

src.data.raw.cmr.acdc

ACDC

Bases: RawDataset

Source code in src/data/raw/cmr/acdc.py
@RawDatasetRegistry.register("cmr", "acdc")
class ACDC(RawDataset):
    dataset_key = "acdc"
    modality = DatasetModality.CMR

    required_metadata_fields = {
        "ed_frame": None,
        "es_frame": None,
        "group": None,
        "height": None,
        "weight": None,
        "nb_frames": None,
    }

    def __init__(self, data_root: Path):
        super().__init__(data_root)
        self.data_path = self.paths["raw"]
        self.__nifti_files = {
            f.parent.name: f.absolute() for f in self.data_path.glob("**/*_4d.nii.gz")
        }
        self.__config_files = {
            f.parent.name: f.absolute() for f in self.data_path.glob("**/Info.cfg")
        }

    def verify_data(self) -> None:
        if not self.data_path.exists():
            raise ValueError(f"ACDC data path does not exist: {self.data_path}")

        if len(self.__nifti_files) == 0:
            raise ValueError(f"No 4D NIFTI files found in {self.data_path}")

        if len(self.__config_files) == 0:
            raise ValueError(f"No Info.cfg files found in {self.data_path}")

        # each NIFTI file should have a corresponding Info.cfg file
        nifti_stems = set(self.__nifti_files.keys())
        config_stems = set(self.__config_files.keys())
        if nifti_stems != config_stems:
            missing_nifti = nifti_stems - config_stems
            missing_config = config_stems - nifti_stems
            raise ValueError(
                f"Missing Info.cfg files for {missing_nifti} and NIFTI files for {missing_config}"
            )

        # based on https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
        assert (
            len(self.get_all_record_ids()) == 150
        ), f"Expected 150 records, got {len(self.get_all_record_ids())}"

    def get_metadata(self) -> Dict[str, Any]:
        return {"default_image_size": 256}

    def _get_record_paths(self, record_id: str) -> Tuple[Path, Path]:
        """Get paths to .nii.gz and Info.cfg files for a given record ID."""
        return self.__nifti_files.get(record_id), self.__config_files.get(record_id)

    @cache
    def get_target_labels(self) -> List[str]:
        """Get list of target labels using streaming to reduce memory usage."""
        groups = set()
        for record_id in self.get_all_record_ids():
            _, config_path = self._get_record_paths(record_id)
            group = self._parse_config(config_path).get("group")
            if group is not None:
                groups.add(group)
        return sorted(list(groups))

    def get_all_record_ids(self) -> List[str]:
        """Get all record IDs without loading data."""
        return list(self.__nifti_files.keys())

    def get_stream(self) -> Generator[RawRecord, None, None]:
        """Stream records one at a time to reduce memory usage."""
        for record_id in self.get_all_record_ids():
            yield self.load_record(record_id)

    @cache
    def load_record(self, record_id: str) -> RawRecord:
        nifti_file, config_file = self._get_record_paths(record_id)

        metadata = self._parse_config(config_file)
        frame_indices = self._calculate_frame_indices(metadata)
        image_data = self._get_extracted_frames(nifti_file, frame_indices)

        metadata.update(
            {
                f"{frame_type}_frame_idx": idx
                for frame_type, idx in zip(["ed", "mid", "es"], frame_indices)
            }
        )

        metadata["nifti_path"] = str(nifti_file.relative_to(self.data_path))

        group = metadata.pop("group", None)
        return RawRecord(
            id=record_id,
            data=image_data,
            target_labels=[group] if group else None,
            metadata=metadata,
        )

    @staticmethod
    def _calculate_frame_indices(metadata: Dict[str, Any]) -> List[int]:
        required_fields = ["ed_frame", "es_frame", "nb_frames"]
        if not all(metadata.get(field) is not None for field in required_fields):
            raise ValueError("Missing required frame information")

        ed_frame_idx = metadata["ed_frame"] - 1
        es_frame_idx = metadata["es_frame"] - 1
        nb_frames = metadata["nb_frames"]

        if ed_frame_idx <= es_frame_idx:
            mid_frame_idx = (ed_frame_idx + es_frame_idx) // 2
        else:
            mid_frame_idx = ((ed_frame_idx + es_frame_idx + nb_frames) // 2) % nb_frames

        return [ed_frame_idx, mid_frame_idx, es_frame_idx]

    def _get_extracted_frames(self, path: Path, frame_indices: List[int]) -> np.ndarray:
        nifti_data = self._read_nifti(path)
        slices = self._extract_middle_slice(nifti_data)
        return self._extract_frames(slices, frame_indices)

    @staticmethod
    def _read_nifti(path: Path) -> np.ndarray:
        try:
            return nib.load(str(path)).get_fdata()
        except Exception as e:
            raise RuntimeError(f"Error reading NIFTI file {path}: {e}")

    @staticmethod
    def _extract_middle_slice(data: np.ndarray) -> np.ndarray:
        """Gets the middle basal apical slice"""
        # We approximate the middle slice by taking the middle slice along the z-axis
        return data[:, :, data.shape[2] // 2, :]

    @staticmethod
    def _extract_frames(slice_data: np.ndarray, frame_indices: List[int]) -> np.ndarray:
        num_frames = slice_data.shape[-1]
        invalid_indices = [idx for idx in frame_indices if idx < 0 or idx >= num_frames]

        if invalid_indices:
            raise IndexError(
                f"Frame indices out of bounds: {invalid_indices}. "
                f"Valid range: [0, {num_frames - 1}]"
            )

        return slice_data[:, :, frame_indices]

    @cache
    def _parse_config(self, path: Path) -> Dict[str, Any]:
        """Parse config file with caching."""
        config = configparser.ConfigParser()
        with open(path, "r") as f:
            config.read_string("[DEFAULT]\n" + f.read())

        parsed_values = {
            "ed_frame": config["DEFAULT"].getint("ED"),
            "es_frame": config["DEFAULT"].getint("ES"),
            "group": config["DEFAULT"].get("Group"),
            "height": config["DEFAULT"].getfloat("Height"),
            "weight": config["DEFAULT"].getfloat("Weight"),
            "nb_frames": config["DEFAULT"].getint("NbFrame"),
        }
        return {k: v for k, v in parsed_values.items() if v is not None}
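The midpoint logic in _calculate_frame_indices wraps around the cardiac cycle when the ED frame comes after the ES frame. A small sketch of that arithmetic:

```python
# Mirrors the midpoint arithmetic in ACDC._calculate_frame_indices
def mid_frame(ed_idx: int, es_idx: int, nb_frames: int) -> int:
    if ed_idx <= es_idx:
        return (ed_idx + es_idx) // 2
    # ED after ES: take the midpoint along the wrapped cycle
    return ((ed_idx + es_idx + nb_frames) // 2) % nb_frames

# ED before ES: plain midpoint
print(mid_frame(0, 10, 30))   # 5
# ED after ES: the midpoint wraps past the end of the cycle
print(mid_frame(25, 5, 30))   # 0
```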
get_all_record_ids()

Get all record IDs without loading data.

Source code in src/data/raw/cmr/acdc.py
def get_all_record_ids(self) -> List[str]:
    """Get all record IDs without loading data."""
    return list(self.__nifti_files.keys())
get_stream()

Stream records one at a time to reduce memory usage.

Source code in src/data/raw/cmr/acdc.py
def get_stream(self) -> Generator[RawRecord, None, None]:
    """Stream records one at a time to reduce memory usage."""
    for record_id in self.get_all_record_ids():
        yield self.load_record(record_id)
get_target_labels() cached

Get list of target labels using streaming to reduce memory usage.

Source code in src/data/raw/cmr/acdc.py
@cache
def get_target_labels(self) -> List[str]:
    """Get list of target labels using streaming to reduce memory usage."""
    groups = set()
    for record_id in self.get_all_record_ids():
        _, config_path = self._get_record_paths(record_id)
        group = self._parse_config(config_path).get("group")
        if group is not None:
            groups.add(group)
    return sorted(list(groups))
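The _parse_config method shown in the class source prepends a [DEFAULT] header because Info.cfg files are section-less key/value files, which configparser cannot read directly. A sketch of the same trick with illustrative values:

```python
import configparser

# Illustrative Info.cfg contents; ACDC files are section-less "key: value" pairs
text = "ED: 1\nES: 12\nGroup: NOR\nHeight: 170.0\nWeight: 70.0\nNbFrame: 30\n"

# Prepending a [DEFAULT] header lets configparser read the section-less file
config = configparser.ConfigParser()
config.read_string("[DEFAULT]\n" + text)

ed_frame = config["DEFAULT"].getint("ED")   # 1
group = config["DEFAULT"].get("Group")      # "NOR"
```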

Using the Unified Dataset

This section shows how to use the UnifiedDataset class to load and access data from a dataset.

A UnifiedDataset is initialized with the data_root, modality, and dataset_key arguments. It then loads the dataset's metadata and provides methods for preprocessing and embedding the data.

from pathlib import Path
from src.data.dataset import DatasetModality
from src.data.unified import UnifiedDataset

# Initialize dataset
dataset = UnifiedDataset(
    data_root=Path("data"),
    modality=DatasetModality.ECG,
    dataset_key="ptbxl"
)

# Access a record by its ID (for PTB-XL, the ecg_id as a string)
record = dataset["1"]
raw_data = record.raw_record.data
preprocessed = record.preprocessed_record
embeddings = record.embeddings

# Get splits and metadata
splits = dataset.get_splits() if dataset.has_splits() else None
metadata_fields = dataset.available_metadata_fields()

# Verify integrity
dataset.verify_integrity()
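If splits are available, they can be used to select record subsets. A minimal sketch, assuming get_splits() returns a mapping from split name to record IDs (the real mapping is loaded from processed/&lt;dataset&gt;/splits.json); the IDs here are illustrative:

```python
# Hypothetical splits mapping with illustrative record IDs
splits = {"train": ["1", "3"], "val": ["2"], "test": ["4"]}

train_ids = set(splits["train"])
all_ids = ["1", "2", "3", "4"]
train_records = [rid for rid in all_ids if rid in train_ids]
```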

Next Steps