Coverage for bmm_multitask_learning/variational/elbo.py: 94%
84 statements

from typing import Literal, Callable
from functools import partial
from pipe import select
from itertools import product

import torch
from torch import nn
from torch import distributions as distr

from .distr import kl_sample_estimation, TargetDistr, LatentDistr


class MultiTaskElbo(nn.Module):
    """General ELBO computer for the variational multitask problem.
    """
    def __init__(
        self,
        task_distrs: list[TargetDistr],
        task_num_samples: list[int],
        classifier_distr: list[distr.Distribution],
        latent_distr: list[LatentDistr],
        classifier_num_particles: int = 1,
        latent_num_particles: int = 1,
        temp_scheduler: Callable[[int], float] | Literal["const"] = "const",
        kl_estimator_num_samples: int = 10
    ):
        """
        Args:
            task_distrs (list[TargetDistr]): Data distribution for each task p_t(y | z, w).
            task_num_samples (list[int]): Number of training samples for each task. Needed for
                unbiased ELBO computation when the data is batched.
            classifier_distr (list[distr.Distribution]): Distribution of the classifier q(w | D).
            latent_distr (list[LatentDistr]): Distribution of the latent state q(z | x, D).
            classifier_num_particles (int, optional): Number of samples drawn from the classifier
                distribution. Defaults to 1.
            latent_num_particles (int, optional): Number of samples drawn from the latent
                distribution. Defaults to 1.
            temp_scheduler (Callable[[int], float] | Literal["const"], optional): Maps the training
                step to the Gumbel-Softmax temperature; "const" keeps the temperature fixed at 1.
                Defaults to "const".
            kl_estimator_num_samples (int, optional): If a pair of distributions has no closed-form
                KL, it is approximated with this number of samples. Defaults to 10.

        Warning:
            This nn.Module does not register nn.Parameters from the distributions inside itself.

        Raises:
            ValueError: if the number of tasks is less than 2.
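
        Example:
            An illustrative sketch: make_task_distr and make_latent_distr are hypothetical
            factories standing in for concrete TargetDistr / LatentDistr implementations,
            and the classifier posteriors are plain diagonal Gaussians from torch.distributions.

                elbo_fn = MultiTaskElbo(
                    task_distrs=[make_task_distr(t) for t in range(3)],
                    task_num_samples=[5000, 8000, 6500],
                    classifier_distr=[
                        distr.Independent(distr.Normal(torch.zeros(16), torch.ones(16)), 1)
                        for _ in range(3)
                    ],
                    latent_distr=[make_latent_distr(t) for t in range(3)],
                    classifier_num_particles=4,
                    latent_num_particles=4,
                )
                # data / targets: one batch of tensors per task; step: current training step
                losses = elbo_fn(data, targets, step)
                losses["elbo"].backward()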
        """
        super().__init__()

        self.task_distrs = task_distrs
        self.classifier_distr = classifier_distr
        self.latent_distr = latent_distr

        self.num_tasks = len(task_distrs)
        if self.num_tasks < 2:
            raise ValueError(f"Number of tasks should be >= 2, {self.num_tasks} was given")
        self.task_num_samples = task_num_samples
        self.classifier_num_particles = classifier_num_particles
        self.latent_num_particles = latent_num_particles
        self.kl_estimator_num_samples = kl_estimator_num_samples

        self.temp_scheduler = temp_scheduler if temp_scheduler != "const" else (lambda t: 1.)

        # define gumbel-softmax mixing parameters for classifier and latent,
        # initialized uniformly over the other tasks
        self._classifier_mixings_params = nn.Parameter(
            torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
        )
        self._latent_mixings_params = nn.Parameter(
            torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
        )

    def forward(self, data: list[torch.Tensor], targets: list[torch.Tensor], step: int) -> dict[str, torch.Tensor]:
        """Computes the ELBO estimate for the variational multitask problem.

        Args:
            data (list[torch.Tensor]): batched data (X) for each task
            targets (list[torch.Tensor]): batched targets (y) for each task
            step (int): training step, passed to the temperature scheduler

        Returns:
            dict[str, torch.Tensor]: the objective ("elbo") together with its components
                ("lh_loss", "lat_kl", "cls_kl")
        """
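        # Sign convention: the "elbo" returned below is a loss to minimise. Per task it is
        #   (N_t / batch_size) * sum over the batch of E_q[-log p_t(y | z, w)]
        #   + mixing-weighted KL between this task's latent posterior and the other tasks'
        #   + mixing-weighted KL between this task's classifier posterior and the other tasks'
        # and the three terms are each averaged over tasks.
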
        # get mixing values in the form of a matrix
        temp = self.temp_scheduler(step)
        classifier_mixing = self._get_gumbelsm_mixing(self._classifier_mixings_params, temp)
        latent_mixing = self._get_gumbelsm_mixing(self._latent_mixings_params, temp)

        # sample classifiers
        # shape = (num_tasks, classifier_num_particles, classifier_shape)
        classifiers = torch.stack(
            list(
                self.classifier_distr |
                select(lambda d: d.rsample((self.classifier_num_particles, )))
            )
        )

        # sample latents
        # shape = [num_tasks, (batch_size, latent_num_particles, latent_shape)]
        latents = []
        for i, latent_cond_distr in enumerate(self.latent_distr):
            latents.append(
                latent_cond_distr(data[i]).rsample((self.latent_num_particles, )).swapaxes(0, 1)
            )

        # get the negative log likelihood per task, sample-averaged across latent and classifier particles
        lh_per_task = []
        for i in range(self.num_tasks):
            cur_lh = self._compute_lh_per_task(i, latents[i], classifiers[i], targets[i])
            lh_per_task.append(cur_lh)
        # average lh samples across tasks
        lh_val = torch.stack(lh_per_task).mean()

        # get summed latents kl for each task
        latents_kl = []
        for i in range(self.num_tasks):
            cur_data = data[i]
            cur_mixing = latent_mixing[i]
            cur_kl = self._compute_latent_kl_per_task(i, cur_data, cur_mixing)
            latents_kl.append(cur_kl)
        # average kl among tasks
        latents_kl = torch.stack(latents_kl).mean()

        # get classifiers kl for each task
        classifiers_kl = []
        for i in range(self.num_tasks):
            cur_mixing = classifier_mixing[i]
            cur_kl = self._compute_cls_kl_per_task(i, cur_mixing)
            classifiers_kl.append(cur_kl)
        # average kl among tasks
        classifiers_kl = torch.stack(classifiers_kl).mean()

        elbo = lh_val + latents_kl + classifiers_kl

        return {
            "elbo": elbo,
            "lh_loss": lh_val,
            "lat_kl": latents_kl,
            "cls_kl": classifiers_kl
        }

    def _compute_lh_per_task(
        self,
        task_num: int,
        latents: torch.Tensor,
        classifiers: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """Computes -log prob for each latent and classifier particle over the batch,
        averages across classifiers and latents, and sums across the batch with a batch-size correction.
        """
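        # scaling the batch sum by (task_num_samples / batch_size) makes it an unbiased
        # estimate of the negative log likelihood over the task's full training set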
        task_cond_distr = self.task_distrs[task_num]
        target_shape = targets.shape[1:]
        batch_size = targets.shape[0]

        # log_prob shape = (batch_size, lat_num_part, classifier_num_part, target_shape)
        return -task_cond_distr(latents, classifiers).log_prob(
            targets[:, None, None, ...].expand(-1, self.latent_num_particles, self.classifier_num_particles, *target_shape)
        ).mean(dim=(1, 2)).sum(dim=0) * (self.task_num_samples[task_num] / batch_size)

    def _compute_latent_kl_per_task(
        self,
        task_num: int,
        inputs: torch.Tensor,
        latent_mixing: torch.Tensor
    ) -> torch.Tensor:
        batch_size = inputs.shape[0]
        cur_distr = self.latent_distr[task_num](inputs)
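
        # KL from this task's latent posterior to every task's latent posterior; the
        # Gumbel-Softmax mixing row has a zero diagonal entry, so the self-KL term carries no weight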
        return torch.stack(
            [self._compute_kl(cur_distr, lat_cond_distr(inputs)) for lat_cond_distr in self.latent_distr],
            dim=1
        ).matmul(latent_mixing).sum() * \
            (self.task_num_samples[task_num] / batch_size)  # sum across batch with batch size correction

    def _compute_cls_kl_per_task(
        self,
        task_num: int,
        clas_mixing: torch.Tensor
    ) -> torch.Tensor:
        cur_distr = self.classifier_distr[task_num]

        return torch.stack(
            [self._compute_kl(cur_distr, cl_cond_distr) for cl_cond_distr in self.classifier_distr]
        ).dot(clas_mixing)

    def _compute_kl(self, distr_1: distr.Distribution, distr_2: distr.Distribution) -> torch.Tensor:
        """Computes the KL analytically if possible, otherwise falls back to a sample-based estimate.
        """
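        # identical objects have zero KL; otherwise try the closed-form KL registered in
        # torch.distributions and fall back to a Monte-Carlo estimate when the pair has
        # no registered implementation (NotImplementedError)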
        if distr_1 is distr_2:
            return torch.zeros(distr_1.batch_shape)

        try:
            return distr.kl_divergence(distr_1, distr_2)
        except NotImplementedError:
            return kl_sample_estimation(distr_1, distr_2, self.kl_estimator_num_samples)

    def _get_gumbelsm_mixing(self, mixings_params: torch.Tensor, temp: float) -> torch.Tensor:
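        # Gumbel-Softmax relaxation of a categorical choice over the other tasks:
        #   mixing[i] = softmax((log(pi[i]) + g) / temp), with g ~ Gumbel(0, 1) noise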
        # mixing with self is prohibited, so we mask diagonal to get zeros after softmax
        mask = torch.diag(torch.full((self.num_tasks, ), -torch.inf))

        mixing = distr.Gumbel(0., 1.).sample((self.num_tasks, self.num_tasks))
        mixing += mixings_params.log()
        mixing = mixing / temp
        mixing += mask
        mixing = torch.softmax(mixing, dim=1)

        return mixing

    @property
    def classifier_mixings_params(self):
        """Accesses the classifier mixing params.
        """
        return self._classifier_mixings_params

    @property
    def latent_mixings_params(self):
        """Accesses the latent mixing params.
        """
        return self._latent_mixings_params