drvi.utils.metrics.discrete_scaled_mutual_info_score
- drvi.utils.metrics.discrete_scaled_mutual_info_score(all_vars_continues, gt_cat_series=None, gt_one_hot=None, n_bins=10)
Compute mutual information scores using discretized continuous variables.
This function discretizes continuous variables into bins and then computes mutual information between the discretized variables and categorical ground truth. This approach can capture non-linear relationships that might be missed by linear correlation measures.
- Parameters:
all_vars_continues (ndarray) – Matrix of continuous variables with shape (n_samples, n_variables). Each column represents a different continuous variable.
gt_cat_series (default: None) – Categorical series with ground truth labels.
gt_one_hot (default: None) – One-hot encoded ground truth matrix with shape (n_samples, n_categories).
n_bins (default: 10) – Number of bins to use for discretizing continuous variables. More bins capture finer details but may be more sensitive to noise.
- Return type:
np.ndarray
- Returns:
Mutual information score matrix with shape (n_variables, n_categories). Element [i, j] represents the normalized mutual information between discretized variable i and category j. Scores range from 0 to 1.
Notes
This function uses uniform binning to discretize continuous variables, then computes mutual information between the discretized variables and categorical ground truth. The scores are normalized by the entropy of each ground truth category.
Advantage over continuous mutual information: discretization is used because continuous mutual information does not behave as expected. More info: scikit-learn/scikit-learn#30772
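The binning and normalization described in these notes can be sketched in plain NumPy. This is an illustrative reconstruction, not DRVI's implementation: the helper name `binned_scaled_mi_sketch` is hypothetical, and it assumes equal-width bins between each column's minimum and maximum and that each category is treated as a binary indicator whose entropy is the normalizer.

```python
import numpy as np


def binned_scaled_mi_sketch(x, one_hot, n_bins=10):
    """Hypothetical sketch: binned MI divided by the entropy of each category indicator."""
    x = np.asarray(x, dtype=float)
    one_hot = np.asarray(one_hot).astype(int)
    n_samples, n_vars = x.shape
    n_cats = one_hot.shape[1]
    scores = np.zeros((n_vars, n_cats))

    for i in range(n_vars):
        # Assumption: equal-width bins spanning the column's observed range.
        edges = np.linspace(x[:, i].min(), x[:, i].max(), n_bins + 1)
        # Interior edges only, so bin ids fall in 0..n_bins-1.
        bins = np.digitize(x[:, i], edges[1:-1])

        for j in range(n_cats):
            y = one_hot[:, j]  # binary indicator for category j

            # Joint distribution over (bin, indicator).
            joint = np.zeros((n_bins, 2))
            np.add.at(joint, (bins, y), 1)
            p = joint / n_samples
            px = p.sum(axis=1, keepdims=True)
            py = p.sum(axis=0, keepdims=True)

            # Mutual information, skipping empty cells.
            nz = p > 0
            mi = np.sum(p[nz] * np.log(p[nz] / (px @ py)[nz]))

            # Entropy of the indicator, used as the normalizer.
            h_y = -np.sum(py[py > 0] * np.log(py[py > 0]))
            scores[i, j] = mi / h_y if h_y > 0 else 0.0

    return scores


# Column 0 separates the two labels perfectly; column 1 is independent of them.
demo_x = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
demo_one_hot = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
demo_scores = binned_scaled_mi_sketch(demo_x, demo_one_hot, n_bins=2)
```

Because the mutual information between two variables cannot exceed the entropy of either one, dividing by the indicator's entropy keeps each score between 0 and 1, and the logarithm base cancels in the ratio.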
Examples
>>> import numpy as np
>>> import pandas as pd
>>> from drvi.utils.metrics import discrete_scaled_mutual_info_score
>>> # Simple example: 3 variables, 2 categories
>>> all_vars = np.array([[1.0, 2.0, 0.5], [2.0, 1.0, 0.8], [3.0, 0.5, 1.2], [0.5, 3.0, 0.9]])
>>> gt_series = pd.Series(["A", "A", "B", "B"], dtype="category")
>>> scores = discrete_scaled_mutual_info_score(all_vars, gt_cat_series=gt_series, n_bins=2)
>>> print(scores.shape)  # (3, 2)
>>> print(scores)
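The gt_cat_series and gt_one_hot parameters describe the same labels in two encodings. The following is a minimal sketch of deriving a one-hot matrix of shape (n_samples, n_categories) from a categorical series with pandas. Whether the function expects a NumPy array or a DataFrame for gt_one_hot is not stated on this page, so the array conversion here is an assumption.

```python
import numpy as np
import pandas as pd

# Same labels as in the example above.
gt_series = pd.Series(["A", "A", "B", "B"], dtype="category")

# One column per category, in the order of gt_series.cat.categories.
gt_one_hot = pd.get_dummies(gt_series).to_numpy(dtype=float)
category_names = list(gt_series.cat.categories)

print(category_names)
print(gt_one_hot)
```

Keeping `category_names` alongside the matrix makes it possible to map column j of the returned score matrix back to a label.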