# Source code for drvi.scvi_tools_based.model._drvi

from __future__ import annotations

import logging
import warnings
from typing import TYPE_CHECKING

import numpy as np
import scvi
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField, NumericalJointObsField
from scvi.model.base import BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin, VAEMixin
from scvi.utils import setup_anndata_dsp

import drvi
from drvi.scvi_tools_based.data.fields import FixedCategoricalJointObsField
from drvi.scvi_tools_based.model.base import DRVIArchesMixin, GenerativeMixin, InterpretabilityMixin
from drvi.scvi_tools_based.module import DRVIModule, DRVITrainingPlan

if TYPE_CHECKING:
    from typing import Any

    from anndata import AnnData


# Storage keys for the latent posterior mean ("qzm") and variance ("qzv"); the DRVI class
# exposes the first two as its ``_LATENT_QZM_KEY`` / ``_LATENT_QZV_KEY`` attributes.
# ``_DRVI_OBSERVED_LIB_SIZE`` is not referenced in the visible part of this module.
_DRVI_LATENT_QZM: str = "_drvi_latent_qzm"
_DRVI_LATENT_QZV: str = "_drvi_latent_qzv"
_DRVI_OBSERVED_LIB_SIZE: str = "_drvi_observed_lib_size"


# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)


class DRVI(
    RNASeqMixin,
    VAEMixin,
    DRVIArchesMixin,
    UnsupervisedTrainingMixin,
    BaseModelClass,
    GenerativeMixin,
    InterpretabilityMixin,
):
    """DRVI model based on scvi-tools framework for disentangled representation learning.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~drvi.model.DRVI.setup_anndata`.
    registry
        Optional scvi-tools data registry. It is forwarded to the scvi-tools base class only
        when the installed scvi-tools version is at least 1.3.1. When given, it is also used
        to migrate model arguments saved by older DRVI versions.
    n_latent
        Dimensionality of the latent space.
    categorical_embedding_dims
        Dictionary mapping categorical covariate names to their embedding dimensions.
        Used only if `covariate_modeling_strategy` passed to DRVIModule is based on embedding
        (not onehot encoding). Keys should match the covariate names used in
        :meth:`~drvi.model.DRVI.setup_anndata`. If not provided, default embedding dimension
        of 10 is used for all covariates.
    **model_kwargs
        Additional keyword arguments passed to :class:`~drvi.model.DRVIModule`.

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> drvi.model.DRVI.setup_anndata(adata, categorical_covariate_keys=["batch"])
    >>> vae = drvi.model.DRVI(adata)
    >>> vae.train()
    >>> adata.obsm["latent"] = vae.get_latent_representation()
    """

    _module_cls = DRVIModule
    _training_plan_cls = DRVITrainingPlan
    _LATENT_QZM_KEY = _DRVI_LATENT_QZM
    _LATENT_QZV_KEY = _DRVI_LATENT_QZV
    _DEFAULT_CATEGORICAL_EMBEDDING_DIM = 10

    def __init__(
        self,
        adata: AnnData | None = None,
        registry: dict | None = None,
        n_latent: int = 32,
        categorical_embedding_dims: dict[str, int] | None = None,
        **model_kwargs,
    ) -> None:
        # ``registry`` is only handed to the base class for scvi-tools >= 1.3.1. The helper
        # compares parsed versions; a plain string comparison misorders e.g. "1.10.0" < "1.3.1".
        if self._scvi_version_at_least("1.3.1"):
            super().__init__(adata, registry)
        else:
            super().__init__(adata)

        # Drops deprecated kwargs and migrates kwargs saved by older DRVI versions (in place).
        self._handle_backward_compatibility_for_model_kwargs(registry, model_kwargs)

        # The batch is always treated as the first categorical covariate.
        n_batch = self.summary_stats.n_batch
        n_cats_per_cov = [n_batch] + self._get_n_cats_per_cov()
        n_continuous_cov = self.summary_stats.get("n_extra_continuous_covs", 0)
        categorical_covariates_dims = self._compute_categorical_covariates_dims(
            n_cats_per_cov, categorical_embedding_dims
        )

        self._module_kwargs = dict(
            n_input=self.summary_stats["n_vars"],
            n_latent=n_latent,
            n_cats_per_cov=n_cats_per_cov,
            n_continuous_cov=n_continuous_cov,
            categorical_covariate_dims=categorical_covariates_dims,
            n_labels=self.summary_stats.get("n_labels", 1),
            **model_kwargs,
        )
        self.module = self._module_cls(**self._module_kwargs)

        self._model_summary_string = (
            "DRVI \n"
            f"Latent size: {self.module.n_latent}, "
            f"splits: {self.module.n_split_latent}, "
            f"pooling of splits: '{self.module.split_aggregation}', \n"
            f"Encoder dims: {self.module.encoder_dims}, \n"
            f"Decoder dims: {self.module.decoder_dims}, \n"
            f"Gene likelihood: {self.module.gene_likelihood}, \n"
        )
        # necessary line to get params that will be used for saving/loading
        self.init_params_ = self._get_init_params(locals())
        logger.info("The model has been initialized")

    @staticmethod
    def _scvi_version_at_least(min_version: str) -> bool:
        """Return whether the installed scvi-tools version is ``>= min_version``.

        Versions are compared with :class:`packaging.version.Version`, not as strings.
        """
        from packaging.version import Version

        return Version(scvi.__version__) >= Version(min_version)

    def _get_state_registry(self, registry_key: str) -> dict:
        """Return the state registry of the field registered under ``registry_key``.

        The AnnData manager is used when one is attached to the model. Otherwise, e.g. when
        the model was built from ``registry`` without ``adata``, the raw registry dict is read.
        """
        adata_manager = getattr(self, "adata_manager", None)
        if adata_manager is not None:
            return adata_manager.get_state_registry(registry_key)
        return self.registry["field_registries"][registry_key]["state_registry"]

    def _get_n_cats_per_cov(self) -> list[int]:
        """Get the number of categories per categorical covariate (excluding the batch)."""
        if not self._scvi_version_at_least("1.3.1"):
            if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry:
                cat_cov_stats = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)
                return list(cat_cov_stats.n_cats_per_key)
            return []

        field_registries = self.registry["field_registries"]
        if REGISTRY_KEYS.CAT_COVS_KEY in field_registries:
            cat_cov_stats = field_registries[REGISTRY_KEYS.CAT_COVS_KEY]["state_registry"]
            # Always return a list: the caller prepends the batch with ``[n_batch] + ...``.
            return list(cat_cov_stats.get("n_cats_per_key", []))
        return []

    def _compute_categorical_covariates_dims(
        self,
        n_cats_per_cov: list[int],
        categorical_embedding_dims: dict[str, int] | None,
    ) -> list[int]:
        """Compute categorical covariates embedding dimensions.

        Parameters
        ----------
        n_cats_per_cov
            Number of categories per categorical covariate. This includes the batch covariate
            at position 0.
        categorical_embedding_dims
            Dictionary mapping covariate names to their embedding dimensions. This includes the
            batch covariate. Names not found among the registered covariates are ignored.

        Returns
        -------
        list[int]
            List of embedding dimensions for each categorical covariate.
        """
        categorical_embedding_dims = categorical_embedding_dims or {}
        categorical_covariates_dims = [self._DEFAULT_CATEGORICAL_EMBEDDING_DIM] * len(n_cats_per_cov)

        # A single batch category means no batch key was set up, so there is no name to match.
        if n_cats_per_cov[0] > 1:
            batch_original_key = self._get_state_registry(REGISTRY_KEYS.BATCH_KEY)["original_key"]
            if batch_original_key in categorical_embedding_dims:
                categorical_covariates_dims[0] = categorical_embedding_dims[batch_original_key]

        if len(n_cats_per_cov) > 1:
            cat_cov_names = self._get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY)["field_keys"]
            # Extra categorical covariates follow the batch, hence ``start=1``.
            for position, obs_name in enumerate(cat_cov_names, start=1):
                if obs_name in categorical_embedding_dims:
                    categorical_covariates_dims[position] = categorical_embedding_dims[obs_name]

        return categorical_covariates_dims

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        is_count_data: bool = True,
        batch_key: str | None = None,
        labels_key: str | None = None,
        categorical_covariate_keys: list[str] | None = None,
        continuous_covariate_keys: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        %(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_labels_key)s
        %(param_layer)s
        is_count_data
            Whether the data in ``layer`` is count data. Forwarded to ``LayerField``.
        %(param_batch_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        **kwargs
            Passed to ``AnnDataManager.register_fields``. A ``source_registry`` entry is first
            migrated in place if it was saved by an older DRVI version.

        Returns
        -------
        %(returns)s
        """
        # Must stay the first statement: ``locals()`` is recorded as the setup arguments.
        setup_method_args = cls._get_setup_method_args(**locals())
        setup_method_args["drvi_version"] = drvi.__version__
        # Manipulate kwargs in case of version updates (only when loading a model).
        if "source_registry" in kwargs:
            cls._handle_backward_compatibility_for_data_setup(kwargs["source_registry"])
        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=is_count_data),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            FixedCategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    @staticmethod
    def _handle_backward_compatibility_for_data_setup(source_registry: dict[str, Any]) -> None:
        """Migrate, in place, a data registry saved by an older DRVI version.

        Parameters
        ----------
        source_registry
            Registry of the saved model. Its ``"drvi_version"`` entry selects the migrations
            to apply. When the entry is missing, ``"0.1.0"`` is assumed (legacy code from
            before the PyPI release).
        """
        from packaging.version import Version

        current_version = Version(drvi.__version__)
        source_registry_drvi_version = Version(
            source_registry.get("drvi_version", "0.1.0")
        )  # "0.1.0" for legacy code before pypi release
        logger.info(f"The model is trained with DRVI version {source_registry_drvi_version}.")
        logger.info("Updating data setup config ...")
        while source_registry_drvi_version < current_version:
            if source_registry_drvi_version < Version("0.1.10"):
                # No breaking change up to 0.1.10
                source_registry_drvi_version = Version("0.1.10")
            elif source_registry_drvi_version == Version("0.1.10"):
                # setup_anndata now always registers a batch field; registries saved by 0.1.10
                # get a single-category placeholder batch field instead.
                logger.info("Modifying data args from 0.1.10 to 0.1.11 (no user action required)")
                logger.info("Adding empty batch key ...")
                source_registry["setup_args"]["batch_key"] = None
                source_registry["field_registries"]["batch"] = {
                    "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"},
                    "state_registry": {"categorical_mapping": np.array([0]), "original_key": "_scvi_batch"},
                    "summary_stats": {"n_batch": 1},
                }
                source_registry_drvi_version = Version("0.1.11")
                logger.info("Done updating data source registry from 0.1.10 to 0.1.11.")
            else:
                # No breaking change yet!
                source_registry_drvi_version = current_version
        logger.info(f"Done updating data source registry. Loading in DRVI version {drvi.__version__}.")

    def _handle_backward_compatibility_for_model_kwargs(self, source_registry: dict | None, model_kwargs: dict) -> None:
        """Handle backward compatibility for model kwargs (``model_kwargs`` is modified in place).

        Deprecated kwargs are removed with a ``DeprecationWarning``. If ``source_registry`` is
        given, kwargs saved by older DRVI versions are migrated step by step to the current one.

        Raises
        ------
        ValueError
            If a saved gene likelihood or prior is no longer supported, or if a saved
            ``prior_init_obs`` is not ``None``.
        """
        # TODO: Deprecate in near future
        for key in ["categorical_covariates", "batch_key"]:
            if key in model_kwargs:
                model_kwargs.pop(key)
                warnings.warn(
                    f"Passing {key} to DRVI model is deprecated. "
                    "It is enough to pass this argument to DRVI.setup_anndata. "
                    "This will cause error in the near future.",
                    DeprecationWarning,
                    stacklevel=settings.warnings_stacklevel,
                )

        # A registry is only relevant when loading a model that may come from a previous version.
        if source_registry is None:
            return

        from packaging.version import Version

        saved_version = Version(source_registry.get("drvi_version", "0.1.0"))
        current_version = Version(drvi.__version__)
        logger.info(f"Loading model from DRVI version {saved_version}.")

        # Process version transitions incrementally
        while saved_version < current_version:
            if saved_version <= Version("0.2.1"):  # TODO: update accordingly before merging
                self._migrate_model_kwargs_to_0_2_2(model_kwargs)
                saved_version = Version("0.2.2")
            else:
                # No breaking changes for versions >= change_version
                saved_version = current_version
        logger.info(f"Done updating model args. Loading in {current_version}.")

    @staticmethod
    def _migrate_model_kwargs_to_0_2_2(model_kwargs: dict) -> None:
        """Rewrite, in place, model kwargs saved by DRVI <= 0.2.1 into their 0.2.2 form."""
        logger.info("Modifying model args from 0.2.1 to 0.2.2 (no user action required)")

        # Handle gene_likelihood backward compatibility. The normal_v / normal_sv variants map
        # to "normal" with an explicit ``dispersion`` setting.
        if "gene_likelihood" in model_kwargs:
            gl = model_kwargs["gene_likelihood"]
            gl_mapping = {
                "pnb_softmax": "pnb",
                "poisson_orig": "poisson",
                "nb_orig": "nb",
                "nb": "nb",
                "normal": "normal_unit_var",
                "normal_v": "normal",
                "normal_sv": "normal",
            }
            if gl not in gl_mapping:
                raise ValueError(
                    f"Gene likelihood '{gl}' support is dropped in version 0.2.2. "
                    "Please create an issue if using this gene likelihood is still needed."
                )
            logger.info(f"Mapping gene likelihood '{gl}' to '{gl_mapping[gl]}'.")
            model_kwargs["gene_likelihood"] = gl_mapping[gl]
            if gl == "normal_v":
                model_kwargs["dispersion"] = "gene-cell"
            elif gl == "normal_sv":
                model_kwargs["dispersion"] = "gene"

        # Handle prior backward compatibility: only the normal prior is still supported.
        if "prior" in model_kwargs:
            prior = model_kwargs["prior"]
            if prior != "normal":
                raise ValueError(
                    f"Prior '{prior}' support is dropped in version 0.2.2. "
                    "Please create an issue if using this prior is still needed."
                )
            model_kwargs["prior"] = "normal"

        if "prior_init_obs" in model_kwargs:
            # Raise instead of ``assert`` so the check also runs under ``python -O``.
            if model_kwargs["prior_init_obs"] is not None:
                raise ValueError(
                    "prior_init_obs is not supported since version 0.2.2 and must be None, "
                    f"got {model_kwargs['prior_init_obs']!r}."
                )
            logger.info("Removing prior_init_obs from model args.")
            del model_kwargs["prior_init_obs"]