Coverage for bmm_multitask_learning/variational/distr.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-13 13:33 +0000

1"""Utils for working with distributions 

2""" 

3from typing import Callable 

4 

5import torch 

6from torch import distributions as distr 

7 

# Conditional distribution for the targets: a callable taking two tensors and
# returning a torch Distribution.
# NOTE(review): the argument order (presumably latent, classifier) is not shown
# in this module — confirm against MultiTaskElbo.
# (``type`` alias statements require Python 3.12+.)
type TargetDistr = Callable[[torch.Tensor, torch.Tensor], distr.Distribution]
# Conditional distribution for the latents given the inputs; ``build_predictive``
# calls it as ``latent_distr(X)``.
type LatentDistr = Callable[[torch.Tensor], distr.Distribution]
# Conditional distribution for the targets, but for *batched* latents and
# classifiers. ``build_predictive`` calls it as
# ``pred_distr(latent_samples, classifier_samples)``. It uses the result as the
# component distribution of a MixtureSameFamily, which mixes over exactly one
# (the rightmost) batch dimension. The latent and classifier particle axes must
# therefore be flattened into that single dimension.
# See https://github.com/pytorch/pytorch/issues/76709 for possible future
# automation of this flattening.
type PredictiveDistr = Callable[[torch.Tensor, torch.Tensor], distr.Distribution]

16 

17 

def kl_sample_estimation(
    distr_1: distr.Distribution,
    distr_2: distr.Distribution,
    num_particles: int = 1
) -> torch.Tensor:
    """Make a Monte Carlo estimate of the KL divergence KL(distr_1 || distr_2).

    Samples are drawn from ``distr_1`` with ``rsample`` (reparameterized), so
    gradients flow back into ``distr_1``'s parameters. ``distr_1`` must
    therefore support reparameterized sampling.

    Args:
        distr_1 (distr.Distribution): distribution the expectation is taken
            under (first argument of the KL).
        distr_2 (distr.Distribution): reference distribution (second argument
            of the KL).
        num_particles (int, optional): number of samples for estimation.
            Must be >= 1. Defaults to 1.

    Returns:
        torch.Tensor: scalar estimate. The log-ratio is averaged over the
        particles AND over any batch dimensions of the distributions (a plain
        ``.mean()``), so for batched distributions this is the mean, not the
        sum, of the per-element KL values.

    Raises:
        ValueError: if ``num_particles`` is smaller than 1. Zero particles
            would otherwise silently yield NaN (mean of an empty tensor).
    """
    if num_particles < 1:
        raise ValueError(
            f"num_particles must be a positive integer, got {num_particles}"
        )
    samples = distr_1.rsample([num_particles])
    # log q(z) - log p(z) evaluated at z ~ q, i.e. the single-sample KL estimator
    log_ratio = distr_1.log_prob(samples) - distr_2.log_prob(samples)

    return log_ratio.mean()

33 

34 

def build_predictive(
    pred_distr: PredictiveDistr,
    classifier_distr: distr.Distribution,
    latent_distr: LatentDistr,
    X: torch.Tensor,
    classifier_num_particles: int = 1,
    latent_num_particles: int = 1
) -> distr.MixtureSameFamily:
    """Construct a torch distribution approximating the true predictive distribution.

    The true predictive distribution is taken in the Bayesian sense. Particles
    are drawn from the variational distributions over the classifier and the
    latents. ``pred_distr`` turns them into conditional target distributions,
    and these are combined into an equally weighted mixture.

    Args:
        pred_distr (PredictiveDistr): maps ``(latent_samples, classifier_samples)``
            to a distribution whose rightmost batch dimension indexes the mixture
            components (see MultiTaskElbo).
        classifier_distr (distr.Distribution): variational distribution over the
            classifier (see MultiTaskElbo).
        latent_distr (LatentDistr): maps inputs ``X`` to the variational
            distribution over the latents (see MultiTaskElbo).
        X (torch.Tensor): new inputs for which to build the predictive distribution.
        classifier_num_particles (int, optional): number of samples drawn from
            ``classifier_distr``. Defaults to 1.
        latent_num_particles (int, optional): number of samples drawn from
            ``latent_distr(X)``. Defaults to 1.

    Returns:
        distr.MixtureSameFamily: uniform mixture over the components produced by
        ``pred_distr``. Its batch shape is the batch shape of ``pred_distr``'s
        output with the rightmost (component) dimension removed.
    """
    # Sample hidden state (classifier + latent) from the variational posterior.
    # ``.sample`` (not ``rsample``) is used, so no gradients flow through these.
    classifier_samples = classifier_distr.sample((classifier_num_particles, ))
    # Move the particle axis from position 0 to position 1 so the leading axis
    # stays aligned with X (presumably the data-point axis — confirm with caller).
    latent_samples = latent_distr(X).sample((latent_num_particles, )).swapaxes(0, 1)

    # Conditional target distributions, one per particle combination. Named
    # separately so the ``pred_distr`` callable is not shadowed by its result.
    component_distr = pred_distr(latent_samples, classifier_samples)

    # Equal weights over the rightmost batch dimension. The weights must live on
    # the same device as the components: a CPU-only tensor makes
    # MixtureSameFamily.log_prob / .sample fail with a device mismatch when the
    # model runs on an accelerator.
    mixing_distr = distr.Categorical(
        torch.ones(component_distr.batch_shape, device=latent_samples.device)
    )

    return distr.MixtureSameFamily(mixing_distr, component_distr)