Uncertainty
bensemble.uncertainty.decomposition
decompose_classification_uncertainty
decompose_classification_uncertainty(
probs: Tensor, eps: float = 1e-08
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Decomposes total predictive uncertainty for classification into aleatoric and epistemic components.
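The three outputs can be written as follows. Here M = M_models, K = Num_classes, b indexes the batch, and p_{m,b} is the class distribution that member m predicts for input b. The epistemic term is the mutual information between the prediction and the choice of ensemble member. Placing ε inside the logarithm is an assumption about how eps is applied.

$$
\bar{p}_{b} = \frac{1}{M}\sum_{m=1}^{M} p_{m,b}, \qquad
\mathcal{H}[p] = -\sum_{k=1}^{K} p_{k}\,\log\left(p_{k} + \varepsilon\right)
$$

$$
U^{\text{total}}_{b} = \mathcal{H}\left[\bar{p}_{b}\right], \qquad
U^{\text{ale}}_{b} = \frac{1}{M}\sum_{m=1}^{M} \mathcal{H}\left[p_{m,b}\right], \qquad
U^{\text{epi}}_{b} = U^{\text{total}}_{b} - U^{\text{ale}}_{b}
$$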
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| probs | Tensor | Predicted probabilities from the ensemble. Expected shape: [M_models, Batch_size, Num_classes]. | required |
| eps | float | Small epsilon for numerical stability in torch.log(). Defaults to 1e-8. | 1e-08 |
Returns:
| Type | Description |
|---|---|
| Tuple[Tensor, Tensor, Tensor] | A tuple (total_unc, aleatoric_unc, epistemic_unc): total_unc is the total uncertainty (entropy of the mean prediction), shape [Batch_size]; aleatoric_unc is the aleatoric uncertainty (mean of the per-member entropies), shape [Batch_size]; epistemic_unc is the epistemic uncertainty (mutual information), shape [Batch_size]. |
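The minimal PyTorch sketch below reproduces the entropy-based decomposition so you can check shapes and sanity-test outputs. The function name `classification_decomposition_sketch` is hypothetical and is not part of bensemble. Adding `eps` inside the logarithm is also an assumption, so the library's internals may differ.

```python
from typing import Tuple

import torch


def classification_decomposition_sketch(
    probs: torch.Tensor, eps: float = 1e-8
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical re-implementation for illustration, not bensemble's code.

    probs: [M_models, Batch_size, Num_classes], each row summing to 1.
    """
    # Ensemble-averaged predictive distribution: [Batch_size, Num_classes]
    mean_probs = probs.mean(dim=0)
    # Total: entropy of the mean distribution -> [Batch_size]
    # (adding eps inside the log is an assumed placement)
    total_unc = -(mean_probs * torch.log(mean_probs + eps)).sum(dim=-1)
    # Per-member entropies -> [M_models, Batch_size], averaged -> [Batch_size]
    member_entropy = -(probs * torch.log(probs + eps)).sum(dim=-1)
    aleatoric_unc = member_entropy.mean(dim=0)
    # Epistemic: mutual information = total - aleatoric
    # (non-negative up to rounding, by Jensen's inequality)
    epistemic_unc = total_unc - aleatoric_unc
    return total_unc, aleatoric_unc, epistemic_unc


# Usage: 3 ensemble members, 2 inputs, 4 classes
torch.manual_seed(0)
probs = torch.softmax(torch.randn(3, 2, 4), dim=-1)
total, aleatoric, epistemic = classification_decomposition_sketch(probs)
print(total.shape, aleatoric.shape, epistemic.shape)
# torch.Size([2]) torch.Size([2]) torch.Size([2])
```

Two limiting cases are worth checking. If every member outputs the same distribution, epistemic_unc is zero. If two members are confident but disagree (one-hot on different classes), aleatoric_unc is near zero and epistemic_unc is near log 2.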
decompose_regression_uncertainty
decompose_regression_uncertainty(
means: Tensor, variances: Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Decomposes total predictive uncertainty for regression into aleatoric and epistemic components.
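The regression decomposition can be written elementwise for each batch index b and output dimension d. It follows the law of total variance for an equally weighted ensemble of M = M_models members, where μ and σ² are the predicted means and variances. The 1/M normalization of the epistemic term is an assumption. An implementation could instead use the unbiased 1/(M-1) estimator.

$$
U^{\text{ale}}_{b,d} = \frac{1}{M}\sum_{m=1}^{M}\sigma^{2}_{m,b,d}, \qquad
\bar{\mu}_{b,d} = \frac{1}{M}\sum_{m=1}^{M}\mu_{m,b,d}
$$

$$
U^{\text{epi}}_{b,d} = \frac{1}{M}\sum_{m=1}^{M}\left(\mu_{m,b,d} - \bar{\mu}_{b,d}\right)^{2}, \qquad
U^{\text{total}}_{b,d} = U^{\text{ale}}_{b,d} + U^{\text{epi}}_{b,d}
$$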
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| means | Tensor | Predicted means from the ensemble. Expected shape: [M_models, Batch_size, Out_dim]. | required |
| variances | Tensor | Predicted variances from the ensemble. Expected shape: [M_models, Batch_size, Out_dim]. | required |
Returns:
| Type | Description |
|---|---|
| Tuple[Tensor, Tensor, Tensor] | A tuple (total_unc, aleatoric_unc, epistemic_unc): total_unc is the total predictive variance (aleatoric plus epistemic), shape [Batch_size, Out_dim]; aleatoric_unc is the aleatoric uncertainty (mean of the predicted variances), shape [Batch_size, Out_dim]; epistemic_unc is the epistemic uncertainty (variance of the predicted means), shape [Batch_size, Out_dim]. |
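The minimal PyTorch sketch below shows the variance-based decomposition with the same input and output shapes. The function name `regression_decomposition_sketch` is hypothetical and is not part of bensemble. The population (1/M) variance of the means is an assumption. The variance is computed by hand so that the normalization is explicit.

```python
from typing import Tuple

import torch


def regression_decomposition_sketch(
    means: torch.Tensor, variances: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical re-implementation for illustration, not bensemble's code.

    means, variances: [M_models, Batch_size, Out_dim].
    """
    # Aleatoric: average of the members' predicted variances -> [Batch_size, Out_dim]
    aleatoric_unc = variances.mean(dim=0)
    # Epistemic: spread of the members' means around the ensemble mean,
    # using the population (1/M) normalization (an assumption here)
    ensemble_mean = means.mean(dim=0, keepdim=True)
    epistemic_unc = ((means - ensemble_mean) ** 2).mean(dim=0)
    # Total: law of total variance for the equally weighted ensemble
    total_unc = aleatoric_unc + epistemic_unc
    return total_unc, aleatoric_unc, epistemic_unc


# Usage: 4 members, 2 inputs, 1 output dimension
torch.manual_seed(0)
means = torch.randn(4, 2, 1)
variances = torch.rand(4, 2, 1) + 0.1  # keep variances strictly positive
total, aleatoric, epistemic = regression_decomposition_sketch(means, variances)
print(total.shape)  # torch.Size([2, 1])
```

If an implementation uses the unbiased 1/(M-1) estimator instead, epistemic_unc grows by a factor of M/(M-1). For M = 5 members, that factor is 1.25.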