relaxit.distributions
Distribution Classes
Bases: TorchDistribution
Correlated Relaxed Bernoulli distribution class from https://openreview.net/pdf?id=oDFvtxzPOx.
- Parameters:
pi (torch.Tensor) – Selection probability vector.
R (torch.Tensor) – Covariance matrix.
tau (torch.Tensor) – Temperature hyper-parameter.
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if pi is a tensor of shape (batch_size, pi_shape), the batch shape will be [batch_size], indicating batch_size independent Bernoulli distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event. For example, if pi is a tensor of shape (batch_size, pi_shape), the event shape will be [pi_shape].
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
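The split between batch shape and event shape described above can be sketched without PyTorch. This is a minimal illustration, not the library's code: the helper `split_shape` and its `event_ndims` argument are hypothetical names, and it assumes the event occupies the trailing dimensions of the parameter tensor.

```python
def split_shape(param_shape, event_ndims=1):
    """Split a parameter shape into (batch_shape, event_shape).

    `event_ndims` (hypothetical) is the number of trailing dimensions
    that belong to a single event; the leading ones index independent
    distributions.
    """
    if event_ndims < 0 or event_ndims > len(param_shape):
        raise ValueError("event_ndims must be between 0 and len(param_shape)")
    cut = len(param_shape) - event_ndims
    return tuple(param_shape[:cut]), tuple(param_shape[cut:])


# pi of shape (batch_size, pi_shape) with batch_size=8 and pi_shape=5
batch_shape, event_shape = split_shape((8, 5))
print(batch_shape, event_shape)
```

With `event_ndims=0` every dimension is treated as a batch dimension, which matches the "vector of length 3 gives batch shape [3]" convention used by the elementwise distributions on this page.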
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- class relaxit.distributions.GaussianRelaxedBernoulli.GaussianRelaxedBernoulli(loc: Tensor, scale: Tensor, validate_args: bool = None)
Bases: TorchDistribution
Gaussian-based continuous Relaxed Bernoulli distribution class from https://arxiv.org/abs/1810.04247.
- Parameters:
loc (torch.Tensor) – Mean of the normal distribution.
scale (torch.Tensor) – Standard deviation of the normal distribution.
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if loc is a vector of length 3, the batch shape will be [3], indicating 3 independent Bernoulli distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
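The reparameterization trick used by rsample can be sketched in plain Python. The sketch assumes the usual construction for a Gaussian-based relaxation (standard normal noise, shifted by loc, scaled by scale, then clipped to [0, 1]); whether the library clips in exactly this way is an assumption, and the function name is hypothetical.

```python
import random


def gaussian_relaxed_bernoulli_rsample(loc, scale, rng=random):
    """Draw one relaxed Bernoulli value in [0, 1] (scalar sketch)."""
    eps = rng.gauss(0.0, 1.0)      # noise that does not depend on loc or scale
    z = loc + scale * eps          # differentiable in loc and scale
    return min(1.0, max(0.0, z))   # assumed hard clipping to the unit interval


rng = random.Random(0)
samples = [gaussian_relaxed_bernoulli_rsample(0.5, 0.3, rng) for _ in range(5)]
print(samples)
```

Because the randomness enters only through `eps`, gradients with respect to `loc` and `scale` are well defined wherever the clipping is inactive.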
- sample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- support = Real()
- class relaxit.distributions.GumbelSoftmaxTopK.GumbelSoftmaxTopK(probs: Tensor = None, logits: Tensor = None, K: Tensor = tensor(1), tau: Tensor = tensor(0.1000), hard: bool = True, validate_args: bool = None)
Bases: TorchDistribution
Implementation of the Gumbel-Softmax Top-K trick from https://arxiv.org/pdf/1903.06059.
- Parameters:
probs (torch.Tensor, optional) – Probabilities of the categories.
logits (torch.Tensor, optional) – Logits of the categories.
K (torch.Tensor, optional) – How many samples without replacement to pick. Defaults to 1.
tau (torch.Tensor, optional) – Temperature hyper-parameter. Defaults to 0.1.
hard (bool, optional) – If True, the returned samples are discretized as one-hot vectors, but autograd differentiates them as if they were the soft samples. Defaults to True.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- arg_constraints = {'K': IntegerGreaterThan(lower_bound=1), 'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'tau': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if probs is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
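Picking K categories without replacement can be illustrated with the Gumbel top-k idea: perturb each logit with independent Gumbel noise and keep the K largest. The sketch below is a plain-Python illustration under that assumption; the function name is hypothetical and it leaves out the temperature-based relaxation controlled by tau.

```python
import math
import random


def gumbel_top_k(logits, k, rng=random):
    """Return k distinct category indices, sampled without replacement."""
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and the number of categories")
    perturbed = []
    for index, logit in enumerate(logits):
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # avoid log(0)
        gumbel = -math.log(-math.log(u))                # Gumbel(0, 1) noise
        perturbed.append((logit + gumbel, index))
    perturbed.sort(reverse=True)                        # largest score first
    return [index for _, index in perturbed[:k]]


chosen = gumbel_top_k([0.1, 2.0, -1.0, 0.5], k=2, rng=random.Random(0))
print(chosen)
```

With K = 1 this reduces to the Gumbel-max trick, that is, an ordinary categorical sample.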
- class relaxit.distributions.HardConcrete.HardConcrete(alpha: Tensor, beta: Tensor, xi: Tensor, gamma: Tensor, validate_args: bool = None)
Bases: TorchDistribution
HardConcrete distribution class from https://arxiv.org/abs/1712.01312.
- Parameters:
alpha (torch.Tensor) – Parameter alpha.
beta (torch.Tensor) – Parameter beta.
xi (torch.Tensor) – Parameter xi.
gamma (torch.Tensor) – Parameter gamma.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- arg_constraints = {'alpha': GreaterThan(lower_bound=0.0), 'beta': GreaterThan(lower_bound=0.0), 'gamma': LessThan(upper_bound=0.0), 'xi': GreaterThan(lower_bound=1.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if alpha is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- sample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- support = Real()
- class relaxit.distributions.InvertibleGaussian.InvertibleGaussian(loc, scale, temperature, validate_args: bool = None)
Bases: TorchDistribution
Invertible Gaussian distribution class from https://arxiv.org/abs/1912.09588.
- Parameters:
loc (torch.Tensor) – The mean (mu) of the normal distribution.
scale (torch.Tensor) – The standard deviation (sigma) of the normal distribution.
temperature (float) – Temperature parameter for the softmax++ function.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log likelihood of a value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the generated samples. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- softmax_plus_plus(y: Tensor, delta: float = 1) -> Tensor
Computes the softmax++ function.
- Parameters:
y (torch.Tensor) – Input tensor of shape (batch_size, num_classes).
delta (float, optional) – Additional term delta > 0. Defaults to 1.
- Returns:
Output tensor of the same shape as y.
- Return type:
torch.Tensor
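The softmax++ function can be sketched for a single vector in plain Python. The sketch assumes the usual definition, exp(y_k / tau) / (sum_j exp(y_j / tau) + delta); the standalone function `softmax_pp` and its explicit `tau` argument are hypothetical, since the method above takes the temperature from the distribution instance.

```python
import math


def softmax_pp(y, tau=1.0, delta=1.0):
    """softmax++ of one vector: like softmax, but delta is added to the normalizer."""
    scaled = [v / tau for v in y]
    shift = max(0.0, max(scaled))                 # keeps every exponent <= 0
    exps = [math.exp(s - shift) for s in scaled]
    denom = sum(exps) + delta * math.exp(-shift)  # same shift applied to delta
    return [e / denom for e in exps]


out = softmax_pp([0.0, 0.0])
print(out, sum(out))
```

Because delta > 0, the outputs sum to strictly less than 1, and the leftover mass can be read as an implicit extra category.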
- support = Real()
- class relaxit.distributions.LogisticNormalSoftmax.LogisticNormalSoftmax(loc, scale, validate_args=None)
Bases: TransformedDistribution
Creates a logistic-normal distribution parameterized by loc and scale, which define the base Normal distribution that is transformed with the SoftmaxTransform such that:
X ~ LogisticNormal(loc, scale)
Y = Logistic(X) ~ Normal(loc, scale)
- Parameters:
loc (float or torch.Tensor) – Mean of the base distribution.
scale (float or torch.Tensor) – Standard deviation of the base distribution.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- arg_constraints: dict[str, Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
- expand(batch_shape, _instance=None)
Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape.
- Parameters:
batch_shape (torch.Size) – The desired expanded size.
_instance (LogisticNormalSoftmax, optional) – New instance of the distribution to populate. Defaults to None.
- Returns:
New distribution instance with batch dimensions expanded to batch_shape.
- Return type:
LogisticNormalSoftmax
- has_rsample = True
- property loc
Returns the location (mean) of the base distribution.
- Returns:
The location of the base distribution.
- Return type:
float or torch.Tensor
- property scale
Returns the scale (standard deviation) of the base distribution.
- Returns:
The scale of the base distribution.
- Return type:
float or torch.Tensor
- support = Simplex()
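The reason the support is the simplex can be seen from a small sketch: draw each coordinate from the base Normal with the reparameterization trick, then apply a softmax. This is a plain-Python illustration with hypothetical function names, not the library's implementation.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax of a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def logistic_normal_softmax_rsample(loc, scale, rng=random):
    """Sample y ~ Normal(loc, scale) coordinate-wise and map it to the simplex."""
    y = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]
    return softmax(y)


x = logistic_normal_softmax_rsample([0.0, 1.0, -1.0], [1.0, 0.5, 2.0], random.Random(0))
print(x, sum(x))
```

Every output coordinate is positive and the coordinates sum to 1, which is what the Simplex constraint requires.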
- class relaxit.distributions.StochasticTimesSmooth.StochasticTimesSmooth(*args, **kwargs)
Bases: Bernoulli
Implementation of the Stochastic Times Smooth estimator from https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235.
Creates a Bernoulli distribution parameterized by probs or logits (but not both).
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Unlike a plain Bernoulli, it supports gradient flow through its parameters via the stochastic times smooth gradient estimator.
- Parameters:
probs (torch.Tensor, optional) – Event probabilities.
logits (torch.Tensor, optional) – Event log-odds.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- has_rsample = True
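One common statement of the stochastic-times-smooth idea multiplies a binary draw b ~ Bernoulli(sqrt(p)) by the smooth factor sqrt(p), so that the product has expectation p while the smooth factor carries the gradient. The sketch below illustrates that formulation in plain Python. It is an assumption about the estimator, not a description of how this class computes its samples, and the function names are hypothetical.

```python
import math
import random


def sts_sample(p, rng=random):
    """Return b * sqrt(p) with b ~ Bernoulli(sqrt(p)) (assumed formulation)."""
    root = math.sqrt(p)
    b = 1.0 if rng.random() < root else 0.0   # stochastic, non-differentiable part
    return b * root                           # smooth part carries the gradient


def sts_expectation(p):
    """Exact expectation of sts_sample: P(b = 1) * sqrt(p) = p."""
    root = math.sqrt(p)
    return root * root


print(sts_expectation(0.3))
print(sts_sample(0.3, random.Random(0)))
```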
- class relaxit.distributions.StraightThroughBernoulli.StraightThroughBernoulli(*args, **kwargs)
Bases: Bernoulli
Implementation of the Straight Through Bernoulli from https://arxiv.org/abs/1910.02176.
Creates a Bernoulli distribution parameterized by probs or logits (but not both).
Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.
Unlike a plain Bernoulli, it supports gradient flow through its parameters via the straight-through gradient estimator.
- Parameters:
probs (torch.Tensor, optional) – Event probabilities.
logits (torch.Tensor, optional) – Event log-odds.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- has_rsample = True
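The straight-through idea can be shown with a hand-written forward and backward pass. This sketch assumes the identity variant of the estimator, in which the backward pass treats the derivative of the binary sample with respect to p as 1; the function names and the toy loss are hypothetical and no autograd library is involved.

```python
import random


def straight_through_forward(p, rng=random):
    """Forward pass: an ordinary hard Bernoulli(p) draw."""
    return 1.0 if rng.random() < p else 0.0


def straight_through_backward(upstream_grad):
    """Backward pass: pretend d(sample)/d(p) == 1 and pass the gradient on."""
    return upstream_grad * 1.0


p = 0.7
b = straight_through_forward(p, random.Random(0))
dloss_db = 2.0 * (b - 1.0)           # gradient of the toy loss (b - 1)^2
dloss_dp = straight_through_backward(dloss_db)
print(b, dloss_dp)
```

The forward value stays exactly binary, while the parameter still receives a (biased) gradient signal.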
- class relaxit.distributions.GeneralizedGumbelSoftmax.GeneralizedGumbelSoftmax(values: Tensor, probs: Tensor = None, logits: Tensor = None, tau: Tensor = tensor(0.5000), hard: bool = False, validate_args: bool = None)
Bases: TorchDistribution
Generalized Gumbel-Softmax from https://arxiv.org/abs/2003.01847.
- Parameters:
values (torch.Tensor) – Discrete support values.
probs (torch.Tensor, optional) – Category probabilities. Provide either probs or logits.
logits (torch.Tensor, optional) – Category logits. Provide either probs or logits.
tau (torch.Tensor, optional) – Temperature hyper-parameter. Defaults to 0.5.
hard (bool, optional) – If True, returned samples are discretized but differentiated as soft.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'tau': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
- Returns:
Batch shape.
- Return type:
torch.Size
- property event_shape: Size
- Returns:
Event shape.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log probability of a soft or hard one-hot value.
- Parameters:
value (torch.Tensor) – Soft or hard one-hot vector.
- Returns:
Log probability.
- Return type:
torch.Tensor
- rsample() -> Tensor
Generates a reparameterized sample using the Gumbel-Softmax trick.
- Returns:
Soft or hard one-hot sample.
- Return type:
torch.Tensor
- rsample_value() -> Tensor
Generates a reparameterized sample and projects it to value space.
- Returns:
Sampled scalar values.
- Return type:
torch.Tensor
- sample() -> Tensor
Generates a non-differentiable sample.
- Returns:
Sample (no grad).
- Return type:
torch.Tensor
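The relationship between a soft one-hot sample and its projection to value space can be sketched as a weighted sum of the support values. The sketch assumes the projection is a dot product between the Gumbel-Softmax weights and values; the function names are hypothetical and the code is plain Python rather than the library's tensor implementation.

```python
import math
import random


def gumbel_softmax(logits, tau, rng=random):
    """Return a soft one-hot vector: softmax((logits + Gumbel noise) / tau)."""
    scores = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # avoid log(0)
        scores.append((logit - math.log(-math.log(u))) / tau)
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def project_to_values(weights, values):
    """Assumed projection: convex combination of the discrete support values."""
    return sum(w * v for w, v in zip(weights, values))


values = [-1.0, 0.0, 2.0]
weights = gumbel_softmax([0.0, 0.5, -0.5], tau=0.5, rng=random.Random(0))
sample_value = project_to_values(weights, values)
print(weights, sample_value)
```

As tau approaches 0 the weights approach a hard one-hot vector, so the projected value approaches one of the entries of `values`.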
- class relaxit.distributions.DecoupledStraightThroughGumbelSoftmax.DecoupledStraightThroughGumbelSoftmax(temperature_forward, temperature_backward, logits=None, probs=None, validate_args=None)
Bases: TorchDistribution
Decoupled Straight-Through Gumbel-Softmax distribution.
- This distribution uses two temperatures:
temperature_forward: for generating the hard (discrete) sample (forward pass).
temperature_backward: for computing smooth gradients (backward pass).
The output is a one-hot vector (hard sample), but gradients flow through a soft Gumbel-Softmax sample computed with a different (typically higher) temperature.
- Parameters:
temperature_forward (torch.Tensor) – Temperature for hard sampling (low, e.g., 0.1).
temperature_backward (torch.Tensor) – Temperature for gradient estimation (higher, e.g., 1.0).
logits (torch.Tensor, optional) – Event logits.
probs (torch.Tensor, optional) – Event probabilities.
validate_args (bool, optional) – Whether to validate distribution arguments.
- arg_constraints = {'logits': IndependentConstraint(Real(), 1), 'probs': Simplex(), 'temperature_backward': GreaterThan(lower_bound=0.0), 'temperature_forward': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if probs is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- enumerate_support(expand: bool = True) -> Tensor
Enumerate all one-hot vectors in the support. Same as original Categorical support.
- Parameters:
expand (bool, optional) – Whether to expand the support. Defaults to True.
- Returns:
The enumerated support.
- Return type:
torch.Tensor
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_enumerate_support = True
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- property mean: Tensor
Returns the mean of the distribution.
- Returns:
The mean of the distribution.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) -> Tensor
- Generates a decoupled straight-through sample:
Hard sample from Gumbel-Softmax with temperature_forward.
Gradient flows through soft sample from temperature_backward.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
One-hot sample with straight-through gradients.
- Return type:
torch.Tensor
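The two-temperature scheme can be sketched in plain Python by computing both soft samples from one set of Gumbel noise. Sharing the noise between the forward and backward paths is an assumption of this sketch, and the function names are hypothetical. In an autograd framework the returned value would typically be formed as hard + soft_backward - detach(soft_backward), which equals the hard sample numerically.

```python
import math
import random


def _softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def decoupled_st_sample(logits, tau_forward, tau_backward, rng=random):
    """Return (hard one-hot sample, soft sample that would carry the gradient)."""
    perturbed = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # avoid log(0)
        perturbed.append(logit - math.log(-math.log(u)))  # add Gumbel(0, 1) noise
    soft_forward = _softmax([v / tau_forward for v in perturbed])
    soft_backward = _softmax([v / tau_backward for v in perturbed])
    winner = max(range(len(logits)), key=soft_forward.__getitem__)
    hard = [1.0 if i == winner else 0.0 for i in range(len(logits))]
    return hard, soft_backward


hard, soft = decoupled_st_sample([0.2, 1.5, -0.3], 0.1, 1.0, random.Random(0))
print(hard, soft)
```

The backward temperature only changes how sharply the soft sample concentrates, so a higher value spreads the gradient across more categories without changing the forward output.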
- sample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution without gradients.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- support = Real()
- property variance: Tensor
Returns the variance of the distribution.
- Returns:
The variance of the distribution.
- Return type:
torch.Tensor
- class relaxit.distributions.RebarRelaxation.RebarRelaxation(theta: Tensor, lambd: Tensor, validate_args: bool = None)
Bases: TorchDistribution
Rebar continuous Relaxed Bernoulli distribution class from https://arxiv.org/pdf/1703.07370.
- Parameters:
lambd (torch.Tensor) – Gumbel-Softmax constant.
theta (torch.Tensor) – Mean of the Bernoulli distribution.
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if theta is a vector of length 3, the batch shape will be [3], indicating 3 independent Bernoulli distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) -> Tensor
Computes the log probability density of the Relaxed Bernoulli (Concrete) distribution.
- Parameters:
value (torch.Tensor) – Values in (0, 1) for which to compute log probability.
- Returns:
Log probability density.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
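A scalar sketch of a relaxed Bernoulli draw shows how theta and lambd interact. It assumes the standard Concrete construction: logistic noise is added to logit(theta) and the result is passed through a sigmoid with temperature lambd. Whether the class follows exactly this parameterization is an assumption, and the function names are hypothetical.

```python
import math
import random


def _sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def relaxed_bernoulli_rsample(theta, lambd, rng=random):
    """Return (relaxed sample in [0, 1], matching hard sample in {0, 1})."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)   # uniform noise
    z = math.log(theta) - math.log(1.0 - theta) + math.log(u) - math.log(1.0 - u)
    relaxed = _sigmoid(z / lambd)                    # smooth in theta
    hard = 1.0 if z > 0 else 0.0                     # thresholded version of z
    return relaxed, hard


relaxed, hard_bit = relaxed_bernoulli_rsample(0.7, 0.5, random.Random(0))
print(relaxed, hard_bit)
```

Smaller values of lambd push the relaxed sample toward the hard one, at the cost of higher-variance gradients.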
- sample(sample_shape: Size = ()) -> Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- support = Real()
Utility Modules
- relaxit.distributions.approx.dirichlet_approximation_fn(lognorm_distribution: LogisticNormalSoftmax) -> Dirichlet
Approximates a LogisticNormalSoftmax distribution with a Dirichlet distribution.
- Parameters:
lognorm_distribution (LogisticNormalSoftmax) – The LogisticNormalSoftmax distribution to approximate.
- Returns:
The approximated Dirichlet distribution.
- Return type:
Dirichlet
- relaxit.distributions.approx.lognorm_approximation_fn(dirichlet_distribution: Dirichlet) -> LogisticNormalSoftmax
Approximates a Dirichlet distribution with a LogisticNormalSoftmax distribution.
- Parameters:
dirichlet_distribution (Dirichlet) – The Dirichlet distribution to approximate.
- Returns:
The approximated LogisticNormalSoftmax distribution.
- Return type:
LogisticNormalSoftmax
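One widely used way to approximate a Dirichlet with a logistic-normal matches the mean and a diagonal covariance in the softmax basis. The sketch below applies those closed-form expressions in plain Python. Whether lognorm_approximation_fn uses exactly these formulas is an assumption, and the function name is hypothetical.

```python
import math


def dirichlet_to_logistic_normal(alpha):
    """Return (loc, scale) of a diagonal logistic-normal approximating Dirichlet(alpha)."""
    k = len(alpha)
    log_alpha = [math.log(a) for a in alpha]
    mean_log = sum(log_alpha) / k
    inv_sum = sum(1.0 / a for a in alpha)
    # loc_i = log(alpha_i) - mean_j log(alpha_j)
    loc = [la - mean_log for la in log_alpha]
    # var_i = (1 - 2/K) / alpha_i + (1/K^2) * sum_j 1/alpha_j
    var = [(1.0 - 2.0 / k) / a + inv_sum / k ** 2 for a in alpha]
    scale = [math.sqrt(v) for v in var]
    return loc, scale


loc, scale = dirichlet_to_logistic_normal([1.0, 1.0, 1.0])
print(loc, scale)
```

For a symmetric Dirichlet the resulting loc is zero in every coordinate, and larger concentration values give a smaller scale.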