from __future__ import annotations
import logging
import warnings
from typing import TYPE_CHECKING
import numpy as np
import scvi
from scvi import REGISTRY_KEYS, settings
from scvi.data import AnnDataManager
from scvi.data.fields import CategoricalObsField, LayerField, NumericalJointObsField
from scvi.model.base import BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin, VAEMixin
from scvi.utils import setup_anndata_dsp
import drvi
from drvi.scvi_tools_based.data.fields import FixedCategoricalJointObsField
from drvi.scvi_tools_based.model.base import DRVIArchesMixin, GenerativeMixin, InterpretabilityMixin
from drvi.scvi_tools_based.module import DRVIModule, DRVITrainingPlan
if TYPE_CHECKING:
from typing import Any
from anndata import AnnData
# String keys re-exported on the DRVI class as ``_LATENT_QZM_KEY`` / ``_LATENT_QZV_KEY``.
# NOTE(review): judging by the names these label the latent posterior mean / variance;
# confirm against the scvi-tools mixins that consume the class attributes.
_DRVI_LATENT_QZM = "_drvi_latent_qzm"
_DRVI_LATENT_QZV = "_drvi_latent_qzv"
# Not referenced anywhere in the visible part of this module.
# NOTE(review): presumably a key for an observed library size; confirm where it is used.
_DRVI_OBSERVED_LIB_SIZE = "_drvi_observed_lib_size"
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
class DRVI(
    RNASeqMixin,
    VAEMixin,
    DRVIArchesMixin,
    UnsupervisedTrainingMixin,
    BaseModelClass,
    GenerativeMixin,
    InterpretabilityMixin,
):
    """DRVI model based on scvi-tools framework for disentangled representation learning.

    Parameters
    ----------
    adata
        AnnData object that has been registered via :meth:`~drvi.model.DRVI.setup_anndata`.
    registry
        Setup registry describing the data the model was set up on. It is forwarded to the
        scvi-tools base class only when the installed scvi-tools release is 1.3.1 or newer.
        When given, its ``drvi_version`` entry (``"0.1.0"`` if absent) decides which
        migrations are applied to ``model_kwargs``.
    n_latent
        Dimensionality of the latent space.
    categorical_embedding_dims
        Dictionary mapping categorical covariate names to their embedding dimensions.
        Used only if `covariate_modeling_strategy` passed to DRVIModule is based on embedding (not onehot encoding).
        Keys should match the covariate names used in :meth:`~drvi.model.DRVI.setup_anndata`.
        If not provided, default embedding dimension of 10 is used for all covariates.
    **model_kwargs
        Additional keyword arguments passed to :class:`~drvi.model.DRVIModule`.

    Examples
    --------
    >>> adata = anndata.read_h5ad(path_to_anndata)
    >>> drvi.model.DRVI.setup_anndata(adata, categorical_covariate_keys=["batch"])
    >>> vae = drvi.model.DRVI(adata)
    >>> vae.train()
    >>> adata.obsm["latent"] = vae.get_latent_representation()
    """

    _module_cls = DRVIModule
    _training_plan_cls = DRVITrainingPlan
    _LATENT_QZM_KEY = _DRVI_LATENT_QZM
    _LATENT_QZV_KEY = _DRVI_LATENT_QZV
    _DEFAULT_CATEGORICAL_EMBEDDING_DIM = 10
    # First scvi-tools release for which this class hands ``registry`` to the base
    # initializer and reads field registries from ``self.registry``.
    _MIN_SCVI_RELEASE_WITH_REGISTRY = (1, 3, 1)

    def __init__(
        self,
        adata: AnnData | None = None,
        registry: dict | None = None,
        n_latent: int = 32,
        categorical_embedding_dims: dict[str, int] | None = None,
        **model_kwargs,
    ) -> None:
        if self._scvi_supports_registry():
            super().__init__(adata, registry)
        else:
            super().__init__(adata)

        # Mutates ``model_kwargs`` in place, so both the module below and the saved
        # init params see the migrated arguments.
        self._handle_backward_compatibility_for_model_kwargs(registry, model_kwargs)

        # The batch covariate always occupies position 0, followed by the extra categorical covariates.
        n_batch = self.summary_stats.n_batch
        n_cats_per_cov = [n_batch] + self._get_n_cats_per_cov()
        n_continuous_cov = self.summary_stats.get("n_extra_continuous_covs", 0)
        categorical_covariates_dims = self._compute_categorical_covariates_dims(
            n_cats_per_cov, categorical_embedding_dims
        )

        self._module_kwargs = dict(
            n_input=self.summary_stats["n_vars"],
            n_latent=n_latent,
            n_cats_per_cov=n_cats_per_cov,
            n_continuous_cov=n_continuous_cov,
            categorical_covariate_dims=categorical_covariates_dims,
            n_labels=self.summary_stats.get("n_labels", 1),
            **model_kwargs,
        )
        self.module = self._module_cls(**self._module_kwargs)

        self._model_summary_string = (
            "DRVI \n"
            f"Latent size: {self.module.n_latent}, "
            f"splits: {self.module.n_split_latent}, "
            f"pooling of splits: '{self.module.split_aggregation}', \n"
            f"Encoder dims: {self.module.encoder_dims}, \n"
            f"Decoder dims: {self.module.decoder_dims}, \n"
            f"Gene likelihood: {self.module.gene_likelihood}, \n"
        )
        # necessary line to get params that will be used for saving/loading
        self.init_params_ = self._get_init_params(locals())

        logger.info("The model has been initialized")

    @classmethod
    def _scvi_supports_registry(cls) -> bool:
        """Return whether the installed scvi-tools release is 1.3.1 or newer.

        The release components are compared numerically. Comparing the raw version
        strings would order them lexicographically and, for example, rank ``"1.10.0"``
        below ``"1.3.1"``.
        """
        from packaging.version import Version

        return Version(scvi.__version__).release >= cls._MIN_SCVI_RELEASE_WITH_REGISTRY

    def _get_n_cats_per_cov(self) -> list[int]:
        """Get the number of categories per categorical covariate.

        Returns
        -------
        list[int]
            One entry per extra categorical covariate (the batch covariate is not included),
            or an empty list when no categorical covariates are registered.
        """
        cat_covs_key = REGISTRY_KEYS.CAT_COVS_KEY
        if self._scvi_supports_registry():
            field_registries = self.registry["field_registries"]
            if cat_covs_key in field_registries:
                state_registry = field_registries[cat_covs_key]["state_registry"]
                # Always hand back a real list: the caller concatenates with ``[n_batch] + ...``,
                # which would turn into elementwise addition for an array-like value.
                return list(state_registry.get("n_cats_per_key", []))
        elif cat_covs_key in self.adata_manager.data_registry:
            cat_cov_stats = self.adata_manager.get_state_registry(cat_covs_key)
            return list(cat_cov_stats.n_cats_per_key)
        return []

    def _compute_categorical_covariates_dims(
        self,
        n_cats_per_cov: list[int],
        categorical_embedding_dims: dict[str, int] | None,
    ) -> list[int]:
        """Compute categorical covariates embedding dimensions.

        Parameters
        ----------
        n_cats_per_cov
            Number of categories per categorical covariate. This includes the batch covariate.
        categorical_embedding_dims
            Dictionary mapping covariate names to their embedding dimensions. This includes the batch covariate.

        Returns
        -------
        list[int]
            List of embedding dimensions for each categorical covariate.
        """
        requested_dims = categorical_embedding_dims or {}
        # Start from the default everywhere, then override the covariates the user named.
        categorical_covariates_dims = [self._DEFAULT_CATEGORICAL_EMBEDDING_DIM] * len(n_cats_per_cov)

        # Position 0 is the batch covariate; it is looked up by its original obs key
        # only when it has more than one category.
        if n_cats_per_cov[0] > 1:
            batch_original_key = self.adata_manager.get_state_registry(REGISTRY_KEYS.BATCH_KEY).original_key
            if batch_original_key in requested_dims:
                categorical_covariates_dims[0] = requested_dims[batch_original_key]

        # Remaining positions follow the order of the registered categorical covariate keys.
        if len(n_cats_per_cov) > 1:
            cat_cov_names = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).field_keys
            for position, obs_name in enumerate(cat_cov_names, start=1):
                if obs_name in requested_dims:
                    categorical_covariates_dims[position] = requested_dims[obs_name]

        return categorical_covariates_dims

    @classmethod
    @setup_anndata_dsp.dedent
    def setup_anndata(
        cls,
        adata: AnnData,
        layer: str | None = None,
        is_count_data: bool = True,
        batch_key: str | None = None,
        labels_key: str | None = None,
        categorical_covariate_keys: list[str] | None = None,
        continuous_covariate_keys: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        %(summary)s.

        Parameters
        ----------
        %(param_adata)s
        %(param_labels_key)s
        %(param_layer)s
        is_count_data
            Forwarded as ``is_count_data`` to the :class:`~scvi.data.fields.LayerField` that registers the expression data.
        %(param_batch_key)s
        %(param_cat_cov_keys)s
        %(param_cont_cov_keys)s
        **kwargs
            Additional keyword arguments passed to ``AnnDataManager.register_fields``.
            If ``source_registry`` is among them, it is first upgraded in place to the current DRVI version.

        Returns
        -------
        %(returns)s
        """
        # Must stay the first statement: ``locals()`` is captured as the setup arguments,
        # so any local defined earlier would leak into the stored setup args.
        setup_method_args = cls._get_setup_method_args(**locals())
        setup_method_args["drvi_version"] = drvi.__version__

        # Manipulate kwargs in case of version updates (only when loading a model).
        if "source_registry" in kwargs:
            cls._handle_backward_compatibility_for_data_setup(kwargs["source_registry"])

        anndata_fields = [
            LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=is_count_data),
            CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key),
            CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key),
            FixedCategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys),
            NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys),
        ]
        adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args)
        adata_manager.register_fields(adata, **kwargs)
        cls.register_manager(adata_manager)

    @staticmethod
    def _handle_backward_compatibility_for_data_setup(source_registry: dict[str, Any]) -> dict[str, Any]:
        """Handle backward compatibility for data setup.

        Upgrades ``source_registry`` in place, one version transition at a time, until it
        matches the installed DRVI version.

        Parameters
        ----------
        source_registry
            Data setup registry of a previously saved model. A missing ``drvi_version``
            entry is treated as ``"0.1.0"``.

        Returns
        -------
        dict[str, Any]
            The same ``source_registry`` object, after in-place modification.
        """
        from packaging.version import Version

        current_version = Version(drvi.__version__)
        registry_version = Version(
            source_registry.get("drvi_version", "0.1.0")
        )  # "0.1.0" for legacy code before pypi release
        logger.info("The model is trained with DRVI version %s.", registry_version)
        logger.info("Updating data setup config ...")

        while registry_version < current_version:
            if registry_version < Version("0.1.10"):
                # No breaking change up to 0.1.10
                registry_version = Version("0.1.10")
            elif registry_version == Version("0.1.10"):
                # log the transfer
                logger.info("Modifying data args from 0.1.10 to 0.1.11 (no user action required)")
                logger.info("Adding empty batch key ...")
                source_registry["setup_args"]["batch_key"] = None
                source_registry["field_registries"]["batch"] = {
                    "data_registry": {"attr_name": "obs", "attr_key": "_scvi_batch"},
                    "state_registry": {"categorical_mapping": np.array([0]), "original_key": "_scvi_batch"},
                    "summary_stats": {"n_batch": 1},
                }
                registry_version = Version("0.1.11")
                logger.info("Done updating data source registry from 0.1.10 to 0.1.11.")
            else:
                # No breaking change yet!
                registry_version = current_version

        logger.info("Done updating data source registry. Loading in DRVI version %s.", drvi.__version__)
        return source_registry

    def _handle_backward_compatibility_for_model_kwargs(self, source_registry: dict | None, model_kwargs: dict) -> None:
        """Handle backward compatibility for model kwargs.

        Modifies ``model_kwargs`` in place.

        Parameters
        ----------
        source_registry
            Registry passed to the model initializer, or None. When given, its
            ``drvi_version`` entry (``"0.1.0"`` if absent) selects the migrations to apply.
        model_kwargs
            Keyword arguments destined for the module; deprecated keys are removed and
            legacy values are rewritten.

        Raises
        ------
        ValueError
            If the saved arguments use a gene likelihood, prior, or ``prior_init_obs`` value
            that is no longer supported.
        """
        # TODO: Deprecate in near future
        for key in ("categorical_covariates", "batch_key"):
            if key in model_kwargs:
                model_kwargs.pop(key)
                warnings.warn(
                    f"Passing {key} to DRVI model is deprecated. "
                    "It is enough to pass this argument to DRVI.setup_anndata. "
                    "This will cause error in the near future.",
                    DeprecationWarning,
                    stacklevel=settings.warnings_stacklevel,
                )

        # Without a registry there is no recorded version to migrate from.
        if source_registry is None:
            return

        from packaging.version import Version

        saved_version = Version(source_registry.get("drvi_version", "0.1.0"))
        current_version = Version(drvi.__version__)
        logger.info("Loading model from DRVI version %s.", saved_version)

        # Process version transitions incrementally
        while saved_version < current_version:
            if saved_version <= Version("0.2.1"):  # TODO: update accordingly before merging
                logger.info("Modifying model args from 0.2.1 to 0.2.2 (no user action required)")
                self._upgrade_model_kwargs_to_0_2_2(model_kwargs)
                saved_version = Version("0.2.2")
            else:
                # No breaking changes for versions >= change_version
                saved_version = current_version

        logger.info("Done updating model args. Loading in %s.", current_version)

    @staticmethod
    def _upgrade_model_kwargs_to_0_2_2(model_kwargs: dict) -> None:
        """Rewrite, in place, model arguments saved by DRVI <= 0.2.1 into their 0.2.2 form."""
        gene_likelihood_renames = {
            "pnb_softmax": "pnb",
            "poisson_orig": "poisson",
            "nb_orig": "nb",
            "nb": "nb",
            "normal": "normal_unit_var",
            "normal_v": "normal",
            "normal_sv": "normal",
        }
        # The two legacy normal variants collapse into "normal" and differ only in dispersion.
        legacy_normal_dispersions = {"normal_v": "gene-cell", "normal_sv": "gene"}

        # Handle gene_likelihood backward compatibility
        if "gene_likelihood" in model_kwargs:
            gl = model_kwargs["gene_likelihood"]
            if gl not in gene_likelihood_renames:
                raise ValueError(
                    f"Gene likelihood '{gl}' support is dropped in version 0.2.2. "
                    "Please create an issue if using this gene likelihood is still needed."
                )
            logger.info("Mapping gene likelihood '%s' to '%s'.", gl, gene_likelihood_renames[gl])
            model_kwargs["gene_likelihood"] = gene_likelihood_renames[gl]
            if gl in legacy_normal_dispersions:
                model_kwargs["dispersion"] = legacy_normal_dispersions[gl]

        # Handle prior backward compatibility
        if "prior" in model_kwargs:
            prior = model_kwargs["prior"]
            if prior != "normal":
                raise ValueError(
                    f"Prior '{prior}' support is dropped in version 0.2.2. "
                    "Please create an issue if using this prior is still needed."
                )
            model_kwargs["prior"] = "normal"

        if "prior_init_obs" in model_kwargs:
            # Explicit check instead of ``assert`` so it also runs under ``python -O``.
            if model_kwargs["prior_init_obs"] is not None:
                raise ValueError(
                    "A non-None 'prior_init_obs' is not supported since version 0.2.2. "
                    "Please create an issue if using this argument is still needed."
                )
            logger.info("Removing prior_init_obs from model args.")
            del model_kwargs["prior_init_obs"]