# Source code for drvi.scvi_tools_based.module._drvi

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
import torch
from scvi import REGISTRY_KEYS
from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
from torch.distributions import Normal

from drvi.nn_modules.embedding import MultiEmbedding
from drvi.nn_modules.layer.factory import LayerFactory
from drvi.nn_modules.noise_model import (
    LogNegativeBinomialNoiseModel,
    NegativeBinomialNoiseModel,
    NormalNoiseModel,
    PoissonNoiseModel,
)
from drvi.nn_modules.prior import StandardPrior
from drvi.scvi_tools_based.module._constants import MODULE_KEYS
from drvi.scvi_tools_based.module._metrics import LatentStats, StreamingPairwiseMI
from drvi.scvi_tools_based.nn import DecoderDRVI, Encoder

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Sequence
    from typing import Any, Literal

    TensorDict = dict[str, torch.Tensor]


class DRVIModule(BaseModuleClass):
    """DRVI (Disentangled Representation Variational Inference) pytorch module.

    Parameters
    ----------
    n_input
        Number of input genes.
    n_latent
        Dimensionality of the latent space.
    n_labels
        Number of ground-truth labels. Only used for calculation of metrics during training.
    n_split_latent
        Number of splits in the latent space.
        -1 (or None) means split all dimensions (n_split_latent=n_latent).
    split_aggregation
        How to aggregate splits in the last layer of the decoder.
    split_method
        How to make splits:

        - "split" : Split the latent space
        - "power" : Transform the latent space to n_split vectors of size n_latent
        - "power@Z" : Transform the latent space to n_split vectors of size Z
        - "split_map" : Split the latent space then map each to latent space using unique transformations
        - "split_map@Z" : Split the latent space then map each to vector of size Z using unique transformations
        - "split_diag" : Simple diagonal splitting
    decoder_reuse_weights
        Where to reuse the weights of the decoder layers when using splitting.
        Possible values are 'everywhere', 'last', 'intermediate', 'nowhere', 'not_first'.
        Defaults to "everywhere".
    encoder_dims
        Number of nodes in hidden layers of the encoder.
    decoder_dims
        Number of nodes in hidden layers of the decoder.
    n_cats_per_cov
        Number of categories for each categorical covariate.
    n_continuous_cov
        Number of continuous covariates.
    encode_covariates
        Whether to concatenate covariates to expression in encoder.
    deeply_inject_covariates
        Whether to concatenate covariates into output of hidden layers in encoder/decoder.
        This option only applies when `n_layers` >= 1.
        The covariates are concatenated to the input of subsequent hidden layers.
    categorical_covariate_dims
        Embedding dimension of covariate keys if applicable.
    covariate_modeling_strategy
        The strategy model takes to remove covariates.
    use_batch_norm
        Whether to use batch norm in layers.
    affine_batch_norm
        Whether to use affine batch norm in layers.
    use_layer_norm
        Whether to use layer norm in layers.
    fill_in_the_blanks_ratio
        Ratio for fill-in-the-blanks training.
    reconstruction_strategy
        Strategy for reconstruction.

        - "dense" : Reconstruct all features.
        - "random_batch@M" : Reconstruct M random features for each batch.
    input_dropout_rate
        Dropout rate to apply to the input.
    encoder_dropout_rate
        Dropout rate to apply to each of the encoder hidden layers.
    decoder_dropout_rate
        Dropout rate to apply to each of the decoder hidden layers.
    gene_likelihood
        Gene likelihood model. Options include:

        - "poisson" : Poisson distributions
        - "nb" : Negative binomial distributions
        - "pnb": Log negative binomial distributions
        - "normal" : Normal distributions
        - "normal_unit_var" : Normal distributions with unit variance
    dispersion
        Dispersion parameter modeling strategy for negative binomial distributions. Options:

        - "gene" (default): Dispersion parameter is constant per gene across all cells
        - "gene-batch": Dispersion can differ between different batches
        - "gene-cell": Dispersion can differ for every gene in every cell

        Only used when relevant to the gene likelihood model.
    prior
        Prior model.
    var_activation
        The activation function to ensure positivity of the variational distribution.
        Options include "exp", "pow2", "2sig" or a custom callable.
    mean_activation
        The activation function at the end of mean encoder.
        Options include "identity", "relu", "leaky_relu", "leaky_relu_{slope}",
        "elu", "elu_{min_value}" or a custom callable.
    encoder_layer_factory
        A layer Factory instance for building encoder layers.
    decoder_layer_factory
        A layer Factory instance for building decoder layers.
    last_layer_gradient_scale
        Gradient scale for the last layer of the decoder.
    extra_encoder_kwargs
        Extra keyword arguments passed into encoder.
    extra_decoder_kwargs
        Extra keyword arguments passed into decoder.
    """

    def __init__(
        self,
        n_input: int,
        n_latent: int = 32,
        n_labels: int = 1,
        n_split_latent: int | None = -1,
        split_aggregation: Literal["sum", "logsumexp", "max"] = "logsumexp",
        split_method: Literal["split", "power", "split_map", "split_diag"] | str = "split_map",
        decoder_reuse_weights: Literal["everywhere", "last", "intermediate", "nowhere", "not_first"] = "everywhere",
        encoder_dims: Sequence[int] = (128, 128),
        decoder_dims: Sequence[int] = (128, 128),
        n_cats_per_cov: Iterable[int] | None = (),
        n_continuous_cov: int = 0,
        encode_covariates: bool = False,
        deeply_inject_covariates: bool = False,
        categorical_covariate_dims: Sequence[int] = (),
        covariate_modeling_strategy: Literal[
            "one_hot",
            "emb",
            "emb_shared",
            "one_hot_linear",
            "emb_linear",
            "emb_shared_linear",
        ] = "one_hot",
        use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "none",
        affine_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "both",
        fill_in_the_blanks_ratio: float = 0.0,
        reconstruction_strategy: str = "dense",
        input_dropout_rate: float = 0.0,
        encoder_dropout_rate: float = 0.1,
        decoder_dropout_rate: float = 0.0,
        gene_likelihood: Literal["pnb", "nb", "poisson", "normal", "normal_unit_var"] = "pnb",
        dispersion: Literal["gene", "gene-batch", "gene-cell"] = "gene",
        prior: Literal["normal"] = "normal",
        var_activation: Callable | Literal["exp", "pow2", "2sig"] = "exp",
        mean_activation: Callable | str = "identity",
        encoder_layer_factory: LayerFactory | None = None,
        decoder_layer_factory: LayerFactory | None = None,
        last_layer_gradient_scale: float = 1.0,
        extra_encoder_kwargs: dict[str, Any] | None = None,
        extra_decoder_kwargs: dict[str, Any] | None = None,
    ) -> None:
        super().__init__()
        self.n_latent = n_latent
        self.n_labels = n_labels
        self.encoder_dims = encoder_dims
        self.decoder_dims = decoder_dims

        # None and -1 both mean "one split per latent dimension".
        if n_split_latent is None or n_split_latent == -1:
            n_split_latent = n_latent
        self.n_split_latent = n_split_latent
        self.split_aggregation = split_aggregation

        self.latent_distribution = "normal"
        self.gene_likelihood = gene_likelihood
        self.encode_covariates = encode_covariates
        self.deeply_inject_covariates = deeply_inject_covariates

        self.gene_likelihood_module = self._construct_gene_likelihood_module(gene_likelihood, dispersion)
        self.fill_in_the_blanks_ratio = fill_in_the_blanks_ratio
        self.reconstruction_strategy = reconstruction_strategy

        # Expand the "encoder" / "decoder" / "both" / "none" switches into per-network booleans.
        use_batch_norm_encoder = use_batch_norm in ("encoder", "both")
        use_batch_norm_decoder = use_batch_norm in ("decoder", "both")
        affine_batch_norm_encoder = affine_batch_norm in ("encoder", "both")
        affine_batch_norm_decoder = affine_batch_norm in ("decoder", "both")
        use_layer_norm_encoder = use_layer_norm in ("encoder", "both")
        use_layer_norm_decoder = use_layer_norm in ("decoder", "both")

        assert covariate_modeling_strategy in [
            "one_hot",
            "emb",
            "emb_shared",
            "one_hot_linear",
            "emb_linear",
            "emb_shared_linear",
        ]

        # `n_cats_per_cov` is allowed to be None by its annotation; guard the length check so
        # that selecting a shared-embedding strategy without covariates does not crash.
        has_categorical_covs = n_cats_per_cov is not None and len(n_cats_per_cov) > 0
        if covariate_modeling_strategy in ["emb_shared", "emb_shared_linear"] and has_categorical_covs:
            self.shared_covariate_emb = MultiEmbedding(
                n_cats_per_cov, categorical_covariate_dims, init_method="normal", max_norm=1.0
            )
        else:
            # Register the attribute explicitly so `self.shared_covariate_emb is None` checks work.
            self.register_module("shared_covariate_emb", None)

        self.z_encoder = Encoder(
            n_input,
            n_latent,
            layers_dim=encoder_dims,
            input_dropout_rate=input_dropout_rate,
            dropout_rate=encoder_dropout_rate,
            n_cat_list=n_cats_per_cov if self.encode_covariates else [],
            n_continuous_cov=n_continuous_cov if self.encode_covariates else 0,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_encoder,
            affine_batch_norm=affine_batch_norm_encoder,
            use_layer_norm=use_layer_norm_encoder,
            var_activation=var_activation,
            mean_activation=mean_activation,
            layer_factory=encoder_layer_factory,
            covariate_modeling_strategy=covariate_modeling_strategy,
            categorical_covariate_dims=categorical_covariate_dims if self.encode_covariates else [],
            **(extra_encoder_kwargs or {}),
        )

        self.decoder = DecoderDRVI(
            n_latent,
            n_input,
            n_split=n_split_latent,
            split_aggregation=split_aggregation,
            split_method=split_method,
            reuse_weights=decoder_reuse_weights,
            gene_likelihood_module=self.gene_likelihood_module,
            layers_dim=decoder_dims,
            dropout_rate=decoder_dropout_rate,
            n_cat_list=n_cats_per_cov,
            n_continuous_cov=n_continuous_cov,
            inject_covariates=deeply_inject_covariates,
            use_batch_norm=use_batch_norm_decoder,
            affine_batch_norm=affine_batch_norm_decoder,
            use_layer_norm=use_layer_norm_decoder,
            layer_factory=decoder_layer_factory,
            covariate_modeling_strategy=covariate_modeling_strategy,
            categorical_covariate_dims=categorical_covariate_dims,
            last_layer_gradient_scale=last_layer_gradient_scale,
            **(extra_decoder_kwargs or {}),
        )

        self._setup_streaming_metrics(n_latent, n_labels)
        self.prior = self._construct_prior(prior)
        # When True, `generative` asks the decoder to also return its un-aggregated parameters.
        self.inspect_mode = False

    @property
    def fully_deterministic(self) -> bool:
        """Proxy for the encoder's ``fully_deterministic`` flag."""
        return self.z_encoder.fully_deterministic

    @fully_deterministic.setter
    def fully_deterministic(self, value: bool) -> None:
        self.z_encoder.fully_deterministic = value

    def _construct_gene_likelihood_module(self, gene_likelihood: str, dispersion: str) -> Any:
        """Construct the gene likelihood module based on the specified type.

        Parameters
        ----------
        gene_likelihood
            Type of gene likelihood model to construct.
        dispersion
            Dispersion parameter modeling strategy. Only used for relevant likelihoods.

        Returns
        -------
        object
            Constructed gene likelihood module.

        Raises
        ------
        NotImplementedError
            If the gene likelihood type is not supported.
        """
        if gene_likelihood == "normal_unit_var":
            return NormalNoiseModel(model_var="fixed=1")
        if gene_likelihood == "normal":
            return NormalNoiseModel(model_var=dispersion)
        if gene_likelihood == "poisson":
            return PoissonNoiseModel(mean_transformation="softmax", library_normalization="none")
        if gene_likelihood == "nb":
            return NegativeBinomialNoiseModel(
                dispersion=dispersion, mean_transformation="softmax", library_normalization="none"
            )
        if gene_likelihood == "pnb":
            return LogNegativeBinomialNoiseModel(
                dispersion=dispersion, mean_transformation="softmax", library_normalization="none"
            )
        raise NotImplementedError(f"Gene likelihood {gene_likelihood} not implemented.")

    def _construct_prior(self, prior: str) -> Any:
        """Construct the prior model based on the specified type.

        Parameters
        ----------
        prior
            Type of prior model to construct.

        Returns
        -------
        object
            Constructed prior model.

        Raises
        ------
        NotImplementedError
            If the prior type is not supported.
        """
        if prior == "normal":
            return StandardPrior()
        raise NotImplementedError(f"Prior {prior} not implemented.")

    def _get_inference_input(self, tensors: TensorDict) -> dict[str, torch.Tensor | None]:
        """Parse the dictionary to get appropriate args.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data. Its expression entry is replaced in place by
            the result of ``to_dense()``.

        Returns
        -------
        dict
            Dictionary with parsed input data.
        """
        # The densified matrix is written back into `tensors` on purpose: later steps
        # (`_get_generative_input`, `loss`) read the expression matrix from the same dict.
        tensors[REGISTRY_KEYS.X_KEY] = tensors[REGISTRY_KEYS.X_KEY].to_dense()
        x = tensors[REGISTRY_KEYS.X_KEY]

        batch_index = tensors.get(REGISTRY_KEYS.BATCH_KEY)
        cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)
        cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY)

        # The batch index is treated as the first categorical covariate.
        if batch_index is not None:
            if cat_covs is not None:
                cat_covs = torch.cat([batch_index, cat_covs], dim=1)
            else:
                cat_covs = batch_index

        return {
            MODULE_KEYS.X_KEY: x,
            MODULE_KEYS.CONT_COVS_KEY: cont_covs,
            MODULE_KEYS.CAT_COVS_KEY: cat_covs,
        }

    def _input_pre_processing(
        self,
        x: torch.Tensor,
        cont_covs: torch.Tensor | None = None,
        cat_covs: torch.Tensor | None = None,
    ) -> dict[str, Any]:
        """Pre-process input data for the model.

        Parameters
        ----------
        x
            Input tensor data.
        cont_covs
            Continuous covariates.
        cat_covs
            Categorical covariates.

        Returns
        -------
        dict
            Dictionary containing pre-processed input data. Covariates are dropped (set to
            None) unless ``encode_covariates`` is enabled.
        """
        # log the input to the variational distribution for numerical stability
        encoder_input = self.gene_likelihood_module.initial_transformation(x)
        return {
            MODULE_KEYS.X_KEY: encoder_input,
            MODULE_KEYS.CAT_COVS_KEY: cat_covs if self.encode_covariates else None,
            MODULE_KEYS.CONT_COVS_KEY: cont_covs if self.encode_covariates else None,
        }

    @auto_move_data
    def inference(
        self,
        x: torch.Tensor,
        cont_covs: torch.Tensor | None = None,
        cat_covs: torch.Tensor | None = None,
        n_samples: int = 1,
    ) -> dict[str, Any]:
        """High level inference method.

        Runs the inference (encoder) model.

        Parameters
        ----------
        x
            Input tensor data.
        cont_covs
            Continuous covariates.
        cat_covs
            Categorical covariates.
        n_samples
            Number of samples to generate.

        Returns
        -------
        dict
            Dictionary containing inference outputs including latent variables.
        """
        pre_processed_input = self._input_pre_processing(x, cont_covs, cat_covs).copy()
        x_ = pre_processed_input[MODULE_KEYS.X_KEY]
        # Library size is computed on the raw (untransformed) input.
        library = self._get_library_size(x_original=x, reconstruction_indices=None)

        # Mask if needed
        if self.fill_in_the_blanks_ratio > 0.0 and self.training:
            assert cont_covs is None  # We do not consider cont_cov here
            # Mask value 1.0 keeps a feature, 0.0 blanks it out of the encoder input.
            x_mask = torch.where(torch.rand_like(x_) >= self.fill_in_the_blanks_ratio, 1.0, 0.0)
            x_ = x_ * x_mask
            # TODO: check this:
            # x_ = x_ * x_mask / x_mask.mean(dim=1, keepdim=True)
        else:
            x_mask = None

        # Prepare shared emb
        if self.shared_covariate_emb is not None and self.encode_covariates:
            pre_processed_input[MODULE_KEYS.CAT_COVS_KEY] = self.shared_covariate_emb(
                pre_processed_input[MODULE_KEYS.CAT_COVS_KEY].int()
            )

        # get variational parameters via the encoder networks
        qz_m, qz_v, z = self.z_encoder(
            x_,
            cat_full_tensor=pre_processed_input[MODULE_KEYS.CAT_COVS_KEY],
            cont_full_tensor=pre_processed_input[MODULE_KEYS.CONT_COVS_KEY],
        )

        outputs = {
            MODULE_KEYS.Z_KEY: z,
            MODULE_KEYS.QZM_KEY: qz_m,
            MODULE_KEYS.QZV_KEY: qz_v,
            MODULE_KEYS.QL_KEY: None,  # We do not model library size
            MODULE_KEYS.LIBRARY_KEY: library,
            MODULE_KEYS.X_MASK_KEY: x_mask,
            MODULE_KEYS.N_SAMPLES_KEY: n_samples,
        }

        if n_samples > 1:
            # Prepend a sample dimension of size n_samples to every per-cell tensor.
            for key in [
                MODULE_KEYS.Z_KEY,
                MODULE_KEYS.QZM_KEY,
                MODULE_KEYS.QZV_KEY,
                MODULE_KEYS.LIBRARY_KEY,
                MODULE_KEYS.X_MASK_KEY,
            ]:
                value = outputs[key]
                if value is None:
                    continue
                assert value.shape[0] == z.shape[0]
                outputs[key] = value.unsqueeze(0).repeat(n_samples, *([1] * value.ndim))
            # Draw an independent latent sample for each repetition.
            outputs[MODULE_KEYS.Z_KEY] = Normal(
                outputs[MODULE_KEYS.QZM_KEY], outputs[MODULE_KEYS.QZV_KEY].sqrt()
            ).rsample()

        return outputs

    def _get_reconstruction_indices(self, tensors: TensorDict) -> None | torch.Tensor:
        """Pick the feature indices to reconstruct for the current batch.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data.

        Returns
        -------
        None | torch.Tensor
            None when every feature is reconstructed, otherwise a 1-D tensor of randomly
            chosen feature indices.

        Raises
        ------
        NotImplementedError
            If the reconstruction strategy is not supported.
        """
        # Outside of training (e.g. validation) every feature is reconstructed.
        if not self.training:
            return None

        if self.reconstruction_strategy == "dense":
            return None
        if self.reconstruction_strategy.startswith("random_batch@"):
            x = tensors[REGISTRY_KEYS.X_KEY]
            n_random_features = int(self.reconstruction_strategy.split("@")[1])
            return torch.randperm(x.shape[1])[:n_random_features]
        raise NotImplementedError(f"Reconstruction strategy {self.reconstruction_strategy} not implemented.")

    def _get_library_size(
        self, x_original: torch.Tensor, reconstruction_indices: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Sum the counts of each cell, optionally restricted to a subset of features.

        Parameters
        ----------
        x_original
            Untransformed expression matrix.
        reconstruction_indices
            Optional 1-D tensor of feature indices to restrict the sum to.

        Returns
        -------
        torch.Tensor
            Sum of ``x_original`` over dimension 1.

        Raises
        ------
        NotImplementedError
            If ``reconstruction_indices`` is not one-dimensional.
        """
        # Note: this is different from scvi implementation of library size that is log transformed
        # All our noise models accept non-normalized library to work
        if reconstruction_indices is None:
            return x_original.sum(1)
        if reconstruction_indices.dim() == 1:
            return x_original[:, reconstruction_indices.to(x_original.device)].sum(1)
        raise NotImplementedError(f"Reconstruction indices {reconstruction_indices} not implemented.")

    def _get_generative_input(
        self,
        tensors: TensorDict,
        inference_outputs: dict[str, Any],
        transform_batch: int | None = None,
        library_to_inject: torch.Tensor | None = None,
    ) -> dict[str, Any]:
        """Prepare input for the generative model.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data.
        inference_outputs
            Outputs from the inference step.
        transform_batch
            Batch to condition on.
        library_to_inject
            Library size to inject (it should not be log transformed. Just x.sum(1)).

        Returns
        -------
        dict
            Dictionary containing input for generative model.

        Raises
        ------
        ValueError
            If no library size is available.
        """
        z = inference_outputs[MODULE_KEYS.Z_KEY]

        batch_index = tensors.get(REGISTRY_KEYS.BATCH_KEY)
        cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY)
        cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY)

        if transform_batch is not None:
            batch_index = torch.ones_like(batch_index) * transform_batch

        # The batch index is treated as the first categorical covariate.
        if batch_index is not None:
            if cat_covs is not None:
                cat_covs = torch.cat([batch_index, cat_covs], dim=1)
            else:
                cat_covs = batch_index

        n_samples = inference_outputs.get(MODULE_KEYS.N_SAMPLES_KEY, 1)
        reconstruction_indices = self._get_reconstruction_indices(tensors)

        # Set library size
        if library_to_inject is not None:
            library = library_to_inject
            assert reconstruction_indices is None
        elif reconstruction_indices is not None:
            # Override library size as we do not decode everything
            library = self._get_library_size(tensors[REGISTRY_KEYS.X_KEY], reconstruction_indices)
        elif MODULE_KEYS.LIBRARY_KEY in inference_outputs:
            library = inference_outputs[MODULE_KEYS.LIBRARY_KEY]
        else:
            raise ValueError("Library size not found in inference outputs.")

        input_dict = {
            MODULE_KEYS.Z_KEY: z,
            MODULE_KEYS.LIBRARY_KEY: library,
            MODULE_KEYS.CONT_COVS_KEY: cont_covs,
            MODULE_KEYS.CAT_COVS_KEY: cat_covs,
            MODULE_KEYS.N_SAMPLES_KEY: n_samples,
            MODULE_KEYS.RECONSTRUCTION_INDICES: reconstruction_indices,
        }

        if n_samples > 1:
            # Repeat the covariates for each sample
            for key in [MODULE_KEYS.CAT_COVS_KEY, MODULE_KEYS.CONT_COVS_KEY]:
                value = input_dict[key]
                if value is None:
                    continue
                input_dict[key] = value.repeat(n_samples, *([1] * (value.ndim - 1)))

            # Combine samples and batch dimensions into the first dimension
            for key in [MODULE_KEYS.Z_KEY, MODULE_KEYS.LIBRARY_KEY]:
                value = input_dict[key]
                if value is None:
                    continue
                if key == MODULE_KEYS.LIBRARY_KEY and value.ndim == 1:
                    # An injected or overridden library has no sample dimension yet, so
                    # `flatten(0, 1)` would raise. Tile it in the same sample-major order
                    # used for the covariates above.
                    input_dict[key] = value.repeat(n_samples)
                else:
                    input_dict[key] = value.flatten(0, 1)

        return input_dict

    @auto_move_data
    def generative(
        self,
        z: torch.Tensor,
        library: torch.Tensor,
        cat_covs: torch.Tensor | None = None,
        cont_covs: torch.Tensor | None = None,
        transform_batch: torch.Tensor | None = None,
        reconstruction_indices: torch.Tensor | None = None,
        n_samples: int = 1,
    ) -> dict[str, Any]:
        """Runs the generative model.

        Parameters
        ----------
        z
            Latent variables.
        library
            Library size information.
        cat_covs
            Categorical covariates.
        cont_covs
            Continuous covariates.
        transform_batch
            Batch to condition on.
            Currently not used but required for RNASeqMixin compatibility.
        reconstruction_indices
            Indices of features to reconstruct.
        n_samples
            Number of samples folded into the first dimension of ``z`` and ``library``.

        Returns
        -------
        dict
            Dictionary containing generative model outputs.
        """
        # Parameter transform_batch is not used!
        # But, we keep it here since _rna_mixin.py checks module.generative to include this as a parameter!

        if self.shared_covariate_emb is not None:
            cat_covs = self.shared_covariate_emb(cat_covs.int())

        # form the likelihood
        px, params, original_params = self.decoder(
            z,
            cat_full_tensor=cat_covs,
            cont_full_tensor=cont_covs,
            library=library,
            reconstruction_indices=reconstruction_indices,
            return_original_params=self.inspect_mode,
        )

        if n_samples > 1:
            # Undo the sample/batch flattening done in `_get_generative_input`.
            n_batch = z.shape[0] // n_samples
            for key, value in params.items():
                # Shape : (n_samples, n_batch, ...)
                params[key] = value.reshape(n_samples, n_batch, *value.shape[1:])
            library = library.reshape(n_samples, n_batch, *library.shape[1:])
            # Rebuild the distribution from the reshaped parameters.
            px = self.gene_likelihood_module.dist(parameters=params, lib_y=library)

        return {
            MODULE_KEYS.PX_KEY: px,
            MODULE_KEYS.PX_PARAMS_KEY: params,
            MODULE_KEYS.PX_UNAGGREGATED_PARAMS_KEY: original_params,
            MODULE_KEYS.RECONSTRUCTION_INDICES: reconstruction_indices,
        }

    def _get_reconstruction_loss(
        self,
        px: torch.distributions.Distribution,
        x: torch.Tensor,
        x_mask: torch.Tensor | None,
        reconstruction_indices: torch.Tensor | None,
    ) -> dict[str, torch.Tensor]:
        """Get reconstruction loss.

        Parameters
        ----------
        px
            Distribution produced by the generative step.
        x
            Observed expression matrix.
        x_mask
            Fill-in-the-blanks mask from the inference step (1.0 = kept in the encoder input),
            or None when masking is inactive.
        reconstruction_indices
            Feature indices that were reconstructed, or None when all features were.

        Returns
        -------
        dict
            Reconstruction loss (negative log-likelihood summed over the last dimension)
            and the MSE between ``x`` and ``px.mean``.

        Raises
        ------
        NotImplementedError
            If the reconstruction strategy is not supported.
        """
        fill_in_the_blanks = self.fill_in_the_blanks_ratio > 0.0 and self.training

        if self.reconstruction_strategy == "dense" or reconstruction_indices is None:
            pass
        elif self.reconstruction_strategy.startswith("random_batch@"):
            # Restrict the targets (and the mask) to the features that were actually decoded.
            reconstruction_indices = reconstruction_indices.to(x.device)
            x = x[:, reconstruction_indices]
            if fill_in_the_blanks:
                x_mask = x_mask[:, reconstruction_indices]
        else:
            raise NotImplementedError(f"Reconstruction strategy {self.reconstruction_strategy} not implemented.")

        if fill_in_the_blanks:
            # The likelihood is scored only on the blanked-out entries (mask == 0), while
            # the MSE metric is computed on the entries the encoder saw (mask == 1).
            reconst_loss = -(px.log_prob(x) * (1 - x_mask)).sum(dim=-1)
            mse = torch.nn.functional.mse_loss(x * x_mask, px.mean * x_mask, reduction="none").sum(dim=1).mean(dim=0)
        else:
            reconst_loss = -px.log_prob(x).sum(dim=-1)
            mse = torch.nn.functional.mse_loss(x, px.mean, reduction="none").sum(dim=1).mean(dim=0)

        return {
            MODULE_KEYS.RECONSTRUCTION_LOSS_KEY: reconst_loss,
            MODULE_KEYS.MSE_LOSS_KEY: mse,
        }

    def _get_kl_divergence_z(self, qz_m: torch.Tensor, qz_v: torch.Tensor) -> torch.Tensor:
        """Get KL divergence term for z.

        ``qz_v`` is a variance, so its square root is used as the scale of the Normal.
        """
        return self.prior.kl(Normal(qz_m, torch.sqrt(qz_v))).sum(dim=-1)

    def _setup_streaming_metrics(self, n_latent: int, n_labels: int) -> None:
        """Setup Latent stats and MI metrics.

        The MI metric is only created when more than one label is declared.
        """
        self.latent_stats = LatentStats(n_latent=n_latent)
        if n_labels > 1:
            self.mi_metric = StreamingPairwiseMI(latent_stats=self.latent_stats, n_label=n_labels)
        else:
            self.mi_metric = None

    def _streaming_metrics_step(
        self,
        tensors: TensorDict,
        inference_outputs: dict[str, Any],
    ) -> None:
        """Update the latent statistics and, when labels are available, the MI metric.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data.
        inference_outputs
            Outputs from the inference step; the posterior mean is used as the latent value.
        """
        z = inference_outputs[MODULE_KEYS.QZM_KEY]
        self.latent_stats.update(z)

        if self.n_labels > 1 and self.mi_metric is not None:
            # Use `.get` so that a batch without a labels entry skips the MI update instead
            # of raising KeyError (the original code already guarded against None).
            labels = tensors.get(REGISTRY_KEYS.LABELS_KEY)
            if labels is not None:
                # Clamp so that out-of-range label codes cannot index outside [0, n_labels).
                labels_flat = torch.clamp(labels.view(-1).long(), 0, self.n_labels - 1)
                self.mi_metric.update(z, labels_flat, is_train=self.training)

    def loss(
        self,
        tensors: TensorDict,
        inference_outputs: dict[str, Any],
        generative_outputs: dict[str, Any],
        kl_weight: float = 1.0,
    ) -> LossOutput:
        """Loss function.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data.
        inference_outputs
            Outputs from the inference step.
        generative_outputs
            Outputs from the generative step.
        kl_weight
            Weight for KL divergence term.

        Returns
        -------
        LossOutput
            Loss output object containing various loss components.
        """
        x = tensors[REGISTRY_KEYS.X_KEY]
        x_mask = inference_outputs[MODULE_KEYS.X_MASK_KEY]
        qz_m = inference_outputs[MODULE_KEYS.QZM_KEY]
        qz_v = inference_outputs[MODULE_KEYS.QZV_KEY]
        px = generative_outputs[MODULE_KEYS.PX_KEY]
        reconstruction_indices = generative_outputs[MODULE_KEYS.RECONSTRUCTION_INDICES]

        kl_divergence_z = self._get_kl_divergence_z(qz_m, qz_v)
        reconst_losses = self._get_reconstruction_loss(px, x, x_mask, reconstruction_indices)
        reconst_loss = reconst_losses[MODULE_KEYS.RECONSTRUCTION_LOSS_KEY]
        assert kl_divergence_z.shape == reconst_loss.shape

        # Only the latent KL term is subject to KL warm-up; there is no un-weighted KL term.
        kl_local_for_warmup = kl_divergence_z
        kl_local_no_warmup = 0.0
        weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

        loss = torch.mean(reconst_loss + weighted_kl_local)

        # Side effect: streaming latent statistics are updated on every loss evaluation.
        self._streaming_metrics_step(tensors, inference_outputs)

        kl_local = {MODULE_KEYS.KL_Z_KEY: kl_divergence_z}
        reconstruction_loss = {MODULE_KEYS.RECONSTRUCTION_LOSS_KEY: reconst_loss}
        return LossOutput(
            loss=loss,
            reconstruction_loss=reconstruction_loss,
            kl_local=kl_local,
            extra_metrics={
                MODULE_KEYS.MSE_LOSS_KEY: reconst_losses[MODULE_KEYS.MSE_LOSS_KEY],
            },
        )

    @torch.no_grad()
    def sample(
        self,
        tensors: TensorDict,
        n_samples: int = 1,
        library_size: int = 1,
        generative_kwargs: dict | None = None,
    ) -> torch.Tensor:
        # Note: Not tested
        r"""
        Generate observation samples from the posterior predictive distribution.

        The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data.
        n_samples
            Number of required samples for each cell.
        library_size
            Library size to scale samples to.
            Currently not used by this implementation.
        generative_kwargs
            Keyword args for ``generative()`` in fwd pass

        Returns
        -------
        torch.Tensor
            Tensor with shape (n_cells, n_genes, n_samples) when ``n_samples > 1``;
            otherwise the sample is returned without a trailing sample dimension.
        """
        inference_kwargs = {"n_samples": n_samples}
        _, generative_outputs = self.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            generative_kwargs=generative_kwargs,
            compute_loss=False,
        )

        dist = generative_outputs[MODULE_KEYS.PX_KEY]
        if n_samples > 1:
            # Move the leading sample dimension to the end.
            exprs = dist.sample().movedim(0, -1)
        else:
            exprs = dist.sample()

        return exprs.cpu()

    @torch.no_grad()
    @auto_move_data
    def marginal_ll(self, tensors: TensorDict, n_mc_samples: int) -> float:
        """Compute marginal log-likelihood.

        Uses an importance-sampling estimate with the variational posterior as proposal:
        ``log p(x) ~= logsumexp_i[log p(z_i) + log p(x | z_i) - log q(z_i | x)] - log(n)``,
        summed over the cells of the batch.

        Parameters
        ----------
        tensors
            Dictionary containing tensor data.
        n_mc_samples
            Number of Monte Carlo samples for estimation.

        Returns
        -------
        float
            Marginal log-likelihood value.

        Raises
        ------
        ValueError
            If ``n_mc_samples`` is smaller than 1.
        """
        if n_mc_samples < 1:
            raise ValueError("n_mc_samples must be at least 1.")

        # Collecting per-sample tensors and stacking them keeps everything on the device
        # the forward pass ran on (no pre-allocated CPU buffer).
        log_weights = []
        for _ in range(n_mc_samples):
            # Distribution parameters and sampled variables
            inference_outputs, _, losses = self.forward(tensors)
            qz_m = inference_outputs[MODULE_KEYS.QZM_KEY]
            qz_v = inference_outputs[MODULE_KEYS.QZV_KEY]
            z = inference_outputs[MODULE_KEYS.Z_KEY]

            # Reconstruction Loss
            reconst_loss = losses.dict_sum(losses.reconstruction_loss)

            # Log-probabilities
            p_z = Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)).log_prob(z).sum(dim=-1)
            p_x_zl = -reconst_loss
            # The proposal density must be divided out for the estimate to target p(x).
            q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)

            log_weights.append(p_z + p_x_zl - q_z_x)

        to_sum = torch.stack(log_weights, dim=-1)
        batch_log_lkl = torch.logsumexp(to_sum, dim=-1) - np.log(n_mc_samples)
        log_lkl = torch.sum(batch_log_lkl).item()
        return log_lkl