Skip to content

Model Architectures

This section provides detailed documentation for all model architectures in the project. Our models are organized into several categories:

  • Self-Supervised Learning: Pre-training architectures for learning representations without labels
  • Encoders: Architectures for extracting representations from data
  • Classifiers: Architectures for classifying data and modeling downstream tasks

Base Components

The foundational interfaces and base classes that models build upon.

src.utils.model_weights

PretrainedWeightsMixin

Mixin class for loading pretrained weights with intelligent key matching.

Source code in src/utils/model_weights.py
class PretrainedWeightsMixin:
    """Mixin class for loading pretrained weights with intelligent key matching.

    Subclasses can customize behavior by overriding
    ``_find_matching_state_dict_key`` (key translation) and
    ``_load_state_dict_key`` (key filtering).
    """

    def load_pretrained_weights(
        self,
        weights_path: str,
        strict: bool = False,
        missing_key_threshold: float = 0.1,
    ) -> None:
        """Load pretrained weights with intelligent key matching.

        Args:
            weights_path: Path to the pretrained weights file
            strict: Whether to strictly enforce matching keys
            missing_key_threshold: Maximum allowed percentage of missing keys

        Raises:
            RuntimeError: If loading criteria are not met
        """
        try:
            state_dict = self._load_and_extract_state_dict(weights_path)

            target_keys = set(self.state_dict().keys())
            target_keys = self._get_filtered_state_dict_keys(target_keys)

            new_state_dict, stats = self._match_and_validate_state_dict_keys(
                state_dict, target_keys
            )

            self._validate_weights_loading_criteria(
                stats, strict, missing_key_threshold
            )
            # strict=False here is intentional: partial loading has already
            # been validated against the threshold above.
            self.load_state_dict(new_state_dict, strict=False)

            self._log_weights_loading_results(stats, weights_path)

        except Exception as e:
            logger.error(
                f"Error loading pretrained weights from {weights_path}: {str(e)}"
            )
            raise

    def _load_and_extract_state_dict(self, weights_path: str) -> Dict[str, Any]:
        """Load a checkpoint and unwrap common wrapper keys.

        Checkpoints are often saved as ``{"state_dict": {...}}`` (or "model" /
        "network"); each known wrapper is unwrapped in turn, which also handles
        nested wrapping.
        """
        state_dict = torch.load(weights_path, map_location="cpu")

        # Extract nested state dict if necessary
        for key in ["state_dict", "model", "network"]:
            if isinstance(state_dict, dict) and key in state_dict:
                state_dict = state_dict[key]

        return state_dict

    def _get_filtered_state_dict_keys(self, target_keys: Set[str]) -> Set[str]:
        """Drop keys rejected by the overridable ``_load_state_dict_key`` filter."""
        before_filter_keys = target_keys.copy()
        target_keys = {k for k in target_keys if self._load_state_dict_key(k)}
        if len(target_keys) < len(before_filter_keys):
            logger.info(
                f"Filtered {len(before_filter_keys)} keys to {len(target_keys)} keys after filtering. "
                f"Filtered keys: {before_filter_keys - target_keys}"
            )
        return target_keys

    def _match_and_validate_state_dict_keys(
        self, source_dict: Dict[str, torch.Tensor], target_keys: Set[str]
    ) -> Tuple[OrderedDict, Dict]:
        """Match source tensors to target keys and collect loading statistics.

        Returns:
            Tuple of (state dict to load, stats dict with missing/unexpected
            keys, shape mismatches and match counts).
        """
        new_state_dict = OrderedDict()
        missing_keys = []
        matched_keys = set()
        shape_mismatches = []

        source_keys = set(source_dict.keys())
        # Hoisted out of the loop: self.state_dict() rebuilds the mapping on
        # every call, which previously made matching O(n^2)-ish.
        own_state = self.state_dict()

        for target_key in target_keys:
            matching_key = self._find_matching_state_dict_key(target_key, source_keys)

            if not matching_key:
                missing_keys.append(target_key)
                continue

            if self._shapes_match(source_dict[matching_key], own_state[target_key]):
                new_state_dict[target_key] = source_dict[matching_key]
                matched_keys.add(matching_key)
            else:
                # Include both shapes so the resulting error is actionable.
                shape_mismatches.append(
                    f"Shape mismatch for {target_key}: "
                    f"source {tuple(source_dict[matching_key].shape)} vs "
                    f"target {tuple(own_state[target_key].shape)}"
                )

        return new_state_dict, {
            "missing_keys": missing_keys,
            "unexpected_keys": list(source_keys - matched_keys),
            "shape_mismatches": shape_mismatches,
            "total_keys": len(target_keys),
            "matched_keys": len(matched_keys),
        }

    @staticmethod
    def _shapes_match(source_tensor: torch.Tensor, target_tensor: torch.Tensor) -> bool:
        """Return True when the two tensors have identical shapes."""
        return source_tensor.shape == target_tensor.shape

    def _validate_weights_loading_criteria(
        self, stats: Dict, strict: bool, missing_key_threshold: float
    ) -> None:
        """Raise RuntimeError on shape mismatches or too many missing keys."""
        if stats["shape_mismatches"]:
            raise RuntimeError("\n".join(stats["shape_mismatches"]))

        total = stats["total_keys"]
        # Guard: total can be 0 when every key was filtered out; nothing is
        # missing in that case (previously raised ZeroDivisionError).
        missing_ratio = len(stats["missing_keys"]) / total if total else 0.0
        if missing_ratio > missing_key_threshold:
            raise RuntimeError(
                f"Too many missing keys: {len(stats['missing_keys'])}/{total} "
                f"({missing_ratio:.1%} > {missing_key_threshold:.1%}) threshold. "
                f"Missing keys: {stats['missing_keys']}"
            )

        if strict and stats["missing_keys"]:
            raise RuntimeError(f"Strict loading failed: {stats['missing_keys']}")

    def _log_weights_loading_results(self, stats: Dict, weights_path: str) -> None:
        """Log a summary of what was (and was not) loaded."""
        total = stats["total_keys"]
        if stats["missing_keys"]:
            # total > 0 here: missing keys are a subset of the target keys.
            logger.warning(
                f"Missing keys: {len(stats['missing_keys'])}/{total} "
                f"({len(stats['missing_keys']) / total:.1%}). "
                f"Missing keys: {stats['missing_keys']}"
            )
        if stats["unexpected_keys"]:
            logger.warning(f"Unexpected keys: {stats['unexpected_keys']}")

        # Guard the coverage ratio against an empty target set as well.
        coverage = stats["matched_keys"] / total if total else 0.0
        logger.info(
            f"Successfully loaded weights from {weights_path} "
            f"({stats['matched_keys']}/{total} layers). "
            f"Coverage: {coverage:.1%}"
        )

    def _find_matching_state_dict_key(
        self, target_key: str, available_keys: Set[str]
    ) -> Optional[str]:
        """Find matching key by handling the encoder prefix in model's state dict. Can be overridden."""
        if target_key in available_keys:
            return target_key
        return None

    def _load_state_dict_key(self, key: str) -> bool:
        """Filter keys to load. Can be overridden."""
        return True
load_pretrained_weights(weights_path, strict=False, missing_key_threshold=0.1)

Load pretrained weights with intelligent key matching.

Parameters:

Name Type Description Default
weights_path str

Path to the pretrained weights file

required
strict bool

Whether to strictly enforce matching keys

False
missing_key_threshold float

Maximum allowed percentage of missing keys

0.1

Raises:

Type Description
RuntimeError

If loading criteria are not met

Source code in src/utils/model_weights.py
def load_pretrained_weights(
    self,
    weights_path: str,
    strict: bool = False,
    missing_key_threshold: float = 0.1,
) -> None:
    """Load pretrained weights with intelligent key matching.

    Args:
        weights_path: Path to the pretrained weights file
        strict: Whether to strictly enforce matching keys
        missing_key_threshold: Maximum allowed percentage of missing keys

    Raises:
        RuntimeError: If loading criteria are not met
    """
    try:
        # Read the checkpoint and unwrap any nested wrapper dicts.
        source_state = self._load_and_extract_state_dict(weights_path)

        # Determine which of this model's keys we actually want to fill.
        wanted_keys = self._get_filtered_state_dict_keys(
            set(self.state_dict().keys())
        )

        # Pair up source tensors with target keys and gather statistics.
        matched_state, stats = self._match_and_validate_state_dict_keys(
            source_state, wanted_keys
        )

        # Abort before mutating the model if the match quality is too poor.
        self._validate_weights_loading_criteria(stats, strict, missing_key_threshold)
        self.load_state_dict(matched_state, strict=False)

        self._log_weights_loading_results(stats, weights_path)

    except Exception as e:
        logger.error(
            f"Error loading pretrained weights from {weights_path}: {str(e)}"
        )
        raise

Our models implement a sophisticated weights management system through the PretrainedWeightsMixin. This system is designed with several key principles:

  1. Robustness: Gracefully handle different weight file formats and structures
  2. Flexibility: Support partial loading and key matching
  3. Safety: Validate shapes and provide meaningful errors
  4. Transparency: Detailed logging of the loading process

The mixin provides intelligent weight loading with features such as automatic nested state-dict extraction, a configurable missing-key tolerance, shape validation, detailed loading statistics, and extensible key-matching logic.

Usage Example

# Mixin listed first so its methods take MRO precedence, consistent with
# how ECGEncoder and the other project models declare their bases.
class MyModel(PretrainedWeightsMixin, nn.Module):
    """Example model showing how to opt into pretrained-weight loading."""

    def __init__(self):
        super().__init__()
        # ... model definition ...

    def load_my_weights(self, path):
        # Load with 20% missing key tolerance
        self.load_pretrained_weights(path, strict=False, missing_key_threshold=0.2)

Design Philosophy

The mixin is designed to solve common issues in deep learning weight management:

  • Versioning: Models evolve, but weights should remain usable
  • Flexibility: Support both exact and partial loading
  • Debugging: Clear feedback about what was loaded
  • Safety: Prevent silent failures with shape mismatches
  • Extensibility: Easy to customize key matching logic

src.models.encoder_interface

EncoderInterface

Bases: ABC

Interface for neural network encoders that extract features from input data.

Source code in src/models/encoder_interface.py
class EncoderInterface(ABC):
    """Interface for neural network encoders that extract features from input data."""

    @abstractmethod
    def forward_features(
        self, x: torch.Tensor, localized: bool = False
    ) -> torch.Tensor:
        """Extract features from the input tensor.

        Args:
            x: Input tensor to encode.
            localized: If True, return localized features rather than
                global features.

        Returns:
            Tensor of extracted features.
        """
        ...
forward_features(x, localized=False) abstractmethod

Extract features from input tensor.

Parameters:

Name Type Description Default
x Tensor

Input tensor

required
localized bool

Whether to return localized features instead of global features

False

Returns:

Type Description
Tensor

Tensor of extracted features

Source code in src/models/encoder_interface.py
@abstractmethod
def forward_features(
    self, x: torch.Tensor, localized: bool = False
) -> torch.Tensor:
    """Extract features from the input tensor.

    Args:
        x: Input tensor to encode.
        localized: If True, return localized features rather than global
            features.

    Returns:
        Tensor of extracted features.
    """
    ...

The encoder interface abstracts the data preprocessing pipeline and provides a unified interface through which every encoder in the project can extract features. It was designed specifically to support the generation of representations.

Self-Supervised Learning

Our self-supervised learning implementations are based on state-of-the-art approaches adapted for medical data.

src.models.mae

Implementation Credits

Our MAE implementation is based on:

src.models.mae_lit

LitMAE

Bases: MaskedAutoencoderViT, LightningModule

Source code in src/models/mae_lit.py
class LitMAE(MaskedAutoencoderViT, pl.LightningModule):
    """Lightning wrapper around MaskedAutoencoderViT for MAE pre-training.

    Adds optimizer/scheduler configuration, train/validation steps and the
    zero-normalized cross-correlation (NCC) reconstruction metric on top of
    the base MAE model.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        embedding_dim,
        depth,
        num_heads,
        decoder_embed_dim,
        decoder_depth,
        decoder_num_heads,
        mlp_ratio,
        norm_layer,
        norm_pix_loss,
        ncc_weight,
        mask_ratio,
        learning_rate,
        weight_decay,
        warmup_epochs,
        max_epochs,
        pretrained_weights,
        min_lr=0.0,
    ):
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            decoder_embed_dim=decoder_embed_dim,
            decoder_depth=decoder_depth,
            decoder_num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            norm_layer=norm_layer,
            norm_pix_loss=norm_pix_loss,
            ncc_weight=ncc_weight,
            pretrained_weights=pretrained_weights,
        )
        self.save_hyperparameters()

        # Optimization hyperparameters.
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.warmup_epochs = warmup_epochs
        self.min_lr = min_lr
        self.max_epochs = max_epochs

        # Fraction of input patches masked during pre-training.
        self.mask_ratio = mask_ratio

    def _step(self, batch):
        """Run one masked-autoencoding forward pass.

        Returns:
            Tuple of (reconstruction loss, NCC metric, batch size).
        """
        samples, _, _ = batch
        loss, samples_hat, samples_hat_masked = self(
            samples, mask_ratio=self.mask_ratio
        )

        normalized_corr = self.ncc(samples, samples_hat)
        batch_size = samples.shape[0]
        return loss, normalized_corr, batch_size

    def training_step(self, batch, batch_idx):
        """Log training loss/NCC/LR and return the loss for backprop."""
        loss, normalized_corr, batch_size = self._step(batch)
        loss_value = loss.item()
        self.log_dict(
            {
                "train/loss": loss_value,
                "train/ncc": normalized_corr,
                "lr": self.optimizers().param_groups[0]["lr"],
            },
            on_step=True,
            on_epoch=True,
            batch_size=batch_size,
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/NCC/LR; returns the loss as a float."""
        loss, normalized_corr, batch_size = self._step(batch)
        loss_value = loss.item()
        self.log_dict(
            {
                # Fix: log the detached float (previously logged the raw
                # tensor while the float was computed and unused), matching
                # training_step.
                "val/loss": loss_value,
                "val/ncc": normalized_corr,
                "lr": self.optimizers().param_groups[0]["lr"],
            },
            on_epoch=True,
            batch_size=batch_size,
            sync_dist=True,
        )

        return loss_value

    def configure_optimizers(self):
        """AdamW with per-parameter weight decay and warmup + cosine schedule."""
        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_pretrain.py#L256
        param_groups = optim_factory.add_weight_decay(self, self.weight_decay)

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_pretrain.py#L257
        optimizer = torch.optim.AdamW(
            param_groups, lr=self.learning_rate, betas=(0.9, 0.95)
        )

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/util/lr_sched.py
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs,
            max_epochs=self.max_epochs,
            eta_min=self.min_lr,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_pretrain.py#L79
                "frequency": 1,
            },
        }

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_pretrain.py#L25
    # Fix: annotations were `torch.Tensor()` (which *instantiates* an empty
    # tensor at definition time) instead of the class `torch.Tensor`.
    def norm(self, data: torch.Tensor) -> torch.Tensor:
        """
        Zero-Normalize data to have mean=0 and standard_deviation=1
        along the last dimension.

        Parameters
        ----------
        data:  tensor
        """
        mean = torch.mean(data, dim=-1, keepdim=True)
        var = torch.var(data, dim=-1, keepdim=True)

        # 1e-12 guards against division by zero for constant signals.
        return (data - mean) / (var + 1e-12) ** 0.5

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_pretrain.py#L38
    # Fix: `torch.Tensor()` annotations replaced by `torch.Tensor` (see norm).
    def ncc(self, data_0: torch.Tensor, data_1: torch.Tensor) -> torch.Tensor:
        """
        Zero-Normalized cross-correlation coefficient between two data sets

        Zero-Normalized cross-correlation equals the cosine of the angle between the unit vectors F and T,
        being thus 1 if and only if F equals T multiplied by a positive scalar.

        Parameters
        ----------
        data_0, data_1 :  tensors of same size
        """

        # Number of individual signals = product of all leading (batch) dims.
        nb_of_signals = 1
        for dim in range(
            data_0.dim() - 1
        ):  # all but the last dimension (which is the actual signal)
            nb_of_signals = nb_of_signals * data_0.shape[dim]

        # Pearson correlation per signal, then averaged over all signals.
        cross_corrs = (1.0 / (data_0.shape[-1] - 1)) * torch.sum(
            self.norm(data=data_0) * self.norm(data=data_1), dim=-1
        )
        return cross_corrs.sum() / nb_of_signals
ncc(data_0, data_1)

Zero-Normalized cross-correlation coefficient between two data sets

Zero-Normalized cross-correlation equals the cosine of the angle between the unit vectors F and T, being thus 1 if and only if F equals T multiplied by a positive scalar.

Parameters

data_0, data_1 : tensors of same size

Source code in src/models/mae_lit.py
def ncc(self, data_0: torch.Tensor(), data_1: torch.Tensor()) -> torch.Tensor():
    """
    Zero-Normalized cross-correlation coefficient between two data sets

    Zero-Normalized cross-correlation equals the cosine of the angle between the unit vectors F and T,
    being thus 1 if and only if F equals T multiplied by a positive scalar.

    Parameters
    ----------
    data_0, data_1 :  tensors of same size
    """

    nb_of_signals = 1
    for dim in range(
        data_0.dim() - 1
    ):  # all but the last dimension (which is the actual signal)
        nb_of_signals = nb_of_signals * data_0.shape[dim]

    cross_corrs = (1.0 / (data_0.shape[-1] - 1)) * torch.sum(
        self.norm(data=data_0) * self.norm(data=data_1), dim=-1
    )
    return cross_corrs.sum() / nb_of_signals
norm(data)

Zero-Normalize data to have mean=0 and standard_deviation=1

Parameters

data: tensor

Source code in src/models/mae_lit.py
# Fix: annotations were `torch.Tensor()` — a call that instantiates an empty
# tensor at function-definition time — instead of the class `torch.Tensor`.
def norm(self, data: torch.Tensor) -> torch.Tensor:
    """
    Zero-Normalize data to have mean=0 and standard_deviation=1
    along the last dimension.

    Parameters
    ----------
    data:  tensor
    """
    mean = torch.mean(data, dim=-1, keepdim=True)
    var = torch.var(data, dim=-1, keepdim=True)

    # 1e-12 guards against division by zero for constant signals.
    return (data - mean) / (var + 1e-12) ** 0.5

src.models.sim_clr

SimCLR

Bases: LightningModule

Lightning module for imaging SimCLR.

Alternates training between contrastive model and online classifier.

Source code in src/models/sim_clr.py
class SimCLR(pl.LightningModule):
    """
    Lightning module for imaging SimCLR.

    Alternates training between contrastive model and online classifier.
    Uses manual optimization: two optimizers, stepped explicitly in
    training_step.
    """

    def __init__(
        self,
        encoder_backbone_model_name: str,
        projection_dim: int,
        temperature: float,
        num_classes: int,
        init_strat: str,
        weights: Optional[List[float]],
        learning_rate: float,
        weight_decay: float,
        lr_classifier: float,
        weight_decay_classifier: float,
        scheduler: str,
        anneal_max_epochs: int,
        warmup_epochs: int,
        max_epochs: int,
        check_val_every_n_epoch: int,
        log_images: bool = False,
        pretrained_weights: Optional[str] = None,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.encoder_backbone_model_name = encoder_backbone_model_name
        self.projection_dim = projection_dim
        self.temperature = temperature
        self.num_classes = num_classes
        self.init_strat = init_strat
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.lr_classifier = lr_classifier
        self.weight_decay_classifier = weight_decay_classifier
        self.scheduler = scheduler
        self.anneal_max_epochs = anneal_max_epochs
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.check_val_every_n_epoch = check_val_every_n_epoch
        self.log_images = log_images

        # Manual optimization for multiple optimizers
        self.automatic_optimization = False

        # Initialize encoder
        self.encoder_imaging = CMREncoder(
            backbone_model_name=encoder_backbone_model_name,
            pretrained_weights=pretrained_weights,
        )
        pooled_dim = self.encoder_imaging.pooled_dim

        self.projection_head = SimCLRProjectionHead(
            pooled_dim, pooled_dim, projection_dim
        )
        self.criterion_train = NTXentLoss(temperature=temperature)
        self.criterion_val = NTXentLoss(temperature=temperature)

        # Defines weights to be used for the classifier in case of imbalanced
        # data (uniform when not provided).  Fix: previously `self.weights`
        # was assigned twice — once the raw list, then the tensor; only the
        # tensor assignment is kept.
        if not weights:
            weights = [1.0 for _ in range(num_classes)]
        self.weights = torch.tensor(weights)

        # Classifier
        self.classifier = LinearClassifier(
            in_size=pooled_dim, num_classes=num_classes, init_type=init_strat
        )
        self.classifier_criterion = torch.nn.CrossEntropyLoss(weight=self.weights)

        # NOTE(review): the top-k metrics below are logged in the train/val
        # steps but never updated with predictions anywhere in this class —
        # confirm whether the update calls were dropped intentionally.
        self.top1_acc_train = torchmetrics.Accuracy(
            task="multiclass", top_k=1, num_classes=num_classes
        )
        self.top1_acc_val = torchmetrics.Accuracy(
            task="multiclass", top_k=1, num_classes=num_classes
        )

        self.top5_acc_train = torchmetrics.Accuracy(
            task="multiclass", top_k=5, num_classes=num_classes
        )
        self.top5_acc_val = torchmetrics.Accuracy(
            task="multiclass", top_k=5, num_classes=num_classes
        )

        self.f1_train = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
        self.f1_val = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

        self.classifier_acc_train = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes, average="weighted"
        )
        self.classifier_acc_val = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes, average="weighted"
        )

        # Fix: created eagerly here (instead of lazily via hasattr in
        # validation_step) and cleared each epoch in on_validation_epoch_end
        # so cached batches cannot accumulate without bound.
        self.validation_step_outputs = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generates projection of data.
        """
        x = self.encoder_imaging(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

    def training_step(
        self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
    ) -> torch.Tensor:
        """
        Alternates calculation of loss for training between contrastive model and online classifier.
        """
        x0, x1, y, indices = batch

        opt1, opt2 = self.optimizers()

        # Train contrastive model using opt1
        z0 = self.forward(x0)
        z1 = self.forward(x1)

        loss, _, _ = self.criterion_train(z0, z1)

        self.log(
            "imaging.train.loss", loss, on_epoch=True, on_step=False, sync_dist=True
        )
        self.log(
            "imaging.train.top1",
            self.top1_acc_train,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            "imaging.train.top5",
            self.top5_acc_train,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )

        opt1.zero_grad()
        self.manual_backward(loss)
        opt1.step()

        # Train classifier using opt2
        # NOTE(review): torch.squeeze removes the batch dimension when the
        # batch size is 1 — verify batches are always larger than one.
        embedding = torch.squeeze(self.encoder_imaging(x0))
        y_hat = self.classifier(embedding)
        cls_loss = self.classifier_criterion(y_hat, y)

        # Labels are assumed one-hot; both predictions and targets are
        # converted to class indices for the metrics.
        y_hat = y_hat.argmax(dim=1)
        y = y.argmax(dim=1)

        self.f1_train(y_hat, y)
        self.classifier_acc_train(y_hat, y)

        self.log(
            "classifier.train.loss",
            cls_loss,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            "classifier.train.f1",
            self.f1_train,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            "classifier.train.accuracy",
            self.classifier_acc_train,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )

        opt2.zero_grad()
        self.manual_backward(cls_loss)
        opt2.step()

        return loss

    def validation_step(
        self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
    ) -> torch.Tensor:
        """
        Validate both contrastive model and classifier
        """
        x0, x1, y, indices = batch

        # Validate contrastive model
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss, _, _ = self.criterion_val(z0, z1)

        self.log("val_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log("imaging.val.loss", loss, on_epoch=True, on_step=False, sync_dist=True)
        self.log(
            "imaging.val.top1",
            self.top1_acc_val,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            "imaging.val.top5",
            self.top5_acc_val,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )

        # Validate classifier
        self.classifier.eval()
        embedding = torch.squeeze(self.encoder_imaging(x0))
        y_hat = self.classifier(embedding)
        loss = self.classifier_criterion(y_hat, y)

        y_hat = y_hat.argmax(dim=1)
        y = y.argmax(dim=1)

        self.f1_val(y_hat, y)
        self.classifier_acc_val(y_hat, y)

        self.log(
            "classifier.val.loss", loss, on_epoch=True, on_step=False, sync_dist=True
        )
        self.log(
            "classifier.val.f1",
            self.f1_val,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.log(
            "classifier.val.accuracy",
            self.classifier_acc_val,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        self.classifier.train()

        # Cache the batch for example-image logging; the list is cleared in
        # on_validation_epoch_end.
        self.validation_step_outputs.append(x0)
        return x0

    def on_validation_epoch_end(self) -> None:
        """
        Log an example validation image (if enabled) and release cached batches.
        """
        if self.log_images and self.validation_step_outputs:
            example_img = (
                self.validation_step_outputs[0]
                .cpu()
                .detach()
                .numpy()[0][0]  # First image in batch, first channel
            )

            if isinstance(self.logger, NeptuneLogger):
                self.logger.run["Image Example"].upload(File.as_image(example_img))
            elif isinstance(self.logger, WandbLogger):
                self.logger.log_image(key="Image Example", images=[example_img])

        # Fix: the cached batches were previously never released, so memory
        # grew with every validation epoch.
        self.validation_step_outputs.clear()

    def configure_optimizers(self) -> Tuple[Dict, Dict]:
        """
        Define and return optimizer and scheduler for contrastive model and online classifier.
        The online classifier runs without a scheduler.

        Raises:
            ValueError: If ``self.scheduler`` is not "cosine" or "anneal".
        """
        optimizer = torch.optim.Adam(
            [
                {"params": self.encoder_imaging.parameters()},
                {"params": self.projection_head.parameters()},
            ],
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )
        classifier_optimizer = torch.optim.Adam(
            self.classifier.parameters(),
            lr=self.lr_classifier,
            weight_decay=self.weight_decay_classifier,
        )

        if self.scheduler == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.anneal_max_epochs, eta_min=0, last_epoch=-1
            )
        elif self.scheduler == "anneal":
            scheduler = LinearWarmupCosineAnnealingLR(
                optimizer, warmup_epochs=self.warmup_epochs, max_epochs=self.max_epochs
            )
        else:
            # Fix: previously an unknown name fell through and later raised a
            # confusing NameError on the unbound `scheduler` variable.
            raise ValueError(f"Unknown scheduler: {self.scheduler!r}")

        # Fix: a ReduceLROnPlateau for the classifier was constructed here but
        # never returned (dead code); it has been removed.

        return (
            {"optimizer": optimizer, "lr_scheduler": scheduler},  # Contrastive
            {"optimizer": classifier_optimizer},  # Classifier
        )
configure_optimizers()

Define and return the optimizer and scheduler for the contrastive model and the online classifier. The scheduler for the online classifier is often disabled.

Source code in src/models/sim_clr.py
def configure_optimizers(self) -> Tuple[Dict, Dict]:
    """
    Define and return optimizer and scheduler for contrastive model and online classifier.
    The online classifier runs without a scheduler.

    Returns:
        Tuple of optimizer-config dicts (contrastive model first).

    Raises:
        ValueError: If ``self.scheduler`` is not "cosine" or "anneal".
    """
    optimizer = torch.optim.Adam(
        [
            {"params": self.encoder_imaging.parameters()},
            {"params": self.projection_head.parameters()},
        ],
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )
    classifier_optimizer = torch.optim.Adam(
        self.classifier.parameters(),
        lr=self.lr_classifier,
        weight_decay=self.weight_decay_classifier,
    )

    if self.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.anneal_max_epochs, eta_min=0, last_epoch=-1
        )
    elif self.scheduler == "anneal":
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=self.warmup_epochs, max_epochs=self.max_epochs
        )
    else:
        # Fix: previously an unknown name fell through and later raised a
        # confusing NameError on the unbound `scheduler` variable.
        raise ValueError(f"Unknown scheduler: {self.scheduler!r}")

    # Fix: a ReduceLROnPlateau for the classifier was constructed here but
    # never returned (dead code); it has been removed.

    return (
        {"optimizer": optimizer, "lr_scheduler": scheduler},  # Contrastive
        {"optimizer": classifier_optimizer},  # Classifier
    )
forward(x)

Generates projection of data.

Source code in src/models/sim_clr.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Generates projection of data.
    """
    # Encode, flatten everything after the batch dimension, then project.
    pooled = self.encoder_imaging(x)
    flat = pooled.flatten(start_dim=1)
    return self.projection_head(flat)
on_validation_epoch_end()

Log an image from each validation step using the appropriate logger.

Source code in src/models/sim_clr.py
def on_validation_epoch_end(self) -> None:
    """
    Log an example validation image (if enabled) and release cached batches.
    """
    if (
        self.log_images
        and hasattr(self, "validation_step_outputs")
        and self.validation_step_outputs
    ):
        example_img = (
            self.validation_step_outputs[0]
            .cpu()
            .detach()
            .numpy()[0][0]  # First image in batch, first channel
        )

        if isinstance(self.logger, NeptuneLogger):
            self.logger.run["Image Example"].upload(File.as_image(example_img))
        elif isinstance(self.logger, WandbLogger):
            self.logger.log_image(key="Image Example", images=[example_img])

    # Fix: the batches collected in validation_step were previously never
    # released, so memory grew with every validation epoch.
    if hasattr(self, "validation_step_outputs"):
        self.validation_step_outputs.clear()
training_step(batch, _)

Alternates calculation of loss for training between contrastive model and online classifier.

Source code in src/models/sim_clr.py
def training_step(
    self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
) -> torch.Tensor:
    """
    Alternates calculation of loss for training between contrastive model and online classifier.

    Runs under manual optimization: the contrastive encoder/projection head
    is updated via the first optimizer, then the online classifier is updated
    via the second optimizer on the same batch.

    Args:
        batch: Tuple of (view 0, view 1, labels, sample indices).
        _: Unused batch index.

    Returns:
        The contrastive (NT-Xent) loss.
    """
    # Two augmented views, labels and dataset indices.  NOTE(review): y is
    # argmax'd below, so it is presumably one-hot — confirm against the
    # datamodule.
    x0, x1, y, indices = batch

    opt1, opt2 = self.optimizers()

    # Train contrastive model using opt1
    z0 = self.forward(x0)
    z1 = self.forward(x1)

    # The criterion returns (loss, ...); only the loss is used here.
    loss, _, _ = self.criterion_train(z0, z1)

    self.log(
        "imaging.train.loss", loss, on_epoch=True, on_step=False, sync_dist=True
    )
    # NOTE(review): top1/top5 accuracy metrics are logged here but never
    # updated with predictions anywhere visible in this module, so they
    # appear to always report their initial value — confirm intent.
    self.log(
        "imaging.train.top1",
        self.top1_acc_train,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )
    self.log(
        "imaging.train.top5",
        self.top5_acc_train,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )

    # Backward + step for the contrastive branch only.
    opt1.zero_grad()
    self.manual_backward(loss)
    opt1.step()

    # Train classifier using opt2
    # NOTE(review): torch.squeeze drops the batch dimension when the batch
    # size is 1 — verify batches are always larger than one.  The embedding
    # is not detached, so cls_loss also produces encoder gradients; opt2 only
    # steps classifier parameters and opt1.zero_grad() discards the rest on
    # the next iteration.
    embedding = torch.squeeze(self.encoder_imaging(x0))
    y_hat = self.classifier(embedding)
    cls_loss = self.classifier_criterion(y_hat, y)

    # Convert logits and (one-hot) targets to class indices for the metrics.
    y_hat = y_hat.argmax(dim=1)
    y = y.argmax(dim=1)

    self.f1_train(y_hat, y)
    self.classifier_acc_train(y_hat, y)

    self.log(
        "classifier.train.loss",
        cls_loss,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )
    self.log(
        "classifier.train.f1",
        self.f1_train,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )
    self.log(
        "classifier.train.accuracy",
        self.classifier_acc_train,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )

    # Backward + step for the classifier branch.
    opt2.zero_grad()
    self.manual_backward(cls_loss)
    opt2.step()

    return loss
validation_step(batch, _)

Validate both contrastive model and classifier

Source code in src/models/sim_clr.py
def validation_step(
    self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
) -> torch.Tensor:
    """
    Validate both contrastive model and classifier

    Args:
        batch: Tuple of (view 0, view 1, one-hot labels, sample indices).
        _: Unused batch index.

    Returns:
        torch.Tensor: The first augmented view ``x0`` (also appended to
        ``self.validation_step_outputs``).
    """
    # NOTE(review): `indices` is unpacked but never used in this step.
    x0, x1, y, indices = batch

    # Validate contrastive model
    z0 = self.forward(x0)
    z1 = self.forward(x1)
    loss, _, _ = self.criterion_val(z0, z1)

    self.log("val_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    self.log("imaging.val.loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    # NOTE(review): top1/top5 metric objects are logged but no update call is
    # visible in this block — presumably updated elsewhere; confirm.
    self.log(
        "imaging.val.top1",
        self.top1_acc_val,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )
    self.log(
        "imaging.val.top5",
        self.top5_acc_val,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )

    # Validate classifier
    self.classifier.eval()
    embedding = torch.squeeze(self.encoder_imaging(x0))
    y_hat = self.classifier(embedding)
    # NOTE(review): `loss` is rebound here; the value returned below is the
    # contrastive x0 batch, not either loss, so this rebinding is harmless but
    # easy to misread.
    loss = self.classifier_criterion(y_hat, y)

    # Convert logits and one-hot targets to class indices for the metrics.
    y_hat = y_hat.argmax(dim=1)
    y = y.argmax(dim=1)

    self.f1_val(y_hat, y)
    self.classifier_acc_val(y_hat, y)

    self.log(
        "classifier.val.loss", loss, on_epoch=True, on_step=False, sync_dist=True
    )
    self.log(
        "classifier.val.f1",
        self.f1_val,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )
    self.log(
        "classifier.val.accuracy",
        self.classifier_acc_val,
        on_epoch=True,
        on_step=False,
        sync_dist=True,
    )
    self.classifier.train()

    # NOTE(review): x0 batches are accumulated every validation step; this
    # grows memory unless the list is cleared in on_validation_epoch_end —
    # verify that hook exists.
    if not hasattr(self, "validation_step_outputs"):
        self.validation_step_outputs = []
    self.validation_step_outputs.append(x0)
    return x0

Implementation Credits

Our SimCLR implementation is based on:

Encoders

Model definitions mainly designed to obtain representations.

src.models.ecg_encoder

ECGEncoder

Bases: PretrainedWeightsMixin, EncoderInterface, VisionTransformer

Source code in src/models/ecg_encoder.py
class ECGEncoder(
    PretrainedWeightsMixin,
    EncoderInterface,
    timm.models.vision_transformer.VisionTransformer,
):
    """Vision-Transformer encoder for ECG recordings.

    Treats an ECG as a (1, leads, time) "image", patch-embeds it and runs the
    standard timm ViT blocks. The representation returned by
    :meth:`forward_features` depends on ``global_pool``:

    * ``"attention_pool"`` — attention-pools the patch tokens (mean as query);
    * any other truthy value — mean-pools the patch tokens;
    * falsy — the normalized class token.
    """

    def __init__(
        self,
        img_size: Union[Tuple[int, int, int], List[int]],
        patch_size: Union[Tuple[int, int], List[int]],
        embedding_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        norm_layer: Optional[nn.Module],
        global_pool: str,
        pretrained_weights: Optional[str] = None,
    ):
        """Build the encoder.

        Args:
            img_size: Input size, e.g. (1, leads, samples).
            patch_size: Patch size along the last two input dims.
            embedding_dim: Transformer embedding dimension.
            depth: Number of transformer blocks.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden/embed ratio of the MLP in each block.
            qkv_bias: Whether QKV projections use a bias.
            norm_layer: Normalization layer factory; defaults to LayerNorm(eps=1e-6).
            global_pool: Pooling mode (see class docstring).
            pretrained_weights: Optional checkpoint path, loaded via
                PretrainedWeightsMixin.load_pretrained_weights.
        """
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)

        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
        )

        self.global_pool = global_pool
        self.embed_dim = embedding_dim
        self.num_heads = num_heads
        self.norm_layer = norm_layer

        if self.global_pool == "attention_pool":
            self.attention_pool = nn.MultiheadAttention(
                embed_dim=embedding_dim, num_heads=num_heads, batch_first=True
            )
        # Any truthy pooling mode replaces the final token norm with a norm
        # applied after pooling (fc_norm).
        if self.global_pool:
            self.fc_norm = norm_layer(embedding_dim)
            del self.norm  # remove the original norm

        if pretrained_weights:
            self.load_pretrained_weights(pretrained_weights)

    def forward_features(self, x, localized=False):
        """Encode ``x``; with ``localized=True`` return per-patch tokens
        (class token excluded) instead of a pooled representation."""
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if localized:
            # Per-token output; note: not normalized or pooled.
            outcome = x[:, 1:]
        elif self.global_pool == "attention_pool":
            q = x[:, 1:, :].mean(dim=1, keepdim=True)
            k = x[:, 1:, :]
            v = x[:, 1:, :]
            x, x_weights = self.attention_pool(
                q, k, v
            )  # attention pool without cls token
            outcome = self.fc_norm(x.squeeze(dim=1))
        elif self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def _find_matching_state_dict_key(
        self, target_key: str, available_keys: set
    ) -> Optional[str]:
        """Find matching key by handling the encoder prefix in model's state dict.

        Classification-head keys are never matched, so the head stays randomly
        initialized even when a checkpoint contains one.

        Args:
            target_key: Key from model's state dict (with 'encoder.' prefix)
            available_keys: Keys available in the loaded weights

        Returns:
            Optional[str]: Matching key from available_keys if found, None otherwise
        """
        if target_key.startswith("head."):
            return None

        return super()._find_matching_state_dict_key(target_key, available_keys)

vit_patchX(**kwargs)

Function to create a Vision Transformer conforming to the pre-trained weights by Turgut et al. (2025)

Source code in src/models/ecg_encoder.py
def vit_patchX(**kwargs):
    """Build the Vision Transformer matching the pre-trained weights released
    by Turgut et al. (2025).

    Any keyword argument accepted by ``ECGEncoder`` (e.g. ``global_pool``,
    ``pretrained_weights``) can be forwarded via ``kwargs``.
    """
    config = dict(
        patch_size=(1, 100),  # matches patch_embed.proj.weight: [384, 1, 1, 100]
        img_size=(1, 12, 2500),
        embedding_dim=384,  # matches the checkpoint's embedding dimension
        depth=3,  # 3 transformer blocks
        num_heads=6,  # 384 / 64 = 6 heads (standard head dim of 64)
        mlp_ratio=4,  # yields the 1536-dim MLP hidden layers
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
    )
    # Duplicate keywords still raise TypeError, exactly like explicit kwargs.
    return ECGEncoder(**config, **kwargs)

Implementation Credits

The classifier is also based on the MAE implementation by Turgut et al., as outlined in the MAE section.

src.models.cmr_encoder

CMREncoder

Bases: PretrainedWeightsMixin, EncoderInterface, Module

Source code in src/models/cmr_encoder.py
class CMREncoder(PretrainedWeightsMixin, EncoderInterface, nn.Module):
    """ResNet-backed encoder for CMR images.

    The torchvision ResNet is used without its final fully-connected layer,
    so a forward pass yields the globally pooled feature vector of dimension
    ``pooled_dim`` (512 for resnet18, 2048 for resnet50).
    """

    # Supported torchvision backbones.
    BACKBONE_MODELS = ["resnet18", "resnet50"]

    def __init__(
        self, backbone_model_name: str, pretrained_weights: Optional[str] = None
    ):
        """Create the encoder.

        Args:
            backbone_model_name: One of ``BACKBONE_MODELS``.
            pretrained_weights: Optional checkpoint path, loaded via
                ``PretrainedWeightsMixin.load_pretrained_weights``.

        Raises:
            ValueError: If ``backbone_model_name`` is not supported.
        """
        super().__init__()
        if backbone_model_name not in self.BACKBONE_MODELS:
            raise ValueError(f"Unknown backbone model: {backbone_model_name}")

        if backbone_model_name == "resnet18":
            resnet = torchvision.models.resnet18()
            self.pooled_dim = 512
        elif backbone_model_name == "resnet50":
            resnet = torchvision.models.resnet50()
            self.pooled_dim = 2048
        else:  # unreachable: guarded by the membership check above
            raise ValueError(f"Unknown model type: {backbone_model_name}")

        self.encoder = self._remove_last_layer(resnet)

        if pretrained_weights:
            self.load_pretrained_weights(pretrained_weights)

    def _find_matching_state_dict_key(
        self, target_key: str, available_keys: set
    ) -> Optional[str]:
        """Find matching key by handling the encoder prefix in model's state dict.

        Args:
            target_key: Key from model's state dict (with 'encoder.' prefix)
            available_keys: Keys available in the loaded weights

        Returns:
            Optional[str]: Matching key from available_keys if found, None otherwise
        """
        for prefix in ("encoder.", ""):
            # Bugfix: only strip the prefix when it is actually present. The
            # previous code sliced off len(prefix) characters unconditionally,
            # producing corrupted candidate keys for names that do not start
            # with 'encoder.'.
            if not target_key.startswith(prefix):
                continue
            imaging_key = f"encoder_imaging.{target_key[len(prefix):]}"
            if imaging_key in available_keys:
                return imaging_key

        return super()._find_matching_state_dict_key(target_key, available_keys)

    def _remove_last_layer(self, resnet):
        """Remove the final fully-connected layer from the resnet model.

        Note: only the fc head is dropped; the global average-pool is kept, so
        the sequential encoder outputs shape (N, pooled_dim, 1, 1).
        """
        return nn.Sequential(*list(resnet.children())[:-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return pooled features of shape (batch, pooled_dim).

        Bugfix: ``flatten(1)`` instead of ``squeeze()`` so a batch of size 1
        keeps its batch dimension (``squeeze()`` collapsed it to
        ``(pooled_dim,)``).
        """
        return torch.flatten(self.encoder(x), 1)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Alias of :meth:`forward` for the EncoderInterface contract."""
        return self.forward(x)

Implementation Credits

The encoder is also based on the MMCL-ECG-CMR implementation by Turgut et al., as outlined in the MMCL-ECG-CMR section.

Classifiers

Model definitions mainly designed to classify data. Generally, these models can be extended to perform any kind of downstream task.

src.models.ecg_classifier

ECGClassifier

Bases: PretrainedWeightsMixin, MetricsMixin, VisionTransformer, LightningModule

Source code in src/models/ecg_classifier.py
class ECGClassifier(
    PretrainedWeightsMixin,
    MetricsMixin,
    timm.models.vision_transformer.VisionTransformer,
    pl.LightningModule,
):
    """ViT-based LightningModule for ECG classification fine-tuning.

    Wraps timm's VisionTransformer with:
      * optional random / blockwise token masking during the forward pass;
      * configurable pooling ("attention_pool", truthy global pool, or class token);
      * layer-wise lr decay with a warmup-cosine schedule (configure_optimizers);
      * BCEWithLogitsLoss (multilabel) or CrossEntropyLoss (multiclass).
    """

    def __init__(
        self,
        img_size: Union[Tuple[int, int, int], List[int]],
        patch_size: Union[Tuple[int, int], List[int]],
        embedding_dim: int,
        depth: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        num_classes: int,
        learning_rate: float,
        weight_decay: float,
        warmup_epochs: int,
        max_epochs: int,
        layer_decay: float,
        norm_layer: nn.Module,
        drop_path_rate: float,
        smoothing: float,
        task_type: Literal["multiclass", "multilabel"],
        global_pool=False,
        masking_blockwise=False,
        mask_ratio=0.0,
        mask_c_ratio=0.0,
        mask_t_ratio=0.0,
        min_lr=0.0,
        pretrained_weights: Optional[str] = None,
        pos_weight: Optional[torch.Tensor] = None,
    ):
        # NOTE(review): called before super().__init__(); unconventional order
        # for a LightningModule — confirm this is intentional and works with
        # the installed Lightning version.
        self.save_hyperparameters()

        # NOTE(review): drop_path_rate is forwarded as timm's `drop_rate`
        # (plain dropout), not `drop_path_rate` (stochastic depth) — confirm
        # this mapping is intended.
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            embed_dim=embedding_dim,
            depth=depth,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            norm_layer=norm_layer,
            drop_rate=drop_path_rate,
            num_classes=num_classes,
        )

        self.pretrained_weights = pretrained_weights
        if pretrained_weights:
            self.load_pretrained_weights(pretrained_weights)

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L373
        self.blocks[-1].attn.forward = ECGClassifier.attention_forward_wrapper(
            self.blocks[-1].attn
        )  # required to read out the attention map of the last layer

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L402
        # manually initialize fc layer (as presumably not part of pretrained weights)
        trunc_normal_(self.head.weight, std=0.01)  # 2e-5)

        # Masking configuration used by forward_features.
        self.masking_blockwise = masking_blockwise
        self.mask_ratio = mask_ratio
        self.mask_c_ratio = mask_c_ratio
        self.mask_t_ratio = mask_t_ratio

        # Optimizer / schedule hyperparameters (see configure_optimizers).
        self.learning_rate = learning_rate
        self.min_lr = min_lr
        self.weight_decay = weight_decay
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.layer_decay = layer_decay

        self.task_type = task_type
        self.downstream_task = "classification"
        self.pos_weight = pos_weight

        self.global_pool = global_pool
        if self.global_pool == "attention_pool":
            self.attention_pool = nn.MultiheadAttention(
                embed_dim=embedding_dim,
                num_heads=num_heads,
                batch_first=True,
            )
        if self.global_pool:
            self.fc_norm = norm_layer(embedding_dim)
            del self.norm  # remove the original norm

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L289
        # self.class_weights = 2.0 / (
        #    2.0 * torch.Tensor([1.0, 1.0])
        # )  # total_nb_samples / (nb_classes * samples_per_class)
        self.class_weights = None

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L445
        # We deviate here in favor of BCEWithLogitsLoss which is multi-label compatible.
        if task_type == "multilabel":
            self.criterion = torch.nn.BCEWithLogitsLoss(
                weight=self.class_weights, pos_weight=self.pos_weight
            )
        else:  # multiclass
            self.criterion = torch.nn.CrossEntropyLoss(
                weight=self.class_weights,
                label_smoothing=smoothing,
            )

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/models_vit.py#L44
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0

        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/models_vit.py#L72
    def random_masking_blockwise(self, x, mask_c_ratio, mask_t_ratio):
        """
        2D: ECG recording (N, 1, C, T) (masking c and t under mask_c_ratio and mask_t_ratio)
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        C, T = (
            int(self.img_size[-2] / self.patch_size[-2]),
            int(self.img_size[-1] / self.patch_size[-1]),
        )

        # mask C
        x = x.reshape(N, C, T, D)
        len_keep_C = int(C * (1 - mask_c_ratio))
        noise = torch.rand(N, C, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_keep = ids_shuffle[:, :len_keep_C]
        index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
        x = torch.gather(x, dim=1, index=index)  # N, len_keep_C(C'), T, D

        # mask T
        x = x.permute(0, 2, 1, 3)  # N C' T D => N T C' D
        len_keep_T = int(T * (1 - mask_t_ratio))
        noise = torch.rand(N, T, device=x.device)  # noise in [0, 1]
        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_keep = ids_shuffle[:, :len_keep_T]
        index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_C, D)
        x_masked = torch.gather(x, dim=1, index=index)
        x_masked = x_masked.permute(0, 2, 1, 3)  # N T' C' D => N C' T' D

        x_masked = x_masked.reshape(
            N, len_keep_T * len_keep_C, D
        )  # N C' T' D => N L' D

        # Blockwise masking produces no mask / restore indices.
        return x_masked, None, None

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/models_vit.py#L107
    def forward_features(self, x):
        """
        x: [B=N, L, D], sequence

        Patch-embeds the input, optionally masks tokens, prepends the class
        token, runs all blocks, and pools according to ``self.global_pool``.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        x = x + self.pos_embed[:, 1:, :]
        if self.masking_blockwise:
            x, _, _ = self.random_masking_blockwise(
                x, self.mask_c_ratio, self.mask_t_ratio
            )
        else:
            x, _, _ = self.random_masking(x, self.mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, 0, :]
        cls_tokens = cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool == "attention_pool":
            q = x[:, 1:, :].mean(dim=1, keepdim=True)
            k = x[:, 1:, :]
            v = x[:, 1:, :]
            x, x_weights = self.attention_pool(
                q, k, v
            )  # attention pool without cls token
            outcome = self.fc_norm(x.squeeze(dim=1))
        elif self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/models_vit.py#L144
    def forward_head(self, x, pre_logits: bool = False):
        """Map the pooled representation to logits (or return it unchanged
        when ``pre_logits`` is True)."""
        if self.global_pool:
            x = (
                x[:, self.num_prefix_tokens :].mean(dim=1)
                if self.global_pool == "avg"
                else x[:, :]
            )
        # NOTE(review): assumes self.fc_norm exists even when global_pool is
        # falsy (timm's VisionTransformer defines fc_norm) — confirm for the
        # pinned timm version.
        x = self.fc_norm(x)

        if self.downstream_task == "classification":
            return x if pre_logits else self.head(x)
        elif self.downstream_task == "regression":
            return x if pre_logits else self.head(x)  # .sigmoid()

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], _
    ) -> torch.Tensor:
        """Training step for downstream task."""
        x, y, _ = batch

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L261
        # We inherit forward from VisionTransformer, which calls forward_features and forward_head as per reference above
        y_hat = self(x)

        loss = self.criterion(y_hat, y)
        self.log("train_loss", loss, on_epoch=True, on_step=False, sync_dist=True)

        self.compute_metrics(y_hat, y.long(), "train")
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, int], _
    ) -> torch.Tensor:
        """Validation step for downstream task."""
        x, y, _ = batch

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L261
        # We inherit forward from VisionTransformer, which calls forward_features and forward_head as per reference above
        y_hat = self(x)

        loss = self.criterion(y_hat, y)
        self.log("val_loss", loss, on_epoch=True, on_step=False, sync_dist=True)

        self.compute_metrics(y_hat, y.long(), "val")
        return loss

    def test_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor, int], _
    ) -> torch.Tensor:
        """Test step for downstream task."""
        x, y, _ = batch

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L261
        # We inherit forward from VisionTransformer, which calls forward_features and forward_head as per reference above
        y_hat = self(x)

        loss = self.criterion(y_hat, y)
        self.log("test_loss", loss, on_epoch=True, on_step=False, sync_dist=True)

        self.compute_metrics(y_hat, y.long(), "test")
        return loss

    def configure_optimizers(self) -> Dict:
        """AdamW with layer-wise lr decay and a warmup-cosine schedule,
        stepped per training step."""
        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L438
        param_groups = lrd.param_groups_lrd(
            self,
            self.weight_decay,
            no_weight_decay_list=self.no_weight_decay(),
            layer_decay=self.layer_decay,
        )

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L442C29-L442C33
        optimizer = torch.optim.AdamW(param_groups, lr=self.learning_rate)

        # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/util/lr_sched.py
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=self.warmup_epochs,
            max_epochs=self.max_epochs,
            eta_min=self.min_lr,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
                # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L74
                "frequency": 1,
            },
        }

    def on_train_epoch_end(self):
        """Finalize (compute and reset) the epoch's training metrics."""
        self.finalize_metrics("train")

    def on_validation_epoch_end(self):
        """Finalize (compute and reset) the epoch's validation metrics."""
        self.finalize_metrics("val")

    def on_test_epoch_end(self):
        """Finalize (compute and reset) the epoch's test metrics."""
        self.finalize_metrics("test")

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/main_finetune.py#L237C1-L259C22
    @staticmethod
    def attention_forward_wrapper(attn_obj):
        """
        Modified version of def forward() of class Attention() in timm.models.vision_transformer

        The returned closure reproduces timm's attention forward but stores the
        post-dropout attention weights on ``attn_obj.attn_map`` for inspection.
        """

        def my_forward(x):
            B, N, C = x.shape  # C = embed_dim
            # (3, B, Heads, N, head_dim)
            qkv = (
                attn_obj.qkv(x)
                .reshape(B, N, 3, attn_obj.num_heads, C // attn_obj.num_heads)
                .permute(2, 0, 3, 1, 4)
            )
            q, k, v = qkv.unbind(
                0
            )  # make torchscript happy (cannot use tensor as tuple)

            # (B, Heads, N, N)
            attn = (q @ k.transpose(-2, -1)) * attn_obj.scale
            attn = attn.softmax(dim=-1)
            attn = attn_obj.attn_drop(attn)
            # (B, Heads, N, N)
            attn_obj.attn_map = attn  # this was added

            # (B, N, Heads*head_dim)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            x = attn_obj.proj(x)
            x = attn_obj.proj_drop(x)
            return x

        return my_forward
attention_forward_wrapper(attn_obj) staticmethod

Modified version of def forward() of class Attention() in timm.models.vision_transformer

Source code in src/models/ecg_classifier.py
@staticmethod
def attention_forward_wrapper(attn_obj):
    """
    Modified version of def forward() of class Attention() in timm.models.vision_transformer

    Returns a closure identical to timm's attention forward, except the
    post-dropout attention weights are stored on ``attn_obj.attn_map``.
    """

    def my_forward(x):
        batch, tokens, dim = x.shape  # dim = embed_dim
        head_dim = dim // attn_obj.num_heads

        # Project to q/k/v and split heads: (3, B, Heads, N, head_dim)
        qkv = (
            attn_obj.qkv(x)
            .reshape(batch, tokens, 3, attn_obj.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # torchscript cannot unpack a tensor as a tuple

        # Scaled dot-product attention weights: (B, Heads, N, N)
        weights = ((q @ k.transpose(-2, -1)) * attn_obj.scale).softmax(dim=-1)
        weights = attn_obj.attn_drop(weights)
        attn_obj.attn_map = weights  # exposed so callers can read the attention map

        # Merge heads back: (B, N, Heads*head_dim)
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, dim)
        out = attn_obj.proj(out)
        return attn_obj.proj_drop(out)

    return my_forward
forward_features(x)

x: [B=N, L, D], sequence

Source code in src/models/ecg_classifier.py
def forward_features(self, x):
    """
    x: [B=N, L, D], sequence

    Patch-embeds the input, optionally masks tokens, prepends the class
    token, runs all transformer blocks and pools per ``self.global_pool``.
    """
    batch = x.shape[0]
    tokens = self.patch_embed(x)

    # Positional embedding for patch tokens; the cls slot is handled below.
    tokens = tokens + self.pos_embed[:, 1:, :]
    if self.masking_blockwise:
        tokens, _, _ = self.random_masking_blockwise(
            tokens, self.mask_c_ratio, self.mask_t_ratio
        )
    else:
        tokens, _, _ = self.random_masking(tokens, self.mask_ratio)

    # Prepend the class token carrying its own positional embedding.
    cls = (self.cls_token + self.pos_embed[:, 0, :]).expand(batch, -1, -1)
    tokens = torch.cat((cls, tokens), dim=1)
    tokens = self.pos_drop(tokens)

    for block in self.blocks:
        tokens = block(tokens)

    if self.global_pool == "attention_pool":
        # Attention-pool the patch tokens (cls excluded), mean token as query.
        query = tokens[:, 1:, :].mean(dim=1, keepdim=True)
        pooled, _ = self.attention_pool(query, tokens[:, 1:, :], tokens[:, 1:, :])
        return self.fc_norm(pooled.squeeze(dim=1))
    if self.global_pool:
        # Plain mean pool over patch tokens (cls excluded).
        return self.fc_norm(tokens[:, 1:, :].mean(dim=1))
    # No pooling: return the normalized class token.
    return self.norm(tokens)[:, 0]
random_masking(x, mask_ratio)

Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. x: [N, L, D], sequence

Source code in src/models/ecg_classifier.py
def random_masking(self, x, mask_ratio):
    """
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    batch, seq_len, dim = x.shape
    num_keep = int(seq_len * (1 - mask_ratio))

    # Uniform scores decide which tokens survive: the num_keep smallest stay.
    scores = torch.rand(batch, seq_len, device=x.device)
    shuffle_idx = scores.argsort(dim=1)       # ascending: low score == keep
    restore_idx = shuffle_idx.argsort(dim=1)  # inverse permutation

    keep_idx = shuffle_idx[:, :num_keep]
    x_masked = torch.gather(
        x, dim=1, index=keep_idx.unsqueeze(-1).repeat(1, 1, dim)
    )

    # Binary mask in original token order: 0 == kept, 1 == removed.
    mask = torch.ones([batch, seq_len], device=x.device)
    mask[:, :num_keep] = 0
    mask = torch.gather(mask, dim=1, index=restore_idx)

    return x_masked, mask, restore_idx
random_masking_blockwise(x, mask_c_ratio, mask_t_ratio)

2D: ECG recording (N, 1, C, T) (masking c and t under mask_c_ratio and mask_t_ratio) Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random noise. x: [N, L, D], sequence

Source code in src/models/ecg_classifier.py
def random_masking_blockwise(self, x, mask_c_ratio, mask_t_ratio):
    """
    2D: ECG recording (N, 1, C, T) (masking c and t under mask_c_ratio and mask_t_ratio)
    Perform per-sample random masking by per-sample shuffling.
    Per-sample shuffling is done by argsort random noise.
    x: [N, L, D], sequence
    """
    batch, _, dim = x.shape
    num_c = int(self.img_size[-2] / self.patch_size[-2])
    num_t = int(self.img_size[-1] / self.patch_size[-1])

    # Stage 1: drop whole channels (C axis).
    x = x.reshape(batch, num_c, num_t, dim)
    keep_c = int(num_c * (1 - mask_c_ratio))
    scores = torch.rand(batch, num_c, device=x.device)
    keep_idx = scores.argsort(dim=1)[:, :keep_c]  # ascending: low score == keep
    gather_idx = keep_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, num_t, dim)
    x = torch.gather(x, dim=1, index=gather_idx)  # (N, C', T, D)

    # Stage 2: drop time patches (T axis) on the transposed layout.
    x = x.permute(0, 2, 1, 3)  # (N, C', T, D) => (N, T, C', D)
    keep_t = int(num_t * (1 - mask_t_ratio))
    scores = torch.rand(batch, num_t, device=x.device)
    keep_idx = scores.argsort(dim=1)[:, :keep_t]
    gather_idx = keep_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, keep_c, dim)
    x_masked = torch.gather(x, dim=1, index=gather_idx)
    x_masked = x_masked.permute(0, 2, 1, 3)  # (N, T', C', D) => (N, C', T', D)

    # Flatten back to a token sequence: (N, C' * T', D).
    x_masked = x_masked.reshape(batch, keep_t * keep_c, dim)

    # Blockwise masking produces no mask / restore indices.
    return x_masked, None, None
test_step(batch, _)

Test step for downstream task.

Source code in src/models/ecg_classifier.py
def test_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor, int], _
) -> torch.Tensor:
    """Test step for downstream task."""
    x, y, _ = batch

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L261
    # We inherit forward from VisionTransformer, which calls forward_features and forward_head as per reference above
    y_hat = self(x)

    loss = self.criterion(y_hat, y)
    self.log("test_loss", loss, on_epoch=True, on_step=False, sync_dist=True)

    self.compute_metrics(y_hat, y.long(), "test")
    return loss
training_step(batch, _)

Training step for downstream task.

Source code in src/models/ecg_classifier.py
def training_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor], _
) -> torch.Tensor:
    """Training step for downstream task."""
    inputs, targets, _ = batch

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L261
    # Forward is inherited from VisionTransformer (forward_features + forward_head).
    logits = self(inputs)
    loss = self.criterion(logits, targets)

    self.log("train_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    self.compute_metrics(logits, targets.long(), "train")
    return loss
validation_step(batch, _)

Validation step for downstream task.

Source code in src/models/ecg_classifier.py
def validation_step(
    self, batch: Tuple[torch.Tensor, torch.Tensor, int], _
) -> torch.Tensor:
    """Validation step for downstream task."""
    inputs, targets, _ = batch

    # https://github.com/oetu/mae/blob/ba56dd91a7b8db544c1cb0df3a00c5c8a90fbb65/engine_finetune.py#L261
    # Forward is inherited from VisionTransformer (forward_features + forward_head).
    logits = self(inputs)
    loss = self.criterion(logits, targets)

    self.log("val_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    self.compute_metrics(logits, targets.long(), "val")
    return loss

Implementation Credits

The classifier is also based on the MAE implementation by Turgut et al., as outlined in the MAE section.

src.models.cmr_classifier

CMRClassifier

Bases: CMREncoder, MetricsMixin, LightningModule

Source code in src/models/cmr_classifier.py
class CMRClassifier(CMREncoder, MetricsMixin, pl.LightningModule):
    """Supervised multiclass classifier on top of a pretrained CMR encoder.

    Wraps a ``CMREncoder`` backbone with either a linear or an MLP head and
    implements the Lightning train/val/test loop. Only multiclass
    classification is supported.
    """

    def __init__(
        self,
        backbone_model_name: str,
        num_classes: int,
        weights: Optional[List[float]],
        learning_rate: float,
        weight_decay: float,
        scheduler: str,
        anneal_max_epochs: int,
        warmup_epochs: int,
        max_epochs: int,
        freeze_encoder: bool,
        classifier_type: str,
        task_type: Literal["multiclass", "multilabel"] = "multiclass",
        pretrained_weights: Optional[str] = None,
    ):
        """Build the classifier.

        Args:
            backbone_model_name: Name of the encoder backbone to instantiate.
            num_classes: Number of target classes.
            weights: Optional per-class weights for the cross-entropy loss.
            learning_rate: Initial learning rate for Adam.
            weight_decay: Weight decay for Adam.
            scheduler: Learning-rate schedule, "cosine" or "anneal".
            anneal_max_epochs: ``T_max`` for the cosine schedule.
            warmup_epochs: Warmup epochs for the "anneal" schedule.
            max_epochs: Total epochs for the "anneal" schedule.
            freeze_encoder: If True, encoder parameters are frozen.
            classifier_type: "mlp" builds a 3-layer head; anything else a linear head.
            task_type: Only "multiclass" is supported.
            pretrained_weights: Optional path to pretrained encoder weights.

        Raises:
            ValueError: If ``task_type`` is not "multiclass".
        """
        self.task_type = task_type
        if task_type != "multiclass":
            raise ValueError("CMRClassifier only supports multiclass classification")

        super().__init__(backbone_model_name, pretrained_weights)
        self.save_hyperparameters()

        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.initial_lr = learning_rate
        self.weight_decay = weight_decay
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.anneal_max_epochs = anneal_max_epochs
        self.scheduler = scheduler
        self.freeze_encoder = freeze_encoder

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        #  https://github.com/oetu/MMCL-ECG-CMR/blob/bd3c18672de8e5fa73bb753613df94547bd6245b/mmcl/models/ResnetEvalModel.py#L50
        #  NOTE: We specifically omit the possibility of using a projection head as the available model weights come all without it

        # https://github.com/oetu/MMCL-ECG-CMR/blob/bd3c18672de8e5fa73bb753613df94547bd6245b/mmcl/models/ResnetEvalModel.py#L77
        input_dim = self.pooled_dim
        if classifier_type == "mlp":
            self.head = nn.Sequential(
                OrderedDict(
                    [
                        ("fc1", nn.Linear(input_dim, input_dim // 4)),
                        ("relu1", nn.ReLU(inplace=True)),
                        ("fc2", nn.Linear(input_dim // 4, input_dim // 16)),
                        ("relu2", nn.ReLU(inplace=True)),
                        ("fc3", nn.Linear(input_dim // 16, num_classes)),
                    ]
                )
            )
        else:
            self.head = nn.Linear(input_dim, num_classes)

        # https://github.com/oetu/MMCL-ECG-CMR/blob/bd3c18672de8e5fa73bb753613df94547bd6245b/mmcl/models/Evaluator.py#L41
        if weights:
            self.weights = torch.tensor(weights)
        else:
            self.weights = None
        self.criterion = nn.CrossEntropyLoss(weight=self.weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through encoder and classifier."""
        x = super().forward(x)
        x = self.head(x)
        return x

    def _shared_step(self, batch: Tuple, stage: str) -> torch.Tensor:
        """Shared train/val/test logic: forward the first view, compute loss/metrics.

        Args:
            batch: 4-tuple whose first element is the input view and whose
                third element is the one-hot target; the other two elements
                are unused here. (Exact element types depend on the dataloader
                — TODO confirm against the data module.)
            stage: One of "train", "val", or "test"; used for logging keys.

        Returns:
            Scalar cross-entropy loss for this batch.
        """
        x0, _, y, _ = batch
        y_hat = self(x0)
        # Targets arrive one-hot encoded; CrossEntropyLoss expects class indices.
        y_true = y.argmax(dim=1)

        loss = self.criterion(y_hat, y_true)
        self.compute_metrics(y_hat, y_true, stage)
        self.log(f"{stage}_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
        return loss

    def training_step(self, batch: Tuple, _) -> torch.Tensor:
        """Training step for classification."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Tuple, _) -> torch.Tensor:
        """Validation step for classification."""
        return self._shared_step(batch, "val")

    def test_step(self, batch: Tuple, _) -> torch.Tensor:
        """Test step for classification."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self) -> Dict:
        """Configure optimizer and LR scheduler for the classification task.

        Returns:
            Dict with "optimizer" and "lr_scheduler" entries.

        Raises:
            ValueError: If ``self.scheduler`` is not "cosine" or "anneal".
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        if self.scheduler == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.anneal_max_epochs, eta_min=0, last_epoch=-1
            )
        elif self.scheduler == "anneal":
            scheduler = LinearWarmupCosineAnnealingLR(
                optimizer, warmup_epochs=self.warmup_epochs, max_epochs=self.max_epochs
            )
        else:
            # Fail fast: previously an unknown value left `scheduler` unbound
            # and raised an opaque NameError at the return below.
            raise ValueError(f"Unknown scheduler: {self.scheduler!r}")

        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
        }

    def on_train_epoch_end(self):
        """Finalize and reset aggregated training metrics at epoch end."""
        self.finalize_metrics("train")

    def on_validation_epoch_end(self):
        """Finalize and reset aggregated validation metrics at epoch end."""
        self.finalize_metrics("val")

    def on_test_epoch_end(self):
        """Finalize and reset aggregated test metrics at epoch end."""
        self.finalize_metrics("test")
configure_optimizers()

Configure optimizer for classification task.

Source code in src/models/cmr_classifier.py
def configure_optimizers(self) -> Dict:
    """Configure the Adam optimizer and LR scheduler for classification.

    Returns:
        Dict with "optimizer" and "lr_scheduler" entries, as expected by
        PyTorch Lightning.

    Raises:
        ValueError: If ``self.scheduler`` is not "cosine" or "anneal".
    """
    optimizer = torch.optim.Adam(
        self.parameters(),
        lr=self.learning_rate,
        weight_decay=self.weight_decay,
    )

    if self.scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.anneal_max_epochs, eta_min=0, last_epoch=-1
        )
    elif self.scheduler == "anneal":
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=self.warmup_epochs, max_epochs=self.max_epochs
        )
    else:
        # Fail fast: previously an unknown value left `scheduler` unbound,
        # raising an opaque NameError at the return statement below.
        raise ValueError(f"Unknown scheduler: {self.scheduler!r}")

    return {
        "optimizer": optimizer,
        "lr_scheduler": scheduler,
    }
forward(x)

Forward pass through encoder and classifier.

Source code in src/models/cmr_classifier.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Encode the input with the backbone, then map features to class logits."""
    features = super().forward(x)
    return self.head(features)
test_step(batch, _)

Test step for classification.

Source code in src/models/cmr_classifier.py
def test_step(
    self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
) -> torch.Tensor:
    """Test step for classification."""
    x0, _, y, _ = batch
    y_hat = self(x0)
    y_true = y.argmax(dim=1)

    loss = self.criterion(y_hat, y_true)
    self.compute_metrics(y_hat, y_true, "test")
    self.log("test_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    return loss
training_step(batch, _)

Training step for classification.

Source code in src/models/cmr_classifier.py
def training_step(
    self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
) -> torch.Tensor:
    """Run one training batch: classify the first view and log the loss."""
    view0, _, onehot_labels, _ = batch
    logits = self(view0)
    # Convert one-hot labels to class indices for the loss and metrics.
    targets = onehot_labels.argmax(dim=1)

    loss = self.criterion(logits, targets)
    self.compute_metrics(logits, targets, "train")
    self.log("train_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    return loss
validation_step(batch, _)

Validation step for classification.

Source code in src/models/cmr_classifier.py
def validation_step(
    self, batch: Tuple[List[torch.Tensor], torch.Tensor], _
) -> torch.Tensor:
    """Run one validation batch: classify the first view and log the loss."""
    view0, _, onehot_labels, _ = batch
    logits = self(view0)
    # Convert one-hot labels to class indices for the loss and metrics.
    targets = onehot_labels.argmax(dim=1)

    loss = self.criterion(logits, targets)
    self.compute_metrics(logits, targets, "val")
    self.log("val_loss", loss, on_epoch=True, on_step=False, sync_dist=True)
    return loss

Implementation Credits

The classifier is also based on the SimCLR implementation by Turgut et al., as outlined in the SimCLR section.

Utility Models

Helper models for evaluation and feature extraction. These models are not part of the main model architecture and are not meant to be used directly but rather as building blocks for other models.

src.models.linear_classifier

LinearClassifier

Bases: Module

Simple linear classifier that is a single fully connected layer from input to class prediction.

Source code in src/models/linear_classifier.py
class LinearClassifier(nn.Module):
    """Simple linear classifier: a single fully connected layer from input
    features to class logits.

    Args:
        in_size: Dimensionality of the input features.
        num_classes: Number of output classes.
        init_type: Weight-initialization strategy — "normal", "xavier",
            "kaiming", or "orthogonal". Any other value leaves PyTorch's
            default weight initialization in place (the bias is still zeroed).
    """

    def __init__(self, in_size: int, num_classes: int, init_type: str) -> None:
        # Modernized from the legacy `super(LinearClassifier, self).__init__()` form.
        super().__init__()
        self.model = nn.Linear(in_size, num_classes)
        self.init_type = init_type
        self.model.apply(self.init_weights)

    def init_weights(self, m: nn.Module, init_gain: float = 0.02) -> None:
        """Initialize weights of linear layers according to ``self.init_type``.

        Args:
            m: Module visited by ``nn.Module.apply``; only ``nn.Linear`` is touched.
            init_gain: Gain for the xavier/orthogonal schemes. The "normal"
                scheme uses a fixed std of 0.001 and ignores this, following
                the reference implementation.
        """
        if isinstance(m, nn.Linear):
            # nn.init functions run under no_grad internally, so we act on the
            # parameters directly rather than via the deprecated `.data` attribute.
            if self.init_type == "normal":
                nn.init.normal_(m.weight, 0, 0.001)
            elif self.init_type == "xavier":
                nn.init.xavier_normal_(m.weight, gain=init_gain)
            elif self.init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            elif self.init_type == "orthogonal":
                nn.init.orthogonal_(m.weight, gain=init_gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits for input ``x``."""
        return self.model(x)
init_weights(m, init_gain=0.02)

Initializes weights according to desired strategy

Source code in src/models/linear_classifier.py
def init_weights(self, m: nn.Module, init_gain: float = 0.02) -> None:
    """Initialize weights of linear layers according to ``self.init_type``.

    Args:
        m: Module visited by ``nn.Module.apply``; only ``nn.Linear`` is touched.
        init_gain: Gain for the xavier/orthogonal schemes. The "normal"
            scheme uses a fixed std of 0.001 and ignores this, following
            the reference implementation.
    """
    if isinstance(m, nn.Linear):
        # nn.init functions run under no_grad internally, so we act on the
        # parameters directly rather than via the deprecated `.data` attribute.
        if self.init_type == "normal":
            nn.init.normal_(m.weight, 0, 0.001)
        elif self.init_type == "xavier":
            nn.init.xavier_normal_(m.weight, gain=init_gain)
        elif self.init_type == "kaiming":
            nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
        elif self.init_type == "orthogonal":
            nn.init.orthogonal_(m.weight, gain=init_gain)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)