# mypy: allow-untyped-defs
import math
from numbers import Number, Real
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from torch.autograd.functional import jacobian
from torch.distributions import (
Bernoulli,
Binomial,
ContinuousBernoulli,
Distribution,
Geometric,
NegativeBinomial,
RelaxedBernoulli,
constraints,
)
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import broadcast_all, lazy_property
from torch.types import _size
# Shared empty ``torch.Size`` for use as a default ``sample_shape`` argument.
default_size: torch.Size = torch.Size(())
class Beta(ExponentialFamily):
    r"""
    Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`.

    Reparameterized samples are ``z1 / (z1 + z0)`` with ``z1 ~ Gamma(concentration1, 1)``
    and ``z0 ~ Gamma(concentration0, 1)``. Density, entropy and mode are delegated to the
    equivalent two-component :class:`Dirichlet` over ``(x, 1 - x)``.

    Example::

        >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
        >>> m.sample()
        tensor([0.1046])

    Args:
        concentration1 (float or Tensor): 1st concentration parameter of the distribution
            (often referred to as alpha)
        concentration0 (float or Tensor): 2nd concentration parameter of the distribution
            (often referred to as beta)
    """

    arg_constraints = {
        "concentration1": constraints.positive,
        "concentration0": constraints.positive,
    }
    support = constraints.unit_interval
    has_rsample = True

    def __init__(
        self, concentration1: torch.Tensor, concentration0: torch.Tensor, validate_args: Optional[bool] = None
    ) -> None:
        """
        Initializes the Beta distribution with the given concentration parameters.

        Args:
            concentration1: First concentration parameter (alpha), a float or a tensor.
            concentration0: Second concentration parameter (beta), a float or a tensor.
                The two parameters are broadcast against each other.
            validate_args: If True, validates the distribution's parameters.
        """
        # Broadcast first: floats and tensors of different (compatible) shapes must yield a single
        # batch shape shared by both Gammas, the Dirichlet and the base class.
        self.concentration1, self.concentration0 = broadcast_all(concentration1, concentration0)
        unit_rate = torch.ones_like(self.concentration1)
        self._gamma1 = Gamma(self.concentration1, unit_rate, validate_args=validate_args)
        self._gamma0 = Gamma(self.concentration0, unit_rate, validate_args=validate_args)
        self._dirichlet = Dirichlet(
            torch.stack([self.concentration1, self.concentration0], -1), validate_args=validate_args
        )
        super().__init__(self._gamma0._batch_shape, validate_args=validate_args)

    def expand(self, batch_shape: torch.Size, _instance: Optional["Beta"] = None) -> "Beta":
        """
        Expands the Beta distribution to a new batch shape.

        Every parameter-derived attribute (both concentrations, both Gammas and the Dirichlet)
        is expanded, so all properties and methods work on the returned instance.

        Args:
            batch_shape: Desired batch shape.
            _instance: Instance to validate.

        Returns:
            A new Beta distribution instance with expanded parameters.
        """
        new = self._get_checked_instance(Beta, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration1 = self.concentration1.expand(batch_shape)
        new.concentration0 = self.concentration0.expand(batch_shape)
        new._gamma1 = self._gamma1.expand(batch_shape)
        new._gamma0 = self._gamma0.expand(batch_shape)
        new._dirichlet = self._dirichlet.expand(batch_shape)
        super(Beta, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self) -> torch.Tensor:
        """
        Computes the mean of the Beta distribution.

        Returns:
            torch.Tensor: ``concentration1 / (concentration1 + concentration0)``.
        """
        return self.concentration1 / (self.concentration1 + self.concentration0)

    @property
    def mode(self) -> torch.Tensor:
        """
        Computes the mode of the Beta distribution.

        The closed form ``(alpha - 1) / (alpha + beta - 2)`` is only valid when both
        concentrations exceed 1 and can otherwise fall outside [0, 1]. The mode is therefore
        taken from the first component of the equivalent Dirichlet's mode, which clamps
        concentrations below 1.

        Returns:
            torch.Tensor: Mode of the distribution.
        """
        return self._dirichlet.mode[..., 0]

    @property
    def variance(self) -> torch.Tensor:
        """
        Computes the variance of the Beta distribution.

        Returns:
            torch.Tensor: ``alpha * beta / ((alpha + beta)**2 * (alpha + beta + 1))``.
        """
        total = self.concentration1 + self.concentration0
        return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1))

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """
        Generates a reparameterized sample from the Beta distribution.

        Args:
            sample_shape (_size): Shape of the sample.

        Returns:
            torch.Tensor: ``z1 / (z1 + z0)`` for independent reparameterized Gamma draws.
        """
        z1 = self._gamma1.rsample(sample_shape)
        z0 = self._gamma0.rsample(sample_shape)
        return z1 / (z1 + z0)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log probability density of a value under the Beta distribution.

        Args:
            value: Value in [0, 1] to evaluate.

        Returns:
            Log density, computed as the Dirichlet log density of ``(value, 1 - value)``.
        """
        if self._validate_args:
            self._validate_sample(value)
        heads_tails = torch.stack([value, 1.0 - value], -1)
        return self._dirichlet.log_prob(heads_tails)

    def entropy(self) -> torch.Tensor:
        """
        Computes the entropy of the Beta distribution.

        Returns:
            Entropy of the distribution (delegated to the equivalent Dirichlet).
        """
        return self._dirichlet.entropy()

    @property
    def _natural_params(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the natural parameters of the distribution.

        Returns:
            ``(concentration1, concentration0)``.
        """
        return self.concentration1, self.concentration0

    def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Computes the log normalizer (log of the Beta function) for the natural parameters.

        Args:
            x: First natural parameter.
            y: Second natural parameter.

        Returns:
            ``lgamma(x) + lgamma(y) - lgamma(x + y)``.
        """
        return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
class Dirichlet(ExponentialFamily):
    """
    Dirichlet distribution parameterized by a concentration vector.

    The Dirichlet distribution is a multivariate generalization of the Beta distribution. It
    is commonly used in Bayesian statistics, particularly for modeling proportions.
    Reparameterized samples are obtained by normalizing independent ``Gamma(concentration, 1)``
    draws along the last dimension.
    """

    arg_constraints = {"concentration": constraints.independent(constraints.positive, 1)}
    support = constraints.simplex
    has_rsample = True

    def __init__(self, concentration: torch.Tensor, validate_args: Optional[bool] = None) -> None:
        """
        Initializes the Dirichlet distribution.

        Args:
            concentration: Positive concentration tensor; the last dimension is the event
                dimension and all leading dimensions form the batch shape.
            validate_args: If True, validates the distribution's parameters.

        Raises:
            ValueError: If ``concentration`` is a 0-dimensional tensor.
        """
        # The last dimension is the event dimension, so a scalar (0-dim) tensor is invalid.
        if concentration.dim() < 1:
            raise ValueError("`concentration` parameter must be at least one-dimensional.")
        self.concentration = concentration
        # Independent Gamma(concentration, 1) draws used by ``rsample``.
        self.gamma = Gamma(self.concentration, torch.ones_like(self.concentration), validate_args=validate_args)
        batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @property
    def mean(self) -> torch.Tensor:
        """
        Computes the mean of the Dirichlet distribution.

        Returns:
            Mean vector, calculated as `concentration / concentration.sum(-1, keepdim=True)`.
        """
        return self.concentration / self.concentration.sum(-1, keepdim=True)

    @property
    def mode(self) -> torch.Tensor:
        """
        Computes the mode of the Dirichlet distribution.

        Note:
            - The mode is defined only when all concentration values are > 1.
            - Concentrations below 1 are clamped to contribute zero mass; when every
              concentration of a vector is below 1, a one-hot vector is returned instead.

        Returns:
            Mode vector.
        """
        concentration_minus_one = (self.concentration - 1).clamp(min=0.0)
        mode = concentration_minus_one / concentration_minus_one.sum(-1, keepdim=True)
        # All-below-one rows give 0/0 above; replace them with a one-hot vector.
        mask = (self.concentration < 1).all(dim=-1)
        mode[mask] = F.one_hot(mode[mask].argmax(dim=-1), concentration_minus_one.shape[-1]).to(mode)
        return mode

    @property
    def variance(self) -> torch.Tensor:
        """
        Computes the variance of the Dirichlet distribution.

        Returns:
            Variance vector for each component.
        """
        total_concentration = self.concentration.sum(-1, keepdim=True)
        return (
            self.concentration
            * (total_concentration - self.concentration)
            / (total_concentration.pow(2) * (total_concentration + 1))
        )

    def rsample(self, sample_shape: _size = ()) -> torch.Tensor:
        """
        Generates a reparameterized sample from the Dirichlet distribution.

        Args:
            sample_shape (_size): Desired sample shape.

        Returns:
            torch.Tensor: Gamma draws normalized to sum to one along the last dimension.
        """
        z = self.gamma.rsample(sample_shape)
        return z / torch.sum(z, dim=-1, keepdim=True)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log probability density for a given value.

        Args:
            value (torch.Tensor): Point on the simplex to evaluate the log density at.

        Returns:
            torch.Tensor: Log probability density of the value.
        """
        if self._validate_args:
            self._validate_sample(value)
        return (
            torch.xlogy(self.concentration - 1.0, value).sum(-1)
            + torch.lgamma(self.concentration.sum(-1))
            - torch.lgamma(self.concentration).sum(-1)
        )

    def entropy(self) -> torch.Tensor:
        """
        Computes the entropy of the Dirichlet distribution.

        Returns:
            torch.Tensor: Entropy of the distribution.
        """
        k = self.concentration.size(-1)
        total_concentration = self.concentration.sum(-1)
        return (
            torch.lgamma(self.concentration).sum(-1)
            - torch.lgamma(total_concentration)
            - (k - total_concentration) * torch.digamma(total_concentration)
            - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1)
        )

    def expand(self, batch_shape: torch.Size, _instance: Optional["Dirichlet"] = None) -> "Dirichlet":
        """
        Expands the distribution parameters to a new batch shape.

        The internal Gamma used by ``rsample`` is expanded as well, so sampling works on
        the returned instance.

        Args:
            batch_shape (torch.Size): Desired batch shape.
            _instance (Optional): Instance to validate.

        Returns:
            A new Dirichlet distribution instance with expanded parameters.
        """
        new = self._get_checked_instance(Dirichlet, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape + self.event_shape)
        new.gamma = self.gamma.expand(batch_shape + self.event_shape)
        super(Dirichlet, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def _natural_params(self) -> tuple:
        """
        Returns the natural parameters of the distribution.

        Returns:
            tuple: Natural parameter tuple `(concentration,)`.
        """
        return (self.concentration,)

    def _log_normalizer(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the log normalizer for the natural parameters.

        Args:
            x (torch.Tensor): Natural parameter.

        Returns:
            torch.Tensor: ``sum(lgamma(x)) - lgamma(sum(x))`` over the last dimension.
        """
        return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
class StudentT(Distribution):
    """
    Student's t-distribution parameterized by degrees of freedom (df), location (loc), and scale (scale).

    This distribution is commonly used for robust statistical modeling, particularly when the data
    may have outliers or heavier tails than a Normal distribution.
    """

    arg_constraints = {
        "df": constraints.positive,
        "loc": constraints.real,
        "scale": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    def __init__(
        self, df: torch.Tensor, loc: float = 0.0, scale: float = 1.0, validate_args: Optional[bool] = None
    ) -> None:
        """
        Initializes the Student's t-distribution.

        Args:
            df (torch.Tensor): Degrees of freedom (must be positive).
            loc (float or torch.Tensor): Location parameter (default: 0.0).
            scale (float or torch.Tensor): Scale parameter (default: 1.0).
            validate_args (Optional[bool]): If True, validates distribution parameters.
        """
        self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
        batch_shape = self.df.size()
        super().__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self) -> torch.Tensor:
        """
        Computes the mean of the distribution.

        Note: The mean is undefined when `df <= 1`.

        Returns:
            torch.Tensor: Mean of the distribution, or NaN for undefined cases.
        """
        m = self.loc.clone(memory_format=torch.contiguous_format)
        m[self.df <= 1] = float("nan")  # Mean is undefined for df <= 1
        return m

    @property
    def mode(self) -> torch.Tensor:
        """
        Computes the mode of the distribution.

        Returns:
            torch.Tensor: Mode of the distribution, which is equal to `loc`.
        """
        return self.loc

    @property
    def variance(self) -> torch.Tensor:
        """
        Computes the variance of the distribution.

        Note:
            - Variance is infinite for 1 < df <= 2.
            - Variance is undefined (NaN) for df <= 1.

        Returns:
            torch.Tensor: Variance of the distribution, or appropriate values for edge cases.
        """
        m = self.df.clone(memory_format=torch.contiguous_format)
        # Variance for df > 2
        m[self.df > 2] = self.scale[self.df > 2].pow(2) * self.df[self.df > 2] / (self.df[self.df > 2] - 2)
        # Infinite variance for 1 < df <= 2
        m[(self.df <= 2) & (self.df > 1)] = float("inf")
        # Undefined variance for df <= 1
        m[self.df <= 1] = float("nan")
        return m

    def expand(self, batch_shape: torch.Size, _instance: Optional["StudentT"] = None) -> "StudentT":
        """
        Expands the distribution parameters to a new batch shape.

        Args:
            batch_shape (torch.Size): Desired batch size for the expanded distribution.
            _instance (Optional): Instance to validate.

        Returns:
            StudentT: A new StudentT distribution with expanded parameters.
        """
        new = self._get_checked_instance(StudentT, _instance)
        batch_shape = torch.Size(batch_shape)
        new.df = self.df.expand(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(StudentT, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log probability density for a given value.

        Args:
            value (torch.Tensor): Value to evaluate the log probability at.

        Returns:
            torch.Tensor: Log probability density of the given value.
        """
        if self._validate_args:
            self._validate_sample(value)
        y = (value - self.loc) / self.scale
        Z = (
            self.scale.log()
            + 0.5 * self.df.log()
            + 0.5 * math.log(math.pi)
            + torch.lgamma(0.5 * self.df)
            - torch.lgamma(0.5 * (self.df + 1.0))
        )
        return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z

    def entropy(self) -> torch.Tensor:
        """
        Computes the entropy of the Student's t-distribution.

        Returns:
            torch.Tensor: Entropy of the distribution.
        """
        lbeta = torch.lgamma(0.5 * self.df) + math.lgamma(0.5) - torch.lgamma(0.5 * (self.df + 1))
        return (
            self.scale.log()
            + 0.5 * (self.df + 1) * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df))
            + 0.5 * self.df.log()
            + lbeta
        )

    def _transform(self, z: torch.Tensor) -> torch.Tensor:
        """
        Standardizes ``z`` with the location and scale.

        Args:
            z (torch.Tensor): Input tensor to transform.

        Returns:
            torch.Tensor: ``(z - loc) / scale``.
        """
        return (z - self.loc) / self.scale

    def _d_transform_d_z(self) -> torch.Tensor:
        """
        Computes the derivative of ``_transform`` with respect to `z`.

        Returns:
            torch.Tensor: ``1 / scale``.
        """
        return 1 / self.scale

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a reparameterized sample from the Student's t-distribution.

        Draws ``x = loc + scale * eps / sqrt(g)`` with ``eps ~ N(0, 1)`` and
        ``g ~ Gamma(df / 2, df / 2)``. Gradients reach ``df`` through the reparameterized Gamma
        draw, and reach ``loc`` / ``scale`` through a surrogate built from ``_transform``.
        The distribution's own parameters are never modified.

        Args:
            sample_shape (_size): Shape of the sample.

        Returns:
            torch.Tensor: Reparameterized sample of shape ``sample_shape + batch_shape``.
        """
        gamma_samples = Gamma(self.df * 0.5, self.df * 0.5).rsample(sample_shape)
        # Standard-normal noise with the same shape, dtype and device as the Gamma draws.
        normal_samples = torch.randn_like(gamma_samples)
        # loc/scale broadcast against the (sample_shape + batch_shape) noise; they are kept
        # local here rather than written back to ``self``.
        x = self.loc.detach() + self.scale.detach() * normal_samples * torch.rsqrt(gamma_samples)
        transform = self._transform(x.detach())
        # Implicit reparameterization: dx/dtheta = -(dT/dtheta) / (dT/dx).
        surrogate_x = -transform / self._d_transform_d_z().detach()
        return x + (surrogate_x - surrogate_x.detach())
class Gamma(ExponentialFamily):
    """
    Gamma distribution parameterized by `concentration` (shape) and `rate` (inverse scale).

    The Gamma distribution is often used to model the time until an event occurs,
    and it is a continuous probability distribution defined for non-negative real values.
    Samples are reparameterized with respect to both parameters.
    """

    arg_constraints = {
        "concentration": constraints.positive,
        "rate": constraints.positive,
    }
    support = constraints.nonnegative
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        concentration: torch.Tensor,
        rate: torch.Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        """
        Initializes the Gamma distribution.

        Args:
            concentration (torch.Tensor): Shape parameter of the distribution (often referred to as alpha).
            rate (torch.Tensor): Rate parameter (inverse of scale, often referred to as beta).
            validate_args (Optional[bool]): If True, validates the distribution's parameters.
        """
        self.concentration, self.rate = broadcast_all(concentration, rate)
        # Two Python numbers give a scalar distribution; otherwise use the broadcast shape.
        if isinstance(concentration, Number) and isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.concentration.size()
        super().__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self) -> torch.Tensor:
        """
        Computes the mean of the Gamma distribution.

        Returns:
            torch.Tensor: Mean of the distribution, calculated as `concentration / rate`.
        """
        return self.concentration / self.rate

    @property
    def mode(self) -> torch.Tensor:
        """
        Computes the mode of the Gamma distribution.

        Note:
            - The mode is defined only for `concentration > 1`. For `concentration <= 1`,
              the mode is clamped to 0.

        Returns:
            torch.Tensor: Mode of the distribution.
        """
        return ((self.concentration - 1) / self.rate).clamp(min=0)

    @property
    def variance(self) -> torch.Tensor:
        """
        Computes the variance of the Gamma distribution.

        Returns:
            torch.Tensor: Variance of the distribution, calculated as `concentration / rate^2`.
        """
        return self.concentration / self.rate.pow(2)

    def expand(self, batch_shape: torch.Size, _instance: Optional["Gamma"] = None) -> "Gamma":
        """
        Expands the distribution parameters to a new batch shape.

        Args:
            batch_shape (torch.Size): Desired batch shape.
            _instance (Optional): Instance to validate.

        Returns:
            Gamma: A new Gamma distribution instance with expanded parameters.
        """
        new = self._get_checked_instance(Gamma, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        super(Gamma, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a reparameterized sample from the Gamma distribution.

        Args:
            sample_shape (_size): Shape of the sample.

        Returns:
            torch.Tensor: A reparameterized sample of shape ``sample_shape + batch_shape``.
        """
        shape = self._extended_shape(sample_shape)
        concentration = self.concentration.expand(shape)
        rate = self.rate.expand(shape)
        # Unit-rate draw (differentiable in concentration) rescaled by a constant rate.
        value = torch._standard_gamma(concentration) / rate.detach()
        # Zero-valued surrogate whose gradient w.r.t. rate is -value / rate.
        u = value.detach() * rate.detach() / rate
        value = value + (u - u.detach())
        # Keep samples strictly positive without recording the clamp in autograd.
        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)
        return value

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log probability density for a given value.

        Args:
            value (torch.Tensor or float): Value to evaluate the log probability at.

        Returns:
            torch.Tensor: Log probability density of the given value.
        """
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
        if self._validate_args:
            self._validate_sample(value)
        return (
            torch.xlogy(self.concentration, self.rate)
            + torch.xlogy(self.concentration - 1, value)
            - self.rate * value
            - torch.lgamma(self.concentration)
        )

    def entropy(self) -> torch.Tensor:
        """
        Computes the entropy of the Gamma distribution.

        Returns:
            torch.Tensor: Entropy of the distribution.
        """
        return (
            self.concentration
            - torch.log(self.rate)
            + torch.lgamma(self.concentration)
            + (1.0 - self.concentration) * torch.digamma(self.concentration)
        )

    @property
    def _natural_params(self) -> tuple:
        """
        Returns the natural parameters of the distribution.

        Returns:
            tuple: Tuple of natural parameters `(concentration - 1, -rate)`.
        """
        return self.concentration - 1, -self.rate

    def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Computes the log normalizer for the natural parameters.

        Args:
            x (torch.Tensor): First natural parameter.
            y (torch.Tensor): Second natural parameter.

        Returns:
            torch.Tensor: Log normalizer value.
        """
        return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the cumulative distribution function (CDF) for the Gamma distribution.

        Args:
            value (torch.Tensor or float): Value to evaluate the CDF at.

        Returns:
            torch.Tensor: Regularized lower incomplete gamma ``P(concentration, rate * value)``.
        """
        # Convert like ``log_prob`` so Python numbers are accepted; ``_validate_sample``
        # rejects non-tensor values.
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
        if self._validate_args:
            self._validate_sample(value)
        return torch.special.gammainc(self.concentration, self.rate * value)
class Normal(ExponentialFamily):
    """
    Normal (Gaussian) distribution with mean ``loc`` and standard deviation ``scale``.

    ``rsample`` draws a value without gradient tracking and then attaches gradients for
    ``loc`` and ``scale`` through a surrogate built from the standardizing transform
    ``(z - loc) / scale``.
    """

    has_rsample = True
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real

    def __init__(
        self,
        loc: torch.Tensor,
        scale: torch.Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        """
        Initializes the Normal distribution.

        Args:
            loc (torch.Tensor): Mean (location) parameter of the distribution.
            scale (torch.Tensor): Standard deviation (scale) parameter of the distribution.
            validate_args (Optional[bool]): If True, checks the distribution parameters for validity.
        """
        self.loc, self.scale = broadcast_all(loc, scale)
        # Two Python numbers give a scalar distribution; otherwise use the broadcast shape.
        batch_shape = torch.Size() if isinstance(loc, Number) and isinstance(scale, Number) else self.loc.size()
        super().__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self) -> torch.Tensor:
        """
        Returns the mean of the distribution.

        Returns:
            torch.Tensor: The mean (location) parameter `loc`.
        """
        return self.loc

    @property
    def mode(self) -> torch.Tensor:
        """
        Returns the mode of the distribution.

        Returns:
            torch.Tensor: The mode (equal to `loc` in a Normal distribution).
        """
        return self.loc

    @property
    def stddev(self) -> torch.Tensor:
        """
        Returns the standard deviation of the distribution.

        Returns:
            torch.Tensor: The standard deviation (scale) parameter `scale`.
        """
        return self.scale

    @property
    def variance(self) -> torch.Tensor:
        """
        Returns the variance of the distribution.

        Returns:
            torch.Tensor: The variance, computed as `scale ** 2`.
        """
        return self.stddev.pow(2)

    def entropy(self) -> torch.Tensor:
        """
        Computes the entropy of the distribution.

        Returns:
            torch.Tensor: ``0.5 + 0.5 * log(2 * pi) + log(scale)``.
        """
        return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale)

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the cumulative distribution function (CDF) of the distribution at a given value.

        Args:
            value (torch.Tensor): The value at which to evaluate the CDF.

        Returns:
            torch.Tensor: The probability that a random variable from the distribution is less than or equal to `value`.
        """
        return 0.5 * (1 + torch.erf((value - self.loc) / (self.scale * math.sqrt(2))))

    def expand(self, batch_shape: torch.Size, _instance: Optional["Normal"] = None) -> "Normal":
        """
        Expands the distribution parameters to a new batch shape.

        Args:
            batch_shape (torch.Size): Desired batch size for the expanded distribution.
            _instance (Optional): Instance to check for validity.

        Returns:
            Normal: A new Normal distribution with parameters expanded to the specified batch shape.
        """
        new = self._get_checked_instance(Normal, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(Normal, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def icdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the inverse cumulative distribution function (quantile function) at a given value.

        Args:
            value (torch.Tensor): The probability value at which to evaluate the inverse CDF.

        Returns:
            torch.Tensor: The quantile corresponding to `value`.
        """
        return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log probability density of the distribution at a given value.

        Args:
            value (torch.Tensor): The value at which to evaluate the log probability.

        Returns:
            torch.Tensor: The log probability density at `value`.
        """
        var = self.scale**2
        log_scale = self.scale.log() if not isinstance(self.scale, Real) else math.log(self.scale)
        return -((value - self.loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))

    @property
    def _natural_params(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Natural parameters ``(loc / scale**2, -1 / (2 * scale**2))`` of the Normal family.

        This class is not torch's own ``Normal``, so generic machinery such as
        ``torch.distributions.kl_divergence`` falls back to the ExponentialFamily rule. That rule
        needs this property and ``_log_normalizer``.
        """
        return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())

    def _log_normalizer(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Log normalizer ``-x**2 / (4 * y) + 0.5 * log(-pi / y)`` for natural parameters ``(x, y)``.

        Args:
            x (torch.Tensor): First natural parameter (``loc / scale**2``).
            y (torch.Tensor): Second natural parameter (``-1 / (2 * scale**2)``, negative).

        Returns:
            torch.Tensor: Log normalizer value.
        """
        return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)

    def _transform(self, z: torch.Tensor) -> torch.Tensor:
        """
        Standardizes ``z`` with the mean and scale.

        Args:
            z (torch.Tensor): Input tensor to transform.

        Returns:
            torch.Tensor: ``(z - loc) / scale``.
        """
        return (z - self.loc) / self.scale

    def _d_transform_d_z(self) -> torch.Tensor:
        """
        Computes the derivative of ``_transform`` with respect to `z`.

        Returns:
            torch.Tensor: ``1 / scale``.
        """
        return 1 / self.scale

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Generates a sample from the Normal distribution using `torch.normal`.

        Args:
            sample_shape (torch.Size): Shape of the sample to generate.

        Returns:
            torch.Tensor: A tensor with samples from the Normal distribution, detached from the computation graph.
        """
        shape = self._extended_shape(sample_shape)
        with torch.no_grad():
            return torch.normal(self.loc.expand(shape), self.scale.expand(shape))

    def rsample(self, sample_shape: _size = torch.Size()) -> torch.Tensor:
        """
        Generates a reparameterized sample from the Normal distribution, enabling gradient backpropagation.

        Args:
            sample_shape (_size): Shape of the sample to generate.

        Returns:
            torch.Tensor: A reparameterized sample whose gradients reach ``loc`` and ``scale``.
        """
        # Sample without gradient tracking.
        x = self.sample(sample_shape)
        # Implicit reparameterization: dx/dtheta = -(dT/dtheta) / (dT/dx) for T = _transform.
        transform = self._transform(x)
        surrogate_x = -transform / self._d_transform_d_z().detach()
        # Value stays x; only the surrogate's gradient is added.
        return x + (surrogate_x - surrogate_x.detach())
class MixtureSameFamily(torch.distributions.MixtureSameFamily):
    """
    Represents a mixture of distributions from the same family.

    Supports reparameterized sampling for gradient-based optimization: ``rsample`` builds a
    surrogate from the mixture's distributional transform (its CDF, or per-coordinate
    conditional CDFs for multivariate events).
    """

    has_rsample = True

    def __init__(self, *args, **kwargs) -> None:
        """
        Initializes the mixture and checks that the component distributions support
        reparameterized sampling (required for `rsample`).

        Args:
            *args: Positional arguments forwarded to ``torch.distributions.MixtureSameFamily``.
            **kwargs: Keyword arguments forwarded to ``torch.distributions.MixtureSameFamily``.

        Raises:
            ValueError: If the component distributions do not support reparameterized sampling.
        """
        super().__init__(*args, **kwargs)
        if not self._component_distribution.has_rsample:
            raise ValueError("Cannot reparameterize a mixture of non-reparameterizable components.")
        # Kept as a public attribute for backward compatibility. ``_log_cdf`` no longer consults
        # it, because the mixing weights are Categorical whatever the component family is.
        self.discrete_distributions: List[type] = [
            Bernoulli,
            Binomial,
            ContinuousBernoulli,
            Geometric,
            NegativeBinomial,
            RelaxedBernoulli,
        ]

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Generates a reparameterized sample from the mixture of distributions.

        A sample is drawn without gradients. A surrogate ``-J^{-1} F(x)`` is then added with zero
        value, where ``F`` is the distributional transform and ``J`` its Jacobian, so that the
        result carries implicit-reparameterization gradients.

        Args:
            sample_shape (torch.Size): The shape of the sample to generate.

        Returns:
            torch.Tensor: A reparameterized sample with gradients enabled.
        """
        x = self.sample(sample_shape=sample_shape)
        event_size = math.prod(self.event_shape)
        if event_size != 1:
            # Multivariate events: coordinate i of the transform is conditioned only on the
            # coordinates before it, so the per-sample Jacobian is lower-triangular.
            # NOTE(review): x is flattened to (-1, event_size) before the transform, which appears
            # to assume an empty batch_shape — confirm before relying on batched mixtures here.
            def reshaped_dist_trans(input_x: torch.Tensor) -> torch.Tensor:
                return torch.reshape(self._distributional_transform(input_x), (-1, event_size))

            def reshaped_dist_trans_summed(x_2d: torch.Tensor) -> torch.Tensor:
                # Summing over samples yields a (E, N, E) Jacobian holding each sample's (E, E) block.
                return torch.sum(reshaped_dist_trans(x_2d), dim=0)

            x_2d = x.reshape((-1, event_size))
            transform_2d = reshaped_dist_trans(x)
            jac = jacobian(reshaped_dist_trans_summed, x_2d).detach().movedim(1, 0)
            surrogate_x_2d = -torch.linalg.solve_triangular(jac.detach(), transform_2d[..., None], upper=False)
            surrogate_x = surrogate_x_2d.reshape(x.shape)
        else:
            # A single event coordinate: the transform's derivative is the density p(x).
            transform = self._distributional_transform(x)
            log_prob_x = self.log_prob(x)
            # ``transform`` keeps the event dimensions of x but ``log_prob`` drops them, so append one
            # singleton dimension per event dimension (e.g. event_shape (1,) needs one).
            if self._event_ndims > 0:
                log_prob_x = log_prob_x.reshape(log_prob_x.shape + (1,) * self._event_ndims)
            surrogate_x = -transform * torch.exp(-log_prob_x.detach())
        return x + (surrogate_x - surrogate_x.detach())

    def _distributional_transform(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies the mixture's distributional transform to the sample `x`.

        Univariate events: the mixture CDF ``sum_k w_k * F_k(x)``.

        Multivariate events: for each coordinate, the component CDFs are weighted by posterior
        component weights. These weights combine the mixture logits with the summed log
        densities of the preceding coordinates (an exclusive cumulative sum).

        Args:
            x (torch.Tensor): The input sample to transform.

        Returns:
            torch.Tensor: The transformed tensor, with the same shape as ``x``.
        """
        if isinstance(self._component_distribution, torch.distributions.Independent):
            univariate_components = self._component_distribution.base_dist
        else:
            univariate_components = self._component_distribution
        # Insert the component axis and evaluate each component's per-coordinate log density.
        x = self._pad(x)  # [S, B, 1, E]
        log_prob_x = univariate_components.log_prob(x)  # [S, B, K, E]
        event_size = math.prod(self.event_shape)
        if event_size != 1:
            # Exclusive cumulative sum over the event coordinates (first coordinate gets 0).
            cumsum_log_prob_x = log_prob_x.reshape(-1, event_size)
            cumsum_log_prob_x = torch.cumsum(cumsum_log_prob_x, dim=-1)
            cumsum_log_prob_x = cumsum_log_prob_x.roll(shifts=1, dims=-1)
            cumsum_log_prob_x[:, 0] = 0
            cumsum_log_prob_x = cumsum_log_prob_x.reshape(log_prob_x.shape)
            logits_mix_prob = self._pad_mixture_dimensions(self._mixture_distribution.logits)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x
            component_axis = -self._event_ndims - 1
            cdf_x = univariate_components.cdf(x)
            posterior_weights_x = torch.softmax(log_posterior_weights_x, dim=component_axis)
        else:
            # One event coordinate: weights are the (prior) mixture probabilities.
            log_posterior_weights_x = self._mixture_distribution.logits
            component_axis = -self._event_ndims - 1
            cdf_x = univariate_components.cdf(x)
            posterior_weights_x = torch.softmax(log_posterior_weights_x, dim=-1)
            posterior_weights_x = self._pad_mixture_dimensions(posterior_weights_x)
        return torch.sum(posterior_weights_x * cdf_x, dim=component_axis)

    def _log_cdf(self, x: torch.Tensor) -> torch.Tensor:
        """
        Computes the log of the mixture CDF, ``log sum_k w_k F_k(x)``.

        For multivariate events this is the per-coordinate (marginal) mixture log CDF.

        Args:
            x (torch.Tensor): The input tensor for which to compute the log CDF.

        Returns:
            torch.Tensor: The log CDF values, with the same shape as ``x``.
        """
        x = self._pad(x)
        if isinstance(self._component_distribution, torch.distributions.Independent):
            univariate_components = self._component_distribution.base_dist
        else:
            univariate_components = self._component_distribution
        if callable(getattr(univariate_components, "_log_cdf", None)):
            log_cdf_x = univariate_components._log_cdf(x)
        else:
            log_cdf_x = torch.log(univariate_components.cdf(x))
        # The mixing distribution is a Categorical, so its log weights are the log-softmax of its
        # logits for every component family (a sigmoid of the logits is not a log-probability).
        log_mix_prob = self._pad_mixture_dimensions(F.log_softmax(self._mixture_distribution.logits, dim=-1))
        # Reduce over the component axis, which precedes the event dimensions.
        component_axis = -1 - self._event_ndims
        return torch.logsumexp(log_cdf_x + log_mix_prob, dim=component_axis)
def _eval_poly(y: torch.Tensor, coef: torch.Tensor) -> torch.Tensor:
"""
Evaluate a polynomial at given points.
Args:
y: Input tensor.
coeffs: Polynomial coefficients.
Returns:
Evaluated polynomial tensor.
"""
coef = list(coef)
result = coef.pop()
while coef:
result = coef.pop() + y * result
return result
# Polynomial coefficients (lowest order first, as consumed by ``_eval_poly``) for the
# modified Bessel functions of the first kind I0 and I1, used by ``_log_modified_bessel_fn``.
# NOTE(review): the values appear to match the approximations in Abramowitz & Stegun
# 9.8.1-9.8.4 — confirm against the reference before editing them.
#
# "Small" fits: polynomials in (x / 3.75)**2, used where x < 3.75.
_I0_COEF_SMALL = [
    1.0,
    3.5156229,
    3.0899424,
    1.2067492,
    0.2659732,
    0.360768e-1,
    0.45813e-2,
]
# "Large" fits: polynomials in 3.75 / x, combined in log space with x - 0.5 * log(x),
# i.e. scaled by exp(x) / sqrt(x); used where x >= 3.75.
_I0_COEF_LARGE = [
    0.39894228,
    0.1328592e-1,
    0.225319e-2,
    -0.157565e-2,
    0.916281e-2,
    -0.2057706e-1,
    0.2635537e-1,
    -0.1647633e-1,
    0.392377e-2,
]
# Small-argument I1 fit; ``_log_modified_bessel_fn`` multiplies it by |x|.
_I1_COEF_SMALL = [
    0.5,
    0.87890594,
    0.51498869,
    0.15084934,
    0.2658733e-1,
    0.301532e-2,
    0.32411e-3,
]
# Large-argument I1 fit (same exp(x) / sqrt(x) scaling as the I0 large fit).
_I1_COEF_LARGE = [
    0.39894228,
    -0.3988024e-1,
    -0.362018e-2,
    0.163801e-2,
    -0.1031555e-1,
    0.2282967e-1,
    -0.2895312e-1,
    0.1787654e-1,
    -0.420059e-2,
]
# Lookup tables indexed by Bessel order (0 or 1).
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
def _log_modified_bessel_fn(x: torch.Tensor, order: int = 0) -> torch.Tensor:
    """
    Approximate ``log I_order(x)``, the log of the modified Bessel function of the first kind.

    Both a small-argument and a large-argument polynomial fit are evaluated for every element,
    and the result is selected elementwise: the small fit where ``x < 3.75``, the large fit
    elsewhere.

    Args:
        x: Input tensor, expected to be positive (the large-argument fit takes ``log(x)``).
        order: Bessel order, 0 or 1.

    Returns:
        Elementwise approximation of ``log I_order(x)``.

    Raises:
        ValueError: If ``order`` is not 0 or 1.
    """
    if order not in {0, 1}:
        raise ValueError("Order must be 0 or 1.")

    # Small-argument fit: polynomial in (x / 3.75)**2, times |x| for order 1.
    scaled = x / 3.75
    small_poly = _eval_poly(scaled * scaled, _COEF_SMALL[order])
    if order == 1:
        small_poly = x.abs() * small_poly
    log_small = small_poly.log()

    # Large-argument fit: log(exp(x) / sqrt(x) * poly(3.75 / x)).
    inv_scaled = 3.75 / x
    log_large = x - 0.5 * x.log() + _eval_poly(inv_scaled, _COEF_LARGE[order]).log()

    return torch.where(x < 3.75, log_small, log_large)
# Compiled with TorchScript only when called while tracing.
@torch.jit.script_if_tracing
def _rejection_sample(
    loc: torch.Tensor, concentration: torch.Tensor, proposal_r: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """
    Draw von Mises samples by rejection sampling.

    Each round does three things:

    * draws three uniforms per element;
    * builds a candidate from ``proposal_r``;
    * applies the acceptance test.

    Accepted candidates are merged into ``x``. Rounds repeat until every
    element has been accepted at least once. An element that is accepted
    again in a later round takes the newer candidate. Finally the centred
    samples are shifted by ``loc`` and wrapped into ``[-pi, pi)``.

    Args:
        loc: Location parameter; its dtype and device are used for the
            uniform draws.
        concentration: Concentration parameter.
        proposal_r: Precomputed proposal parameter (``VonMises._proposal_r``).
        x: Initial tensor that fixes the output shape. A new tensor is
            built from it via ``torch.where``; ``x`` is not modified in place.

    Returns:
        Tensor of samples with the shape of ``x``.
    """
    # Positions that have been accepted in at least one round.
    done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
    while not done.all():
        # Three independent uniforms per output element.
        u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Accept if either the log-free test or the log-based test passes
        # (both are evaluated element-wise).
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # u3 supplies a random sign for acos(f), which lies in [0, pi].
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    # Shift by loc and wrap into [-pi, pi) (torch's % takes the sign of the divisor).
    return (x + math.pi + loc) % (2 * math.pi) - math.pi
class VonMises(Distribution):
    """
    Von Mises distribution class for circular data.

    Samples from :meth:`sample` and :meth:`rsample` are angles in
    ``[-pi, pi]``, returned in the dtype of ``loc``. :meth:`rsample` is
    differentiable w.r.t. ``loc`` (additive shift). It is also differentiable
    w.r.t. ``concentration``, through the implicit-reparameterization
    backward of ``_VonMisesSampler``.

    Args:
        loc: Location parameter of the distribution.
        concentration: Positive concentration parameter (kappa).
        validate_args: If True, checks the distribution parameters for validity.
    """

    arg_constraints = {
        "loc": constraints.real,
        "concentration": constraints.positive,
    }
    support = constraints.real
    has_rsample = True

    def __init__(
        self,
        loc: torch.Tensor,
        concentration: torch.Tensor,
        validate_args: Optional[bool] = None,
    ) -> None:
        """
        Args:
            loc: loc parameter of the distribution.
            concentration: concentration parameter of the distribution.
            validate_args: If True, checks the distribution parameters for validity.
        """
        self.loc, self.concentration = broadcast_all(loc, concentration)
        batch_shape = self.loc.shape
        super().__init__(batch_shape, torch.Size(), validate_args)

    @lazy_property
    @torch.no_grad()
    def _proposal_r(self) -> torch.Tensor:
        """Proposal parameter of the rejection sampler, computed in double precision."""
        kappa = self._concentration
        tau = 1 + (1 + 4 * kappa**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
        _proposal_r = (1 + rho**2) / (2 * rho)
        # second order Taylor expansion around 0 for small kappa
        _proposal_r_taylor = 1 / kappa + kappa
        return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Compute the log probability of the given value.

        Args:
            value: Tensor of values.

        Returns:
            Tensor of log probabilities:
            ``kappa * cos(value - loc) - log(2 pi) - log(I0(kappa))``.
        """
        if self._validate_args:
            self._validate_sample(value)
        log_prob = self.concentration * torch.cos(value - self.loc)
        log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
        return log_prob

    @lazy_property
    def _loc(self) -> torch.Tensor:
        """``loc`` cast to float64, used for double-precision sampling."""
        return self.loc.to(torch.double)

    @lazy_property
    def _concentration(self) -> torch.Tensor:
        """``concentration`` cast to float64, used for double-precision sampling."""
        return self.concentration.to(torch.double)

    @torch.no_grad()
    def sample(self, sample_shape: _size = default_size) -> torch.Tensor:
        """
        The sampling algorithm for the von Mises distribution is based on the
        following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the
        von Mises distribution." Applied Statistics (1979): 152-157.

        Sampling is always done in double precision internally to avoid a hang
        in _rejection_sample() for small values of the concentration, which
        starts to happen for single precision around 1e-4 (see issue #88443).
        """
        shape = self._extended_shape(sample_shape)
        x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
        return _rejection_sample(self._loc, self._concentration, self._proposal_r, x).to(self.loc.dtype)

    def rsample(self, sample_shape: _size = default_size) -> torch.Tensor:
        """
        Generate reparameterized samples from the distribution.

        Returns a tensor of shape ``sample_shape + batch_shape`` in the dtype
        of ``loc``, with values mapped to ``[-pi, pi]``.
        """
        shape = self._extended_shape(sample_shape)
        samples = _VonMisesSampler.apply(self.concentration, self._proposal_r, shape)
        # _proposal_r is float64, so type promotion inside the sampler can make
        # its output float64 even for lower-precision parameters. Cast back
        # (differentiably) so rsample() matches the dtype returned by sample().
        samples = samples.to(self.loc.dtype) + self.loc
        # Map the samples to [-pi, pi].
        return samples - 2.0 * torch.pi * torch.round(samples / (2.0 * torch.pi))

    @property
    def mean(self) -> torch.Tensor:
        """Mean of the distribution (the location parameter)."""
        return self.loc

    @property
    def variance(self) -> torch.Tensor:
        """Circular variance of the distribution: ``1 - I1(kappa) / I0(kappa)``."""
        return (
            1
            - (
                _log_modified_bessel_fn(self.concentration, order=1)
                - _log_modified_bessel_fn(self.concentration, order=0)
            ).exp()
        )
# Compiled with TorchScript only when called while tracing; runs without autograd.
@torch.jit.script_if_tracing
@torch.no_grad()
def _rejection_rsample(concentration: torch.Tensor, proposal_r: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Draw zero-location von Mises samples by rejection sampling.

    This is the same accept/reject loop as ``_rejection_sample``, but with no
    location shift and no wrapping. Results are ``+/- acos(f)`` and therefore
    lie in ``[-pi, pi]``. Gradient tracking is disabled here; gradients
    w.r.t. ``concentration`` are supplied by ``_VonMisesSampler.backward``.

    NOTE(review): ``x`` starts in ``concentration.dtype``. If ``proposal_r``
    has a wider dtype (``VonMises._proposal_r`` is computed in float64), then
    ``f`` and hence the ``torch.where`` result are type-promoted. The
    returned dtype can therefore differ from ``concentration.dtype``;
    confirm callers expect this.

    Args:
        concentration (torch.Tensor): Concentration parameter (kappa) of the distribution.
        proposal_r (torch.Tensor): Proposal distribution parameter.
        shape (torch.Size): Desired shape of the samples.

    Returns:
        torch.Tensor: Zero-location samples of the given shape.
    """
    x = torch.empty(shape, dtype=concentration.dtype, device=concentration.device)
    # Positions that have been accepted in at least one round.
    done = torch.zeros(x.shape, dtype=torch.bool, device=concentration.device)
    while not done.all():
        # Three independent uniforms per output element.
        u = torch.rand((3,) + x.shape, dtype=concentration.dtype, device=concentration.device)
        u1, u2, u3 = u.unbind()
        z = torch.cos(math.pi * u1)
        f = (1 + proposal_r * z) / (proposal_r + z)
        c = concentration * (proposal_r - f)
        # Accept if either the log-free test or the log-based test passes.
        accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
        if accept.any():
            # u3 supplies a random sign; positions accepted again take the newer value.
            x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
            done = done | accept
    return x
def cosxm1(x: torch.Tensor) -> torch.Tensor:
    """
    Return ``cos(x) - 1`` without the cancellation of the naive form.

    Uses the identity ``cos(x) - 1 == -2 * sin(x / 2) ** 2``, which stays
    accurate for ``x`` near zero, where ``cos(x)`` rounds to 1.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Element-wise ``cos(x) - 1``.
    """
    half_angle_sine = torch.sin(x / 2.0)
    squared = torch.square(half_angle_sine)
    return -2 * squared
class _VonMisesSampler(torch.autograd.Function):
    """
    Draw zero-location von Mises samples with gradients w.r.t. the concentration.

    The forward pass uses rejection sampling. The backward pass applies
    implicit reparameterization, ``dx/dk = -(dF(x; k)/dk) / p(x; k)``, where
    ``F`` is the von Mises CDF and ``p`` its density.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        concentration: torch.Tensor,
        proposal_r: torch.Tensor,
        shape: torch.Size,
    ) -> torch.Tensor:
        """
        Perform forward sampling using rejection sampling.

        Args:
            ctx (torch.autograd.function.FunctionCtx): Context object for saving tensors.
            concentration (torch.Tensor): Concentration parameter (kappa).
            proposal_r (torch.Tensor): Proposal distribution parameter.
            shape (torch.Size): Desired shape of the samples.

        Returns:
            torch.Tensor: Samples from the von Mises distribution.
        """
        samples = _rejection_rsample(concentration, proposal_r, shape)
        ctx.save_for_backward(concentration, proposal_r, samples)
        return samples

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(
        ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, None, None]:
        """
        Compute gradients for backward pass using implicit reparameterization.

        Args:
            ctx (torch.autograd.function.FunctionCtx): Context object containing saved tensors.
            grad_output (torch.Tensor): Gradient of the loss with respect to the output.

        Returns:
            Tuple[torch.Tensor, None, None]: Gradient w.r.t. ``concentration``
            (summed over the sample dimensions, in the parameter's shape and
            dtype), and ``None`` for ``proposal_r`` and ``shape``.
        """
        concentration, proposal_r, samples = ctx.saved_tensors
        param_shape, param_dtype = concentration.shape, concentration.dtype
        # Give every sample its own concentration entry so both CDF
        # approximations return per-sample derivatives. Without this, a
        # derivative computed by autograd against the un-broadcast parameter
        # is summed over the sample dimensions.
        concentration = concentration.expand(samples.shape)

        num_periods = torch.round(samples / (2.0 * torch.pi))
        x_mapped = samples - (2.0 * torch.pi) * num_periods

        # Parameters from the paper: series/normal cut-off and number of series terms.
        ck = 10.5
        num_terms = 20

        # Only the derivatives of the CDF are needed here; the CDF values are discarded.
        _, dcdf_dconcentration_series = von_mises_cdf_series(x_mapped, concentration, num_terms)
        _, dcdf_dconcentration_normal = von_mises_cdf_normal(x_mapped, concentration)
        use_series = concentration < ck
        dcdf_dconcentration = torch.where(use_series, dcdf_dconcentration_series, dcdf_dconcentration_normal)

        # Density of the zero-location von Mises at the samples:
        # exp(k * (cos x - 1)) / (2 pi * i0e(k)) == exp(k * cos x) / (2 pi * I0(k)).
        prob = torch.exp(concentration * cosxm1(samples)) / (2 * math.pi * torch.special.i0e(concentration))
        # Implicit reparameterization: dx/dk = -(dF/dk) / p(x).
        grad_concentration = grad_output * (-dcdf_dconcentration / prob)
        # Reduce over the sample dimensions back to the parameter's shape and dtype.
        return grad_concentration.sum_to_size(param_shape).to(param_dtype), None, None
def von_mises_cdf_series(
    x: torch.Tensor, concentration: torch.Tensor, num_terms: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the CDF of the von Mises distribution using a series approximation.

    The CDF on ``[-pi, pi]`` is ``0.5 + x / (2 pi) + V / pi``, where ``V`` is
    a truncated sine series with ``num_terms`` terms. ``V`` is accumulated
    from the highest term downwards. The coefficient ratios ``rn`` come from
    the backward recurrence ``rn = 1 / (2 n / k + r_{n+1})``. The derivative
    w.r.t. ``concentration`` is carried through the same recurrence.

    Args:
        x (torch.Tensor): Points at which to evaluate the CDF.
        concentration (torch.Tensor): Concentration parameter (kappa),
            broadcastable against ``x``.
        num_terms (int): Number of terms in the series. A value ``<= 0``
            gives the uniform CDF ``0.5 + x / (2 pi)``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The CDF clipped to ``[0, 1]``, and
        its derivative w.r.t. ``concentration``. The derivative is zeroed
        wherever clipping was applied.
    """
    vn = torch.zeros_like(x)
    dvn_dconcentration = torch.zeros_like(x)
    rn = torch.zeros_like(x)
    drn_dconcentration = torch.zeros_like(x)
    # n runs num_terms, num_terms - 1, ... while > 0, as 0-dim tensors of x's
    # dtype/device, so the arithmetic is the same as with a tensor counter.
    # Iterating a precomputed tensor avoids reading a device value on the host
    # each step, which a `while n > 0` test on a CUDA tensor would force.
    terms = torch.arange(max(num_terms, 0), 0, -1, dtype=x.dtype, device=x.device)
    for n in terms:
        denominator = 2.0 * n / concentration + rn
        ddenominator_dk = -2.0 * n / concentration**2 + drn_dconcentration
        rn = 1.0 / denominator
        drn_dconcentration = -ddenominator_dk / denominator**2
        multiplier = torch.sin(n * x) / n + vn
        vn = rn * multiplier
        dvn_dconcentration = drn_dconcentration * multiplier + rn * dvn_dconcentration
    cdf = 0.5 + x / (2.0 * torch.pi) + vn / torch.pi
    dcdf_dconcentration = dvn_dconcentration / torch.pi
    cdf_clipped = torch.clamp(cdf, 0.0, 1.0)
    # Clipped values no longer depend on the concentration.
    dcdf_dconcentration *= (cdf >= 0.0) & (cdf <= 1.0)
    return cdf_clipped, dcdf_dconcentration
def cdf_func(concentration: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Approximate the CDF of the von Mises distribution.

    ``z = sqrt(2 / pi) * sin(x / 2) / i0e(concentration)`` is treated as
    approximately standard normal. A correction depending on
    ``24 * concentration`` maps it to ``xi``, and the standard normal CDF of
    ``xi`` is returned. All operations are element-wise and differentiable
    w.r.t. ``concentration``.

    Args:
        concentration (torch.Tensor): Concentration parameter (kappa).
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: Approximate CDF values.
    """
    # Keep sqrt(2/pi) as a Python float: it stays at full double precision and
    # needs no tensor allocation. A default-dtype (float32) tensor constant
    # would limit float64 inputs to ~1e-8 relative accuracy.
    z = math.sqrt(2.0 / math.pi) * torch.sin(0.5 * x) / torch.special.i0e(concentration)
    # Apply corrections to z to improve the approximation
    z2 = z**2
    z3 = z2 * z
    z4 = z2**2
    c = 24.0 * concentration
    c1 = 56.0
    xi = z - z3 / (((c - 2.0 * z2 - 16.0) / 3.0) - (z4 + (7.0 / 4.0) * z2 + 167.0 / 2.0) / (c - c1 - z2 + 3.0)) ** 2
    # Use the standard normal distribution for the approximation
    distrib = torch.distributions.Normal(
        torch.tensor(0.0, dtype=x.dtype, device=x.device), torch.tensor(1.0, dtype=x.dtype, device=x.device)
    )
    return distrib.cdf(xi)
def von_mises_cdf_normal(x: torch.Tensor, concentration: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the CDF of the von Mises distribution using a normal approximation.

    The derivative w.r.t. ``concentration`` is obtained by autograd through
    ``cdf_func``. ``x`` and ``concentration`` are broadcast against each other
    first, so the derivative is element-wise, like the one returned by
    ``von_mises_cdf_series``. Differentiating a summed CDF w.r.t. an
    un-broadcast concentration would instead add up the derivatives of all
    elements that share a concentration entry. Neither ``x.grad`` nor any
    caller-visible ``.grad`` is touched.

    Args:
        x (torch.Tensor): Input tensor.
        concentration (torch.Tensor): Concentration parameter (kappa),
            broadcastable against ``x``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The CDF (detached) and its
        element-wise derivative w.r.t. concentration, both with the broadcast
        shape of ``x`` and ``concentration``.
    """
    x, concentration = torch.broadcast_tensors(x, concentration)
    with torch.enable_grad():
        concentration_ = concentration.detach().clone().requires_grad_(True)
        cdf = cdf_func(concentration_, x)
        # cdf_func is element-wise, so the vector-Jacobian product with ones
        # yields exactly d cdf[i] / d concentration[i] for every element.
        (dcdf_dconcentration,) = torch.autograd.grad(cdf, concentration_, grad_outputs=torch.ones_like(cdf))
    return cdf.detach(), dcdf_dconcentration