Coverage for bmm_multitask_learning/variational/distr.py: 100%
17 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-13 13:33 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-13 13:33 +0000
1"""Utils for working with distributions
2"""
3from typing import Callable
5import torch
6from torch import distributions as distr
8# conditional distribution for targets
9type TargetDistr = Callable[[torch.Tensor, torch.Tensor], distr.Distribution]
10# conditional distribution for latents
11type LatentDistr = Callable[[torch.Tensor], distr.Distribution]
12# conditional distribution for targets, but batched latents and classifiers
13# must be flattened in one dimension. This is needed because of MixtureSameFamily design.
14# See https://github.com/pytorch/pytorch/issues/76709 for future possible automation
15type PredictiveDistr = Callable[[torch.Tensor, torch.Tensor], distr.Distribution]
def kl_sample_estimation(
    distr_1: distr.Distribution,
    distr_2: distr.Distribution,
    num_particles: int = 1
) -> torch.Tensor:
    """Make a Monte Carlo estimate of the KL divergence KL(distr_1 || distr_2).

    Draws reparameterized samples from ``distr_1`` and averages
    ``log p_1(x) - log p_2(x)`` over them, so gradients can flow back
    into the parameters of ``distr_1``.

    Args:
        distr_1 (distr.Distribution): distribution the samples are drawn from
            (first argument of the KL). Must support reparameterized
            sampling (``rsample``).
        distr_2 (distr.Distribution): distribution the samples are scored
            against (second argument of the KL).
        num_particles (int, optional): number of samples for estimation. Defaults to 1.

    Returns:
        torch.Tensor: scalar tensor, the mean of the log-density difference over
            all particles and all batch elements.

    Raises:
        ValueError: if ``num_particles`` is less than 1.
    """
    if num_particles < 1:
        # rsample with zero particles gives an empty tensor whose mean is NaN;
        # fail loudly instead of silently returning NaN
        raise ValueError(f"num_particles must be a positive integer, got {num_particles}")

    # sample_shape (num_particles,) is prepended to the distribution's own shape
    samples = distr_1.rsample((num_particles,))
    log_p_1 = distr_1.log_prob(samples)
    log_p_2 = distr_2.log_prob(samples)

    # .mean() averages over the particle dimension AND any batch dimensions
    return (log_p_1 - log_p_2).mean()
def build_predictive(
    pred_distr: PredictiveDistr,
    classifier_distr: distr.Distribution,
    latent_distr: LatentDistr,
    X: torch.Tensor,
    classifier_num_particles: int = 1,
    latent_num_particles: int = 1
) -> distr.MixtureSameFamily:
    """Construct a torch distribution approximating the true predictive distribution
    (in the Bayesian sense) using variational distributions.

    Hidden state (classifier + latent) is sampled from the variational
    posteriors. Each sample yields one component of ``pred_distr``, and all
    components are mixed with equal weights.

    Args:
        pred_distr (PredictiveDistr): see MultiTaskElbo
        classifier_distr (distr.Distribution): see MultiTaskElbo
        latent_distr (LatentDistr): see MultiTaskElbo
        X (torch.Tensor): new inputs for which to build predictive distr
        classifier_num_particles (int, optional): see MultiTaskElbo. Defaults to 1.
        latent_num_particles (int, optional): see MultiTaskElbo. Defaults to 1.

    Returns:
        distr.MixtureSameFamily: equally weighted mixture. The last batch
            dimension of the distribution returned by ``pred_distr`` is used
            as the mixture-component dimension.
    """
    # sample hidden state (classifier + latent) from the variational posterior
    classifier_samples = classifier_distr.sample((classifier_num_particles, ))
    # .sample prepends the particle dim; swap it behind the leading dim of
    # latent_distr(X) -- NOTE(review): assumes that leading dim indexes the rows of X
    latent_samples = latent_distr(X).sample((latent_num_particles, )).swapaxes(0, 1)

    # build conditional distribution objects for the target (kept separate from
    # the `pred_distr` callable instead of rebinding the parameter)
    components = pred_distr(latent_samples, classifier_samples)

    # uniform mixture weights, allocated on the samples' device so the mixture
    # and its components live on the same device (e.g. GPU)
    mixing_distr = distr.Categorical(
        probs=torch.ones(components.batch_shape, device=latent_samples.device)
    )

    return distr.MixtureSameFamily(mixing_distr, components)