
Embedding Space Visualization

This guide explains how to visualize and analyze embedding spaces with UMAP projections. The tools help you compare embeddings from different training stages (for example, pre-trained vs. fine-tuned) and see how groups of records are distributed in the embedding space.

Core Visualization Functions

UMAP Projection

src.visualization.embedding_viz

Utilities for embedding analysis and visualization. This module provides functions for dimensionality reduction, group analysis, and plotting of embeddings.

run_umap(embeddings, metric='euclidean', min_dist=0.1, random_state=42, n_neighbors=15, n_components=2)

Source code in src/visualization/embedding_viz.py
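import umap  # provided by the umap-learn package


def run_umap(
    embeddings, metric="euclidean", min_dist=0.1, random_state=42, n_neighbors=15, n_components=2
):
    """Fit UMAP on an (N, D) array and return the (N, n_components) projection."""
    reducer = umap.UMAP(
        metric=metric,
        min_dist=min_dist,
        random_state=random_state,  # fixed seed: reproducible layouts (umap-learn then runs single-threaded)
        n_neighbors=n_neighbors,
        n_components=n_components,
    )
    # Each call fits a new reducer, so separate calls produce independent layouts
    projection = reducer.fit_transform(embeddings)
    return projection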
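The metric argument is forwarded unchanged to umap.UMAP. If your embeddings come from a model trained with a cosine-similarity objective, metric="cosine" is a frequent choice. Alternatively, you can L2-normalize the rows and keep the default "euclidean": for unit vectors the squared Euclidean distance equals 2 - 2 * cos(a, b), so both metrics rank neighbors identically. The NumPy sketch below checks that identity on random stand-in data; l2_normalize and the other helpers are illustrative names, not part of embedding_viz.

```python
import numpy as np


def l2_normalize(embeddings: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm (all-zero rows stay zero)."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.maximum(norms, eps)


def cosine_distance_matrix(x: np.ndarray) -> np.ndarray:
    """Pairwise cosine distance, 1 - cos(a, b)."""
    xn = l2_normalize(x)
    return 1.0 - xn @ xn.T


def squared_euclidean_matrix(x: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distance."""
    diff = x[:, None, :] - x[None, :, :]
    return np.sum(diff ** 2, axis=-1)


rng = np.random.default_rng(0)
emb = rng.normal(size=(50, 16))  # stand-in for real (N, D) embeddings

# On unit-normalized rows, squared Euclidean distance is exactly twice the
# cosine distance, so Euclidean nearest neighbors match cosine nearest neighbors.
unit = l2_normalize(emb)
assert np.allclose(squared_euclidean_matrix(unit), 2.0 * cosine_distance_matrix(emb))

# Hypothetical usage with the module's run_umap (requires umap-learn):
# coords = run_umap(l2_normalize(emb), metric="euclidean")
```

Normalizing once up front also means every run_umap call in a comparison sees the same preprocessing, which supports the consistency advice under Best Practices.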

Global Visualization

src.visualization.embedding_viz

Group Analysis Tools

src.visualization.embedding_viz

prepare_group_data(df, group_name, max_samples=100)

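For a given disease group, keep records that have exactly one label whose group is group_name (they may also have labels in other groups), and store the integration names of those other labels for multi-labeled records. Returns a tuple (df_group, df_group_sub): every matching record, and at most max_samples of them, preferring records with the fewest total labels.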

Source code in src/visualization/embedding_viz.py
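def prepare_group_data(df, group_name, max_samples=100):
    """
    For a given disease group, keep records that have EXACTLY one label whose
    "group" is group_name (they may also have labels in other groups), and
    store the integration names of those other labels for multi-labeled records.

    Returns (df_group, df_group_sub): all matching records, and the first
    max_samples of them after sorting by total label count.
    """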
    def has_exact_one_in_group(labels_meta):
        count = sum(lbl.get("group") == group_name for lbl in labels_meta)
        return count == 1

    df_group = df[df["labels_meta"].apply(has_exact_one_in_group)].copy()

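    # Integration name of the record's single label in this group
    # (extract_integration_name_for_group is defined or imported elsewhere in the module; not shown here)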
    df_group["integration_name"] = df_group["labels_meta"].apply(
        lambda lm: extract_integration_name_for_group(lm, group_name)
    )
    df_group = df_group.dropna(subset=["integration_name"])

    # total label count
    df_group["total_labels"] = df_group["labels_meta"].apply(len)

    # Identify other integration names if multi-labeled
    def get_other_integration_names(labels_meta):
        """
        Return integration names for labels that do NOT belong to `group_name`.
        """
        others = []
        for lbl in labels_meta:
            if lbl.get("group") != group_name:
                iname = lbl.get("integration_name", "Unknown")
                others.append(iname)
        return sorted(set(others))

    df_group["other_inames"] = df_group["labels_meta"].apply(get_other_integration_names)

    # Mark "exclusive" vs "multi" based on whether there are other integration names
    def is_multi(other_inames):
        return "multi" if len(other_inames) > 0 else "exclusive"

    df_group["mlabel_flag"] = df_group["other_inames"].apply(is_multi)

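    # Sort by total_labels (ascending) so exclusive records come first, then keep
    # the first max_samples rows (a deterministic cut, not a random sample)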
    df_group_sorted = df_group.sort_values("total_labels")
    df_group_sub = df_group_sorted.head(max_samples)
    return df_group, df_group_sub
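To make the selection rules concrete, here is a compact, self-contained re-implementation run on a toy DataFrame. It is a simplified sketch, not the module's code: select_group and integration_name_in_group are hypothetical names, the latter standing in for extract_integration_name_for_group (whose implementation isn't shown), and the group and label names are invented.

```python
import pandas as pd


def integration_name_in_group(labels_meta, group_name):
    """Hypothetical stand-in: integration_name of the first label in group_name, else None."""
    for lbl in labels_meta:
        if lbl.get("group") == group_name:
            return lbl.get("integration_name")
    return None


def select_group(df, group_name, max_samples=100):
    """Simplified sketch of the selection rules used by prepare_group_data."""
    # Keep records with exactly one label in the target group
    in_group = df["labels_meta"].apply(
        lambda lm: sum(lbl.get("group") == group_name for lbl in lm) == 1
    )
    out = df[in_group].copy()
    out["integration_name"] = out["labels_meta"].apply(
        lambda lm: integration_name_in_group(lm, group_name)
    )
    out = out.dropna(subset=["integration_name"])
    out["total_labels"] = out["labels_meta"].apply(len)
    # Integration names of labels from *other* groups
    out["other_inames"] = out["labels_meta"].apply(
        lambda lm: sorted({lbl.get("integration_name", "Unknown")
                           for lbl in lm if lbl.get("group") != group_name})
    )
    out["mlabel_flag"] = out["other_inames"].apply(
        lambda others: "multi" if others else "exclusive"
    )
    return out, out.sort_values("total_labels").head(max_samples)


# Toy records with invented labels
toy = pd.DataFrame({
    "record_id": [0, 1, 2, 3],
    "labels_meta": [
        [{"group": "Rhythm", "integration_name": "AF"}],
        [{"group": "Rhythm", "integration_name": "AF"},
         {"group": "Conduction", "integration_name": "LBBB"}],
        [{"group": "Rhythm", "integration_name": "AF"},
         {"group": "Rhythm", "integration_name": "AFL"}],      # two labels in group: dropped
        [{"group": "Conduction", "integration_name": "RBBB"}],  # no label in group: dropped
    ],
})

full, sub = select_group(toy, "Rhythm", max_samples=1)
print(full[["record_id", "integration_name", "other_inames", "mlabel_flag"]])
```

Record 0 ends up "exclusive", record 1 "multi" (its other label is LBBB), and records 2 and 3 are excluded; with max_samples=1 only the exclusive record survives in sub.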

Example Usage

Basic Embedding Comparison

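from src.visualization.embedding_viz import run_umap, plot_global_umap_grid

# Prepare embeddings for the same N records, in the same row order
pretrained_embeddings = ...  # shape: (N, D)
finetuned_embeddings = ...   # shape: (N, D)
metadata_df = ...            # DataFrame with one row per record, aligned with the embeddings

# Create UMAP projections with identical parameters
umap_pretrained = run_umap(pretrained_embeddings)
umap_finetuned = run_umap(finetuned_embeddings)

# Visualize the projections side by side
umaps_dict = {
    'Pre-trained': umap_pretrained,
    'Fine-tuned': umap_finetuned
}
plot_global_umap_grid(umaps_dict, metadata_df)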
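The implementation of plot_global_umap_grid isn't shown on this page. As a rough idea of what a side-by-side comparison involves, here is a minimal matplotlib sketch; plot_umap_panels and its parameters are hypothetical and not the module's API, and the random arrays stand in for real projections.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_umap_panels(umaps_dict, colors=None, point_size=3, panel_size=4.0):
    """Hypothetical sketch: one scatter panel per projection, in a single row.

    umaps_dict maps a panel title to an (N, 2) array; colors is an optional
    length-N array used to color points identically in every panel.
    """
    n = len(umaps_dict)
    fig, axes = plt.subplots(1, n, figsize=(panel_size * n, panel_size), squeeze=False)
    for ax, (title, coords) in zip(axes[0], umaps_dict.items()):
        coords = np.asarray(coords)
        ax.scatter(coords[:, 0], coords[:, 1], c=colors, s=point_size)
        ax.set_title(title)
        # UMAP axes have no intrinsic meaning, so hide the tick labels
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    return fig, axes[0]


rng = np.random.default_rng(0)
fake = {"Pre-trained": rng.normal(size=(200, 2)), "Fine-tuned": rng.normal(size=(200, 2))}
fig, axes = plot_umap_panels(fake, colors=rng.integers(0, 3, size=200))
# fig.savefig("umap_panels.png")
```

Coloring every panel by the same metadata column is what makes the comparison meaningful: because the two layouts come from independent fits, you compare how the colored groups cluster, not where they sit.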

Group Analysis

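import matplotlib.pyplot as plt
from src.visualization.embedding_viz import prepare_group_data, overlay_group_on_embedding

# Analyze a specific disease group
target_group = "Atrial Fibrillation"
df_group_full, df_group_sub = prepare_group_data(metadata_df, target_group)

# Overlay the group on a 2-D projection whose rows align with metadata_df
umap_coords = umap_finetuned  # e.g. the fine-tuned projection from the previous example
fig, ax = overlay_group_on_embedding(umap_coords, metadata_df, df_group_sub)
ax.set_title(f"{target_group} Distribution")
plt.show()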

Best Practices

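  1. Consistency: Use the same UMAP parameters (metric, n_neighbors, min_dist, random_state) when comparing different embeddings. Each run_umap call still fits its own reducer, so the axes and orientation of two plots are not directly comparable; compare structure (clusters, neighborhoods), not raw coordinates.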
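  2. Sampling: For large groups, use the max_samples argument of prepare_group_data to cap the subset you plot. Note that it keeps the records with the fewest total labels rather than drawing a random sample.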
  3. Visual Clarity:
    • Use a low alpha for background points so that highlighted groups stand out
    • Choose distinct colors for different groups
    • Add legends and titles for clear interpretation
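The visual-clarity points above can be sketched as a simple two-layer scatter: every record in faint grey, the group of interest drawn on top in one distinct color, plus a legend and title. overlay_sketch, its parameters, and the default alpha are hypothetical (the highlight_color name only mirrors the wording under Advanced Customization); this is not the implementation of overlay_group_on_embedding.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np


def overlay_sketch(coords, highlight_mask, group_label,
                   highlight_color="tab:red", background_alpha=0.15):
    """Hypothetical sketch: all points faint grey, highlighted group on top."""
    coords = np.asarray(coords)
    mask = np.asarray(highlight_mask, dtype=bool)
    fig, ax = plt.subplots(figsize=(5, 5))
    # Background layer: every record, low alpha so dense regions don't hide the group
    ax.scatter(coords[:, 0], coords[:, 1], s=4, color="lightgrey",
               alpha=background_alpha, label="All records")
    # Foreground layer: only the group, larger markers in a distinct color
    ax.scatter(coords[mask, 0], coords[mask, 1], s=12, color=highlight_color,
               label=group_label)
    ax.legend(loc="best")
    ax.set_title(f"{group_label} distribution")
    return fig, ax


rng = np.random.default_rng(0)
coords = rng.normal(size=(300, 2))  # stand-in for a 2-D UMAP projection
mask = np.zeros(300, dtype=bool)
mask[:40] = True                    # pretend the first 40 records form the group
fig, ax = overlay_sketch(coords, mask, "Example group")
```

Drawing the group as a separate, later scatter call guarantees it renders above the background regardless of point density.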

Advanced Customization

The plotting functions can be adapted in several ways:

  • Modify color schemes by adjusting the highlight_color and background colors
  • Customize marker styles for different types of samples
  • Adjust figure sizes and grid layouts for different numbers of embeddings
  • Add further metadata overlays or annotations
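For the grid-layout point, one simple approach is to cap the number of columns, add rows as more embeddings are compared, and scale the figure size per panel. grid_shape and figure_size are hypothetical helpers, and their defaults (3 columns, 4-inch panels) are arbitrary choices for illustration.

```python
import math


def grid_shape(n_panels: int, max_cols: int = 3) -> tuple[int, int]:
    """Hypothetical helper: (rows, cols) for n_panels, with at most max_cols per row."""
    if n_panels < 1:
        raise ValueError("n_panels must be >= 1")
    n_cols = min(n_panels, max_cols)
    n_rows = math.ceil(n_panels / n_cols)
    return n_rows, n_cols


def figure_size(n_rows: int, n_cols: int, panel_size: float = 4.0) -> tuple[float, float]:
    """Scale the figure so each panel keeps roughly the same size (width, height)."""
    return (panel_size * n_cols, panel_size * n_rows)


for n in (1, 2, 3, 4, 7):
    rows, cols = grid_shape(n)
    print(n, "panels ->", (rows, cols), "figsize", figure_size(rows, cols))
```

The resulting shape can be passed to plt.subplots(rows, cols, figsize=figure_size(rows, cols)); any unused axes in the last row can be hidden with ax.axis("off").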