Working with Datasets
Overview
This guide explains how to work with datasets in the project. The system is designed to handle multiple modalities (ECG, CMR) through a unified pipeline that consists of three main stages:
- Raw Data Handling: Loading and validating original data files
- Preprocessing: Converting data into standardized tensor format
- Unified Access: Bringing everything together for analysis and training
Data Organization
The project follows a consistent directory structure across all datasets:
data/
├── raw/ # Original dataset files
│ ├── cmr/ # Cardiac MRI datasets
│ │ └── acdc/ # ACDC dataset files
│ └── ecg/ # ECG datasets
│ └── ptbxl/ # PTB-XL dataset files
├── interim/ # Preprocessed tensors (*.pt files)
│ ├── acdc/ # Preprocessed ACDC records
│ └── ptbxl/ # Preprocessed PTB-XL records
├── processed/ # Final dataset artifacts
│ ├── acdc/ # ACDC splits and metadata
│ │ ├── splits.json # Train/val/test splits
│ │ └── metadata.db # Record metadata
│ └── ptbxl/ # PTB-XL splits and metadata
└── embeddings/ # Pre-computed embeddings (optional)
├── type1/ # Embeddings from model 1
└── type2/ # Embeddings from model 2
Dataset Components
1. Raw Dataset Handlers
Raw dataset handlers provide the interface to original data files. They handle:
- Data loading and validation
- Metadata extraction
- Label processing
- Format standardization
The base classes that all handlers must implement:
src.data.raw.data
RawDataset
Bases: BaseDataset, ABC
Base class for handling raw medical data.
Source code in src/data/raw/data.py
| class RawDataset(BaseDataset, ABC):
"""Base class for handling raw medical data."""
def __init__(self, data_root: Path):
super().__init__(data_root)
if not self.paths["raw"].exists():
raise ValueError(f"Raw data path does not exist: {self.paths['raw']}")
@abstractmethod
def verify_data(self) -> None:
"""Verify raw data structure and contents. Throw error if invalid."""
pass
@abstractmethod
def get_metadata(self) -> Dict[str, Any]:
"""Get dataset-specific metadata."""
pass
@abstractmethod
def get_target_labels(self) -> List[str]:
"""Get list of target labels."""
pass
@abstractmethod
def get_all_record_ids(self) -> List[str]:
"""Get all available record IDs without loading data."""
pass
@abstractmethod
def load_record(self, record_id: str) -> RawRecord:
"""Load a single record."""
pass
@abstractmethod
def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream raw records one at a time."""
pass
|
get_all_record_ids()
abstractmethod
Get all available record IDs without loading data.
Source code in src/data/raw/data.py
| @abstractmethod
def get_all_record_ids(self) -> List[str]:
"""Get all available record IDs without loading data."""
pass
|
get_metadata()
abstractmethod
Get dataset-specific metadata.
Source code in src/data/raw/data.py
| @abstractmethod
def get_metadata(self) -> Dict[str, Any]:
"""Get dataset-specific metadata."""
pass
|
get_stream()
abstractmethod
Stream raw records one at a time.
Source code in src/data/raw/data.py
| @abstractmethod
def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream raw records one at a time."""
pass
|
get_target_labels()
abstractmethod
Get list of target labels.
Source code in src/data/raw/data.py
| @abstractmethod
def get_target_labels(self) -> List[str]:
"""Get list of target labels."""
pass
|
load_record(record_id)
abstractmethod
Load a single record.
Source code in src/data/raw/data.py
| @abstractmethod
def load_record(self, record_id: str) -> RawRecord:
"""Load a single record."""
pass
|
verify_data()
abstractmethod
Verify raw data structure and contents. Throw error if invalid.
Source code in src/data/raw/data.py
| @abstractmethod
def verify_data(self) -> None:
"""Verify raw data structure and contents. Throw error if invalid."""
pass
|
RawRecord
dataclass
Bases: DatasetRecord
Data class for raw samples.
Source code in src/data/raw/data.py
| @dataclass
class RawRecord(DatasetRecord):
"""Data class for raw samples."""
data: ndarray
target_labels: List[str] | None
metadata: Dict[str, Any]
def __str__(self) -> str:
return f"RawRecord(id={self.id}, targets={";".join(self.target_labels)})"
def __post_init__(self):
if not isinstance(self.data, ndarray):
raise ValueError(f"Data must be a numpy array, got {type(self.data)}")
if not isinstance(self.target_labels, list):
raise ValueError(f"Targets must be a list, got {type(self.target_labels)}")
if not isinstance(self.metadata, dict):
raise ValueError(
f"Metadata must be a dictionary, got {type(self.metadata)}"
)
|
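For illustration, a handler's load_record method typically builds a RawRecord like this (a minimal sketch with made-up values; the field names match the dataclass above, while the array shape and metadata keys depend on your dataset):
import numpy as np
from src.data.raw.data import RawRecord

record = RawRecord(
    id="0001",                                     # record identifier as a string
    data=np.zeros((12, 5000), dtype=np.float32),   # e.g. a 12-lead ECG, 10 s at 500 Hz
    target_labels=["NORM"],                        # dataset-specific label codes
    metadata={"age": 54, "sex": "M"},              # free-form per-record metadata
)
print(record)  # RawRecord(id=0001, targets=NORM)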
Dataset handlers are registered using:
src.data.raw.registry
RawDatasetRegistry
Registry for raw data handlers.
Source code in src/data/raw/registry.py
| class RawDatasetRegistry:
"""Registry for raw data handlers."""
_registry: Dict[str, Dict[str, Type[RawDataset]]] = {"ecg": {}, "cmr": {}}
@classmethod
def register(cls, modality: str, dataset_key: str):
"""Decorator to register a raw data handler."""
def decorator(raw_data_class: Type[RawDataset]):
if modality not in cls._registry:
cls._registry[modality] = {}
cls._registry[modality][dataset_key] = raw_data_class
return raw_data_class
return decorator
@classmethod
def get_handler(cls, modality: str, dataset_key: str) -> Type[RawDataset]:
"""Get raw data handler by modality and key."""
if modality not in cls._registry:
raise ValueError(f"Unknown modality: {modality}")
if dataset_key not in cls._registry[modality]:
raise ValueError(f"Unknown dataset key for {modality}: {dataset_key}")
return cls._registry[modality][dataset_key]
@classmethod
def get_modality(cls, dataset_key: str) -> DatasetModality:
"""Get modality by dataset key."""
for modality, datasets in cls._registry.items():
if dataset_key in datasets:
return DatasetModality(modality)
raise ValueError(f"Unknown dataset key: {dataset_key}")
@classmethod
def list_datasets(cls) -> Dict[str, List[str]]:
"""List available datasets per modality."""
return {
modality: list(datasets.keys())
for modality, datasets in cls._registry.items()
}
@classmethod
def list_modalities(cls) -> List[str]:
"""List available modalities."""
return list(cls._registry.keys())
|
get_handler(modality, dataset_key)
classmethod
Get raw data handler by modality and key.
Source code in src/data/raw/registry.py
| @classmethod
def get_handler(cls, modality: str, dataset_key: str) -> Type[RawDataset]:
"""Get raw data handler by modality and key."""
if modality not in cls._registry:
raise ValueError(f"Unknown modality: {modality}")
if dataset_key not in cls._registry[modality]:
raise ValueError(f"Unknown dataset key for {modality}: {dataset_key}")
return cls._registry[modality][dataset_key]
|
get_modality(dataset_key)
classmethod
Get modality by dataset key.
Source code in src/data/raw/registry.py
| @classmethod
def get_modality(cls, dataset_key: str) -> DatasetModality:
"""Get modality by dataset key."""
for modality, datasets in cls._registry.items():
if dataset_key in datasets:
return DatasetModality(modality)
raise ValueError(f"Unknown dataset key: {dataset_key}")
|
list_datasets()
classmethod
List available datasets per modality.
Source code in src/data/raw/registry.py
| @classmethod
def list_datasets(cls) -> Dict[str, List[str]]:
"""List available datasets per modality."""
return {
modality: list(datasets.keys())
for modality, datasets in cls._registry.items()
}
|
list_modalities()
classmethod
List available modalities.
Source code in src/data/raw/registry.py
| @classmethod
def list_modalities(cls) -> List[str]:
"""List available modalities."""
return list(cls._registry.keys())
|
register(modality, dataset_key)
classmethod
Decorator to register a raw data handler.
Source code in src/data/raw/registry.py
| @classmethod
def register(cls, modality: str, dataset_key: str):
"""Decorator to register a raw data handler."""
def decorator(raw_data_class: Type[RawDataset]):
if modality not in cls._registry:
cls._registry[modality] = {}
cls._registry[modality][dataset_key] = raw_data_class
return raw_data_class
return decorator
|
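The registry is what connects a dataset key such as "ptbxl" to its handler class. The sketch below shows how registration and lookup fit together; MyECGDataset and the "my_ecg" key are hypothetical placeholders, and the built-in handlers later in this guide follow the same pattern:
from src.data.raw.data import RawDataset
from src.data.raw.registry import RawDatasetRegistry

# Registering a handler: the decorator stores the class under (modality, dataset_key).
@RawDatasetRegistry.register("ecg", "my_ecg")  # "my_ecg" is a hypothetical dataset key
class MyECGDataset(RawDataset):
    ...  # the abstract methods from RawDataset still need to be implemented

# Looking the handler class up again by modality and key.
handler_cls = RawDatasetRegistry.get_handler("ecg", "my_ecg")

# Introspection helpers.
print(RawDatasetRegistry.list_modalities())       # e.g. ['ecg', 'cmr']
print(RawDatasetRegistry.list_datasets())         # e.g. {'ecg': [..., 'my_ecg'], 'cmr': [...]}
print(RawDatasetRegistry.get_modality("my_ecg"))  # e.g. DatasetModality.ECG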
2. Unified Dataset System
The unified dataset system brings everything together, providing:
- Access to raw, preprocessed, and embedded data
- Automatic data integrity validation
- Efficient caching
- Comprehensive metadata management
src.data.unified
UnifiedDataset
Bases: BaseDataset
Source code in src/data/unified.py
| class UnifiedDataset(BaseDataset):
def __init__(self, data_root: Path, modality: DatasetModality, dataset_key: str):
super().__init__(data_root)
self._dataset_key = dataset_key
self._modality = modality
@property
def dataset_key(self) -> str:
return self._dataset_key
@property
def modality(self) -> DatasetModality:
return self._modality
@cached_property
def metadata_store(self) -> MetadataStore:
return MetadataStore(data_root=self.paths["processed"])
@cached_property
def raw_dataset(self) -> RawDataset:
return RawDatasetRegistry.get_handler(self.modality.value, self.dataset_key)(
self.data_root
)
def has_dataset_info(self) -> bool:
return self.paths["misc"]["dataset_info"].exists()
def get_dataset_info(self) -> Dict[str, Any]:
if not self.has_dataset_info():
raise ValueError("No dataset info found for this dataset.")
with open(self.paths["misc"]["dataset_info"], "r") as f:
return json.load(f)
def has_splits(self) -> bool:
return self.paths["misc"]["splits_file"].exists()
def get_splits(self) -> Dict[str, Dict[str, Any]]:
if not self.has_splits():
raise ValueError("No splits found for this dataset.")
with open(self.paths["misc"]["splits_file"], "r") as f:
splits = json.load(f)
return splits
def get_split_by_record_id(self, record_id: str) -> str:
for split_name, record_ids in self.get_splits().items():
if record_id in record_ids:
return split_name
raise ValueError(f"Record ID '{record_id}' not found in any split.")
def __get_all_record_ids_from_splits(self) -> List[str]:
return [
record_id for split in self.get_splits().values() for record_id in split
]
def __get_all_record_ids_from_interim(self) -> List[str]:
interim_files = self.paths["interim"].glob("*.pt")
return [f.stem for f in interim_files]
@cache
def get_all_record_ids(self) -> List[str]:
# if self.has_splits():
# return self.__get_all_record_ids_from_splits()
return (
self.__get_all_record_ids_from_interim()
) # interim is the ground truth for record ids
@cache
def __load_embeddings(self, embeddings_type: str) -> Dict[str, torch.Tensor]:
if embeddings_type not in self.paths["embeddings"].keys():
raise ValueError(f"Embeddings type '{embeddings_type}' not found.")
embeddings_path = self.paths["embeddings"][embeddings_type]
if not embeddings_path.exists():
return {}
return torch.load(embeddings_path, map_location="cpu") # type: ignore
@lru_cache(maxsize=1000)
def __load_preprocessed_record(self, record_id: str) -> PreprocessedRecord:
if record_id not in self.get_all_record_ids():
raise ValueError(f"Record ID '{record_id}' not found in dataset.")
interim_file = self.paths["interim"] / f"{record_id}.pt"
if not interim_file.exists():
raise ValueError(f"Interim file not found for record ID '{record_id}'.")
return torch.load(interim_file, map_location="cpu") # type: ignore
def get_embeddings(
self, record_id: str, embeddings_type: str = None
) -> Dict[str, torch.Tensor] | torch.Tensor:
if embeddings_type is None:
return {
k: self.__load_embeddings(k).get(record_id, None)
for k in self.paths["embeddings"].keys()
}
return self.__load_embeddings(embeddings_type).get(record_id, None)
def available_metadata_fields(self) -> set[str]:
return self.metadata_store.available_fields()
@lru_cache(maxsize=1000)
def __getitem__(self, record_id: str) -> UnifiedRecord:
preprocessed_record = self.__load_preprocessed_record(record_id)
embeddings = self.get_embeddings(record_id)
raw_record = self.raw_dataset.load_record(record_id)
return UnifiedRecord(
id=record_id,
raw_record=raw_record,
preprocessed_record=preprocessed_record,
embeddings=embeddings,
)
def verify_integrity(self) -> None:
"""Validate dataset integrity using interim files as ground truth.
Performs the following assertions:
1. At least one interim file exists (raises ValueError if empty)
2. If splits exist:
a. All split IDs must exist in interim files (raises on missing IDs)
b. All interim IDs must exist in splits (raises on coverage mismatch)
c. No duplicate IDs across splits (raises on duplicates)
d. Split data must be stored as lists (raises on format error)
3. Embeddings files (if exist):
a. Must not contain IDs absent from interim files (raises on extra IDs)
4. Metadata:
a. Must exist for every interim ID (raises on missing metadata)
Raises:
ValueError: For any integrity check failure, with detailed message
FileNotFoundError: If critical path components are missing
Note:
Interim files (*.pt in interim directory) are considered the canonical
source of truth for valid record IDs. All other components (splits,
embeddings, metadata) must align with these IDs.
"""
interim_ids = set(self.get_all_record_ids())
# First check there's at least one interim record
if not interim_ids:
raise ValueError(f"No interim files found in {self.paths['interim']}")
# Check if every raw record was preprocessed
raw_record_ids = set(self.raw_dataset.get_all_record_ids())
missing_preprocessed = raw_record_ids - interim_ids
if missing_preprocessed:
raise ValueError(
f"{len(missing_preprocessed)} raw records missing from interim files. "
f"First 5: {sorted(missing_preprocessed)[:5]}"
)
# Split consistency checks (if splits exist)
if self.has_splits():
splits = self.get_splits()
split_ids = set()
all_split_ids = []
# Collect all split IDs and check per-split validity
for split_name, split_records in splits.items():
if not isinstance(split_records, list):
raise ValueError(
f"Split '{split_name}' should be a list of IDs, got {type(split_records)}"
)
current_split = set(split_records)
split_ids.update(current_split)
all_split_ids.extend(split_records)
# Check individual split validity
missing_in_interim = current_split - interim_ids
if missing_in_interim:
raise ValueError(
f"Split '{split_name}' contains {len(missing_in_interim)} "
f"IDs missing from interim files. First 5: {sorted(missing_in_interim)[:5]}"
)
# Check full split coverage
if split_ids != interim_ids:
missing_in_splits = interim_ids - split_ids
extra_in_splits = split_ids - interim_ids
error_msg = []
if missing_in_splits:
error_msg.append(
f"{len(missing_in_splits)} interim IDs missing from splits"
)
if extra_in_splits:
error_msg.append(
f"{len(extra_in_splits)} extra IDs in splits not in interim"
)
raise ValueError("Split coverage mismatch: " + ", ".join(error_msg))
# Check for duplicate IDs across splits
duplicate_ids = {id for id in all_split_ids if all_split_ids.count(id) > 1}
if duplicate_ids:
raise ValueError(
f"{len(duplicate_ids)} duplicate IDs found across splits. "
f"First 5: {sorted(duplicate_ids)[:5]}"
)
# Embeddings consistency checks
for emb_type, emb_path in self.paths["embeddings"].items():
if emb_path.exists():
embeddings = self.__load_embeddings(emb_type)
if not embeddings:
continue # Skip empty embeddings files
emb_ids = set(embeddings.keys())
extra_emb_ids = emb_ids - interim_ids
if extra_emb_ids:
raise ValueError(
f"{emb_type} embeddings contain {len(extra_emb_ids)} IDs "
f"not in interim files. First 5: {sorted(extra_emb_ids)[:5]}"
)
# Metadata completeness check
missing_metadata = []
for record_id in interim_ids:
if not self.metadata_store.get(record_id):
missing_metadata.append(record_id)
if len(missing_metadata) >= 5: # Early exit for large datasets
break
if missing_metadata:
raise ValueError(
f"Missing metadata for {len(missing_metadata)} records. "
f"First 5: {missing_metadata[:5]}"
)
def __str__(self) -> str:
return f"ProcessedDataset(data_root={self.data_root}, modality={self.modality}, dataset_key={self.dataset_key})"
|
verify_integrity()
Validate dataset integrity using interim files as ground truth.
Performs the following assertions:
- At least one interim file exists (raises ValueError if empty)
- If splits exist:
a. All split IDs must exist in interim files (raises on missing IDs)
b. All interim IDs must exist in splits (raises on coverage mismatch)
c. No duplicate IDs across splits (raises on duplicates)
d. Split data must be stored as lists (raises on format error)
- Embeddings files (if exist):
a. Must not contain IDs absent from interim files (raises on extra IDs)
- Metadata:
a. Must exist for every interim ID (raises on missing metadata)
Raises:
- ValueError: For any integrity check failure, with detailed message
- FileNotFoundError: If critical path components are missing
Note
Interim files (*.pt in interim directory) are considered the canonical
source of truth for valid record IDs. All other components (splits,
embeddings, metadata) must align with these IDs.
Source code in src/data/unified.py
| def verify_integrity(self) -> None:
"""Validate dataset integrity using interim files as ground truth.
Performs the following assertions:
1. At least one interim file exists (raises ValueError if empty)
2. If splits exist:
a. All split IDs must exist in interim files (raises on missing IDs)
b. All interim IDs must exist in splits (raises on coverage mismatch)
c. No duplicate IDs across splits (raises on duplicates)
d. Split data must be stored as lists (raises on format error)
3. Embeddings files (if exist):
a. Must not contain IDs absent from interim files (raises on extra IDs)
4. Metadata:
a. Must exist for every interim ID (raises on missing metadata)
Raises:
ValueError: For any integrity check failure, with detailed message
FileNotFoundError: If critical path components are missing
Note:
Interim files (*.pt in interim directory) are considered the canonical
source of truth for valid record IDs. All other components (splits,
embeddings, metadata) must align with these IDs.
"""
interim_ids = set(self.get_all_record_ids())
# First check there's at least one interim record
if not interim_ids:
raise ValueError(f"No interim files found in {self.paths['interim']}")
# Check if every raw record was preprocessed
raw_record_ids = set(self.raw_dataset.get_all_record_ids())
missing_preprocessed = raw_record_ids - interim_ids
if missing_preprocessed:
raise ValueError(
f"{len(missing_preprocessed)} raw records missing from interim files. "
f"First 5: {sorted(missing_preprocessed)[:5]}"
)
# Split consistency checks (if splits exist)
if self.has_splits():
splits = self.get_splits()
split_ids = set()
all_split_ids = []
# Collect all split IDs and check per-split validity
for split_name, split_records in splits.items():
if not isinstance(split_records, list):
raise ValueError(
f"Split '{split_name}' should be a list of IDs, got {type(split_records)}"
)
current_split = set(split_records)
split_ids.update(current_split)
all_split_ids.extend(split_records)
# Check individual split validity
missing_in_interim = current_split - interim_ids
if missing_in_interim:
raise ValueError(
f"Split '{split_name}' contains {len(missing_in_interim)} "
f"IDs missing from interim files. First 5: {sorted(missing_in_interim)[:5]}"
)
# Check full split coverage
if split_ids != interim_ids:
missing_in_splits = interim_ids - split_ids
extra_in_splits = split_ids - interim_ids
error_msg = []
if missing_in_splits:
error_msg.append(
f"{len(missing_in_splits)} interim IDs missing from splits"
)
if extra_in_splits:
error_msg.append(
f"{len(extra_in_splits)} extra IDs in splits not in interim"
)
raise ValueError("Split coverage mismatch: " + ", ".join(error_msg))
# Check for duplicate IDs across splits
duplicate_ids = {id for id in all_split_ids if all_split_ids.count(id) > 1}
if duplicate_ids:
raise ValueError(
f"{len(duplicate_ids)} duplicate IDs found across splits. "
f"First 5: {sorted(duplicate_ids)[:5]}"
)
# Embeddings consistency checks
for emb_type, emb_path in self.paths["embeddings"].items():
if emb_path.exists():
embeddings = self.__load_embeddings(emb_type)
if not embeddings:
continue # Skip empty embeddings files
emb_ids = set(embeddings.keys())
extra_emb_ids = emb_ids - interim_ids
if extra_emb_ids:
raise ValueError(
f"{emb_type} embeddings contain {len(extra_emb_ids)} IDs "
f"not in interim files. First 5: {sorted(extra_emb_ids)[:5]}"
)
# Metadata completeness check
missing_metadata = []
for record_id in interim_ids:
if not self.metadata_store.get(record_id):
missing_metadata.append(record_id)
if len(missing_metadata) >= 5: # Early exit for large datasets
break
if missing_metadata:
raise ValueError(
f"Missing metadata for {len(missing_metadata)} records. "
f"First 5: {missing_metadata[:5]}"
)
|
UnifiedRecord
dataclass
Bases: DatasetRecord
Source code in src/data/unified.py
| @dataclass
class UnifiedRecord(DatasetRecord):
raw_record: RawRecord = None
preprocessed_record: PreprocessedRecord = None
embeddings: Optional[torch.Tensor] = None
def __str__(self) -> str:
return f"UnifiedPreprocessedRecord(id={self.id}, has_embeddings={self.embeddings is not None})"
|
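Beyond indexing with dataset[record_id], UnifiedDataset exposes a few per-record and dataset-level helpers. A minimal sketch, assuming a dataset rooted at data/ with the "ptbxl" key and an embeddings type named "type1" as in the directory layout above:
from pathlib import Path
from src.data.dataset import DatasetModality
from src.data.unified import UnifiedDataset

dataset = UnifiedDataset(Path("data"), DatasetModality.ECG, "ptbxl")
record_id = dataset.get_all_record_ids()[0]

# Which split a record belongs to (requires splits.json to exist).
if dataset.has_splits():
    print(dataset.get_split_by_record_id(record_id))  # e.g. "train"

# Embeddings: all configured types at once, or a single type by name.
all_embeddings = dataset.get_embeddings(record_id)          # dict keyed by embeddings type
one_embedding = dataset.get_embeddings(record_id, "type1")  # tensor or None; "type1" assumed from the layout

# Dataset-level info and available metadata fields.
if dataset.has_dataset_info():
    print(dataset.get_dataset_info())
print(dataset.available_metadata_fields())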
Implementation Guide
Now that you have a basic understanding of the unified dataset system, let's walk through the steps for creating a new dataset handler for a new source of raw data.
Creating a New Dataset Handler
1. Choose the Appropriate Location:
- ECG handlers: src/data/raw/ecg/<dataset_name>.py
- CMR handlers: src/data/raw/cmr/<dataset_name>.py
2. Implement Required Methods:
All handlers must implement the RawDataset abstract base class, which defines the interface for verifying, describing, and loading the raw data.
src.data.raw.data
RawDataset
Bases: BaseDataset, ABC
Base class for handling raw medical data.
Source code in src/data/raw/data.py
| class RawDataset(BaseDataset, ABC):
"""Base class for handling raw medical data."""
def __init__(self, data_root: Path):
super().__init__(data_root)
if not self.paths["raw"].exists():
raise ValueError(f"Raw data path does not exist: {self.paths['raw']}")
@abstractmethod
def verify_data(self) -> None:
"""Verify raw data structure and contents. Throw error if invalid."""
pass
@abstractmethod
def get_metadata(self) -> Dict[str, Any]:
"""Get dataset-specific metadata."""
pass
@abstractmethod
def get_target_labels(self) -> List[str]:
"""Get list of target labels."""
pass
@abstractmethod
def get_all_record_ids(self) -> List[str]:
"""Get all available record IDs without loading data."""
pass
@abstractmethod
def load_record(self, record_id: str) -> RawRecord:
"""Load a single record."""
pass
@abstractmethod
def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream raw records one at a time."""
pass
|
get_all_record_ids()
abstractmethod
Get all available record IDs without loading data.
Source code in src/data/raw/data.py
| @abstractmethod
def get_all_record_ids(self) -> List[str]:
"""Get all available record IDs without loading data."""
pass
|
get_metadata()
abstractmethod
Get dataset-specific metadata.
Source code in src/data/raw/data.py
| @abstractmethod
def get_metadata(self) -> Dict[str, Any]:
"""Get dataset-specific metadata."""
pass
|
get_stream()
abstractmethod
Stream raw records one at a time.
Source code in src/data/raw/data.py
| @abstractmethod
def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream raw records one at a time."""
pass
|
get_target_labels()
abstractmethod
Get list of target labels.
Source code in src/data/raw/data.py
| @abstractmethod
def get_target_labels(self) -> List[str]:
"""Get list of target labels."""
pass
|
load_record(record_id)
abstractmethod
Load a single record.
Source code in src/data/raw/data.py
| @abstractmethod
def load_record(self, record_id: str) -> RawRecord:
"""Load a single record."""
pass
|
verify_data()
abstractmethod
Verify raw data structure and contents. Throw error if invalid.
Source code in src/data/raw/data.py
| @abstractmethod
def verify_data(self) -> None:
"""Verify raw data structure and contents. Throw error if invalid."""
pass
|
The RawRecord dataclass is defined alongside the RawDataset class and represents a single record in the dataset. It holds the raw data, target labels, and metadata for that record, ready for preprocessing.
src.data.raw.data
RawRecord
dataclass
Bases: DatasetRecord
Data class for raw samples.
Source code in src/data/raw/data.py
| @dataclass
class RawRecord(DatasetRecord):
"""Data class for raw samples."""
data: ndarray
target_labels: List[str] | None
metadata: Dict[str, Any]
def __str__(self) -> str:
return f"RawRecord(id={self.id}, targets={";".join(self.target_labels)})"
def __post_init__(self):
if not isinstance(self.data, ndarray):
raise ValueError(f"Data must be a numpy array, got {type(self.data)}")
if not isinstance(self.target_labels, list):
raise ValueError(f"Targets must be a list, got {type(self.target_labels)}")
if not isinstance(self.metadata, dict):
raise ValueError(
f"Metadata must be a dictionary, got {type(self.metadata)}"
)
|
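Putting steps 1 and 2 together, a new handler might start from a skeleton like the one below. This is a hedged sketch: the "my_ecg" key, the one-.npy-file-per-record layout, and all loading logic are hypothetical placeholders; the PTB-XL and ACDC handlers in the next section show complete implementations.
from pathlib import Path
from typing import Any, Dict, Generator, List

import numpy as np

from src.data.dataset import DatasetModality
from src.data.raw.data import RawDataset, RawRecord
from src.data.raw.registry import RawDatasetRegistry


@RawDatasetRegistry.register(DatasetModality.ECG.value, "my_ecg")  # hypothetical dataset key
class MyECGDataset(RawDataset):
    """Skeleton handler for a hypothetical ECG dataset stored as one .npy file per record."""

    dataset_key = "my_ecg"
    modality = DatasetModality.ECG

    def __init__(self, data_root: Path):
        super().__init__(data_root)
        self.root = self.paths["raw"]

    def verify_data(self) -> None:
        if not any(self.root.glob("*.npy")):
            raise ValueError(f"No .npy records found in {self.root}")

    def get_metadata(self) -> Dict[str, Any]:
        return {"sampling_rate": 500}  # whatever dataset-level facts you want to expose

    def get_target_labels(self) -> List[str]:
        # Collect the unique labels across all records (fine for a sketch; cache for large datasets).
        labels = set()
        for record_id in self.get_all_record_ids():
            labels.update(self.load_record(record_id).target_labels or [])
        return sorted(labels)

    def get_all_record_ids(self) -> List[str]:
        return [f.stem for f in self.root.glob("*.npy")]

    def load_record(self, record_id: str) -> RawRecord:
        data = np.load(self.root / f"{record_id}.npy")
        return RawRecord(id=record_id, data=data, target_labels=[], metadata={})

    def get_stream(self) -> Generator[RawRecord, None, None]:
        for record_id in self.get_all_record_ids():
            yield self.load_record(record_id)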
3. Register the Dataset Handler:
To use your new dataset handler, import its module in src/data/__init__.py so that its registration decorator runs:
# Import datasets in order to register them
import src.data.raw.cmr.acdc
import src.data.raw.ecg.arrhythmia
import src.data.raw.ecg.grouped_arrhythmia
import src.data.raw.ecg.shandong
import src.data.raw.ecg.ptb_xl
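# import src.data.raw.ecg.my_ecg  # hypothetical: add an import like this for your own handler module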
Example Implementations of Dataset Handlers
src.data.raw.ecg.ptb_xl
PTBXL
Bases: RawDataset
Handler for raw PTB-XL ECG data at 500 Hz.
Expected data structure in the raw data directory (self.paths["raw"]):
ptbxl/
├── ptbxl_database.csv # Contains one row per record (indexed by ecg_id) with extensive metadata
├── scp_statements.csv # Contains SCP-ECG annotation mappings (optional)
└── records500/ # WFDB records at 500 Hz
The ptbxl_database.csv file contains many columns including:
- ecg_id, patient_id, age, sex, height, weight, nurse, site, device, recording_date, etc.
- scp_codes: a string representation of a dictionary mapping SCP-ECG statement codes to likelihoods.
- filename_hr: path (relative to the ptbxl folder) to the WFDB record for 500 Hz.
Source code in src/data/raw/ecg/ptb_xl.py
| @RawDatasetRegistry.register(DatasetModality.ECG.value, "ptbxl")
class PTBXL(RawDataset):
"""Handler for raw PTB-XL ECG data at 500 Hz.
Expected data structure in the raw data directory (self.paths["raw"]):
ptbxl/
├── ptbxl_database.csv # Contains one row per record (indexed by ecg_id) with extensive metadata
├── scp_statements.csv # Contains SCP-ECG annotation mappings (optional)
└── records500/ # WFDB records at 500 Hz
The ptbxl_database.csv file contains many columns including:
- ecg_id, patient_id, age, sex, height, weight, nurse, site, device, recording_date, etc.
- scp_codes: a string representation of a dictionary mapping SCP-ECG statement codes to likelihoods.
- filename_hr: path (relative to the ptbxl folder) to the WFDB record for 500 Hz.
"""
dataset_key = "ptbxl"
modality = DatasetModality.ECG
def __init__(self, data_root: Path):
"""
Args:
data_root: Path to the dataset root folder.
It should contain the ptbxl folder with the PTB-XL files.
"""
super().__init__(data_root)
self.root = self.paths["raw"]
self.database_csv = self.root / "ptbxl_database.csv"
self.scp_csv = self.root / "scp_statements.csv"
if not self.database_csv.exists():
raise FileNotFoundError(f"Database CSV not found: {self.database_csv}")
if not self.scp_csv.exists():
logger.warning(f"SCP statements CSV not found: {self.scp_csv}")
# Use "ecg_id" as the index for fast lookup.
self.metadata_df = pd.read_csv(self.database_csv, index_col="ecg_id")
# ecg_id is unique, but patient_id is not. Group by patient_id and use random record
self.metadata_df = (
self.metadata_df.groupby("patient_id")
.sample(n=1, random_state=1337)
.reset_index(drop=True)
)
self.metadata_df["patient_id"] = self.metadata_df["patient_id"].astype(int)
self.metadata_df["scp_codes"] = self.metadata_df["scp_codes"].apply(
lambda x: ast.literal_eval(x)
if isinstance(x, str)
else x # scp codes are stored as dict strings
)
self.statements_dict = pd.read_csv(self.scp_csv, index_col=0).to_dict(
orient="index"
)
self.metadata_dict: Dict[str, Any] = self.metadata_df.to_dict(orient="index")
self.records_folder = (
self.root / "records500"
) # we are only interested in the 500Hz data
if not self.records_folder.exists():
raise FileNotFoundError(f"Records folder not found: {self.records_folder}")
self.filename_col = (
"filename_hr" # 500 Hz records / for 100 hz use "filename_lr"
)
def verify_data(self) -> None:
"""Verify the raw data structure and contents."""
if not self.database_csv.exists():
raise ValueError(f"Database CSV does not exist: {self.database_csv}")
if not self.records_folder.exists():
raise ValueError(f"Records folder does not exist: {self.records_folder}")
for record_id in self.get_all_record_ids():
file_path = self._get_file_path(record_id)
header_path = file_path.with_suffix(".hea")
dat_path = file_path.with_suffix(".dat")
if not header_path.exists():
raise FileNotFoundError(
f"WFDB header file for record {record_id} not found: {header_path}"
)
if not dat_path.exists():
raise FileNotFoundError(
f"WFDB data file for record {record_id} not found: {dat_path}"
)
# based on https://physionet.org/content/ptb-xl/1.0.3/, we want only unique patients with one record per patient
assert (
len(self.metadata_df) == 18869
), f"Expected 18869 records, got {len(self.metadata_df)}"
def get_metadata(self) -> Dict[str, Any]:
"""Get dataset-specific metadata.
Returns:
A dictionary containing dataset details.
"""
return {
"num_leads": 12,
"sampling_rate": 500,
"record_length_seconds": 10,
"total_records": len(self.metadata_df),
"num_patients": int(self.metadata_df["patient_id"].nunique()),
}
def _get_file_path(self, record_id: str) -> Path:
"""Get the full path to the WFDB record file for a given record ID.
Args:
record_id: The unique ECG record identifier.
Returns:
Full path to the WFDB record file.
"""
return self.root / self.metadata_dict[int(record_id)][self.filename_col]
@cache
def get_target_labels(self) -> List[str]:
"""Extract all unique SCP codes (with likelihood > 0) from the metadata.
Returns:
Sorted list of unique target label codes.
"""
labels = set()
for record_id in self.get_all_record_ids():
codes = self._extract_scp_codes(int(record_id))
labels.update(self._extract_target_labels(codes))
return sorted(list(labels))
def get_all_record_ids(self) -> List[str]:
"""Get all available record IDs without loading the actual data.
Returns:
List of record identifiers (ecg_id) as strings.
"""
return self.metadata_df.index.astype(str).tolist()
def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream raw records one at a time.
Yields:
RawRecord objects for each record.
"""
for record_id in self.get_all_record_ids():
try:
yield self.load_record(record_id)
except Exception as e:
logger.error(f"Error loading record {record_id}: {e}")
continue
@cache
def load_record(self, record_id: str) -> RawRecord:
"""Load a single ECG record by its identifier.
Args:
record_id: The unique ECG record identifier.
Returns:
A RawRecord object containing the ECG signal data, target labels, and metadata.
"""
record_id = int(record_id)
if record_id not in self.metadata_dict:
raise ValueError(f"Record {record_id} not found in metadata.")
meta = self.metadata_dict[record_id]
file_path = self._get_file_path(record_id)
record = self._read_wfdb_record(record_id)
data = record[0]
codes = self._extract_scp_codes(record_id)
target_labels = self._extract_target_labels(codes)
record_metadata = dict(meta)
record_metadata["wfdb_file"] = str(file_path.relative_to(self.root))
record_metadata["signal_fields"] = record[1]
record_metadata["scp_statements"] = {
code: self._extract_scp_statement(code) for code in codes
}
return RawRecord(
id=record_id,
data=data,
target_labels=target_labels,
metadata=record_metadata,
)
def _extract_scp_statement(self, code: str) -> str:
"""Extract the SCP statement for a given code from the SCP statements CSV.
Args:
code: SCP code to look up.
Returns:
SCP statement description.
"""
if code in self.statements_dict:
return self.statements_dict[code]
raise ValueError(f"SCP code {code} not found in statements CSV.")
def _extract_target_labels(self, scp_codes: Dict[str, float]) -> List[str]:
"""Extract target labels from the SCP codes dictionary.
Args:
scp_codes: Dictionary mapping SCP codes to likelihoods.
Returns:
List of target labels with likelihood > 0.
"""
return [
code for code, likelihood in scp_codes.items() if float(likelihood) >= 0
] # include unknown codes (likelihood == 0): "... where likelihood is set to 0 if unknown) ..."
def _extract_scp_codes(self, record_id: int) -> Dict[str, float]:
"""Extract SCP codes and likelihoods for a given record ID.
Args:
record_id: The unique ECG record identifier.
Returns:
Dictionary mapping SCP codes to likelihoods.
"""
codes = self.metadata_dict[record_id].get("scp_codes")
if not isinstance(codes, dict):
raise ValueError(f"SCP codes not found for record {record_id}")
return codes
def _read_wfdb_record(self, record_id: int) -> Tuple[np.ndarray, Dict[str, Any]]:
"""Read the WFDB record data for a given record ID.
Args:
record_id: The unique ECG record identifier.
Returns:
Numpy array containing the ECG signal data.
"""
file_path = self._get_file_path(record_id)
try:
record = wfdb.rdsamp(str(file_path))
data = np.array(record[0], dtype=np.float32)
data = np.swapaxes(data, 0, 1)
return data, record[1]
except Exception as e:
raise RuntimeError(f"Error reading WFDB record {file_path}: {e}") from e
|
__init__(data_root)
Parameters:
- data_root (Path, required): Path to the dataset root folder. It should contain the ptbxl folder with the PTB-XL files.
Source code in src/data/raw/ecg/ptb_xl.py
| def __init__(self, data_root: Path):
"""
Args:
data_root: Path to the dataset root folder.
It should contain the ptbxl folder with the PTB-XL files.
"""
super().__init__(data_root)
self.root = self.paths["raw"]
self.database_csv = self.root / "ptbxl_database.csv"
self.scp_csv = self.root / "scp_statements.csv"
if not self.database_csv.exists():
raise FileNotFoundError(f"Database CSV not found: {self.database_csv}")
if not self.scp_csv.exists():
logger.warning(f"SCP statements CSV not found: {self.scp_csv}")
# Use "ecg_id" as the index for fast lookup.
self.metadata_df = pd.read_csv(self.database_csv, index_col="ecg_id")
# ecg_id is unique, but patient_id is not. Group by patient_id and use random record
self.metadata_df = (
self.metadata_df.groupby("patient_id")
.sample(n=1, random_state=1337)
.reset_index(drop=True)
)
self.metadata_df["patient_id"] = self.metadata_df["patient_id"].astype(int)
self.metadata_df["scp_codes"] = self.metadata_df["scp_codes"].apply(
lambda x: ast.literal_eval(x)
if isinstance(x, str)
else x # scp codes are stored as dict strings
)
self.statements_dict = pd.read_csv(self.scp_csv, index_col=0).to_dict(
orient="index"
)
self.metadata_dict: Dict[str, Any] = self.metadata_df.to_dict(orient="index")
self.records_folder = (
self.root / "records500"
) # we are only interested in the 500Hz data
if not self.records_folder.exists():
raise FileNotFoundError(f"Records folder not found: {self.records_folder}")
self.filename_col = (
"filename_hr" # 500 Hz records / for 100 hz use "filename_lr"
)
|
get_all_record_ids()
Get all available record IDs without loading the actual data.
Returns:
- List[str]: List of record identifiers (ecg_id) as strings.
Source code in src/data/raw/ecg/ptb_xl.py
| def get_all_record_ids(self) -> List[str]:
"""Get all available record IDs without loading the actual data.
Returns:
List of record identifiers (ecg_id) as strings.
"""
return self.metadata_df.index.astype(str).tolist()
|
get_metadata()
Get dataset-specific metadata.
Returns:
- Dict[str, Any]: A dictionary containing dataset details.
Source code in src/data/raw/ecg/ptb_xl.py
| def get_metadata(self) -> Dict[str, Any]:
"""Get dataset-specific metadata.
Returns:
A dictionary containing dataset details.
"""
return {
"num_leads": 12,
"sampling_rate": 500,
"record_length_seconds": 10,
"total_records": len(self.metadata_df),
"num_patients": int(self.metadata_df["patient_id"].nunique()),
}
|
get_stream()
Stream raw records one at a time.
Yields:
- RawRecord: RawRecord objects for each record.
Source code in src/data/raw/ecg/ptb_xl.py
| def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream raw records one at a time.
Yields:
RawRecord objects for each record.
"""
for record_id in self.get_all_record_ids():
try:
yield self.load_record(record_id)
except Exception as e:
logger.error(f"Error loading record {record_id}: {e}")
continue
|
get_target_labels()
cached
Extract all unique SCP codes (with likelihood > 0) from the metadata.
Returns:
- List[str]: Sorted list of unique target label codes.
Source code in src/data/raw/ecg/ptb_xl.py
| @cache
def get_target_labels(self) -> List[str]:
"""Extract all unique SCP codes (with likelihood > 0) from the metadata.
Returns:
Sorted list of unique target label codes.
"""
labels = set()
for record_id in self.get_all_record_ids():
codes = self._extract_scp_codes(int(record_id))
labels.update(self._extract_target_labels(codes))
return sorted(list(labels))
|
load_record(record_id)
cached
Load a single ECG record by its identifier.
Parameters:
- record_id (str, required): The unique ECG record identifier.
Returns:
- RawRecord: A RawRecord object containing the ECG signal data, target labels, and metadata.
Source code in src/data/raw/ecg/ptb_xl.py
| @cache
def load_record(self, record_id: str) -> RawRecord:
"""Load a single ECG record by its identifier.
Args:
record_id: The unique ECG record identifier.
Returns:
A RawRecord object containing the ECG signal data, target labels, and metadata.
"""
record_id = int(record_id)
if record_id not in self.metadata_dict:
raise ValueError(f"Record {record_id} not found in metadata.")
meta = self.metadata_dict[record_id]
file_path = self._get_file_path(record_id)
record = self._read_wfdb_record(record_id)
data = record[0]
codes = self._extract_scp_codes(record_id)
target_labels = self._extract_target_labels(codes)
record_metadata = dict(meta)
record_metadata["wfdb_file"] = str(file_path.relative_to(self.root))
record_metadata["signal_fields"] = record[1]
record_metadata["scp_statements"] = {
code: self._extract_scp_statement(code) for code in codes
}
return RawRecord(
id=record_id,
data=data,
target_labels=target_labels,
metadata=record_metadata,
)
|
verify_data()
Verify the raw data structure and contents.
Source code in src/data/raw/ecg/ptb_xl.py
| def verify_data(self) -> None:
"""Verify the raw data structure and contents."""
if not self.database_csv.exists():
raise ValueError(f"Database CSV does not exist: {self.database_csv}")
if not self.records_folder.exists():
raise ValueError(f"Records folder does not exist: {self.records_folder}")
for record_id in self.get_all_record_ids():
file_path = self._get_file_path(record_id)
header_path = file_path.with_suffix(".hea")
dat_path = file_path.with_suffix(".dat")
if not header_path.exists():
raise FileNotFoundError(
f"WFDB header file for record {record_id} not found: {header_path}"
)
if not dat_path.exists():
raise FileNotFoundError(
f"WFDB data file for record {record_id} not found: {dat_path}"
)
# based on https://physionet.org/content/ptb-xl/1.0.3/, we want only unique patients with one record per patient
assert (
len(self.metadata_df) == 18869
), f"Expected 18869 records, got {len(self.metadata_df)}"
|
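Used on its own, the PTB-XL handler can be exercised directly. A short sketch, assuming the PTB-XL files sit under the raw data path shown in the data layout at the top of this guide:
from pathlib import Path
from src.data.raw.ecg.ptb_xl import PTBXL

ptbxl = PTBXL(Path("data"))                 # expects the ptbxl files under the raw data path
ptbxl.verify_data()
print(ptbxl.get_metadata())                 # {'num_leads': 12, 'sampling_rate': 500, ...}

first_id = ptbxl.get_all_record_ids()[0]
record = ptbxl.load_record(first_id)
print(record.data.shape)                    # (12, 5000): 12 leads, 10 s at 500 Hz
print(record.target_labels)                 # SCP codes for this record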
src.data.raw.cmr.acdc
ACDC
Bases: RawDataset
Source code in src/data/raw/cmr/acdc.py
| @RawDatasetRegistry.register("cmr", "acdc")
class ACDC(RawDataset):
dataset_key = "acdc"
modality = DatasetModality.CMR
required_metadata_fields = {
"ed_frame": None,
"es_frame": None,
"group": None,
"height": None,
"weight": None,
"nb_frames": None,
}
def __init__(self, data_root: Path):
super().__init__(data_root)
self.data_path = self.paths["raw"]
self.__nifti_files = {
f.parent.name: f.absolute() for f in self.data_path.glob("**/*_4d.nii.gz")
}
self.__config_files = {
f.parent.name: f.absolute() for f in self.data_path.glob("**/Info.cfg")
}
def verify_data(self) -> None:
if not self.data_path.exists():
raise ValueError(f"ACDC data path does not exist: {self.data_path}")
if len(self.__nifti_files) == 0:
raise ValueError(f"No 4D NIFTI files found in {self.data_path}")
if len(self.__config_files) == 0:
raise ValueError(f"No Info.cfg files found in {self.data_path}")
# each NIFTI file should have a corresponding Info.cfg file
nifti_stems = set(self.__nifti_files.keys())
config_stems = set(self.__config_files.keys())
if nifti_stems != config_stems:
missing_nifti = nifti_stems - config_stems
missing_config = config_stems - nifti_stems
raise ValueError(
f"Missing Info.cfg files for {missing_nifti} and NIFTI files for {missing_config}"
)
# based on https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
assert (
len(self.get_all_record_ids()) == 150
), f"Expected 150 records, got {len(self.get_all_record_ids())}"
def get_metadata(self) -> Dict[str, Any]:
return {"default_image_size": 256}
def _get_record_paths(self, record_id: str) -> Tuple[Path, Path]:
"""Get paths to .nii.gz and Info.cfg files for a given record ID."""
return self.__nifti_files.get(record_id), self.__config_files.get(record_id)
@cache
def get_target_labels(self) -> List[str]:
"""Get list of target labels using streaming to reduce memory usage."""
groups = set()
for record_id in self.get_all_record_ids():
_, config_path = self._get_record_paths(record_id)
group = self._parse_config(config_path).get("group")
if group is not None:
groups.add(group)
return sorted(list(groups))
def get_all_record_ids(self) -> List[str]:
"""Get all record IDs without loading data."""
return list(self.__nifti_files.keys())
def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream records one at a time to reduce memory usage."""
for record_id in self.get_all_record_ids():
yield self.load_record(record_id)
@cache
def load_record(self, record_id: str) -> RawRecord:
nifti_file, config_file = self._get_record_paths(record_id)
metadata = self._parse_config(config_file)
frame_indices = self._calculate_frame_indices(metadata)
image_data = self._get_extracted_frames(nifti_file, frame_indices)
metadata.update(
{
f"{frame_type}_frame_idx": idx
for frame_type, idx in zip(["ed", "mid", "es"], frame_indices)
}
)
metadata["nifti_path"] = str(nifti_file.relative_to(self.data_path))
group = metadata.pop("group", None)
return RawRecord(
id=record_id,
data=image_data,
target_labels=[group] if group else None,
metadata=metadata,
)
@staticmethod
def _calculate_frame_indices(metadata: Dict[str, Any]) -> List[int]:
required_fields = ["ed_frame", "es_frame", "nb_frames"]
if not all(metadata.get(field) is not None for field in required_fields):
raise ValueError("Missing required frame information")
ed_frame_idx = metadata["ed_frame"] - 1
es_frame_idx = metadata["es_frame"] - 1
nb_frames = metadata["nb_frames"]
if ed_frame_idx <= es_frame_idx:
mid_frame_idx = (ed_frame_idx + es_frame_idx) // 2
else:
mid_frame_idx = ((ed_frame_idx + es_frame_idx + nb_frames) // 2) % nb_frames
return [ed_frame_idx, mid_frame_idx, es_frame_idx]
def _get_extracted_frames(self, path: Path, frame_indices: List[int]) -> np.ndarray:
nifti_data = self._read_nifti(path)
slices = self._extract_middle_slice(nifti_data)
return self._extract_frames(slices, frame_indices)
@staticmethod
def _read_nifti(path: Path) -> np.ndarray:
try:
return nib.load(str(path)).get_fdata()
except Exception as e:
raise RuntimeError(f"Error reading NIFTI file {path}: {e}")
@staticmethod
def _extract_middle_slice(data: np.ndarray) -> np.ndarray:
"""Gets the middle basal apical slice"""
# We approximate the middle slice by taking the middle slice along the z-axis
return data[:, :, data.shape[2] // 2, :]
@staticmethod
def _extract_frames(slice_data: np.ndarray, frame_indices: List[int]) -> np.ndarray:
num_frames = slice_data.shape[-1]
invalid_indices = [idx for idx in frame_indices if idx < 0 or idx >= num_frames]
if invalid_indices:
raise IndexError(
f"Frame indices out of bounds: {invalid_indices}. "
f"Valid range: [0, {num_frames - 1}]"
)
return slice_data[:, :, frame_indices]
@cache
def _parse_config(self, path: Path) -> Dict[str, Any]:
"""Parse config file with caching."""
config = configparser.ConfigParser()
with open(path, "r") as f:
config.read_string("[DEFAULT]\n" + f.read())
parsed_values = {
"ed_frame": config["DEFAULT"].getint("ED"),
"es_frame": config["DEFAULT"].getint("ES"),
"group": config["DEFAULT"].get("Group"),
"height": config["DEFAULT"].getfloat("Height"),
"weight": config["DEFAULT"].getfloat("Weight"),
"nb_frames": config["DEFAULT"].getint("NbFrame"),
}
return {k: v for k, v in parsed_values.items() if v is not None}
|
get_all_record_ids()
Get all record IDs without loading data.
Source code in src/data/raw/cmr/acdc.py
| def get_all_record_ids(self) -> List[str]:
"""Get all record IDs without loading data."""
return list(self.__nifti_files.keys())
|
get_stream()
Stream records one at a time to reduce memory usage.
Source code in src/data/raw/cmr/acdc.py
| def get_stream(self) -> Generator[RawRecord, None, None]:
"""Stream records one at a time to reduce memory usage."""
for record_id in self.get_all_record_ids():
yield self.load_record(record_id)
|
get_target_labels()
cached
Get list of target labels using streaming to reduce memory usage.
Source code in src/data/raw/cmr/acdc.py
| @cache
def get_target_labels(self) -> List[str]:
"""Get list of target labels using streaming to reduce memory usage."""
groups = set()
for record_id in self.get_all_record_ids():
_, config_path = self._get_record_paths(record_id)
group = self._parse_config(config_path).get("group")
if group is not None:
groups.add(group)
return sorted(list(groups))
|
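To make the ED/mid/ES frame selection in _calculate_frame_indices concrete, here is a small worked example with made-up Info.cfg values (the numbers are illustrative only):
# Suppose Info.cfg gives ED=1, ES=12, NbFrame=30 (1-based frame numbers, hypothetical values).
ed_frame_idx = 1 - 1    # 0  (convert to 0-based index)
es_frame_idx = 12 - 1   # 11
nb_frames = 30

# ED comes before ES, so the middle frame is simply the midpoint of the two indices.
mid_frame_idx = (ed_frame_idx + es_frame_idx) // 2            # (0 + 11) // 2 = 5
frame_indices = [ed_frame_idx, mid_frame_idx, es_frame_idx]   # [0, 5, 11]

# If ED came after ES (e.g. ED=25, ES=10), the midpoint wraps around the cine loop:
# ((24 + 9 + 30) // 2) % 30 = 31 % 30 = 1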
Using the Unified Dataset
This section shows how to use the UnifiedDataset class to access a dataset's raw, preprocessed, and embedded data.
The UnifiedDataset is initialized with the data_root, modality, and dataset_key arguments. It then provides a single interface to the dataset's records, splits, metadata, and pre-computed embeddings.
from pathlib import Path
from src.data.dataset import DatasetModality
from src.data.unified import UnifiedDataset
# Initialize dataset
dataset = UnifiedDataset(
data_root=Path("data"),
modality=DatasetModality.ECG,
dataset_key="ptbxl"
)
# Access data
record = dataset[dataset.get_all_record_ids()[0]]  # record IDs are dataset-specific strings
raw_data = record.raw_record.data
preprocessed = record.preprocessed_record
embeddings = record.embeddings
# Get splits and metadata
splits = dataset.get_splits() if dataset.has_splits() else None
metadata_fields = dataset.available_metadata_fields()
# Verify integrity
dataset.verify_integrity()
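A common follow-up is to iterate over one split for training or analysis. A minimal sketch, assuming splits.json exists and contains a split named "train"; how the tensors are batched or collated is left to your training code:
# Iterate over the training split and collect preprocessed records.
splits = dataset.get_splits()
train_ids = splits["train"]  # assumes the split is named "train" in splits.json

for record_id in train_ids:
    record = dataset[record_id]
    preprocessed = record.preprocessed_record  # standardized tensor representation
    embeddings = record.embeddings             # pre-computed embeddings, if available
    # ... feed `preprocessed` / `embeddings` into your model or analysis here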
Next Steps