relaxit.distributions
Distribution Classes
Bases: TorchDistribution
Correlated Relaxed Bernoulli distribution class from https://openreview.net/pdf?id=oDFvtxzPOx.
- Parameters:
pi (torch.Tensor) – Selection probability vector.
R (torch.Tensor) – Covariance matrix.
tau (torch.Tensor) – Temperature hyper-parameter.
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if pi is a tensor of shape (batch_size, pi_shape), the batch shape will be [batch_size], indicating batch_size independent Bernoulli distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event. For example, if pi is a tensor of shape (batch_size, pi_shape), the event shape will be [pi_shape].
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- log_prob(value: Tensor) → Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- sample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- class relaxit.distributions.GaussianRelaxedBernoulli.GaussianRelaxedBernoulli(loc: Tensor, scale: Tensor, validate_args: bool = None)
Bases: TorchDistribution
Gaussian-based continuous Relaxed Bernoulli distribution class from https://arxiv.org/abs/1810.04247.
- Parameters:
loc (torch.Tensor) – Mean of the normal distribution.
scale (torch.Tensor) – Standard deviation of the normal distribution.
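The cited paper builds its Gaussian-based relaxed Bernoulli gate by adding Gaussian noise to a mean and clipping the result to [0, 1], so that P(z > 0) = Φ(μ/σ). The plain-Python sketch below illustrates that construction, assuming `loc` plays the role of μ and `scale` the noise standard deviation σ. The helper names are hypothetical, and this page does not say whether relaxit clips its samples (its `support` is listed as `Real()`).

```python
import math
import random


def stochastic_gate(mu: float, sigma: float, rng: random.Random) -> float:
    """Clipped-Gaussian gate: z = min(1, max(0, mu + eps)) with eps ~ N(0, sigma^2)."""
    eps = rng.gauss(0.0, sigma)
    return min(1.0, max(0.0, mu + eps))


def gate_open_prob(mu: float, sigma: float) -> float:
    """P(z > 0) = Phi(mu / sigma), where Phi is the standard normal CDF."""
    return 0.5 * (1.0 + math.erf(mu / (sigma * math.sqrt(2.0))))


rng = random.Random(0)
gates = [stochastic_gate(mu=0.5, sigma=0.5, rng=rng) for _ in range(1000)]
# Empirical fraction of open gates next to the closed-form probability.
print(sum(g > 0 for g in gates) / len(gates), gate_open_prob(0.5, 0.5))
```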
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if loc is a vector of length 3, the batch shape will be [3], indicating 3 independent Bernoulli distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) → Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- sample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- support = Real()
- class relaxit.distributions.GumbelSoftmaxTopK.GumbelSoftmaxTopK(probs: Tensor = None, logits: Tensor = None, K: Tensor = tensor(1), tau: Tensor = tensor(0.1000), hard: bool = True, validate_args: bool = None)
Bases: TorchDistribution
Implementation of the Gumbel-Softmax Top-K trick from https://arxiv.org/pdf/1903.06059.
- Parameters:
probs (torch.Tensor, optional) – Probabilities of the categories.
logits (torch.Tensor, optional) – Logits of the categories.
K (torch.Tensor, optional) – Number of categories to sample without replacement. Defaults to 1.
tau (torch.Tensor, optional) – Temperature hyper-parameter. Defaults to 0.1.
hard (bool, optional) – If True, the returned samples are discretized as one-hot vectors but are differentiated in autograd as if they were the soft samples. Defaults to True.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
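The Gumbel-top-k trick behind this class draws K categories without replacement by adding independent Gumbel noise to the logits and keeping the K largest. The plain-Python sketch below shows only that hard sampling step. The temperature-controlled relaxation (`tau`) and the straight-through behaviour of `hard=True` are not reproduced, and `gumbel_top_k` is a hypothetical helper, not relaxit's API.

```python
from __future__ import annotations

import math
import random


def gumbel_top_k(logits: list[float], k: int, rng: random.Random) -> list[int]:
    """Return k distinct category indices sampled without replacement."""
    perturbed = []
    for index, logit in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # rng.random() lies in [0, 1); skip 0 to avoid log(0)
            u = rng.random()
        gumbel = -math.log(-math.log(u))  # standard Gumbel noise
        perturbed.append((logit + gumbel, index))
    # The indices of the k largest perturbed logits form the sample.
    perturbed.sort(reverse=True)
    return [index for _, index in perturbed[:k]]


rng = random.Random(0)
probs = [0.5, 0.3, 0.15, 0.05]
print(gumbel_top_k([math.log(p) for p in probs], k=2, rng=rng))
```

Sorting the perturbed logits once and slicing the top k is what makes the draw "without replacement": each index can appear at most once.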
- arg_constraints = {'K': IntegerGreaterThan(lower_bound=1), 'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'tau': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if probs is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) → Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- class relaxit.distributions.HardConcrete.HardConcrete(alpha: Tensor, beta: Tensor, xi: Tensor, gamma: Tensor, validate_args: bool = None)
Bases: TorchDistribution
HardConcrete distribution class from https://arxiv.org/abs/1712.01312.
- Parameters:
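alpha (torch.Tensor) – Location parameter α > 0; log(α) shifts the underlying Binary Concrete distribution.
beta (torch.Tensor) – Temperature β > 0 of the underlying Binary Concrete distribution.
xi (torch.Tensor) – Upper end ξ > 1 of the stretch interval (γ, ξ).
gamma (torch.Tensor) – Lower end γ < 0 of the stretch interval (γ, ξ).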
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
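The cited paper defines the HardConcrete variable by stretching a Binary Concrete sample to an interval (γ, ζ) with γ < 0 and ζ > 1, then clipping it to [0, 1], which puts probability mass exactly on 0 and 1. The plain-Python sketch below follows that construction, assuming the documented `alpha`, `beta`, `xi`, and `gamma` correspond to the paper's location α, temperature β, and stretch limits ζ and γ (consistent with the class's `arg_constraints`). The helper names are hypothetical.

```python
import math
import random


def hard_concrete_sample(alpha: float, beta: float, xi: float, gamma: float,
                         rng: random.Random) -> float:
    """One HardConcrete sample: a stretched, clipped Binary Concrete variable."""
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    # Binary Concrete sample in (0, 1) with location log(alpha) and temperature beta.
    logit = (math.log(u) - math.log(1.0 - u) + math.log(alpha)) / beta
    s = 1.0 / (1.0 + math.exp(-logit))
    # Stretch to (gamma, xi), then clip to [0, 1]; clipping creates exact 0s and 1s.
    s_bar = s * (xi - gamma) + gamma
    return min(1.0, max(0.0, s_bar))


def prob_nonzero(alpha: float, beta: float, xi: float, gamma: float) -> float:
    """P(z != 0) = sigmoid(log(alpha) - beta * log(-gamma / xi))."""
    return 1.0 / (1.0 + math.exp(-(math.log(alpha) - beta * math.log(-gamma / xi))))


rng = random.Random(0)
params = dict(alpha=1.0, beta=2.0 / 3.0, xi=1.1, gamma=-0.1)
zs = [hard_concrete_sample(**params, rng=rng) for _ in range(2000)]
# Empirical fraction of nonzero samples next to the closed-form probability.
print(sum(z > 0 for z in zs) / len(zs), prob_nonzero(**params))
```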
- arg_constraints = {'alpha': GreaterThan(lower_bound=0.0), 'beta': GreaterThan(lower_bound=0.0), 'gamma': LessThan(upper_bound=0.0), 'xi': GreaterThan(lower_bound=1.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions. For example, if alpha is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) → Tensor
Computes the log probability of the given value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- sample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- support = Real()
- class relaxit.distributions.InvertibleGaussian.InvertibleGaussian(loc, scale, temperature, validate_args: bool = None)
Bases: TorchDistribution
Invertible Gaussian distribution class from https://arxiv.org/abs/1912.09588.
- Parameters:
loc (torch.Tensor) – The mean (mu) of the normal distribution.
scale (torch.Tensor) – The standard deviation (sigma) of the normal distribution.
temperature (float) – Temperature parameter for the softmax++ function.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
- property batch_shape: Size
Returns the batch shape of the distribution.
The batch shape represents the shape of independent distributions.
- Returns:
The batch shape of the distribution.
- Return type:
torch.Size
- property event_shape: Size
Returns the event shape of the distribution.
The event shape represents the shape of each individual event.
- Returns:
The event shape of the distribution.
- Return type:
torch.Size
- has_rsample = True
- log_prob(value: Tensor) → Tensor
Computes the log likelihood of a value.
- Parameters:
value (torch.Tensor) – The value for which to compute the log probability.
- Returns:
The log probability of the given value.
- Return type:
torch.Tensor
- rsample(sample_shape: Size = ()) → Tensor
Generates a sample from the distribution using the reparameterization trick.
- Parameters:
sample_shape (torch.Size, optional) – The shape of the generated samples. Defaults to torch.Size().
- Returns:
A sample from the distribution.
- Return type:
torch.Tensor
- softmax_plus_plus(y: Tensor, delta: float = 1) → Tensor
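Computes the softmax++ function, softmax++(y)_k = exp(y_k) / (Σ_j exp(y_j) + δ). Because of the extra δ in the denominator, the outputs are positive and sum to less than 1.
- Parameters:
y (torch.Tensor) – Input tensor of shape (batch_size, num_classes).
delta (float, optional) – Positive constant δ added to the denominator. Defaults to 1.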
- Returns:
Output tensor of the same shape as y.
- Return type:
torch.Tensor
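In the cited paper, a Gaussian sample y is pushed toward the simplex with softmax++, which divides by Σ_j exp(y_j) + δ instead of Σ_j exp(y_j). The extra δ leaves some mass for a final category, and this is what makes the map invertible: y_k = log z_k − log(1 − Σ_j z_j) + log δ. The plain-Python sketch below rests on two assumptions: the temperature divides y before softmax++ (as the `temperature` parameter description suggests), and the leftover mass 1 − Σ_k z_k is appended as the last coordinate. All function names are hypothetical, not relaxit's.

```python
from __future__ import annotations

import math
import random


def softmax_plus_plus(y: list[float], delta: float = 1.0) -> list[float]:
    """softmax++(y)_k = exp(y_k) / (sum_j exp(y_j) + delta), computed stably."""
    # Shift by m >= max(y) and m >= 0 so neither exp(y_k - m) nor exp(-m) overflows.
    m = max(0.0, max(y))
    exps = [math.exp(v - m) for v in y]
    denom = sum(exps) + delta * math.exp(-m)
    return [e / denom for e in exps]


def inverse_softmax_plus_plus(z: list[float], delta: float = 1.0) -> list[float]:
    """Recover y from z: y_k = log(z_k) - log(1 - sum(z)) + log(delta)."""
    rest = 1.0 - sum(z)
    return [math.log(v) - math.log(rest) + math.log(delta) for v in z]


def invertible_gaussian_sample(loc, scale, temperature, rng, delta=1.0):
    """Sample y ~ N(loc, scale^2), apply softmax++(y / temperature),
    and append the leftover mass as the final coordinate."""
    y = [rng.gauss(mu, sigma) / temperature for mu, sigma in zip(loc, scale)]
    z = softmax_plus_plus(y, delta)
    return z + [1.0 - sum(z)]


rng = random.Random(0)
y = [0.3, -1.2, 2.0]
z = softmax_plus_plus(y)
print(z, inverse_softmax_plus_plus(z))
print(invertible_gaussian_sample([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], temperature=0.5, rng=rng))
```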
- support = Real()
- class relaxit.distributions.LogisticNormalSoftmax.LogisticNormalSoftmax(loc, scale, validate_args=None)
Bases: TransformedDistribution
Creates a logistic-normal distribution parameterized by loc and scale, which define the base Normal distribution; samples from it are transformed with SoftmaxTransform, so that:
Y ~ Normal(loc, scale)
X = Softmax(Y) ~ LogisticNormalSoftmax(loc, scale)
- Parameters:
loc (float or torch.Tensor) – Mean of the base distribution.
scale (float or torch.Tensor) – Standard deviation of the base distribution.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
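The transformation above can be sketched in plain Python: draw each coordinate of Y independently from Normal(loc, scale), then apply softmax. Independent coordinates are an assumption matching a per-coordinate `loc` and `scale`, and the helper name is hypothetical. Because softmax is unchanged when the same constant is added to every coordinate, X determines Y only up to an additive constant.

```python
from __future__ import annotations

import math
import random


def logistic_normal_softmax_sample(loc: list[float], scale: list[float],
                                   rng: random.Random) -> list[float]:
    """Draw Y ~ Normal(loc, scale) coordinate-wise and return X = softmax(Y)."""
    y = [rng.gauss(mu, sigma) for mu, sigma in zip(loc, scale)]
    m = max(y)  # subtracting the max avoids overflow; softmax is shift-invariant
    exps = [math.exp(v - m) for v in y]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
x = logistic_normal_softmax_sample([0.0, 1.0, -1.0], [1.0, 0.5, 0.5], rng)
print(x, sum(x))
```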
- arg_constraints: Dict[str, constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
- expand(batch_shape, _instance=None)
Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape.
- Parameters:
batch_shape (torch.Size) – The desired expanded size.
_instance (LogisticNormalSoftmax, optional) – New instance of the distribution to populate. Defaults to None.
- Returns:
New distribution instance with batch dimensions expanded to batch_shape.
- Return type:
LogisticNormalSoftmax
- has_rsample = True
- property loc
Returns the location (mean) of the base distribution.
- Returns:
The location of the base distribution.
- Return type:
float or torch.Tensor
- property scale
Returns the scale (standard deviation) of the base distribution.
- Returns:
The scale of the base distribution.
- Return type:
float or torch.Tensor
- support = Simplex()
- class relaxit.distributions.StochasticTimesSmooth.StochasticTimesSmooth(*args, **kwargs)
Bases: Bernoulli
Implementation of the Stochastic Times Smooth from https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235.
Creates a Bernoulli distribution parameterized by either probs or logits (but not both). Samples are binary (0 or 1): they take the value 1 with probability p and 0 with probability 1 - p.
Unlike the standard Bernoulli distribution, it supports gradient flow to the parameters through the stochastic times smooth gradient estimator.
- Parameters:
probs (torch.Tensor, optional) – Event probabilities.
logits (torch.Tensor, optional) – Event log-odds.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- has_rsample = True
- class relaxit.distributions.StraightThroughBernoulli.StraightThroughBernoulli(*args, **kwargs)
Bases: Bernoulli
Implementation of the Straight Through Bernoulli from https://arxiv.org/abs/1910.02176.
Creates a Bernoulli distribution parameterized by either probs or logits (but not both). Samples are binary (0 or 1): they take the value 1 with probability p and 0 with probability 1 - p.
Unlike the standard Bernoulli distribution, it supports gradient flow to the parameters through the straight-through gradient estimator.
- Parameters:
probs (torch.Tensor, optional) – Event probabilities.
logits (torch.Tensor, optional) – Event log-odds.
validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
- has_rsample = True
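The straight-through estimator uses the hard Bernoulli sample in the forward pass and, in the backward pass, treats the sampling step as the identity, so the gradient with respect to the sample is passed to p unchanged. The plain-Python sketch below makes that rule explicit for a hypothetical squared-error loss. In an autograd framework the same effect is commonly obtained with `z = b + p - p.detach()`, although this page does not say how relaxit implements it.

```python
import random


def straight_through_step(p: float, target: float, u: float):
    """Forward with a hard Bernoulli(p) sample, backward with the straight-through rule.

    u is a uniform draw in [0, 1), passed in explicitly so the step is reproducible.
    Returns the sample z, the loss (z - target)^2, and the surrogate dloss/dp.
    """
    z = 1.0 if u < p else 0.0        # hard sample: P(z = 1) = p
    loss = (z - target) ** 2
    dloss_dz = 2.0 * (z - target)
    dloss_dp = dloss_dz * 1.0        # straight-through: pretend dz/dp = 1
    return z, loss, dloss_dp


rng = random.Random(0)
print(straight_through_step(p=0.7, target=0.0, u=rng.random()))
```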
Utility Modules
- relaxit.distributions.approx.dirichlet_approximation_fn(lognorm_distribution: LogisticNormalSoftmax) → Dirichlet
Approximates a LogisticNormalSoftmax distribution with a Dirichlet distribution.
- Parameters:
lognorm_distribution (LogisticNormalSoftmax) – The LogisticNormalSoftmax distribution to approximate.
- Returns:
The approximated Dirichlet distribution.
- Return type:
Dirichlet
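This page does not state which approximation these functions use. One standard closed-form choice is the Laplace bridge, which matches a Dirichlet(α) to a logistic normal in the softmax basis with mean μ and a diagonal covariance Σ. The plain-Python sketch below assumes that choice and uses hypothetical function names. Its second function covers the opposite direction, like `lognorm_approximation_fn`, and the round trip recovers α exactly.

```python
from __future__ import annotations

import math


def dirichlet_to_logistic_normal(alpha: list[float]) -> tuple[list[float], list[float]]:
    """Laplace bridge: Dirichlet(alpha) -> softmax-basis mean mu and diagonal variances."""
    k = len(alpha)
    log_alpha = [math.log(a) for a in alpha]
    mean_log = sum(log_alpha) / k
    inv_sum = sum(1.0 / a for a in alpha)
    # mu_k = log(alpha_k) - (1/K) * sum_l log(alpha_l)
    mu = [la - mean_log for la in log_alpha]
    # Sigma_kk = (1/alpha_k) * (1 - 2/K) + (1/K^2) * sum_l 1/alpha_l
    var = [(1.0 - 2.0 / k) / a + inv_sum / k**2 for a in alpha]
    return mu, var


def logistic_normal_to_dirichlet(mu: list[float], var: list[float]) -> list[float]:
    """Inverse Laplace bridge: (mu, diagonal variances) -> Dirichlet concentrations."""
    k = len(mu)
    exp_neg_sum = sum(math.exp(-m) for m in mu)
    # alpha_k = (1 / Sigma_kk) * (1 - 2/K + exp(mu_k) / K^2 * sum_l exp(-mu_l))
    return [(1.0 - 2.0 / k + math.exp(m) * exp_neg_sum / k**2) / v
            for m, v in zip(mu, var)]


mu, var = dirichlet_to_logistic_normal([2.0, 3.0, 5.0])
print(mu, var, logistic_normal_to_dirichlet(mu, var))
```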
- relaxit.distributions.approx.lognorm_approximation_fn(dirichlet_distribution: Dirichlet) → LogisticNormalSoftmax
Approximates a Dirichlet distribution with a LogisticNormalSoftmax distribution.
- Parameters:
dirichlet_distribution (Dirichlet) – The Dirichlet distribution to approximate.
- Returns:
The approximated LogisticNormalSoftmax distribution.
- Return type: