Implicit Reparameterization Trick
A PyTorch library implementing implicit reparameterization gradients for continuous distributions that lack tractable inverse CDFs, based on Figurnov et al. (NeurIPS 2018).
Implemented distributions: Normal, Gamma, Beta, Dirichlet, StudentT, VonMises, MixtureSameFamily, and ImplicitReparam (a universal wrapper that adds implicit reparameterization to any distribution with a tractable CDF).
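The core identity behind the method: if a sample z has CDF F(z; θ) and density q(z; θ), differentiating F(z; θ) = u at fixed noise u gives dz/dθ = -(∂F/∂θ) / q(z; θ). Only the CDF and the density are needed, never the inverse CDF. The dependency-free sketch below checks this on a Normal, where explicit reparameterization gives the answer in closed form. It is not the library's implementation: the helper names are hypothetical and a central finite difference stands in for autograd.

```python
import math


def normal_cdf(z, loc, scale):
    """CDF of Normal(loc, scale) at z."""
    return 0.5 * (1.0 + math.erf((z - loc) / (scale * math.sqrt(2.0))))


def normal_pdf(z, loc, scale):
    """Density of Normal(loc, scale) at z."""
    return math.exp(-0.5 * ((z - loc) / scale) ** 2) / (scale * math.sqrt(2.0 * math.pi))


def implicit_grad(cdf, pdf, z, params, i, h=1e-5):
    """Return dz/dparams[i] = -(dF/dparams[i]) / q(z), holding the sample z fixed.

    dF/dparams[i] uses a central finite difference as a stand-in for autograd.
    """
    up = list(params)
    down = list(params)
    up[i] += h
    down[i] -= h
    dF_dtheta = (cdf(z, *up) - cdf(z, *down)) / (2.0 * h)
    return -dF_dtheta / pdf(z, *params)


loc, scale = 0.5, 2.0
z = 1.7  # treat as a sample already drawn from Normal(loc, scale)
g_loc = implicit_grad(normal_cdf, normal_pdf, z, (loc, scale), 0)
g_scale = implicit_grad(normal_cdf, normal_pdf, z, (loc, scale), 1)

# Explicit reparameterization z = loc + scale * eps gives
# dz/dloc = 1 and dz/dscale = eps = (z - loc) / scale = 0.6.
print(g_loc, g_scale)
```

For distributions such as Gamma, Beta or VonMises the explicit route would require an inverse CDF with no closed form, while the implicit formula still only needs F and q.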
Installation
```bash
git clone https://github.com/intsystems/implicit-reparameterization-trick.git
cd implicit-reparameterization-trick
pip install src/
```
Quick Start
Reparameterized sampling from a Beta distribution:
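```python
import torch
from irt.distributions import Beta

alpha = torch.tensor([2.0], requires_grad=True)
beta = torch.tensor([5.0], requires_grad=True)
dist = Beta(alpha, beta)
z = dist.rsample(torch.Size([64]))  # shape (64, 1), differentiable w.r.t. alpha and beta

# Gradients of any loss on the samples flow back to the parameters
z.mean().backward()
print(alpha.grad, beta.grad)
```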
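The Beta inverse CDF has no closed form, so the gradient of a Beta sample with respect to α is exactly the case the implicit formula targets. As a library-independent sanity check of the idea (not the library's code), the plain-Python sketch below computes dz/dα for Beta(2, 5) at a fixed sample with the implicit formula and compares it with a brute-force finite difference through a numerically inverted CDF. The Simpson-rule CDF and bisection inverse are illustrative assumptions, valid here because a, b > 1.

```python
import math


def beta_pdf(x, a, b):
    """Density of Beta(a, b); returns 0 at the endpoints, which is exact only for a, b > 1."""
    if x <= 0.0 or x >= 1.0:
        return 0.0
    log_norm = math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)
    return math.exp((a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x) - log_norm)


def beta_cdf(x, a, b, n=1000):
    """Beta(a, b) CDF by composite Simpson's rule on [0, x]; n must be even."""
    h = x / n
    s = beta_pdf(0.0, a, b) + beta_pdf(x, a, b)
    for k in range(1, n):
        s += (4.0 if k % 2 else 2.0) * beta_pdf(k * h, a, b)
    return s * h / 3.0


def beta_icdf(u, a, b, iters=60):
    """Inverse CDF by bisection: slow, used here only as a reference."""
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if beta_cdf(mid, a, b) < u:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


a, b, h = 2.0, 5.0, 1e-4
z = 0.3  # treat as a sample already drawn from Beta(2, 5)

# Implicit reparameterization: dz/da = -(dF/da) / q(z), with z held fixed.
dF_da = (beta_cdf(z, a + h, b) - beta_cdf(z, a - h, b)) / (2.0 * h)
implicit = -dF_da / beta_pdf(z, a, b)

# Reference: hold the uniform noise u = F(z; a, b) fixed and finite-difference
# the numerically inverted CDF, i.e. the expensive explicit route.
u = beta_cdf(z, a, b)
reference = (beta_icdf(u, a + h, b) - beta_icdf(u, a - h, b)) / (2.0 * h)
print(implicit, reference)
```

The two numbers agree, but the implicit route costs two CDF evaluations and one density evaluation, whereas the reference needs a full root-finding loop per perturbed parameter.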
Wrapping any distribution with a tractable CDF:
```python
import torch
from irt.distributions import ImplicitReparam

loc = torch.tensor(0.0, requires_grad=True)
base = torch.distributions.Laplace(loc, 1.0)
dist = ImplicitReparam(base)
z = dist.rsample(torch.Size([64]))
```
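A CDF-based wrapper only needs three things from its base distribution: a (possibly non-differentiable) sampler, the CDF and the density. The sketch below illustrates that pattern without PyTorch; the class `ImplicitReparamSketch` and its `sample_with_grad` method are hypothetical names, not the library's API, and finite differences stand in for autograd. On a Laplace, a location-scale family, the exact answers dz/dloc = 1 and dz/dscale = (z - loc) / scale are known, which makes the result easy to check.

```python
import math
import random


class ImplicitReparamSketch:
    """Hypothetical CDF-based wrapper: sample without gradients, then
    recover dz/dtheta = -(dF/dtheta) / q(z) for every parameter."""

    def __init__(self, sample_fn, cdf_fn, pdf_fn, params, h=1e-6):
        self.sample_fn = sample_fn
        self.cdf_fn = cdf_fn
        self.pdf_fn = pdf_fn
        self.params = list(params)
        self.h = h

    def sample_with_grad(self, rng):
        z = self.sample_fn(rng, *self.params)
        density = self.pdf_fn(z, *self.params)
        grads = []
        for i in range(len(self.params)):
            up = list(self.params)
            down = list(self.params)
            up[i] += self.h
            down[i] -= self.h
            dF = (self.cdf_fn(z, *up) - self.cdf_fn(z, *down)) / (2.0 * self.h)
            grads.append(-dF / density)
        return z, grads


def laplace_sample(rng, loc, scale):
    # The difference of two i.i.d. Exp(1) variables is Laplace(0, 1)
    return loc + scale * (rng.expovariate(1.0) - rng.expovariate(1.0))


def laplace_cdf(z, loc, scale):
    t = (z - loc) / scale
    return 0.5 * math.exp(t) if t < 0 else 1.0 - 0.5 * math.exp(-t)


def laplace_pdf(z, loc, scale):
    return math.exp(-abs(z - loc) / scale) / (2.0 * scale)


rng = random.Random(0)
dist = ImplicitReparamSketch(laplace_sample, laplace_cdf, laplace_pdf, (0.0, 1.0))
z, (dz_dloc, dz_dscale) = dist.sample_with_grad(rng)
# With loc = 0 and scale = 1, the exact values are dz/dloc = 1 and dz/dscale = z
print(z, dz_dloc, dz_dscale)
```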
Mixture of distributions:
```python
import torch
from torch.distributions import Categorical
from irt.distributions import Normal, MixtureSameFamily

mix_weights = Categorical(torch.tensor([0.3, 0.7]))
components = Normal(
    torch.tensor([-1.0, 1.0], requires_grad=True),
    torch.tensor([0.5, 0.5]),
)
mixture = MixtureSameFamily(mix_weights, components)
z = mixture.rsample(torch.Size([64]))
```
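The discrete component choice in a mixture cannot be reparameterized directly, but for a univariate mixture the mixture CDF F(z) = Σ_k w_k Φ((z - μ_k) / σ_k) is tractable, so the implicit formula still applies. Since ∂F/∂μ_k = -w_k N(z; μ_k, σ_k), the gradient dz/dμ_k equals the posterior responsibility of component k for z. The stdlib sketch below works this out for the two-component mixture above; whether the library's MixtureSameFamily computes its gradients exactly this way is an assumption, and the function name is hypothetical.

```python
import math
import random


def normal_pdf(z, loc, scale):
    return math.exp(-0.5 * ((z - loc) / scale) ** 2) / (scale * math.sqrt(2.0 * math.pi))


def mixture_sample_and_loc_grads(weights, locs, scales, rng):
    """Sample a 1-D Gaussian mixture and return dz/dloc_k via the implicit formula.

    dF/dloc_k = -w_k N(z; loc_k, scale_k) and q(z) = sum_k w_k N(z; loc_k, scale_k),
    so dz/dloc_k = w_k N(z; loc_k, scale_k) / q(z).
    """
    # Discrete, non-differentiable choice of component
    k = rng.choices(range(len(weights)), weights=weights)[0]
    z = rng.gauss(locs[k], scales[k])
    weighted = [w * normal_pdf(z, m, s) for w, m, s in zip(weights, locs, scales)]
    q = sum(weighted)
    return z, [wq / q for wq in weighted]


rng = random.Random(0)
z, grads = mixture_sample_and_loc_grads([0.3, 0.7], [-1.0, 1.0], [0.5, 0.5], rng)
# Shifting every component by the same amount shifts z by that amount,
# so the per-component gradients sum to 1.
print(z, grads, sum(grads))
```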