from __future__ import annotations
import itertools
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib import pyplot as plt
from drvi.utils.plotting import cmap
from drvi.utils.tools import iterate_on_top_differential_vars
from drvi.utils.tools.interpretability._latent_traverse import get_dimensions_of_traverse_data
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any
from anndata import AnnData
def make_heatmap_groups(ordered_list: list) -> tuple[list[tuple[int, int]], list[Any]]:
    """Create group positions and labels for scanpy heatmap visualization of marker genes.

    The input is scanned for runs of consecutive identical elements. Each run
    becomes one group, described by its inclusive ``(start, end)`` index pair
    and its label. The result is suitable for the ``var_group_positions`` /
    ``var_group_labels`` arguments of ``sc.pl.heatmap``.

    Parameters
    ----------
    ordered_list
        Elements where consecutive identical elements form groups. Any iterable
        works, e.g. a list or a pandas Series.

    Returns
    -------
    tuple[list[tuple[int, int]], list]
        A tuple containing:

        - List of ``(start_index, end_index)`` tuples (both inclusive), one per run.
        - List of run labels, one per run. A value can appear more than once if it
          forms several non-adjacent runs.

        Both lists are empty when ``ordered_list`` is empty.

    Notes
    -----
    Runs are detected with ``itertools.groupby``, so only *adjacent* equal
    elements are merged. Sort the input first if you want one group per value.

    Examples
    --------
    >>> groups, labels = make_heatmap_groups(["A", "A", "B", "B", "B", "A"])
    >>> groups
    [(0, 1), (2, 4), (5, 5)]
    >>> labels
    ['A', 'B', 'A']
    """
    group_positions = []
    group_labels = []
    start = 0
    for label, run in itertools.groupby(ordered_list):
        run_length = sum(1 for _ in run)
        # End index is inclusive, as expected by scanpy's var_group_positions.
        group_positions.append((start, start + run_length - 1))
        group_labels.append(label)
        start += run_length
    return group_positions, group_labels
def _max_abs_effect_per_gene(effect: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Return, per row of ``effect``, the column position of the largest absolute entry and that absolute value."""
    abs_values = np.abs(effect.values)
    best_index = abs_values.argmax(axis=1)
    best_value = abs_values[np.arange(abs_values.shape[0]), best_index]
    return best_index, best_value


def differential_vars_heatmap(
    traverse_adata: AnnData,
    key: str,
    title_col: str = "title",
    score_threshold: float = 0.0,
    remove_vanished: bool = True,
    remove_unaffected: bool = False,
    figsize: tuple[int, int] | None = None,
    show: bool = True,
    **kwargs,
):
    """Generate a heatmap of differential variables based on traverse data.

    This function creates a heatmap showing how genes respond to latent
    dimension traversals. Rows are (dimension, step) pairs taken from the
    stepwise effect array. Columns are genes, grouped by the dimension (and
    direction) in which each gene has its largest effect.

    Parameters
    ----------
    traverse_adata
        AnnData object containing traverse data from `traverse_latent` or
        `make_traverse_adata`. It must contain:

        - ``varm[f"{key}_traverse_effect_pos"]`` and
          ``varm[f"{key}_traverse_effect_neg"]`` (DataFrames, one row per gene);
        - ``uns[f"{key}_traverse_effect_stepwise"]``;
        - ``obs`` columns ``dim_id``, ``order``, ``vanished`` and ``title_col``.
    key
        Key prefix for the differential variables in `traverse_adata.varm`.
        Should correspond to a key used in `find_differential_effects` or
        `calculate_differential_vars` (e.g., "max_possible", "min_possible", "combined_score").
    title_col
        Column name in `traverse_adata.obs` to use as dimension labels.
        These titles are used for the row grouping and the gene-group labels.
    score_threshold
        Genes whose maximum absolute effect is below this value are assigned
        to the group "NONE".
    remove_vanished
        Whether to drop rows belonging to dimensions flagged in ``obs["vanished"]``.
    remove_unaffected
        Whether to drop genes whose maximum effect is not strictly greater than
        ``score_threshold``.
    figsize
        Size of the figure (width, height) in inches. If None, it is set to
        ``(20, n_dims / 4)``, where ``n_dims`` is the number of plotted dimensions.
    show
        Forwarded to `sc.pl.heatmap`.
    **kwargs
        Additional keyword arguments passed to `sc.pl.heatmap`. They override
        the defaults set here (``vcenter``, ``vmin``, ``vmax``, ``cmap``,
        ``var_group_positions``, ``var_group_labels``, ``var_group_rotation``).

    Returns
    -------
    The return value of `sc.pl.heatmap`. This is None when ``show=True``;
    otherwise it is the heatmap axes (a dict of matplotlib Axes in recent
    scanpy versions — confirm against the installed version).

    Raises
    ------
    KeyError
        If any of the required ``varm``/``uns``/``obs`` entries listed above is missing.

    Notes
    -----
    Side effect: the columns ``max_effect``, ``max_effect_dim_{col}`` and
    ``max_effect_dim_{col}_plus`` (for ``col`` in ``dim_id``, ``order`` and
    ``title_col``) are written into ``traverse_adata.var``.

    The function performs the following steps:

    1. Computes each gene's maximum absolute effect in both the positive and negative directions.
    2. Identifies which dimension (and direction) carries that maximum.
    3. Builds a plotting AnnData from the stepwise effects, one row per (dimension, step).
    4. Optionally removes vanished dimensions and unaffected genes.
    5. Sorts genes by their max-effect dimension order, then by effect size; sorts rows by ``order``.

    The column position returned by the argmax over the effect DataFrames is
    mapped through ``obs["dim_id"]``. This assumes column position equals
    ``dim_id`` — NOTE(review): confirm against `find_differential_effects`.

    **Visualization Features:**

    - **Color scale**: Red-blue diverging colormap centered at 0
    - **Gene grouping**: Genes are grouped by their maximum effect dimension and direction
    - **Dimension ordering**: Dimensions are ordered by their `order` column
    - **Gene ordering**: Within each group, genes are ordered by increasing maximum effect

    Examples
    --------
    >>> # Basic heatmap with combined scores
    >>> differential_vars_heatmap(traverse_adata, "combined_score")
    >>> # Heatmap with custom parameters
    >>> differential_vars_heatmap(
    ...     traverse_adata, "max_possible", score_threshold=1.0, remove_unaffected=True, figsize=(15, 8)
    ... )
    """
    # ``AnnData`` appears to be imported only under ``TYPE_CHECKING`` at module
    # level, so import it here for the runtime constructor call below.
    import anndata as ad

    n_latent, n_steps, _, n_vars = get_dimensions_of_traverse_data(traverse_adata)

    effect_pos = traverse_adata.varm[f"{key}_traverse_effect_pos"]
    effect_neg = traverse_adata.varm[f"{key}_traverse_effect_neg"]
    pos_index, pos_value = _max_abs_effect_per_gene(effect_pos)
    neg_index, neg_value = _max_abs_effect_per_gene(effect_neg)
    traverse_adata.var["max_effect"] = np.maximum(pos_value, neg_value)

    below_threshold = traverse_adata.var["max_effect"] < score_threshold
    positive_wins = pos_value > neg_value
    for col in ["dim_id", "order", title_col]:
        title_mapping = dict(zip(traverse_adata.obs["dim_id"].values, traverse_adata.obs[col].values, strict=False))
        pos_mapped = pd.Series(pos_index).map(title_mapping)
        neg_mapped = pd.Series(neg_index).map(title_mapping)
        traverse_adata.var[f"max_effect_dim_{col}"] = np.where(
            below_threshold,
            float("nan") if np.isreal(traverse_adata.obs[col].values[0]) else "NONE",
            np.where(positive_wins, pos_mapped, neg_mapped),
        )
        # Same assignment, but suffixed with the direction of the winning effect.
        traverse_adata.var[f"max_effect_dim_{col}_plus"] = np.where(
            below_threshold,
            "NONE",
            np.where(positive_wins, pos_mapped.astype(str) + " +", neg_mapped.astype(str) + " -"),
        )

    # One row per (dimension, step) pair, in dimension-major order.
    plot_adata = ad.AnnData(
        traverse_adata.uns[f"{key}_traverse_effect_stepwise"].reshape(n_latent * n_steps, n_vars),
        var=traverse_adata.var,
        obs=pd.DataFrame(
            {
                "dim_id": np.repeat(np.arange(n_latent), n_steps),
                "step_id": np.tile(np.arange(n_steps), n_latent),
            }
        ),
    )
    for col in ["dim_id", "order", title_col, "vanished"]:
        title_mapping = dict(zip(traverse_adata.obs["dim_id"].values, traverse_adata.obs[col].values, strict=False))
        plot_adata.obs[col] = plot_adata.obs["dim_id"].map(title_mapping)

    if remove_vanished:
        plot_adata = plot_adata[~plot_adata.obs["vanished"]].copy()
    if remove_unaffected:
        plot_adata = plot_adata[:, plot_adata.var["max_effect"] > score_threshold].copy()
    plot_adata = plot_adata[
        :, plot_adata.var.sort_values(["max_effect_dim_order", "max_effect_dim_order_plus", "max_effect"]).index
    ].copy()
    plot_adata = plot_adata[plot_adata.obs.sort_values(["order"]).index].copy()

    if figsize is None:
        figsize = (20, plot_adata.obs["dim_id"].nunique() / 4)

    # Symmetric-ish color limits that always include [-1, 1].
    vmin = min(-1, min(effect_pos.values.min(), effect_neg.values.min()))
    vmax = max(+1, max(effect_pos.values.max(), effect_neg.values.max()))

    var_group_positions, var_group_labels = make_heatmap_groups(plot_adata.var[f"max_effect_dim_{title_col}_plus"])
    kwargs = {
        **dict(  # noqa: C408
            vcenter=0,
            vmin=vmin,
            vmax=vmax,
            cmap=cmap.saturated_red_blue_cmap,
            var_group_positions=var_group_positions,
            var_group_labels=var_group_labels,
            var_group_rotation=90,
        ),
        **kwargs,
    }

    return sc.pl.heatmap(
        plot_adata,
        plot_adata.var.index,
        groupby=title_col,
        layer=None,
        figsize=figsize,
        dendrogram=False,
        show=show,
        **kwargs,
    )
def _bar_plot_top_differential_vars(
    plot_info: Sequence[tuple[str, pd.Series]],
    dim_subset: Sequence[str] | None = None,
    n_top_genes: int = 10,
    ncols: int = 5,
    show: bool = True,
):
    """Draw one horizontal bar chart of top-scoring genes per latent dimension.

    Parameters
    ----------
    plot_info
        Sequence of ``(dimension_title, gene_scores)`` pairs. ``gene_scores`` is a
        pandas Series of scores indexed by gene name.
    dim_subset
        Titles of the dimensions to plot, in the order they should appear. Every
        title must be present in ``plot_info``; a missing title raises ``KeyError``.
        If None, all entries of ``plot_info`` are plotted in their given order.
    n_top_genes
        Number of highest-scoring genes shown in each subplot.
    ncols
        Number of columns in the subplot grid.
    show
        If True, display the figure with ``plt.show()`` and return None.
        Otherwise, return the figure.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure if ``show`` is False, otherwise None.

    Notes
    -----
    The grid has ``ceil(n_dims / ncols)`` rows (at least one), and unused axes
    are removed. Within a subplot, genes are sorted by descending score. The
    y-axis is inverted so the strongest gene appears at the top. All bars are
    sky blue and grid lines are disabled.

    Examples
    --------
    >>> # Basic bar plot
    >>> _bar_plot_top_differential_vars(plot_info)
    >>> # Custom layout
    >>> _bar_plot_top_differential_vars(plot_info, n_top_genes=15, ncols=3, show=False)
    """
    if dim_subset is not None:
        info_by_dim = dict(plot_info)
        plot_info = [(dim_id, info_by_dim[dim_id]) for dim_id in dim_subset]
    else:
        plot_info = list(plot_info)
    n_plots = len(plot_info)

    # At least one row: plt.subplots rejects a zero-row grid.
    n_row = max(1, int(np.ceil(n_plots / ncols)))
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid, so flatten() is always valid.
    fig, axes = plt.subplots(
        n_row, ncols, figsize=(3 * ncols, int(1 + 0.2 * n_top_genes) * n_row), squeeze=False
    )
    flat_axes = axes.flatten()

    for ax, (dim_title, gene_scores) in zip(flat_axes, plot_info, strict=False):
        top_scores = gene_scores.sort_values(ascending=False).iloc[:n_top_genes]
        ax.barh(top_scores.index, top_scores.values, color="skyblue")
        ax.set_xlabel("Gene Score")
        ax.set_title(dim_title)
        ax.invert_yaxis()
        ax.grid(False)

    # Remove the empty cells left over in the last row.
    for ax in flat_axes[n_plots:]:
        fig.delaxes(ax)

    plt.tight_layout()
    if show:
        plt.show()
        return None
    return fig
def show_top_differential_vars(
    traverse_adata: AnnData,
    key: str,
    title_col: str = "title",
    order_col: str = "order",
    dim_subset: Sequence[str] | None = None,
    gene_symbols: str | None = None,
    score_threshold: float = 0.0,
    n_top_genes: int = 10,
    ncols: int = 5,
    show: bool = True,
):
    """Plot the top-scoring genes of every latent dimension as horizontal bar charts.

    Per-dimension gene scores are collected with `iterate_on_top_differential_vars`.
    They are then drawn by `_bar_plot_top_differential_vars`, one subplot per dimension.

    Parameters
    ----------
    traverse_adata
        AnnData object holding the differential analysis results for ``key``
        (e.g. produced by `calculate_differential_vars`).
    key
        Prefix of the differential-effect entries to read. Should correspond to a
        key used in `find_differential_effects` or `calculate_differential_vars`
        (e.g. "max_possible", "min_possible", "combined_score").
    title_col
        Column of ``traverse_adata.obs`` whose values are used as subplot titles.
    order_col
        Column of ``traverse_adata.obs`` used to order the dimensions. When
        ``dim_subset`` is given, the order of ``dim_subset`` is used instead.
    dim_subset
        Titles of the dimensions to plot, in display order. Each must be among the
        dimensions returned by `iterate_on_top_differential_vars`, otherwise a
        ``KeyError`` is raised. If None, every returned dimension is plotted.
    gene_symbols
        Optional column of ``traverse_adata.var`` with gene symbols. When provided,
        symbols are shown instead of the var index.
    score_threshold
        Minimum gene score, forwarded to `iterate_on_top_differential_vars`.
    n_top_genes
        Number of highest-scoring genes shown per dimension.
    ncols
        Number of columns in the subplot grid.
    show
        If True, the figure is displayed and None is returned. Otherwise the
        figure is returned.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure if ``show`` is False, otherwise None.

    Raises
    ------
    KeyError
        If required data is missing from ``traverse_adata`` or a ``dim_subset``
        title is unknown.

    Notes
    -----
    Each subplot lists genes top-to-bottom by decreasing score, as sky-blue
    horizontal bars titled with the dimension title. Dimension titles may end
    in "+" or "-" to mark the direction of the effect.

    Examples
    --------
    >>> # Basic visualization with combined scores
    >>> show_top_differential_vars(traverse_adata, "combined_score")
    >>> # Custom parameters with gene symbols
    >>> show_top_differential_vars(
    ...     traverse_adata, "max_possible", gene_symbols="gene_symbol", score_threshold=1.0, n_top_genes=15, ncols=3
    ... )
    >>> # Subset of dimensions
    >>> show_top_differential_vars(traverse_adata, "combined_score", dim_subset=["DR 5+", "DR 12+", "DR 14+"])
    """
    top_vars_per_dim = iterate_on_top_differential_vars(
        traverse_adata,
        key,
        title_col,
        order_col,
        gene_symbols,
        score_threshold,
    )
    return _bar_plot_top_differential_vars(
        top_vars_per_dim,
        dim_subset=dim_subset,
        n_top_genes=n_top_genes,
        ncols=ncols,
        show=show,
    )
def show_differential_vars_scatter_plot(
    traverse_adata: AnnData,
    key_x: str,
    key_y: str,
    key_combined: str,
    title_col: str = "title",
    order_col: str = "order",
    gene_symbols: str | None = None,
    score_threshold: float = 0.0,
    dim_subset: Sequence[str] | None = None,
    ncols: int = 3,
    show: bool = True,
    **kwargs,
):
    """Show a scatter plot of differential variables considering multiple criteria.

    For each latent dimension, this function plots one differential-effect
    measure against another (usually "max_possible" vs "min_possible").
    Points are color-coded by a combined score. This helps show how the
    criteria relate, and highlights genes with consistent effects across
    them. The 20 genes with the highest combined score are labeled.

    Parameters
    ----------
    traverse_adata
        AnnData object containing the differential analysis results from
        `calculate_differential_vars`. Must contain differential effect data
        for all specified keys.
    key_x
        Key for the x-axis variable. Typically "max_possible" or "min_possible".
    key_y
        Key for the y-axis variable. Typically "min_possible" or "max_possible".
    key_combined
        Key for the color-coded variable. Typically "combined_score".
    title_col
        Column name in `traverse_adata.obs` that contains the titles for each dimension.
        These titles are used as subplot titles.
    order_col
        Column name in `traverse_adata.obs` that specifies the order of dimensions.
        Ignored if `dim_subset` is provided; the order of `dim_subset` is used instead.
    gene_symbols
        Column name in `traverse_adata.var` that contains gene symbols. If provided,
        gene symbols are used for point labels instead of the var index.
    score_threshold
        Threshold forwarded to `iterate_on_top_differential_vars` for every key.
    dim_subset
        Dimension titles to plot, in display order. Titles that have no data under
        ``key_combined`` are silently skipped. If None, all dimensions returned
        for ``key_combined`` are plotted.
    ncols
        Number of columns in the plot grid.
    show
        If True, display the figure and return None. Otherwise return the figure.
    **kwargs
        Additional keyword arguments passed to ``ax.scatter`` (e.g. ``alpha``, ``s``).

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure object if `show=False`, otherwise None.

    Raises
    ------
    KeyError
        If required data is missing from `traverse_adata`, or if a dimension present
        for ``key_combined`` is missing for ``key_x`` or ``key_y``.

    Notes
    -----
    The function performs the following steps:

    1. Extracts differential variables for all three keys (x, y, combined).
    2. For each dimension, joins the three score Series and drops genes missing any of them.
    3. Scatters ``key_x`` vs ``key_y``, colored by ``key_combined`` (``Reds`` colormap, one colorbar per subplot).
    4. Labels the 20 genes with the largest combined score.
    """
    # Materialize every key once as a {dim_title: scores} mapping (insertion order
    # is preserved), so the data can be read several times below.
    plot_info = {
        key: dict(
            iterate_on_top_differential_vars(traverse_adata, key, title_col, order_col, gene_symbols, score_threshold)
        )
        for key in [key_x, key_y, key_combined]
    }

    combined_dims = list(plot_info[key_combined])
    if dim_subset is None:
        dim_ids = combined_dims
    else:
        # Follow the caller's order, as documented for ``order_col``.
        available = set(combined_dims)
        dim_ids = [dim_id for dim_id in dim_subset if dim_id in available]

    n_plots = len(dim_ids)
    # At least one row: plt.subplots rejects a zero-row grid.
    n_row = max(1, int(np.ceil(n_plots / ncols)))
    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(
        n_row, ncols, figsize=(5 * ncols, 4 * n_row), sharex=False, sharey=False, squeeze=False
    )
    flat_axes = axes.flatten()

    for ax, dim_id in zip(flat_axes, dim_ids, strict=False):
        df = (
            pd.concat(
                {
                    key_x: plot_info[key_x][dim_id],
                    key_y: plot_info[key_y][dim_id],
                    key_combined: plot_info[key_combined][dim_id],
                },
                axis=1,
            )
            .dropna()
            .sort_values(key_combined, ascending=False)
        )
        scatter = ax.scatter(df[key_x], df[key_y], c=df[key_combined], cmap="Reds", **kwargs)
        cbar = fig.colorbar(scatter, ax=ax)
        cbar.set_label(key_combined)

        top_20 = df.nlargest(20, key_combined)
        for idx, row in top_20.iterrows():
            ax.text(row[key_x], row[key_y], str(idx), fontsize=8, ha="left", va="bottom")

        ax.set_title(dim_id)
        ax.set_xlabel(key_x)
        ax.set_ylabel(key_y)

    # Remove the empty cells left over in the last row.
    for ax in flat_axes[n_plots:]:
        fig.delaxes(ax)

    plt.tight_layout()
    if show:
        plt.show()
        return None
    return fig
def _umap_of_relevant_genes(
    adata: AnnData,
    embed: AnnData,
    plot_info: Sequence[tuple[str, pd.Series]],
    layer: str | None = None,
    title_col: str = "title",
    gene_symbols: str | None = None,
    dim_subset: Sequence[str] | None = None,
    n_top_genes: int = 10,
    max_cells_to_plot: int | None = None,
    show: bool = True,
    **kwargs,
):
    """Plot UMAPs of latent dimensions together with UMAPs of their most relevant genes.

    For each dimension, two things are plotted. First, the latent value on the
    UMAP, colored with a direction-aware colormap. Second, the expression of the
    top ``n_top_genes`` genes for that dimension.

    Parameters
    ----------
    adata
        AnnData object with the gene expression to visualize. Its ``obs`` index
        must be a subset of ``embed.obs`` index.
    embed
        AnnData object with the latent representation. It must have UMAP
        coordinates in ``embed.obsm["X_umap"]`` and a ``title_col`` column in ``.var``.
    plot_info
        Sequence of ``(dimension_title, gene_scores)`` pairs. ``gene_scores`` is
        a pandas Series indexed by gene. A title ending in "+" or "-" selects the
        positive/negative colormap; the suffix is stripped to find the matching
        ``embed.var[title_col]`` entry.
    layer
        Layer name in `adata` to use for gene expression. If None, uses `.X`.
    title_col
        Column name in `embed.var` that contains dimension titles.
    gene_symbols
        Column name in `adata.var` that contains gene symbols (forwarded to
        `sc.pl.embedding`).
    dim_subset
        Titles of dimensions to plot, in display order. If None, all entries of
        ``plot_info`` are plotted.
    n_top_genes
        Number of top genes to visualize for each dimension.
    max_cells_to_plot
        If given and ``adata`` has more cells, a random subsample (copy) of this
        size is plotted instead.
    show
        If True, every plot is drawn immediately with ``plt.show()`` and None is
        returned. If False, a generator is returned. Each dimension yields the
        latent-value Axes and then the list of gene Axes.
    **kwargs
        Additional keyword arguments passed to `sc.pl.embedding`.

    Returns
    -------
    None if ``show=True``, otherwise a generator yielding Axes objects / lists of Axes.

    Notes
    -----
    Side effects: unless ``adata`` is subsampled (a copy), this function writes
    ``adata.obsm["X_umap_method"]`` and one ``adata.obs`` column per plotted
    dimension title into the caller's ``adata``. Each dimension title is also
    printed before its plots.

    Examples
    --------
    >>> plot_info = iterate_on_top_differential_vars(traverse_adata, "combined_score")
    >>> _umap_of_relevant_genes(adata, embed, plot_info)
    >>> for axes in _umap_of_relevant_genes(adata, embed, plot_info, n_top_genes=5, show=False):
    ...     pass
    """
    plots = _iter_umap_of_relevant_genes(
        adata,
        embed,
        plot_info,
        layer=layer,
        title_col=title_col,
        gene_symbols=gene_symbols,
        dim_subset=dim_subset,
        n_top_genes=n_top_genes,
        max_cells_to_plot=max_cells_to_plot,
        show=show,
        **kwargs,
    )
    if not show:
        # Lazy mode: the caller drives the plotting by iterating the generator.
        return plots
    # The helper is a generator, so its body does not run until consumed.
    # Drain it here so that ``show=True`` actually draws the figures.
    for _ in plots:
        pass
    return None


def _iter_umap_of_relevant_genes(
    adata: AnnData,
    embed: AnnData,
    plot_info: Sequence[tuple[str, pd.Series]],
    layer: str | None,
    title_col: str,
    gene_symbols: str | None,
    dim_subset: Sequence[str] | None,
    n_top_genes: int,
    max_cells_to_plot: int | None,
    show: bool,
    **kwargs,
):
    """Generator that performs the plotting for `_umap_of_relevant_genes`; yields Axes only when ``show`` is False."""
    if max_cells_to_plot is not None and adata.n_obs > max_cells_to_plot:
        adata = sc.pp.subsample(adata, n_obs=max_cells_to_plot, copy=True)
    if dim_subset is not None:
        info_by_dim = dict(plot_info)
        plot_info = [(dim_id, info_by_dim[dim_id]) for dim_id in dim_subset]

    # Copy the UMAP coordinates of the (possibly subsampled) cells from ``embed``.
    adata.obsm["X_umap_method"] = embed[adata.obs.index].obsm["X_umap"]

    for dim_title, gene_scores in plot_info:
        print(dim_title)
        relevant_genes = gene_scores.sort_values(ascending=False).index.to_list()

        # A trailing "+"/"-" marks the effect direction; strip it to find the latent column.
        if dim_title[-1] == "+":
            _cmap = cmap.saturated_sky_cmap
            real_dim_title = dim_title[:-1]
        elif dim_title[-1] == "-":
            _cmap = cmap.saturated_sky_cmap.reversed()
            real_dim_title = dim_title[:-1]
        else:
            _cmap = cmap.saturated_red_blue_cmap
            real_dim_title = dim_title

        adata.obs[dim_title] = list(embed[adata.obs.index, embed.var[title_col] == real_dim_title].X[:, 0])
        ax = sc.pl.embedding(
            adata, "X_umap_method", color=[dim_title], cmap=_cmap, vcenter=0, show=False, frameon=False, **kwargs
        )
        # Move the title into a vertical label inside the axes.
        ax.text(0.92, 0.05, ax.get_title(), size=15, ha="left", color="black", rotation=90, transform=ax.transAxes)
        ax.set_title("")
        if show:
            plt.show()
        else:
            yield ax

        axes = sc.pl.embedding(
            adata,
            "X_umap_method",
            layer=layer,
            color=relevant_genes[:n_top_genes],
            cmap=cmap.saturated_just_sky_cmap,
            gene_symbols=gene_symbols,
            show=False,
            frameon=False,
            **kwargs,
        )
        # sc.pl.embedding returns a single Axes for one panel and a list for several.
        if not isinstance(axes, list):
            axes = [axes]
        for ax in axes:
            ax.text(0.92, 0.05, ax.get_title(), size=15, ha="left", color="black", rotation=90, transform=ax.transAxes)
            ax.set_title("")
        if show:
            plt.show()
        else:
            yield axes
def plot_relevant_genes_on_umap(
    adata: AnnData,
    embed: AnnData,
    traverse_adata: AnnData,
    traverse_adata_key: str,
    layer: str | None = None,
    title_col: str = "title",
    order_col: str = "order",
    gene_symbols: str | None = None,
    score_threshold: float = 0.0,
    dim_subset: Sequence[str] | None = None,
    n_top_genes: int = 10,
    max_cells_to_plot: int | None = None,
    show: bool = True,
    **kwargs,
):
    """Plot latent dimensions and their most relevant genes on a UMAP embedding.

    Top genes per dimension are selected with `iterate_on_top_differential_vars`
    on ``traverse_adata``. The latent values and the expression of those genes
    are then plotted on the UMAP by `_umap_of_relevant_genes`.

    Parameters
    ----------
    adata
        AnnData object with the gene expression to visualize (typically the data
        the model was trained on).
    embed
        AnnData object with the latent representation. It must provide UMAP
        coordinates in ``embed.obsm["X_umap"]`` and dimension metadata in ``.var``.
    traverse_adata
        AnnData object with the differential analysis results (e.g. from
        `calculate_differential_vars`) for ``traverse_adata_key``.
    traverse_adata_key
        Prefix of the differential-effect entries to read (e.g. "max_possible",
        "min_possible", "combined_score").
    layer
        Layer of ``adata`` used for gene expression; None uses ``.X``.
    title_col
        Column holding dimension titles. It is used when collecting the top genes
        and to match dimensions in ``embed.var``.
    order_col
        Column used to order dimensions when collecting the top genes. When
        ``dim_subset`` is given, its order is used instead.
    gene_symbols
        Optional gene-symbol column used for display instead of the var index.
    score_threshold
        Minimum gene score, forwarded to `iterate_on_top_differential_vars`.
    dim_subset
        Titles of dimensions to plot, in display order. None plots all returned dimensions.
    n_top_genes
        Number of top genes plotted per dimension.
    max_cells_to_plot
        If set, plot at most this many randomly subsampled cells.
    show
        Forwarded to `_umap_of_relevant_genes`. True shows the plots with
        ``plt.show()``; False requests Axes to be yielded instead.
    **kwargs
        Extra keyword arguments forwarded to `sc.pl.embedding`.

    Returns
    -------
    Whatever `_umap_of_relevant_genes` returns for the given ``show`` value.
    With ``show=False`` this is a generator of Axes / lists of Axes.

    Raises
    ------
    KeyError
        If required data is missing from any of the AnnData objects.

    Notes
    -----
    For each dimension, two visualizations are produced:

    (I) a UMAP of the latent dimension's values across cells;
    (II) UMAPs of the expression of the dimension's top genes.

    Examples
    --------
    >>> # Basic UMAP visualization with combined scores
    >>> plot_relevant_genes_on_umap(adata, embed, traverse_adata, "combined_score")
    >>> # With custom parameters
    >>> plot_relevant_genes_on_umap(
    ...     adata,
    ...     embed,
    ...     traverse_adata,
    ...     "max_possible",
    ...     layer="logcounts",
    ...     gene_symbols="gene_symbol",
    ...     score_threshold=1.0,
    ...     n_top_genes=5,
    ...     max_cells_to_plot=5000,
    ... )
    >>> # Subset of dimensions
    >>> plot_relevant_genes_on_umap(
    ...     adata, embed, traverse_adata, "combined_score", dim_subset=["DR 5+", "DR 12+", "DR 14+"]
    ... )
    """
    top_vars_per_dim = iterate_on_top_differential_vars(
        traverse_adata,
        traverse_adata_key,
        title_col,
        order_col,
        gene_symbols,
        score_threshold,
    )
    return _umap_of_relevant_genes(
        adata,
        embed,
        top_vars_per_dim,
        layer=layer,
        title_col=title_col,
        gene_symbols=gene_symbols,
        dim_subset=dim_subset,
        n_top_genes=n_top_genes,
        max_cells_to_plot=max_cells_to_plot,
        show=show,
        **kwargs,
    )