# Coverage for bmm_multitask_learning/variational/elbo.py: 94% (84 statements)

from typing import Literal, Callable

from pipe import select

import torch
from torch import nn
from torch import distributions as distr

from .distr import kl_sample_estimation, TargetDistr, LatentDistr


class MultiTaskElbo(nn.Module):
    """General ELBO computer for the variational multitask problem."""

    def __init__(
        self,
        task_distrs: list[TargetDistr],
        task_num_samples: list[int],
        classifier_distr: list[distr.Distribution],
        latent_distr: list[LatentDistr],
        classifier_num_particles: int = 1,
        latent_num_particles: int = 1,
        temp_scheduler: Callable[[int], float] | Literal["const"] = "const",
        kl_estimator_num_samples: int = 10
    ):
        """
        Args:
            task_distrs (list[TargetDistr]): Data distribution for each task, p_t(y | z, w).
            task_num_samples (list[int]): Number of train samples for each task. Needed for
                unbiased ELBO computation in case of batched data.
            classifier_distr (list[distr.Distribution]): Distribution of the classifier, q(w | D).
            latent_distr (list[LatentDistr]): Distribution of the latent state, q(z | x, D).
            classifier_num_particles (int, optional): Number of samples drawn from the
                classifier distribution. Defaults to 1.
            latent_num_particles (int, optional): Number of samples drawn from the latent
                distribution. Defaults to 1.
            temp_scheduler (Callable[[int], float] | Literal["const"], optional): Maps the
                current step to the Gumbel-Softmax temperature; "const" keeps the
                temperature fixed at 1. Defaults to "const".
            kl_estimator_num_samples (int, optional): If a pair of distributions has no
                analytic KL divergence, the KL is approximated using this number of
                samples. Defaults to 10.

        Warning:
            This nn.Module does not register nn.Parameters from the distributions inside itself.

        Raises:
            ValueError: if the number of tasks is less than 2.
        """
        super().__init__()

        self.task_distrs = task_distrs
        self.classifier_distr = classifier_distr
        self.latent_distr = latent_distr

        self.num_tasks = len(task_distrs)
        if self.num_tasks < 2:
            raise ValueError(f"Number of tasks should be at least 2, {self.num_tasks} was given")
        self.task_num_samples = task_num_samples
        self.classifier_num_particles = classifier_num_particles
        self.latent_num_particles = latent_num_particles
        self.kl_estimator_num_samples = kl_estimator_num_samples

        self.temp_scheduler = temp_scheduler if temp_scheduler != "const" else lambda t: 1.0

        # Gumbel-Softmax mixing parameters for the classifier and latent KL terms,
        # initialized uniformly over the other tasks (the diagonal is masked out in
        # _get_gumbelsm_mixing). Built as two separate nn.Parameters so the two
        # mixing matrices train independently (`[nn.Parameter(...)] * 2` would
        # alias a single tensor).
        self._classifier_mixings_params = nn.Parameter(
            torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
        )
        self._latent_mixings_params = nn.Parameter(
            torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
        )

    def forward(self, data: list[torch.Tensor], targets: list[torch.Tensor], step: int) -> dict[str, torch.Tensor]:
        """Computes the ELBO estimate for the variational multitask problem.

        Args:
            data (list[torch.Tensor]): batched data (X) for each task
            targets (list[torch.Tensor]): batched targets (y) for each task
            step (int): current step, passed to the temperature scheduler

        Returns:
            dict[str, torch.Tensor]: the scalar loss ("elbo", a negative-ELBO
                estimate) and its components ("lh_loss", "lat_kl", "cls_kl")
        """
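        # A sketch of the objective assembled below (signs follow
        # _compute_lh_per_task, which returns a *negative* log-likelihood):
        #   loss = mean_t [ (N_t/B_t) * sum_b ( E_q[-log p_t(y_b | z, w)]
        #                                       + sum_j pi^z_{tj} KL(q_t(z|x_b) || q_j(z|x_b)) )
        #                   + sum_j pi^w_{tj} KL(q_t(w) || q_j(w)) ]
        # where pi^z, pi^w are rows of the Gumbel-Softmax mixing matrices and
        # N_t, B_t are the dataset and batch sizes of task t.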

        # get mixing values in the form of a matrix
        temp = self.temp_scheduler(step)
        classifier_mixing = self._get_gumbelsm_mixing(self._classifier_mixings_params, temp)
        latent_mixing = self._get_gumbelsm_mixing(self._latent_mixings_params, temp)

        # sample classifiers (`| select(...)` from the `pipe` package maps each
        # task's distribution through rsample)
        # shape = (num_tasks, classifier_num_particles, classifier_shape)
        classifiers = torch.stack(
            list(
                self.classifier_distr |
                select(lambda d: d.rsample((self.classifier_num_particles, )))
            )
        )

        # sample latents
        # shape = [num_tasks x (batch_size_t, latent_num_particles, latent_shape)]
        latents = []
        for i, latent_cond_distr in enumerate(self.latent_distr):
            latents.append(
                latent_cond_distr(data[i]).rsample((self.latent_num_particles, )).swapaxes(0, 1)
            )

        # negative log-likelihood per task, averaged across latent and classifier particles
        lh_per_task = []
        for i in range(self.num_tasks):
            cur_lh = self._compute_lh_per_task(i, latents[i], classifiers[i], targets[i])
            lh_per_task.append(cur_lh)
        # average likelihood terms across tasks
        lh_val = torch.stack(lh_per_task).mean()

        # get the summed latent KL for each task
        latents_kl = []
        for i in range(self.num_tasks):
            cur_data = data[i]
            cur_mixing = latent_mixing[i]
            cur_kl = self._compute_latent_kl_per_task(i, cur_data, cur_mixing)
            latents_kl.append(cur_kl)
        # average KL across tasks
        latents_kl = torch.stack(latents_kl).mean()

        # get the classifier KL for each task
        classifiers_kl = []
        for i in range(self.num_tasks):
            cur_mixing = classifier_mixing[i]
            cur_kl = self._compute_cls_kl_per_task(i, cur_mixing)
            classifiers_kl.append(cur_kl)
        # average KL across tasks
        classifiers_kl = torch.stack(classifiers_kl).mean()

        # negative ELBO: expected negative log-likelihood plus the KL penalties
        elbo = lh_val + latents_kl + classifiers_kl

        return {
            "elbo": elbo,
            "lh_loss": lh_val,
            "lat_kl": latents_kl,
            "cls_kl": classifiers_kl
        }

    def _compute_lh_per_task(
        self,
        task_num: int,
        latents: torch.Tensor,
        classifiers: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """Computes -log prob for each latent and classifier particle of each batch
        element, averages across classifier and latent particles, and sums over the
        batch with a batch-size correction.
        """
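        # The (task_num_samples / batch_size) factor below rescales the batch sum so
        # the term is an unbiased estimate of the likelihood sum over the full
        # training set of this task (see task_num_samples in __init__).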

        task_cond_distr = self.task_distrs[task_num]
        target_shape = targets.shape[1:]
        batch_size = targets.shape[0]

        # log_prob shape=(batch_size, lat_num_part, classifier_num_part, target_shape)
        return -task_cond_distr(latents, classifiers).log_prob(
            targets[:, None, None, ...].expand(-1, self.latent_num_particles, self.classifier_num_particles, *target_shape)
        ).mean(dim=(1, 2)).sum(dim=0) * (self.task_num_samples[task_num] / batch_size)

    def _compute_latent_kl_per_task(
        self,
        task_num: int,
        inputs: torch.Tensor,
        latent_mixing: torch.Tensor
    ) -> torch.Tensor:
        batch_size = inputs.shape[0]
        cur_distr = self.latent_distr[task_num](inputs)

        # (batch, num_tasks) matrix of pairwise KLs KL(q_t(z|x_b) || q_j(z|x_b)),
        # weighted by this task's mixing row (the diagonal j = t has zero weight
        # after masking), then summed over the batch with the batch-size correction
        return torch.stack(
            [self._compute_kl(cur_distr, lat_cond_distr(inputs)) for lat_cond_distr in self.latent_distr],
            dim=1
        ).matmul(latent_mixing).sum() * (self.task_num_samples[task_num] / batch_size)

    def _compute_cls_kl_per_task(
        self,
        task_num: int,
        clas_mixing: torch.Tensor
    ) -> torch.Tensor:
        cur_distr = self.classifier_distr[task_num]

        return torch.stack(
            [self._compute_kl(cur_distr, cl_cond_distr) for cl_cond_distr in self.classifier_distr]
        ).dot(clas_mixing)

    def _compute_kl(self, distr_1: distr.Distribution, distr_2: distr.Distribution) -> torch.Tensor:
        """Computes the KL divergence analytically if possible, otherwise falls back
        to a sample-based estimate.
        """
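        # torch.distributions.kl_divergence raises NotImplementedError when no
        # closed form is registered for a pair of distribution types; that is what
        # triggers the Monte Carlo fallback below. Identical objects short-circuit
        # to an exact zero of the right batch shape.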

        if distr_1 is distr_2:
            return torch.zeros(distr_1.batch_shape)

        try:
            return distr.kl_divergence(distr_1, distr_2)
        except NotImplementedError:
            return kl_sample_estimation(distr_1, distr_2, self.kl_estimator_num_samples)

    def _get_gumbelsm_mixing(self, mixings_params: torch.Tensor, temp: float) -> torch.Tensor:
        # mixing with self is prohibited, so we mask the diagonal to get zeros after softmax
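        # A sketch of the standard Gumbel-Softmax relaxation computed below:
        #   mixing[i, j] = softmax_j((log pi[i, j] + g[i, j]) / temp),  g ~ Gumbel(0, 1)
        # Each row i is a relaxed one-hot vector over the other tasks; lower
        # temperatures push rows closer to hard one-hot samples.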

        mask = torch.diag(torch.full((self.num_tasks, ), -torch.inf))

        mixing = distr.Gumbel(0., 1.).sample((self.num_tasks, self.num_tasks))
        mixing += mixings_params.log()
        mixing = mixing / temp
        mixing += mask
        mixing = torch.softmax(mixing, dim=1)

        return mixing

    @property
    def classifier_mixings_params(self):
        """Accesses the classifier mixing params."""
        return self._classifier_mixings_params

    @property
    def latent_mixings_params(self):
        """Accesses the latent mixing params."""
        return self._latent_mixings_params
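
# A minimal usage sketch (hypothetical names: the concrete TargetDistr / LatentDistr
# factories and classifier distributions below are assumptions, supplied by the
# surrounding library rather than this module):
#
#     elbo_module = MultiTaskElbo(
#         task_distrs=[task_distr_a, task_distr_b],          # p_t(y | z, w)
#         task_num_samples=[len(dataset_a), len(dataset_b)],
#         classifier_distr=[q_w_a, q_w_b],                   # q(w | D)
#         latent_distr=[q_z_a, q_z_b],                       # q(z | x, D)
#     )
#     losses = elbo_module([x_a, x_b], [y_a, y_b], step=global_step)
#     losses["elbo"].backward()  # "elbo" is the negative ELBO, i.e. the training loss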