relaxit.distributions

Distribution Classes

class relaxit.distributions.CorrelatedRelaxedBernoulli.CorrelatedRelaxedBernoulli(pi: Tensor, R: Tensor, tau: Tensor, validate_args: bool = None)

Bases: TorchDistribution

Correlated Relaxed Bernoulli distribution class from https://openreview.net/pdf?id=oDFvtxzPOx.

Parameters:
  • pi (torch.Tensor) – Selection probability vector.

  • R (torch.Tensor) – Covariance matrix.

  • tau (torch.Tensor) – Temperature hyper-parameter.
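The parameters above do not say how pi, R and tau combine into a sample. As a rough illustration only, the sketch below assumes a Gaussian-copula construction. Gaussian noise with covariance R is rescaled to unit variance and mapped to correlated uniforms through the standard normal CDF. Those uniforms then drive the usual logistic relaxation with temperature tau. This is an assumption, not relaxit's code, and the helper correlated_relaxed_bernoulli_sample is hypothetical.

```python
import torch

def correlated_relaxed_bernoulli_sample(pi, R, tau, n_samples=1, eps=1e-6):
    """Hypothetical helper: Gaussian-copula sketch, not relaxit's implementation."""
    # Correlated Gaussian noise z ~ N(0, R), shape (n_samples, d)
    mvn = torch.distributions.MultivariateNormal(torch.zeros_like(pi), covariance_matrix=R)
    z = mvn.rsample((n_samples,))
    # Rescale each coordinate so every marginal is N(0, 1)
    z = z / torch.diagonal(R).sqrt()
    # Gaussian copula: the standard normal CDF maps z to correlated uniforms
    u = torch.distributions.Normal(0.0, 1.0).cdf(z).clamp(eps, 1 - eps)
    # Logistic relaxation: sigmoid((logit(pi) + logit(u)) / tau)
    logit_pi = torch.log(pi) - torch.log1p(-pi)
    logit_u = torch.log(u) - torch.log1p(-u)
    return torch.sigmoid((logit_pi + logit_u) / tau)


torch.manual_seed(0)
pi = torch.tensor([0.2, 0.5, 0.8], requires_grad=True)
R = torch.tensor([[1.0, 0.9, 0.0], [0.9, 1.0, 0.0], [0.0, 0.0, 1.0]])
samples = correlated_relaxed_bernoulli_sample(pi, R, tau=torch.tensor(0.5), n_samples=2000)

# Coordinates 0 and 1 share correlated noise; coordinate 2 is independent of both
corr = torch.corrcoef(samples.detach().T)
samples.mean().backward()  # gradients reach pi through the relaxation
```

In this sketch the correlation comes from the noise rather than from pi. Gates driven by strongly correlated noise therefore tend to open and close together even when their selection probabilities differ.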

arg_constraints = {'R': PositiveDefinite(), 'pi': Interval(lower_bound=0, upper_bound=1), 'tau': GreaterThan(lower_bound=0.0)}
property batch_shape: Size

Returns the batch shape of the distribution.

The batch shape represents the shape of independent distributions. For example, if pi is a tensor of shape (batch_size, pi_shape), the batch shape will be [batch_size], indicating batch_size independent Bernoulli distributions.

Returns:

The batch shape of the distribution.

Return type:

torch.Size

property event_shape: Size

Returns the event shape of the distribution.

The event shape represents the shape of each individual event. For example, if pi is a tensor of shape (batch_size, pi_shape), the event shape will be [pi_shape].

Returns:

The event shape of the distribution.

Return type:

torch.Size
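The batch/event split described above follows the torch.distributions conventions. The sketch below uses PyTorch's built-in Bernoulli and Independent classes as stand-ins, not relaxit classes. It shows how parameters of shape (batch_size, pi_shape) yield batch shape [batch_size] and event shape [pi_shape]:

```python
import torch
from torch.distributions import Bernoulli, Independent

# Parameters of shape (batch_size, pi_shape) = (4, 3)
probs = torch.full((4, 3), 0.3)

# A plain Bernoulli treats every entry as its own scalar distribution
elementwise = Bernoulli(probs)
print(elementwise.batch_shape, elementwise.event_shape)  # torch.Size([4, 3]) torch.Size([])

# Reinterpreting the last dimension as the event gives batch [4], event [3]
joint = Independent(elementwise, 1)
print(joint.batch_shape, joint.event_shape)  # torch.Size([4]) torch.Size([3])

x = joint.sample()
print(joint.log_prob(x).shape)  # torch.Size([4]): one log-probability per batch element
```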

has_rsample = True
log_prob(value: Tensor) → Tensor

Computes the log probability of the given value.

Parameters:

value (torch.Tensor) – The value for which to compute the log probability.

Returns:

The log probability of the given value.

Return type:

torch.Tensor

rsample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution using the reparameterization trick.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

sample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

support = Interval(lower_bound=0, upper_bound=1)
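rsample and sample follow the usual torch.distributions contract. rsample keeps the sample in the autograd graph, while sample returns a detached one. The sketch below shows that contract using PyTorch's built-in RelaxedBernoulli as a stand-in. It is an assumption, not checked against relaxit's source, that relaxit's classes behave the same way.

```python
import torch
from torch.distributions import RelaxedBernoulli

probs = torch.tensor([0.2, 0.7], requires_grad=True)
dist = RelaxedBernoulli(temperature=torch.tensor(0.5), probs=probs)

y = dist.rsample((10,))          # reparameterized: stays in the autograd graph
y_detached = dist.sample((10,))  # same distribution, but computed under torch.no_grad()

print(y.requires_grad, y_detached.requires_grad)  # True False

# Gradients of a sample-based objective reach the parameters only through rsample
y.mean().backward()
print(probs.grad is not None)  # True
```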
class relaxit.distributions.GaussianRelaxedBernoulli.GaussianRelaxedBernoulli(loc: Tensor, scale: Tensor, validate_args: bool = None)

Bases: TorchDistribution

Gaussian-based continuous Relaxed Bernoulli distribution class from https://arxiv.org/abs/1810.04247.

Parameters:
  • loc (torch.Tensor) – Mean of the normal distribution.

  • scale (torch.Tensor) – Standard deviation of the normal distribution.
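The sketch below assumes the clipped-Gaussian gate of the cited paper: a Gaussian sample loc + scale * eps is clipped to [0, 1]. Under that assumption, reparameterized sampling looks like this. The helper name is hypothetical. relaxit's own rsample may not clip at all, since its declared support is Real().

```python
import torch

def gaussian_gate_rsample(loc, scale, sample_shape=torch.Size()):
    """Hypothetical helper: clipped-Gaussian gate in the style of arXiv:1810.04247."""
    shape = torch.Size(sample_shape) + loc.shape
    eps = torch.randn(shape, dtype=loc.dtype)
    # Reparameterization: loc + scale * eps is differentiable in loc and scale
    z = loc + scale * eps
    # Hard clipping to [0, 1]; gradients are zero where the clip is active
    return torch.clamp(z, 0.0, 1.0)


torch.manual_seed(0)
loc = torch.tensor([0.5, 0.5, -5.0], requires_grad=True)
scale = torch.full((3,), 0.1)
gates = gaussian_gate_rsample(loc, scale, (100,))
gates.sum().backward()
# Expected: 100 for the first two gates, 0 for the third, which is always clipped to 0
print(loc.grad)
```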

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
property batch_shape: Size

Returns the batch shape of the distribution.

The batch shape represents the shape of independent distributions. For example, if loc is a vector of length 3, the batch shape will be [3], indicating 3 independent Bernoulli distributions.

Returns:

The batch shape of the distribution.

Return type:

torch.Size

property event_shape: Size

Returns the event shape of the distribution.

The event shape represents the shape of each individual event.

Returns:

The event shape of the distribution.

Return type:

torch.Size

has_rsample = True
log_prob(value: Tensor) → Tensor

Computes the log probability of the given value.

Parameters:

value (torch.Tensor) – The value for which to compute the log probability.

Returns:

The log probability of the given value.

Return type:

torch.Tensor

rsample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution using the reparameterization trick.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

sample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

support = Real()
class relaxit.distributions.GumbelSoftmaxTopK.GumbelSoftmaxTopK(probs: Tensor = None, logits: Tensor = None, K: Tensor = tensor(1), tau: Tensor = tensor(0.1000), hard: bool = True, validate_args: bool = None)

Bases: TorchDistribution

Implementation of the Gumbel-Softmax top-K trick from https://arxiv.org/pdf/1903.06059.

Parameters:
  • probs (torch.Tensor, optional) – Probabilities of the categories.

  • logits (torch.Tensor, optional) – Logits of the categories.

  • K (torch.Tensor, optional) – Number of categories to sample without replacement. Defaults to 1.

  • tau (torch.Tensor, optional) – Temperature hyper-parameter. Defaults to 0.1.

  • hard (bool, optional) – If True, the returned samples are discretized as one-hot vectors but are differentiated in autograd as if they were the soft samples. Defaults to True.

  • validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
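The sketch below is for illustration only and shows Gumbel top-K sampling with a straight-through relaxation. It makes two assumptions. First, the hard sample is the K-hot mask of the top-K Gumbel-perturbed logits. Second, the soft surrogate is a sum of K successive masked softmaxes, a relaxation from the subset-sampling literature. relaxit may use a different relaxation, and the function name gumbel_softmax_topk_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F

def gumbel_softmax_topk_sketch(logits, k, tau=0.1, hard=True, eps=1e-10):
    """Hypothetical helper: K-hot Gumbel top-K sample with a soft surrogate."""
    # Gumbel(0, 1) noise; the top-K of the perturbed logits is a sample without replacement
    u = torch.rand_like(logits).clamp_min(eps)
    keys = logits - torch.log(-torch.log(u))
    # Soft surrogate: K softmaxes, each down-weighting categories already picked
    scores = keys
    step = torch.zeros_like(logits)
    soft = torch.zeros_like(logits)
    for _ in range(k):
        scores = scores + torch.log((1.0 - step).clamp_min(eps))
        step = F.softmax(scores / tau, dim=-1)
        soft = soft + step
    if not hard:
        return soft
    # Hard K-hot mask in the forward pass, soft gradients in the backward pass
    index = keys.topk(k, dim=-1).indices
    k_hot = torch.zeros_like(logits).scatter(-1, index, 1.0)
    return k_hot - soft.detach() + soft


torch.manual_seed(0)
logits = torch.log(torch.tensor([0.1, 0.2, 0.3, 0.4])).requires_grad_()
sample = gumbel_softmax_topk_sketch(logits, k=2, tau=0.5)
print(sample)  # a 2-hot vector: exactly two entries equal to 1
(sample * torch.arange(4.0)).sum().backward()  # gradients flow through the soft surrogate
```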

arg_constraints = {'K': IntegerGreaterThan(lower_bound=1), 'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'tau': GreaterThan(lower_bound=0.0)}
property batch_shape: Size

Returns the batch shape of the distribution.

The batch shape represents the shape of independent distributions. For example, if probs is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.

Returns:

The batch shape of the distribution.

Return type:

torch.Size

property event_shape: Size

Returns the event shape of the distribution.

The event shape represents the shape of each individual event.

Returns:

The event shape of the distribution.

Return type:

torch.Size

has_rsample = True
log_prob(value: Tensor) → Tensor

Computes the log probability of the given value.

Parameters:

value (torch.Tensor) – The value for which to compute the log probability.

Returns:

The log probability of the given value.

Return type:

torch.Tensor

rsample() → Tensor

Generates a sample from the distribution using the Gumbel-softmax top-K trick.

Returns:

A sample from the distribution.

Return type:

torch.Tensor

sample() → Tensor

Generates a sample from the distribution without gradient tracking.

Returns:

A sample from the distribution.

Return type:

torch.Tensor

class relaxit.distributions.HardConcrete.HardConcrete(alpha: Tensor, beta: Tensor, xi: Tensor, gamma: Tensor, validate_args: bool = None)

Bases: TorchDistribution

HardConcrete distribution class from https://arxiv.org/abs/1712.01312.

Parameters:
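  • alpha (torch.Tensor) – Location parameter (α in the cited paper); must be positive.

  • beta (torch.Tensor) – Temperature parameter (β in the cited paper); must be positive.

  • xi (torch.Tensor) – Upper end of the stretched interval (ζ in the cited paper); must be greater than 1.

  • gamma (torch.Tensor) – Lower end of the stretched interval (γ in the cited paper); must be negative.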

  • validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
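For intuition, the sampling recipe of the cited paper can be sketched as follows. A binary Concrete sample with location alpha and temperature beta is stretched to the interval (gamma, xi) and then clipped to [0, 1]. As a result, a finite probability mass lands exactly on 0 and on 1. The helper name is hypothetical. The mapping of alpha, beta, xi and gamma onto the paper's α, β, ζ and γ is an assumption based on the constraints in arg_constraints.

```python
import torch

def hard_concrete_rsample(alpha, beta, xi, gamma, sample_shape=torch.Size(), eps=1e-6):
    """Hypothetical helper following the sampling recipe of arXiv:1712.01312."""
    shape = torch.Size(sample_shape) + alpha.shape
    u = torch.rand(shape, dtype=alpha.dtype).clamp(eps, 1.0 - eps)
    # Binary Concrete sample with location alpha and temperature beta
    s = torch.sigmoid((torch.log(u) - torch.log1p(-u) + torch.log(alpha)) / beta)
    # Stretch from (0, 1) to (gamma, xi), then clip back to [0, 1]
    s_bar = s * (xi - gamma) + gamma
    return s_bar.clamp(0.0, 1.0)


torch.manual_seed(0)
alpha = torch.tensor([1.0, 1.0], requires_grad=True)
z = hard_concrete_rsample(alpha, beta=2 / 3, xi=1.1, gamma=-0.1, sample_shape=(10000,))
# With these settings both fractions should be close to 0.168
print((z == 0).float().mean().item(), (z == 1).float().mean().item())
z.mean().backward()  # gradients reach alpha through the unclipped draws
```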

arg_constraints = {'alpha': GreaterThan(lower_bound=0.0), 'beta': GreaterThan(lower_bound=0.0), 'gamma': LessThan(upper_bound=0.0), 'xi': GreaterThan(lower_bound=1.0)}
property batch_shape: Size

Returns the batch shape of the distribution.

The batch shape represents the shape of independent distributions. For example, if alpha is a vector of length 3, the batch shape will be [3], indicating 3 independent distributions.

Returns:

The batch shape of the distribution.

Return type:

torch.Size

property event_shape: Size

Returns the event shape of the distribution.

The event shape represents the shape of each individual event.

Returns:

The event shape of the distribution.

Return type:

torch.Size

has_rsample = True
log_prob(value: Tensor) → Tensor

Computes the log probability of the given value.

Parameters:

value (torch.Tensor) – The value for which to compute the log probability.

Returns:

The log probability of the given value.

Return type:

torch.Tensor

rsample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution using the reparameterization trick.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

sample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

support = Real()
class relaxit.distributions.InvertibleGaussian.InvertibleGaussian(loc, scale, temperature, validate_args: bool = None)

Bases: TorchDistribution

Invertible Gaussian distribution class from https://arxiv.org/abs/1912.09588.

Parameters:
  • loc (torch.Tensor) – The mean (mu) of the normal distribution.

  • scale (torch.Tensor) – The standard deviation (sigma) of the normal distribution.

  • temperature (float) – Temperature parameter for the softmax++ function.

  • validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
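The cited paper builds this distribution by pushing a reparameterized Gaussian sample through softmax++ at a given temperature. The sketch below assumes that construction. The helper names softmax_pp and invertible_gaussian_rsample are hypothetical, and relaxit's rsample may differ in details such as how delta is chosen.

```python
import math
import torch

def softmax_pp(y, delta=1.0):
    """exp(y_k) / (sum_j exp(y_j) + delta), computed with logsumexp for stability."""
    log_delta = torch.full_like(y[..., :1], math.log(delta))
    log_norm = torch.logsumexp(torch.cat([y, log_delta], dim=-1), dim=-1, keepdim=True)
    return torch.exp(y - log_norm)

def invertible_gaussian_rsample(loc, scale, temperature, sample_shape=torch.Size(), delta=1.0):
    """Hypothetical helper: Gaussian noise pushed through softmax++ at a given temperature."""
    eps = torch.randn(torch.Size(sample_shape) + loc.shape, dtype=loc.dtype)
    y = loc + scale * eps  # reparameterized Gaussian sample
    return softmax_pp(y / temperature, delta)


torch.manual_seed(0)
loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3)
z = invertible_gaussian_rsample(loc, scale, temperature=0.5, sample_shape=(5,))
# Each row sums to strictly less than 1; the implicit extra coordinate is 1 - z.sum(-1)
print(z.sum(dim=-1))
z.sum().backward()
```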

arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
property batch_shape: Size

Returns the batch shape of the distribution.

The batch shape represents the shape of independent distributions.

Returns:

The batch shape of the distribution.

Return type:

torch.Size

property event_shape: Size

Returns the event shape of the distribution.

The event shape represents the shape of each individual event.

Returns:

The event shape of the distribution.

Return type:

torch.Size

has_rsample = True
log_prob(value: Tensor) → Tensor

Computes the log probability of the given value.

Parameters:

value (torch.Tensor) – The value for which to compute the log probability.

Returns:

The log probability of the given value.

Return type:

torch.Tensor

rsample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution using the reparameterization trick.

Parameters:

sample_shape (torch.Size, optional) – The shape of the generated samples. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

softmax_plus_plus(y: Tensor, delta: float = 1) → Tensor

Computes the softmax++ function.

Parameters:
  • y (torch.Tensor) – Input tensor of shape (batch_size, num_classes).

  • delta (float, optional) – Positive constant δ added to the softmax normalizer. Defaults to 1.

Returns:

Output tensor of the same shape as y.

Return type:

torch.Tensor
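This block assumes the definition of softmax++ from the cited paper. In that definition, softmax++ divides by the usual normalizer plus δ, where n is the size of the last dimension of y. The output therefore sums to less than 1, and the map can be inverted in closed form:

```latex
z_k = \operatorname{softmax}_{++}(y)_k = \frac{\exp(y_k)}{\sum_{j=1}^{n} \exp(y_j) + \delta},
\qquad
1 - \sum_{j=1}^{n} z_j = \frac{\delta}{\sum_{j=1}^{n} \exp(y_j) + \delta},
\qquad
y_k = \log z_k - \log\Bigl(1 - \sum_{j=1}^{n} z_j\Bigr) + \log \delta .
```

The last identity follows from dividing the first expression by the second, which gives z_k / (1 - Σ_j z_j) = exp(y_k) / δ.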

support = Real()
class relaxit.distributions.LogisticNormalSoftmax.LogisticNormalSoftmax(loc, scale, validate_args=None)

Bases: TransformedDistribution

Creates a logistic-normal distribution parameterized by loc and scale, which define a base Normal distribution that is transformed with SoftmaxTransform:

Y ~ Normal(loc, scale), X = Softmax(Y) ~ LogisticNormalSoftmax(loc, scale)

Parameters:
  • loc (float or torch.Tensor) – Mean of the base distribution.

  • scale (float or torch.Tensor) – Standard deviation of the base distribution.

  • validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
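The same construction can be reproduced with PyTorch's own building blocks, which is handy for sanity-checking samples. The sketch below is an illustration, not relaxit's implementation. It wraps a Normal base in Independent so that SoftmaxTransform acts on the whole vector, then draws reparameterized samples.

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import SoftmaxTransform

loc = torch.zeros(3, requires_grad=True)
scale = torch.ones(3)

# Y ~ Normal(loc, scale) as one 3-dimensional event, then X = Softmax(Y)
dist = TransformedDistribution(Independent(Normal(loc, scale), 1), [SoftmaxTransform()])

x = dist.rsample((4,))
print(x.sum(dim=-1))      # each row lies on the simplex, so every sum is 1 up to rounding
x[:, 0].sum().backward()  # reparameterized: gradients reach loc
```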

arg_constraints: Dict[str, constraints.Constraint] = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
expand(batch_shape, _instance=None)

Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape.

Parameters:
  • batch_shape (torch.Size) – The desired expanded size.

  • _instance (LogisticNormalSoftmax, optional) – New instance of the distribution to populate. Defaults to None.

Returns:

New distribution instance with batch dimensions expanded to batch_shape.

Return type:

LogisticNormalSoftmax

has_rsample = True
property loc

Returns the location (mean) of the base distribution.

Returns:

The location of the base distribution.

Return type:

float or torch.Tensor

property scale

Returns the scale (standard deviation) of the base distribution.

Returns:

The scale of the base distribution.

Return type:

float or torch.Tensor

support = Simplex()
class relaxit.distributions.StochasticTimesSmooth.StochasticTimesSmooth(*args, **kwargs)

Bases: Bernoulli

Implementation of the stochastic times smooth estimator from https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=62c76ca0b2790c34e85ba1cce09d47be317c7235.

Creates a Bernoulli distribution parameterized by probs or logits (but not both).

Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.

Unlike the standard Bernoulli distribution, it lets gradients flow to the parameters through the stochastic times smooth gradient estimator.

Parameters:
  • probs (torch.Tensor, optional) – Event probabilities.

  • logits (torch.Tensor, optional) – Event log-odds.

  • validate_args (bool, optional) – Whether to validate arguments. Defaults to None.

has_rsample = True
rsample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution using the reparameterization trick.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

class relaxit.distributions.StraightThroughBernoulli.StraightThroughBernoulli(*args, **kwargs)

Bases: Bernoulli

Implementation of the straight-through Bernoulli from https://arxiv.org/abs/1910.02176.

Creates a Bernoulli distribution parameterized by probs or logits (but not both).

Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.

Unlike the standard Bernoulli distribution, it lets gradients flow to the parameters through the straight-through gradient estimator.

Parameters:
  • probs (torch.Tensor, optional) – Event probabilities.

  • logits (torch.Tensor, optional) – Event log-odds.

  • validate_args (bool, optional) – Whether to validate arguments. Defaults to None.
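The straight-through idea fits in a few lines. The forward pass returns a hard Bernoulli draw, and the backward pass treats the draw as if it were the identity in probs. This is a minimal illustration of the general estimator, not relaxit's rsample, and the helper name is hypothetical.

```python
import torch

def straight_through_bernoulli(probs, sample_shape=torch.Size()):
    """Hypothetical helper: hard 0/1 sample with identity (straight-through) gradient."""
    shape = torch.Size(sample_shape) + probs.shape
    with torch.no_grad():
        hard = torch.bernoulli(probs.expand(shape))
    # Forward value equals `hard` up to float rounding; the gradient w.r.t. probs is 1
    return hard + probs - probs.detach()


torch.manual_seed(0)
probs = torch.tensor([0.1, 0.5, 0.9], requires_grad=True)
b = straight_through_bernoulli(probs, (1000,))
b.sum().backward()
print(probs.grad)  # tensor([1000., 1000., 1000.]): one unit of gradient per sample
```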

has_rsample = True
rsample(sample_shape: Size = ()) → Tensor

Generates a sample from the distribution using the reparameterization trick.

Parameters:

sample_shape (torch.Size, optional) – The shape of the sample. Defaults to torch.Size().

Returns:

A sample from the distribution.

Return type:

torch.Tensor

Utility Modules

relaxit.distributions.approx.dirichlet_approximation_fn(lognorm_distribution: LogisticNormalSoftmax) → Dirichlet

Approximates a LogisticNormalSoftmax distribution with a Dirichlet distribution.

Parameters:

lognorm_distribution (LogisticNormalSoftmax) – The LogisticNormalSoftmax distribution to approximate.

Returns:

The approximated Dirichlet distribution.

Return type:

Dirichlet

relaxit.distributions.approx.lognorm_approximation_fn(dirichlet_distribution: Dirichlet) → LogisticNormalSoftmax

Approximates a Dirichlet distribution with a LogisticNormalSoftmax distribution.

Parameters:

dirichlet_distribution (Dirichlet) – The Dirichlet distribution to approximate.

Returns:

The approximated LogisticNormalSoftmax distribution.

Return type:

LogisticNormalSoftmax
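One closed-form way to map between the two families is the Laplace bridge. This sketch assumes these helpers implement it, which has not been verified against relaxit's source. The function names are hypothetical, loc and scale follow LogisticNormalSoftmax, and scale is a standard deviation. For K ≥ 2 categories the two maps below are exact inverses of each other.

```python
import torch

def dirichlet_to_logistic_normal(alpha):
    """Hypothetical helper: Dirichlet(alpha) -> (loc, scale) of a softmax-Gaussian."""
    K = alpha.shape[-1]
    log_alpha = alpha.log()
    loc = log_alpha - log_alpha.mean(dim=-1, keepdim=True)
    var = (1.0 - 2.0 / K) / alpha + (1.0 / alpha).sum(dim=-1, keepdim=True) / K**2
    return loc, var.sqrt()

def logistic_normal_to_dirichlet(loc, scale):
    """Hypothetical helper: inverse map (loc, scale) -> Dirichlet concentration."""
    K = loc.shape[-1]
    inv_sum = torch.exp(-loc).sum(dim=-1, keepdim=True)
    return (1.0 - 2.0 / K + torch.exp(loc) * inv_sum / K**2) / scale**2


alpha = torch.tensor([0.5, 2.0, 5.0])
loc, scale = dirichlet_to_logistic_normal(alpha)
recovered = logistic_normal_to_dirichlet(loc, scale)
print(torch.allclose(recovered, alpha))  # True: the round trip is exact up to rounding
dirichlet = torch.distributions.Dirichlet(recovered)
```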