# Source code for drvi.utils.plotting._latent

from __future__ import annotations

from typing import TYPE_CHECKING

import anndata as ad
import numpy as np
import scanpy as sc
from matplotlib import pyplot as plt

from drvi.utils.plotting import cmap

if TYPE_CHECKING:
    from collections.abc import Sequence
    from typing import Literal

    from anndata import AnnData


def make_balanced_subsample(adata: AnnData, col: str, min_count: int = 10) -> AnnData:
    """Create a balanced subsample of AnnData based on a categorical column.

    The same number of cells is drawn from every category present in
    ``adata.obs[col]``, ensuring balanced representation of all categories.

    Parameters
    ----------
    adata
        Annotated data object to subsample.
    col
        Column name in `adata.obs` containing categorical labels for balancing.
    min_count
        Minimum number of samples per category. If the smallest category has
        fewer cells than this, every category is sampled `min_count` times
        with replacement.

    Returns
    -------
    AnnData
        Balanced subsample (a copy) of the input AnnData object.

    Notes
    -----
    The function uses a fixed random state (0) for reproducible results.
    Each category contributes ``max(min_count, n_smallest)`` cells, where
    ``n_smallest`` is the size of the smallest non-empty category. Unused
    levels of a pandas ``Categorical`` column and missing values are ignored.
    When sampling with replacement the result may contain duplicated cells.
    """
    counts = adata.obs[col].value_counts()
    # Unused levels of a pandas Categorical are reported with a count of 0; keeping
    # them would force `min_count` samples *with replacement* for every category.
    counts = counts[counts > 0]
    n_sample_per_cond = counts.min()
    balanced_sample_index = (
        adata.obs.groupby(col, observed=True)
        .sample(n=max(min_count, n_sample_per_cond), random_state=0, replace=n_sample_per_cond < min_count)
        .index
    )
    return adata[balanced_sample_index].copy()


def plot_latent_dimension_stats(
    embed: AnnData,
    figsize: tuple[int, int] = (5, 3),
    log_scale: bool | Literal["try"] = "try",
    ncols: int = 5,
    columns: Sequence[str] = ("reconstruction_effect", "max_value", "mean", "std"),
    titles: dict[str, str] | None = None,
    remove_vanished: bool = False,
    show: bool = True,
):
    """Plot the statistics of latent dimensions.

    This function creates line plots showing various statistics of latent
    dimensions across their ranking order. It can optionally distinguish
    between vanished and non-vanished dimensions.

    Parameters
    ----------
    embed
        Annotated data object containing the latent dimensions and their
        statistics in the `.var` attribute.
    figsize
        The size of each subplot (width, height) in inches.
    log_scale
        Whether to use a log scale for the y-axis. If "try", log scale is used
        only if the minimum value of the plotted column is greater than 0.
    ncols
        The maximum number of columns in the subplot grid.
    columns
        The columns from `embed.var` to plot. These should be numeric columns
        containing dimension statistics.
    titles
        Custom y-axis titles for each column. If None, default titles are used.
        Columns without an entry are labelled with their column name.
    remove_vanished
        Whether to exclude vanished dimensions from the plot.
    show
        Whether to display the plot. If False, returns the figure object.

    Returns
    -------
    matplotlib.figure.Figure or None
        The matplotlib figure object if `show=False`, otherwise None.

    Raises
    ------
    ValueError
        If `order`, `vanished`, or any of `columns` is missing from `embed.var`.

    Notes
    -----
    The function expects the following columns in `embed.var`:

    - `order`: Ranking of dimensions
    - `vanished`: Boolean indicating vanished dimensions
    - The columns specified in the `columns` parameter

    If `remove_vanished=False`, a figure legend is added to distinguish between
    vanished (black dots) and non-vanished (blue dots) dimensions.

    Examples
    --------
    >>> # Default plot
    >>> plot_latent_dimension_stats(embed)
    >>>
    >>> # Plot basic statistics
    >>> plot_latent_dimension_stats(embed, columns=["reconstruction_effect", "max_value"])
    >>> # Plot with custom titles and log scale
    >>> titles = {"reconstruction_effect": "Reconstruction Impact", "max_value": "Max Activation"}
    >>> plot_latent_dimension_stats(embed, titles=titles, log_scale=True)
    """
    # Fail early with a readable message instead of a KeyError deep inside the plotting loop.
    required_columns = ["order", "vanished", *columns]
    missing = [name for name in dict.fromkeys(required_columns) if name not in embed.var]
    if missing:
        raise ValueError(
            f"Columns {missing} not found in `embed.var`. Please run `set_latent_dimension_stats` first."
        )

    if titles is None:
        titles = {
            "reconstruction_effect": "Reconstruction effect",
            "max_value": "Max value",
            "mean": "Mean",
            "std": "Standard Deviation",
        }

    nrows = int(np.ceil(len(columns) / ncols))
    if nrows == 1:
        ncols = len(columns)
    fig, axes = plt.subplots(
        nrows, ncols, figsize=(figsize[0] * ncols, figsize[1] * nrows), sharey=False, sharex=False, squeeze=False
    )

    # The dimension table does not depend on the plotted column, so prepare it once.
    df = embed.var
    if remove_vanished:
        df = df.query("vanished == False")
    df = df.sort_values("order")
    ranks = df["order"]

    for ax, col in zip(axes.flatten(), columns, strict=False):
        x = df[col]
        ax.plot(ranks, x, linestyle="-", color="grey", label="Line")  # Solid line plot
        for vanished_status_to_plot in [True, False]:
            indices = df["vanished"] == vanished_status_to_plot
            ax.plot(
                ranks[indices],
                x[indices],
                "o",
                markersize=3,
                color="black" if vanished_status_to_plot else "blue",
                label="Data Points",
            )

        ax.set_xlabel("Rank based on Explanation Share")
        ax.set_ylabel(titles.get(col, col))

        # "try" only enables a log axis when it is well-defined (all values positive);
        # any other string leaves the axis linear.
        if isinstance(log_scale, str):
            use_log = log_scale == "try" and x.min() > 0
        else:
            use_log = bool(log_scale)
        if use_log:
            ax.set_yscale("log")

        ax.grid(axis="x")

    if not remove_vanished:
        # One shared figure-level legend instead of a legend per subplot.
        handles = [
            plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=color, markersize=5, label=label)
            for color, label in (("blue", "Non-vanished"), ("black", "Vanished"))
        ]
        fig.legend(
            handles=handles,
            labels=[handle.get_label() for handle in handles],
            loc="center left",
            bbox_to_anchor=(1, 0.5),
            title=None,
        )

    # Drop the unused trailing cells of the subplot grid.
    for ax in axes.flatten()[len(columns) :]:
        fig.delaxes(ax)

    plt.tight_layout()
    if show:
        plt.show()
    else:
        return fig


def plot_latent_dims_in_umap(
    embed: AnnData,
    title_col: str | None = "title",
    additional_columns: Sequence[str] = (),
    max_cells_to_plot: int | None = None,
    order_col: str = "order",
    dim_subset: Sequence[str] | None = None,
    directional: bool = False,
    remove_vanished: bool = True,
    rearrange_titles: bool = True,
    color_bar_rescale_ratio: float = 1.0,
    show: bool = True,
    **kwargs,
):
    """Plot the latent dimensions on a UMAP embedding.

    This function creates UMAP plots for each latent dimension, showing how
    cells are distributed in the UMAP space based on their values for each
    dimension. It can optionally handle directional dimensions and subsample
    cells for performance.

    Parameters
    ----------
    embed
        Annotated data object containing the UMAP embedding in
        `.obsm['X_umap']` and latent dimensions in `.X`.
    title_col
        Name of the column in `embed.var` to use as titles for each dimension.
        If None, the var names are used instead (not supported together with
        `directional=True`).
    additional_columns
        Additional columns from `embed.obs` to plot alongside the latent
        dimensions.
    max_cells_to_plot
        Maximum number of cells to plot. If the number of cells in `embed` is
        greater than this value, a subsample is taken with `sc.pp.subsample`.
    order_col
        The column in `embed.var` to use for ordering the dimensions. Ignored if
        `dim_subset` is provided.
    dim_subset
        The subset of dimensions to plot, given as values of `title_col` (or var
        names if `title_col` is None). If provided, its order overrides
        `order_col`. With `directional=True` the titles carry a "+" or "-"
        suffix, e.g. "DR 1+".
    directional
        Whether to consider positive and negative directions as separate
        dimensions. If True, creates separate plots for + and - directions.
    remove_vanished
        Whether to remove vanished dimensions from the plot.
    rearrange_titles
        Whether to move titles to the bottom right of each plot (rotated).
    color_bar_rescale_ratio
        Ratio to rescale the height of colorbars.
    show
        Whether to display the plot. If False, returns the figure object.
    **kwargs
        Additional keyword arguments passed to `sc.pl.umap`. In directional
        mode a `layer` entry is consumed here: that layer replaces `.X` before
        the directions are split.

    Returns
    -------
    matplotlib.figure.Figure or None
        The UMAP plot figure if `show=False`, otherwise None.

    Raises
    ------
    ValueError
        If required columns (`order_col`, "vanished", `title_col`, "min", "max")
        are not found in `embed.var`, or if `title_col` is None while
        `directional=True`.

    Notes
    -----
    The function expects the following columns in `embed.var`:

    - `order_col`: For ordering dimensions
    - `title_col`: For dimension titles
    - `vanished`: Boolean indicating vanished dimensions (if `remove_vanished=True`)
    - `min`, `max`: For setting color scale limits

    When `directional=True`, the function creates separate plots for positive
    and negative directions, effectively doubling the number of plots.

    Examples
    --------
    >>> # Basic UMAP plot of latent dimensions
    >>> plot_latent_dims_in_umap(embed)
    >>> # Plot with directional dimensions and custom subset
    >>> plot_latent_dims_in_umap(embed, directional=True, dim_subset=["DR 1+", "DR 2-"])
    >>> # Plot with additional metadata columns
    >>> plot_latent_dims_in_umap(embed, additional_columns=["cell_type", "batch"])
    """
    from pandas.api.types import is_bool_dtype, is_numeric_dtype

    if order_col not in embed.var:
        raise ValueError(
            f'Column "{order_col}" not found in `embed.var`. Please run `set_latent_dimension_stats` to set order.'
        )
    if remove_vanished and "vanished" not in embed.var:
        raise ValueError(
            'Column "vanished" not found in `embed.var`. Please run `set_latent_dimension_stats` to set vanished status.'
        )
    if title_col is None:
        if directional:
            # Directional mode derives "+"/"-" titles from `title_col`.
            raise ValueError("`title_col` must be provided when `directional=True`.")
    elif title_col not in embed.var:
        raise ValueError(f'Column "{title_col}" not found in `embed.var`.')
    missing_bounds = [name for name in ("min", "max") if name not in embed.var]
    if missing_bounds:
        raise ValueError(
            f"Columns {missing_bounds} not found in `embed.var`; they are needed to set the color scale limits."
        )

    if max_cells_to_plot is not None and embed.n_obs > max_cells_to_plot:
        embed = sc.pp.subsample(embed, n_obs=max_cells_to_plot, copy=True)

    if directional:
        embed_pos = embed.copy()
        if "layer" in kwargs:
            embed_pos.X = embed_pos.layers[kwargs.pop("layer")]
        embed_neg = embed_pos.copy()
        embed_neg.X = -embed_neg.X
        # Negating the values swaps the roles of the lower and upper bounds.
        embed_neg.var["min"], embed_neg.var["max"] = -embed_neg.var["max"], -embed_neg.var["min"]
        embed_pos.var["_direction"] = "+"
        embed_neg.var["_direction"] = "-"
        embed_pos.var[title_col] = embed_pos.var[title_col] + "+"
        embed_neg.var[title_col] = embed_neg.var[title_col] + "-"
        order_values = embed.var[order_col]
        if is_numeric_dtype(order_values) and not is_bool_dtype(order_values):
            # Cast to float64 first: in float32 a 1e-8 offset vanishes for orders above ~0.1,
            # which would make the + and - copies of a dimension tie.
            embed_pos.var[order_col] = embed_pos.var[order_col].astype(np.float64) + 1e-8
            embed_neg.var[order_col] = embed_neg.var[order_col].astype(np.float64) - 1e-8
        else:
            embed_pos.var[order_col] = embed_pos.var[order_col].astype(str) + "+"
            embed_neg.var[order_col] = embed_neg.var[order_col].astype(str) + "-"
        embed = ad.concat([embed_pos, embed_neg], axis=1, join="inner", merge="first")
        # Both halves share the original var names; replace them with a fresh positional index.
        embed.var.reset_index(drop=True, inplace=True)

    tmp_df = embed.var.sort_values(order_col)
    if remove_vanished:
        tmp_df = tmp_df.query("vanished == False")
    if dim_subset:
        # `list(...)` so that a tuple is not interpreted by `.loc` as a multi-level key.
        if title_col is None:
            tmp_df = tmp_df.loc[list(dim_subset)]
        else:
            tmp_df = tmp_df.set_index(title_col).loc[list(dim_subset)].reset_index()

    cols_to_show = list(tmp_df.index if title_col is None else tmp_df[title_col])
    cols_to_show += list(additional_columns)

    kwargs = {
        **dict(  # noqa: C408
            frameon=False,
            cmap=cmap.saturated_red_blue_cmap if not directional else cmap.saturated_sky_cmap,
            vmin=list(np.minimum(tmp_df["min"].values, -1)),
            vcenter=0,
            vmax=list(np.maximum(tmp_df["max"].values, +1)),
        ),
        **kwargs,
    }

    fig = sc.pl.umap(embed, gene_symbols=title_col, color=cols_to_show, return_fig=True, **kwargs)

    # Each latent-dimension panel is followed by its colorbar axis in `fig.axes`.
    for i, ax in enumerate(fig.axes[1 : 2 * len(tmp_df) : 2]):
        assert hasattr(ax, "_colorbar")  # internal sanity check on scanpy's axes layout
        pos = ax.get_position()
        ax.set_position([pos.x0, pos.y0, pos.width, pos.height * color_bar_rescale_ratio])
        if directional and tmp_df["_direction"].iloc[i] == "-":
            # Show the negative direction with its original (un-negated) values.
            ax.invert_yaxis()
            labels = 0.0 - ax.get_yticks()  # `0.0 - t` avoids a "-0.0" label for t == 0
            if all(x == int(x) for x in labels):
                labels = [int(x) for x in labels]
            ax.set_yticklabels(labels)

    if rearrange_titles:
        for ax in fig.axes:
            ax.text(0.935, 0.05, ax.get_title(), size=15, ha="left", color="black", rotation=90, transform=ax.transAxes)
            ax.set_title("")

    if show:
        plt.show()
    else:
        return fig


def plot_latent_dims_in_heatmap(
    embed: AnnData,
    categorical_column: str,
    title_col: str | None = "title",
    sort_by_categorical: bool = False,
    make_balanced: bool = True,
    order_col: str | None = "order",
    remove_vanished: bool = True,
    figsize: tuple[int, int] | None = None,
    show: bool = True,
    **kwargs,
):
    """Plot the latent dimensions in a heatmap.

    This function creates a heatmap showing the values of latent dimensions
    across different categories. It can optionally create balanced subsamples
    and sort dimensions based on categorical differences.

    Parameters
    ----------
    embed
        Annotated data object containing the latent dimensions in `.X` and
        categorical metadata in `.obs`.
    categorical_column
        The column in `embed.obs` that represents the categorical variable for
        grouping cells.
    title_col
        The column in `embed.var` to use as titles for each dimension. If None,
        uses the var names.
    sort_by_categorical
        Whether to order dimensions by the category of the cell holding each
        dimension's largest absolute value (categories in categorical order),
        which gives a block-diagonal look. If True, `order_col` is ignored.
    make_balanced
        Whether to create a balanced subsample of the data based on the
        categorical variable using `make_balanced_subsample`.
    order_col
        The column in `embed.var` to use for ordering the dimensions. Ignored if
        `sort_by_categorical=True`. If None, the existing var order is kept.
    remove_vanished
        Whether to remove vanished dimensions from the plot.
    figsize
        The size of the figure (width, height) in inches. If None, automatically
        calculated based on number of categories.
    show
        Whether to display the plot. If False, returns the plot axes.
    **kwargs
        Additional keyword arguments passed to `sc.pl.heatmap`.

    Returns
    -------
    dict of matplotlib.axes.Axes or None
        Whatever `sc.pl.heatmap` returns: its axes dictionary if `show=False`,
        otherwise None.

    Raises
    ------
    ValueError
        If required columns (`order_col`, "vanished", `title_col`) are not found
        in `embed.var`, or `categorical_column` is not found in `embed.obs`.

    Notes
    -----
    The function expects the following columns in `embed.var`:

    - `order_col`: For ordering dimensions (if `sort_by_categorical=False`)
    - `title_col`: For dimension titles
    - `vanished`: Boolean indicating vanished dimensions (if `remove_vanished=True`)

    If `figsize=None`, the figure height is automatically calculated as
    `len(unique_categories) / 6` to accommodate all categories.
    By default the heatmap uses a red-blue color map centered at 0, with no
    dendrogram.

    Examples
    --------
    >>> # Basic heatmap of latent dimensions by cell type
    >>> plot_latent_dims_in_heatmap(embed, categorical_column="cell_type")
    >>> # Heatmap with balanced sampling and custom sorting
    >>> plot_latent_dims_in_heatmap(embed, categorical_column="condition", sort_by_categorical=True, make_balanced=True)
    >>> # Heatmap with custom figure size
    >>> plot_latent_dims_in_heatmap(embed, categorical_column="batch", figsize=(12, 8))
    """
    if order_col is not None and order_col not in embed.var:
        raise ValueError(
            f'Column "{order_col}" not found in `embed.var`. Please run `set_latent_dimension_stats` to set order.'
        )
    if remove_vanished and "vanished" not in embed.var:
        raise ValueError(
            'Column "vanished" not found in `embed.var`. Please run `set_latent_dimension_stats` to set vanished status.'
        )
    if categorical_column not in embed.obs:
        raise ValueError(f'Column "{categorical_column}" not found in `embed.obs`.')
    if title_col is not None and title_col not in embed.var:
        raise ValueError(f'Column "{title_col}" not found in `embed.var`.')

    if remove_vanished:
        embed = embed[:, ~embed.var["vanished"]]

    if make_balanced:
        embed = make_balanced_subsample(embed, categorical_column)

    if sort_by_categorical:
        values = embed.X
        values = values.toarray() if hasattr(values, "toarray") else np.asarray(values)
        # Cell with the largest |value| for every dimension.
        peak_cells = np.abs(values).argmax(axis=0)
        # Order by the *category* of that cell (not its row position, which only
        # follows the category order when cells happen to be grouped by category).
        category_codes = embed.obs[categorical_column].astype("category").cat.codes.to_numpy()
        dim_order = np.argsort(category_codes[peak_cells], kind="stable").tolist()
    elif order_col is None:
        dim_order = np.arange(embed.n_vars)
    else:
        dim_order = embed.var[order_col].argsort().tolist()

    if title_col is None:
        vars_to_show = embed.var.iloc[dim_order].index
    else:
        vars_to_show = embed.var.iloc[dim_order][title_col]

    if figsize is None:
        figsize = (10, len(embed.obs[categorical_column].unique()) / 6)

    kwargs = {
        **dict(  # noqa: C408
            vcenter=0,
            cmap=cmap.saturated_red_blue_cmap,
            dendrogram=False,
        ),
        **kwargs,
    }
    return sc.pl.heatmap(
        embed,
        vars_to_show,
        categorical_column,
        gene_symbols=title_col,
        figsize=figsize,
        show_gene_labels=True,
        show=show,
        **kwargs,
    )