drvi.utils.metrics.DiscreteDisentanglementBenchmark
- class drvi.utils.metrics.DiscreteDisentanglementBenchmark(embed, discrete_target=None, one_hot_target=None, dim_titles=None, metrics=('SMI-disc', 'SPN', 'ASC'), aggregation_methods=('LMS', 'MSAS', 'MSGS'), additional_metric_params=None)
Benchmark for evaluating discrete disentanglement in latent representations.
This class provides a comprehensive framework for evaluating how well latent dimensions capture discrete categorical variables (e.g., cell types, experimental conditions, biological processes). It supports multiple evaluation metrics and aggregation methods to provide robust assessment of disentanglement quality.
- Parameters:
embed – Latent representations to evaluate. Shape should be (n_samples, n_dimensions).
discrete_target (default: None) – Discrete categorical target variable containing a categorical label for each sample; must be a pandas Series or numpy array. Mutually exclusive with one_hot_target.
one_hot_target (default: None) – One-hot encoded target variable of shape (n_samples, n_categories); must be a pandas DataFrame or numpy array. Mutually exclusive with discrete_target.
dim_titles (default: None) – Titles for each latent dimension. If None, “dim_0”, “dim_1”, etc. are used.
metrics (default: ('SMI-disc', 'SPN', 'ASC')) – Metrics to compute for evaluation. Available options:
  - “SMI-disc”: Discrete mutual information score
  - “SPN”: Nearest neighbor alignment score
  - “ASC”: Spearman correlation score
  - “SMI-cont”: Continuous mutual information score (not working as expected; see scikit-learn/scikit-learn#30772)
aggregation_methods (default: ('LMS', 'MSAS', 'MSGS')) – Methods to aggregate metric scores across dimensions. Available options:
  - “LMS”: Latent matching score
  - “MSAS”: Most similar averaging score
  - “MSGS”: Most similar gap score
additional_metric_params (default: None) – Additional parameters to pass to specific metrics. Keys should be metric names; values should be parameter dictionaries.
- embed
Copy of the input latent representations.
- one_hot_target
One-hot encoded target variable.
- dim_titles
Titles for each latent dimension.
- metrics
Metrics used for evaluation.
- aggregation_methods
Aggregation methods used.
- additional_metric_params
Additional parameters for metrics.
- results
Raw metric results for each dimension and category.
- aggregated_results
Aggregated scores across dimensions.
- Raises:
ValueError – If neither discrete_target nor one_hot_target is provided; if both discrete_target and one_hot_target are provided; if discrete_target is not a pandas Series or numpy array; or if one_hot_target is not a pandas DataFrame or numpy array.
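The validation rules above effectively reduce both target forms to a single one-hot table. The sketch below is a hypothetical helper (`normalize_target` is not part of drvi) that mirrors the listed ValueError conditions, assuming a discrete target is expanded with `pandas.get_dummies`; the library's actual conversion may differ.

```python
import numpy as np
import pandas as pd


def normalize_target(discrete_target=None, one_hot_target=None):
    """Hypothetical helper: return a one-hot DataFrame of shape (n_samples, n_categories).

    Mirrors the ValueError conditions documented for the constructor; this is an
    illustrative sketch, not the drvi implementation.
    """
    if discrete_target is None and one_hot_target is None:
        raise ValueError("Provide either discrete_target or one_hot_target.")
    if discrete_target is not None and one_hot_target is not None:
        raise ValueError("discrete_target and one_hot_target are mutually exclusive.")

    if discrete_target is not None:
        if not isinstance(discrete_target, (pd.Series, np.ndarray)):
            raise ValueError("discrete_target must be a pandas Series or numpy array.")
        # Assumption: one column per category, one row per sample.
        return pd.get_dummies(pd.Series(np.asarray(discrete_target)))

    if not isinstance(one_hot_target, (pd.DataFrame, np.ndarray)):
        raise ValueError("one_hot_target must be a pandas DataFrame or numpy array.")
    return pd.DataFrame(one_hot_target)
```

Passing a plain Python list, or both targets at once, raises ValueError in this sketch, matching the documented behavior.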
Notes
The benchmark evaluates disentanglement by measuring how well each latent dimension captures information about discrete categorical variables. Higher scores indicate better disentanglement.
Available Metrics:
SMI-disc: Discrete mutual information between latent dimensions and categorical targets. Measures how much information each dimension contains about the categorical variable.
SPN: Nearest neighbor alignment score. Measures how well the nearest neighbor structure in latent space preserves categorical relationships.
ASC: Spearman correlation score. Measures rank-based (monotonic) correlation between latent dimensions and categorical targets (less suitable for discrete targets).
SMI-cont: Continuous mutual information (SMI-cont is not working as expected. More info: scikit-learn/scikit-learn#30772).
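Each metric yields one score per (dimension, category) pair, i.e. the table that get_results_details() exposes per metric. As a rough illustration only, the sketch below assumes SMI-disc discretizes each dimension into quantile bins and computes mutual information (in nats) against the binary category indicator, and that ASC is the absolute Spearman correlation with that indicator. The binning, normalization, and sign handling are assumptions, all function names are hypothetical, and SPN is omitted because its definition is not given here.

```python
import numpy as np
import pandas as pd


def discrete_mi(x_labels, y_labels):
    """Mutual information (in nats) between two discrete label arrays."""
    _, x_idx = np.unique(x_labels, return_inverse=True)
    _, y_idx = np.unique(y_labels, return_inverse=True)
    joint = np.zeros((x_idx.max() + 1, y_idx.max() + 1))
    np.add.at(joint, (x_idx, y_idx), 1.0)  # contingency table of counts
    joint /= joint.sum()  # joint probabilities
    # Outer product of the marginals: what the joint would be under independence.
    expected = joint.sum(axis=1, keepdims=True) @ joint.sum(axis=0, keepdims=True)
    nz = joint > 0
    return float(np.sum(joint[nz] * np.log(joint[nz] / expected[nz])))


def smi_disc_sketch(dim_values, indicator, n_bins=10):
    # Assumption: the latent dimension is discretized into quantile bins.
    edges = np.quantile(dim_values, np.linspace(0, 1, n_bins + 1)[1:-1])
    return discrete_mi(np.digitize(dim_values, edges), indicator)


def asc_sketch(dim_values, indicator):
    # Spearman correlation = Pearson correlation of (average, tie-aware) ranks.
    r1 = pd.Series(dim_values).rank().to_numpy()
    r2 = pd.Series(indicator).rank().to_numpy()
    # Assumption: the sign is ignored, so anti-correlated dimensions also count.
    return float(abs(np.corrcoef(r1, r2)[0, 1]))


def score_matrix(embed, one_hot, metric_fn):
    """Score every (dimension, category) pair; rows = dimensions, columns = categories."""
    one_hot = pd.DataFrame(one_hot).astype(int)
    scores = [
        [metric_fn(embed[:, d], one_hot[c].to_numpy()) for c in one_hot.columns]
        for d in range(embed.shape[1])
    ]
    index = [f"dim_{d}" for d in range(embed.shape[1])]
    return pd.DataFrame(scores, index=index, columns=one_hot.columns)
```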
Aggregation Methods:
LMS: Latent matching score. Finds the optimal matching between latent dimensions and categories. This aggregation discourages the presence of multiple unrelated biological processes in a single dimension.
MSAS: Most similar averaging score. Averages scores for the most similar dimension-category pairs.
MSGS: Most similar gap score. Measures the gap between the best and second-best matches. This aggregation discourages redundancy.
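The three aggregations can be illustrated on a small score table (rows = dimensions, columns = categories). This is a sketch under stated assumptions, not the drvi code: LMS is assumed to be a one-to-one matching that maximizes the total score (solved here with SciPy's Hungarian-algorithm `linear_sum_assignment`) averaged over matched pairs; MSAS is assumed to average each category's best score; MSGS is assumed to average each category's gap between its best and second-best dimension. The function names and the toy values are hypothetical.

```python
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment


def lms_sketch(scores):
    # Assumption: one-to-one dimension/category matching maximizing the total score.
    values = scores.to_numpy()
    rows, cols = linear_sum_assignment(values, maximize=True)
    return float(values[rows, cols].mean())


def msas_sketch(scores):
    # For each category, take its most similar dimension, then average over categories.
    return float(scores.max(axis=0).mean())


def msgs_sketch(scores):
    # For each category, gap between the best and second-best dimension, averaged.
    ordered = np.sort(scores.to_numpy(), axis=0)  # ascending within each column
    return float((ordered[-1] - ordered[-2]).mean())


# Toy score table with illustrative values only.
toy = pd.DataFrame(
    [[0.9, 0.8], [0.1, 0.2], [0.0, 0.7]],
    index=["dim_0", "dim_1", "dim_2"],
    columns=["A", "B"],
)
print(lms_sketch(toy), msas_sketch(toy), msgs_sketch(toy))
```

In this toy table dim_0 is the best dimension for both categories, so MSAS averages 0.9 and 0.8 (0.85), while the matching forces category B onto dim_2 and LMS drops to 0.8. This illustrates how a matching-based score penalizes one dimension covering several categories.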
Examples
>>> # Basic usage with discrete targets
>>> import numpy as np
>>> import pandas as pd
>>> from drvi.utils.metrics import DiscreteDisentanglementBenchmark
>>> # Generate sample data
>>> n_samples, n_dims = 1000, 10
>>> embed = np.random.randn(n_samples, n_dims)
>>> cell_types = pd.Series(np.random.choice(["A", "B", "C"], n_samples))
>>> # Create benchmark
>>> benchmark = DiscreteDisentanglementBenchmark(
...     embed, discrete_target=cell_types, metrics=("SMI-disc", "SPN"), aggregation_methods=("LMS", "MSAS")
... )
>>> # Run evaluation
>>> benchmark.evaluate()
>>> results = benchmark.get_results()
>>> print(f"LMS-SMI-disc: {results['LMS-SMI-disc']:.3f}")
>>> # With one-hot targets
>>> one_hot = pd.get_dummies(cell_types)
>>> benchmark = DiscreteDisentanglementBenchmark(embed, one_hot_target=one_hot)
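Attributes table
| version | Version tag stored with saved results and checked when loading. |
Methods table
| evaluate() | Compute all required metrics and aggregations. |
| get_results() | Get aggregated benchmark results. |
| get_results_details() | Get detailed metric results for each dimension and category. |
| is_complete() | Check if all metrics and aggregations have been computed. |
| load(path, embed[, discrete_target, ...]) | Load a saved benchmark instance. |
| load_results(path) | Load only the aggregated results from a saved benchmark. |
| load_results_details(path) | Load only the detailed results from a saved benchmark. |
| save(path) | Save benchmark results to a file. |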
Attributes
- DiscreteDisentanglementBenchmark.version = 'v2'
Methods
- DiscreteDisentanglementBenchmark.evaluate()
Compute all required metrics and aggregations.
This method computes any missing metrics and updates the aggregated results. It is safe to call multiple times: only missing computations are performed.
Notes
The method performs the following steps:
1. Identifies any missing metrics that need to be computed.
2. Computes the missing metrics using _compute_metrics.
3. Updates the aggregated results using _aggregate_metrics.
Aggregation is always performed, as it is computationally cheap and ensures consistency with any new metric results.
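The "compute only what is missing, always re-aggregate" flow described above can be sketched with a small stand-in class. `LazyBenchmarkSketch`, its constructor arguments, and the `compute_calls` counter are hypothetical and exist only to illustrate the caching pattern; the real class uses its own private helpers.

```python
class LazyBenchmarkSketch:
    """Hypothetical stand-in illustrating the evaluate() flow; not the drvi class."""

    def __init__(self, metric_fns, aggregation_fns):
        self.metric_fns = metric_fns  # metric name -> callable returning a score table
        self.aggregation_fns = aggregation_fns  # aggregation name -> callable(table) -> float
        self.results = {}
        self.aggregated_results = {}
        self.compute_calls = 0  # counts expensive metric computations

    def _missing_metrics(self):
        return [name for name in self.metric_fns if name not in self.results]

    def evaluate(self):
        # Steps 1-2: compute only the metrics that are not cached yet (the expensive part).
        for name in self._missing_metrics():
            self.results[name] = self.metric_fns[name]()
            self.compute_calls += 1
        # Step 3: always rebuild aggregations; cheap, and consistent with new results.
        self.aggregated_results = {
            f"{agg}-{metric}": agg_fn(scores)
            for agg, agg_fn in self.aggregation_fns.items()
            for metric, scores in self.results.items()
        }
```

Calling `evaluate()` twice computes each metric once; registering a new metric and calling it again computes only the new one while refreshing every aggregated score.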
Examples
>>> # Run evaluation
>>> benchmark.evaluate()
>>> # Check results
>>> results = benchmark.get_results()
>>> print(f"Number of results: {len(results)}")
- DiscreteDisentanglementBenchmark.get_results()
Get aggregated benchmark results.
- Returns:
dict Dictionary with “{aggregation_method}-{metric_name}” as keys and aggregated scores as values.
Notes
This method returns the final aggregated scores that summarize the disentanglement performance across all dimensions. Each key follows the pattern “{aggregation_method}-{metric_name}” (e.g., “LMS-SMI-disc”).
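Because metric names such as "SMI-disc" contain a hyphen while the listed aggregation names (LMS, MSAS, MSGS) do not, a key like "LMS-SMI-disc" should be split on its first hyphen only. The helpers below are hypothetical conveniences for regrouping a results dictionary by metric; they are not part of drvi.

```python
def make_key(aggregation_method, metric_name):
    # Key pattern documented for get_results(): "{aggregation_method}-{metric_name}".
    return f"{aggregation_method}-{metric_name}"


def split_key(key):
    # Split on the first hyphen only, so "LMS-SMI-disc" -> ("LMS", "SMI-disc").
    # Assumes aggregation method names never contain a hyphen.
    aggregation_method, metric_name = key.split("-", 1)
    return aggregation_method, metric_name


def results_by_metric(results):
    """Regroup {"AGG-METRIC": score} into {metric: {aggregation: score}}."""
    table = {}
    for key, value in results.items():
        aggregation_method, metric_name = split_key(key)
        table.setdefault(metric_name, {})[aggregation_method] = value
    return table
```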
Examples
>>> benchmark.evaluate()
>>> results = benchmark.get_results()
>>> for key, value in results.items():
...     print(f"{key}: {value:.3f}")
...
- DiscreteDisentanglementBenchmark.get_results_details()
Get detailed metric results for each dimension and category.
- Returns:
dict Dictionary with metric names as keys and pandas DataFrames as values. Each DataFrame shows scores for each dimension (rows) and category (columns).
Notes
This method returns the raw metric results before aggregation, allowing you to examine how each individual dimension performs for each category. This is useful for identifying which dimensions capture which categories and for debugging or detailed analysis.
Examples
>>> details = benchmark.get_results_details()
>>> smi_scores = details["SMI-disc"]
>>> print(f"SMI-disc shape: {smi_scores.shape}")
>>> print(f"Best dimension for category A: {smi_scores['A'].idxmax()}")
- DiscreteDisentanglementBenchmark.is_complete()
Check if all metrics and aggregations have been computed.
- Returns:
bool True if all requested metrics and aggregations are complete.
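A plausible reading of "complete" is that every requested metric has raw results and every (aggregation, metric) pair has an aggregated score under the documented key pattern. The function below is a hypothetical sketch of that check, not the library's implementation.

```python
def is_complete_sketch(results, aggregated_results, metrics, aggregation_methods):
    """Hypothetical completeness check.

    True only if every requested metric has raw results and every
    "{aggregation}-{metric}" key has an aggregated score.
    """
    metrics_done = all(metric in results for metric in metrics)
    aggregations_done = all(
        f"{agg}-{metric}" in aggregated_results
        for agg in aggregation_methods
        for metric in metrics
    )
    return metrics_done and aggregations_done
```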
- classmethod DiscreteDisentanglementBenchmark.load(path, embed, discrete_target=None, one_hot_target=None, metrics=None, aggregation_methods=None)
Load a saved benchmark instance.
- Parameters:
path – File path to the saved benchmark data.
embed – Latent representations (must match the original data).
discrete_target (default: None) – Discrete categorical target variable.
one_hot_target (default: None) – One-hot encoded target variable.
metrics (default: None) – Override the metrics from the saved data.
aggregation_methods (default: None) – Override the aggregation methods from the saved data.
- Returns:
DiscreteDisentanglementBenchmark Loaded benchmark instance with all results restored.
- Raises:
AssertionError – If the saved version doesn’t match the current class version.
FileNotFoundError – If the specified file doesn’t exist.
ValueError – If the target data is invalid (same as constructor).
Notes
This method creates a new benchmark instance with the same configuration as the saved one, but allows you to override metrics and aggregation methods if needed. The embed and target data must be provided to recreate the instance, but the actual results are loaded from the file.
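The override behavior can be pictured as "use the argument if given, otherwise fall back to the saved value", preceded by the version check that raises AssertionError. The function and the saved-dictionary keys ("version", "metrics", "aggregation_methods") below are assumptions for illustration; the actual file layout used by drvi may differ.

```python
def resolve_loaded_config(saved, current_version, metrics=None, aggregation_methods=None):
    """Hypothetical sketch of merging saved configuration with load() overrides."""
    # Mirrors the documented AssertionError on a version mismatch.
    if saved["version"] != current_version:
        raise AssertionError(
            f"Saved version {saved['version']!r} does not match class version {current_version!r}"
        )
    return {
        # None means "keep what was saved"; anything else overrides it.
        "metrics": tuple(metrics) if metrics is not None else tuple(saved["metrics"]),
        "aggregation_methods": (
            tuple(aggregation_methods)
            if aggregation_methods is not None
            else tuple(saved["aggregation_methods"])
        ),
    }
```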
Examples
>>> # Load with same configuration
>>> benchmark = DiscreteDisentanglementBenchmark.load("results.pkl", embed, discrete_target=cell_types)
- classmethod DiscreteDisentanglementBenchmark.load_results(path)
Load only the aggregated results from a saved benchmark.
- Parameters:
path – File path to the saved benchmark data.
- Returns:
dict Dictionary with aggregated results (same format as get_results()).
Notes
This is a convenience method for quickly accessing just the final aggregated scores without needing to recreate the full benchmark instance. Useful when you only need the results for analysis or comparison.
Examples
>>> # Quick access to results
>>> results = DiscreteDisentanglementBenchmark.load_results("results.pkl")
>>> print(f"LMS-SMI-disc score: {results['LMS-SMI-disc']:.3f}")
- classmethod DiscreteDisentanglementBenchmark.load_results_details(path)
Load only the detailed results from a saved benchmark.
- Parameters:
path – File path to the saved benchmark data.
- Returns:
dict Dictionary with detailed results (same format as get_results_details()).
Notes
This is a convenience method for quickly accessing the detailed metric results without needing to recreate the full benchmark instance. Useful for detailed analysis of dimension-category relationships.
Examples
>>> # Quick access to detailed results
>>> details = DiscreteDisentanglementBenchmark.load_results_details("results.pkl")
>>> smi_scores = details["SMI-disc"]
>>> print("Best dimension for each category:")
>>> for col in smi_scores.columns:
...     best_dim = smi_scores[col].idxmax()
...     print(f"  {col}: {best_dim}")
...
- DiscreteDisentanglementBenchmark.save(path)
Save benchmark results to a file.
- Parameters:
path – File path where to save the benchmark data.
Notes
The saved data includes:
- Version information for compatibility
- Raw metric results
- Aggregated results
- Configuration parameters (metrics, aggregation methods, etc.)
The data is saved using Python’s pickle format, which preserves all object structures and data types.
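A minimal sketch of such a pickle round trip, assuming a single dictionary payload holding the four items listed above. The key names and the `save_sketch`/`load_results_sketch` functions are hypothetical; the file written by the real `save()` may be laid out differently.

```python
import pickle
import tempfile
from pathlib import Path


def save_sketch(path, version, results, aggregated_results, config):
    # Hypothetical payload layout; the keys used by drvi may differ.
    payload = {
        "version": version,
        "results": results,
        "aggregated_results": aggregated_results,
        "config": config,
    }
    with open(path, "wb") as f:
        pickle.dump(payload, f)


def load_results_sketch(path):
    # Analogous to load_results(): read the file and return only the aggregated scores.
    with open(path, "rb") as f:
        payload = pickle.load(f)
    return payload["aggregated_results"]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "benchmark_results.pkl"
        save_sketch(target, "v2", {"SMI-disc": [[0.1]]}, {"LMS-SMI-disc": 0.1}, {"metrics": ("SMI-disc",)})
        print(load_results_sketch(target))
```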
Examples
>>> benchmark.evaluate()
>>> benchmark.save("benchmark_results.pkl")
>>> # Load later
>>> loaded_benchmark = DiscreteDisentanglementBenchmark.load(
...     "benchmark_results.pkl", embed, discrete_target=cell_types
... )