
Reference for the variational approach to multitask learning

bmm_multitask_learning.variational.distr

Utils for working with distributions

build_predictive(pred_distr, classifier_distr, latent_distr, X, classifier_num_particles=1, latent_num_particles=1)

Constructs a torch distribution approximating the true (Bayesian) predictive distribution using the variational distributions.

Parameters:

pred_distr (PredictiveDistr): see MultiTaskElbo. Required.
classifier_distr (Distribution): see MultiTaskElbo. Required.
latent_distr (LatentDistr): see MultiTaskElbo. Required.
X (Tensor): new inputs for which to build the predictive distribution. Required.
classifier_num_particles (int): see MultiTaskElbo. Defaults to 1.
latent_num_particles (int): see MultiTaskElbo. Defaults to 1.

Returns:

distr.MixtureSameFamily: the predictive distribution, represented as a mixture over the sampled particles

Source code in bmm_multitask_learning/variational/distr.py
def build_predictive(
    pred_distr: PredictiveDistr,
    classifier_distr: distr.Distribution,
    latent_distr: LatentDistr,
    X: torch.Tensor,
    classifier_num_particles: int = 1,
    latent_num_particles: int = 1
) -> distr.MixtureSameFamily:
    """Constructs torch.distribution as an approximation to the true predictive distribution
    (in bayessian sense) using variational distributions

    Args:
        pred_distr (PredictiveDistr): see MultiTaskElbo
        classifier_distr (distr.Distribution): see MultiTaskElbo
        latent_distr (LatentDistr): see MultiTaskElbo
        X (torch.Tensor): new inputs for which to build predictive distr
        classifier_num_particles (int, optional): see MultiTaskElbo. Defaults to 1.
        latent_num_particles (int, optional): see MultiTaskElbo. Defaults to 1.

    Returns:
        distr.MixtureSameFamily: the predictive distr can be seen as mixture distr
    """
    # sample hidden state (classifier + latent) from posterior
    classifier_samples = classifier_distr.sample((classifier_num_particles, ))
    latent_samples = latent_distr(X).sample((latent_num_particles, )).swapaxes(0, 1)
    # build conditional distribution objects for target
    pred_distr = pred_distr(latent_samples, classifier_samples)

    mixing_distr = distr.Categorical(torch.ones(pred_distr.batch_shape))

    return distr.MixtureSameFamily(mixing_distr, pred_distr)
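Example (illustrative): the sketch below builds a predictive distribution from Gaussian stand-ins for the variational distributions. The concrete forms of PredictiveDistr and LatentDistr are assumptions made for this sketch; the only requirement exercised by build_predictive is that latent_distr(X) and pred_distr(latent_samples, classifier_samples) return torch distributions with the shapes noted in the comments.

# Hypothetical Gaussian stand-ins for the variational distributions; the shapes
# mirror how build_predictive calls them, but the concrete forms are assumptions.
import torch
import torch.distributions as distr

from bmm_multitask_learning.variational.distr import build_predictive

latent_dim, n_new = 4, 8
X_new = torch.randn(n_new, 3)

# q(w | D): factorized Gaussian over classifier weights
classifier_distr = distr.Independent(
    distr.Normal(torch.zeros(latent_dim), torch.ones(latent_dim)), 1
)

# q(z | x, D): maps inputs to a Gaussian over the latent state
def latent_distr(X):
    loc = torch.zeros(X.shape[0], latent_dim)
    return distr.Independent(distr.Normal(loc, torch.ones_like(loc)), 1)

# p(y | z, w): maps latent samples (n_new, lat_p, d) and classifier samples (cls_p, d)
# to a Gaussian over scalar targets with batch shape (n_new, lat_p, cls_p)
def pred_distr(latent_samples, classifier_samples):
    mean = torch.einsum("nld,cd->nlc", latent_samples, classifier_samples)
    return distr.Normal(mean, torch.ones_like(mean))

predictive = build_predictive(
    pred_distr, classifier_distr, latent_distr, X_new,
    classifier_num_particles=3, latent_num_particles=5,
)
print(predictive.mean.shape)  # mixture over the classifier particles; batch shape (8, 5)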

kl_sample_estimation(distr_1, distr_2, num_particles=1)

Monte Carlo (sample-based) estimation of the KL divergence KL(distr_1 || distr_2).

Parameters:

distr_1 (Distribution): distribution the samples are drawn from; the result approximates KL(distr_1 || distr_2). Required.
distr_2 (Distribution): distribution the samples are evaluated against. Required.
num_particles (int): number of samples for the estimation. Defaults to 1.
Source code in bmm_multitask_learning/variational/distr.py
def kl_sample_estimation(
    distr_1: distr.Distribution,
    distr_2: distr.Distribution,
    num_particles: int = 1
) -> torch.Tensor:
    """Make sample estimation of the KL divirgence

    Args:
        num_particles (int, optional): number of samples for estimation. Defaults to 1.
    """
    samples = distr_1.rsample([num_particles])
    log_p_1 = distr_1.log_prob(samples)
    log_p_2 = distr_2.log_prob(samples)

    return (log_p_1 - log_p_2).mean()
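Example (illustrative): for two Gaussians the KL divergence has a closed form, so the sample-based estimate can be checked against distr.kl_divergence; the particular distributions below are arbitrary.

# Illustrative check: compare the sample-based estimate with the closed-form
# KL divergence between two Gaussians (the particular parameters are arbitrary).
import torch
import torch.distributions as distr

from bmm_multitask_learning.variational.distr import kl_sample_estimation

p = distr.Normal(torch.tensor(0.0), torch.tensor(1.0))
q = distr.Normal(torch.tensor(1.0), torch.tensor(2.0))

kl_exact = distr.kl_divergence(p, q)
kl_approx = kl_sample_estimation(p, q, num_particles=10_000)
print(kl_exact.item(), kl_approx.item())  # the estimate should be close to the exact value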

bmm_multitask_learning.variational.elbo

MultiTaskElbo

Bases: Module

General ELBO estimator for the variational multitask problem.

Source code in bmm_multitask_learning/variational/elbo.py
class MultiTaskElbo(nn.Module):
    """General ELBO computer for variational multitask problem. 
    """
    def __init__(
        self,
        task_distrs: list[TargetDistr],
        task_num_samples: list[int],
        classifier_distr: list[distr.Distribution],
        latent_distr: list[LatentDistr],
        classifier_num_particles: int = 1,
        latent_num_particles: int = 1,
        temp_scheduler: Callable[[int], float] | Literal["const"] = "const",
        kl_estimator_num_samples: int = 10
    ):
        """
        Args:
            task_distrs (list[TargetDistr]): Data distribution for each task p_t(y | z, w)
            task_num_samples (list[int]): Number of train samples for each task. Needed for unbiased ELBO computation in case of batched data.
            classifier_distr (list[distr.Distribution]): Distribution for the classifier q(w | D)
            latent_distr (list[LatentDistr]): Distribution for the latent state q(z | x, D)
            classifier_num_particles (int, optional): num samples from classifier distr. Defaults to 1.
            latent_num_particles (int, optional):  num samples from latent distr. Defaults to 1.
            temp_scheduler (Callable[[int], float] | Literal["const"], optional): _description_. Defaults to Literal["const"].
            kl_estimator_num_samples (int, optional): if your distrs does not have implicit kl computation, 
            it will be approximated using this number of samples. Defaults to 10.

            Warning:
                This nn.Module does not register nn.Parameters from the distributions inside itself
        Raises:
            ValueError: if number of tasks <= 2
        """
        super().__init__()

        self.task_distrs = task_distrs
        self.classifier_distr = classifier_distr
        self.latent_distr = latent_distr

        self.num_tasks = len(task_distrs)
        if self.num_tasks < 2:
            raise ValueError(f"Number of tasks should be > 2, {self.num_tasks} was given")
        self.task_num_samples = task_num_samples
        self.classifier_num_particles = classifier_num_particles
        self.latent_num_particles = latent_num_particles
        self.kl_estimator_num_samples = kl_estimator_num_samples

        self.temp_scheduler = temp_scheduler if temp_scheduler != "const" else (lambda t: 1.)

        # define Gumbel-Softmax mixing parameters for the classifier and the latent state,
        # initialized uniformly (two separate nn.Parameter instances, not a shared one)
        self._classifier_mixings_params = nn.Parameter(
            torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
        )
        self._latent_mixings_params = nn.Parameter(
            torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
        )

    def forward(self, data: list[torch.Tensor], targets: list[torch.Tensor], step: int) -> dict[str, torch.Tensor]:
        """Computes ELBO estimation for variational multitask problem.

        Args:
            data (list[torch.Tensor]): batched inputs (X) for each task
            targets (list[torch.Tensor]): batched targets (y) for each task
            step (int): current training step, used by the temperature scheduler

        Returns:
            dict[str, torch.Tensor]: the ELBO estimate ("elbo") and its components ("lh_loss", "lat_kl", "cls_kl")
        """
        # get mixing values in form of matrix
        temp = self.temp_scheduler(step)
        classifier_mixing = self._get_gumbelsm_mixing(self._classifier_mixings_params, temp)
        latent_mixing = self._get_gumbelsm_mixing(self._latent_mixings_params, temp)

        # sample classifiers
        # shape = (num_tasks, classifier_num_particles, classifier_shape)
        classifiers = torch.stack(
            list(
                self.classifier_distr |
                select(lambda d: d.rsample((self.classifier_num_particles, )))
            )
        )

        # sample latents
        # list of length num_tasks; each element has shape (batch_size, latent_num_particles, latent_shape)
        latents = []
        for i, latent_cond_distr in enumerate(self.latent_distr):
            latents.append(
                latent_cond_distr(data[i]).rsample((self.latent_num_particles, )).swapaxes(0, 1)
            )

        # get the negative log-likelihood for each task, averaged across latent and classifier particles
        lh_per_task = []
        for i in range(self.num_tasks):
            cur_lh = self._compute_lh_per_task(i, latents[i], classifiers[i], targets[i])
            lh_per_task.append(cur_lh)
        # average lh samples across tasks
        lh_val = torch.stack(lh_per_task).mean()

        # get summed latents kl for each task
        latents_kl = []
        for i in range(self.num_tasks):
            cur_data = data[i]
            cur_mixing = latent_mixing[i]
            cur_kl = self._compute_latent_kl_per_task(i, cur_data, cur_mixing)
            latents_kl.append(cur_kl)
        # average kl among tasks
        latents_kl = torch.stack(latents_kl).mean()

        # get classifiers kl for each task
        classifiers_kl = []
        for i in range(self.num_tasks):
            cur_mixing = classifier_mixing[i]
            cur_kl = self._compute_cls_kl_per_task(i, cur_mixing)
            classifiers_kl.append(cur_kl)
        # average kl among tasks
        classifiers_kl = torch.stack(classifiers_kl).mean()

        elbo = lh_val + latents_kl + classifiers_kl

        return {
            "elbo": elbo,
            "lh_loss": lh_val,
            "lat_kl": latents_kl,
            "cls_kl": classifiers_kl
        }

    def _compute_lh_per_task(
        self,
        task_num: int,
        latents: torch.Tensor,
        classifiers: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """Compute -log prob for each latent and classifier particle for each batch,
        mean across classifiers and latents, sum across targets with batch size correction
        """
        task_cond_distr = self.task_distrs[task_num]
        target_shape = targets.shape[1:]
        batch_size = targets.shape[0]

        # log_prob shape=(batch_size, lat_num_part, classifier_num_part, target_shape)
        return -task_cond_distr(latents, classifiers).log_prob(
                targets[:, None, None, ...].expand(-1, self.latent_num_particles, self.classifier_num_particles, *target_shape)
            ).mean(dim=(1, 2)).sum(dim=0) * (self.task_num_samples[task_num] / batch_size)

    def _compute_latent_kl_per_task(
        self,
        task_num: int,
        inputs: torch.Tensor,
        latent_mixing: torch.Tensor
    ) -> torch.Tensor:
        batch_size = inputs.shape[0]
        cur_distr = self.latent_distr[task_num](inputs)

        return torch.stack(
            [self._compute_kl(cur_distr, lat_cond_distr(inputs)) for lat_cond_distr in self.latent_distr],
            dim=1
        ).matmul(latent_mixing).sum() * \
            (self.task_num_samples[task_num] / batch_size) # sum across batch with batch size correction

    def _compute_cls_kl_per_task(
        self,
        task_num: int,
        clas_mixing: torch.Tensor
    ) -> torch.Tensor:
        cur_distr = self.classifier_distr[task_num]

        return torch.stack(
            [self._compute_kl(cur_distr, cl_cond_distr) for cl_cond_distr in self.classifier_distr]
        ).dot(clas_mixing)

    def _compute_kl(self, distr_1: distr.Distribution, distr_2: distr.Distribution) -> torch.Tensor:
        """Computes KL analytically if possible else make a sample estimation
        """
        if distr_1 is distr_2:
            return torch.zeros(distr_1.batch_shape)

        try:
            return distr.kl_divergence(distr_1, distr_2)
        except NotImplementedError:
            return kl_sample_estimation(distr_1, distr_2, self.kl_estimator_num_samples)

    def _get_gumbelsm_mixing(self, mixings_params: torch.Tensor, temp: float) -> torch.Tensor:
        # mixing with self is prohibited, so we mask diagonal to get zeros after softmax
        mask = torch.diag(torch.full((self.num_tasks, ), -torch.inf))

        mixing = distr.Gumbel(0., 1.).sample((self.num_tasks, self.num_tasks))
        mixing += mixings_params.log()
        mixing = mixing / temp
        mixing += mask
        mixing = torch.softmax(mixing, dim=1)

        return mixing

    @property
    def classifier_mixings_params(self):
        """Accesses classifer mixing params
        """
        return self._classifier_mixings_params

    @property
    def latent_mixings_params(self):
        """Accesses latent mixing params
        """
        return self._latent_mixings_params
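Example (illustrative): MultiTaskElbo only requires the distribution arguments to behave the way forward uses them. Each entry of classifier_distr is a torch distribution supporting rsample, each entry of latent_distr is a callable mapping a batch of inputs to a distribution over the latent state, and each entry of task_distrs is a callable mapping sampled latents and classifiers to a distribution over the targets. The Gaussian forms, dimensions and dataset sizes below are assumptions made for this sketch, not part of the library.

# Minimal stand-ins for TargetDistr, LatentDistr and the classifier distribution,
# shaped the way MultiTaskElbo.forward consumes them. All concrete choices here
# (Gaussians, dimensions, dataset sizes) are assumptions for the example.
import torch
import torch.distributions as distr
from torch import nn

from bmm_multitask_learning.variational.elbo import MultiTaskElbo

in_dim, latent_dim, num_tasks = 3, 4, 2

class GaussianLatent(nn.Module):
    """q(z | x): diagonal Gaussian whose mean is a linear map of the input."""
    def __init__(self, in_dim: int, latent_dim: int):
        super().__init__()
        self.loc = nn.Linear(in_dim, latent_dim)
        self.log_scale = nn.Parameter(torch.zeros(latent_dim))

    def forward(self, X: torch.Tensor) -> distr.Distribution:
        return distr.Independent(distr.Normal(self.loc(X), self.log_scale.exp()), 1)

def regression_task_distr(latents: torch.Tensor, classifiers: torch.Tensor) -> distr.Distribution:
    """p(y | z, w): latents (batch, lat_p, d) and classifiers (cls_p, d) -> Normal over scalar targets."""
    mean = torch.einsum("bld,cd->blc", latents, classifiers)
    return distr.Normal(mean, torch.ones_like(mean))

latent_distrs = [GaussianLatent(in_dim, latent_dim) for _ in range(num_tasks)]
classifier_distrs = [
    distr.Independent(distr.Normal(torch.zeros(latent_dim), torch.ones(latent_dim)), 1)
    for _ in range(num_tasks)
]

elbo_module = MultiTaskElbo(
    task_distrs=[regression_task_distr] * num_tasks,
    task_num_samples=[1000, 1000],          # assumed training-set sizes per task
    classifier_distr=classifier_distrs,
    latent_distr=latent_distrs,
    classifier_num_particles=2,
    latent_num_particles=4,
)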

classifier_mixings_params property

Accesses the classifier mixing params

latent_mixings_params property

Accesses latent mixing params

__init__(task_distrs, task_num_samples, classifier_distr, latent_distr, classifier_num_particles=1, latent_num_particles=1, temp_scheduler='const', kl_estimator_num_samples=10)

Parameters:

task_distrs (list[TargetDistr]): Data distribution for each task p_t(y | z, w). Required.
task_num_samples (list[int]): Number of train samples for each task; needed for an unbiased ELBO estimate when the data is batched. Required.
classifier_distr (list[Distribution]): Distribution for the classifier q(w | D). Required.
latent_distr (list[LatentDistr]): Distribution for the latent state q(z | x, D). Required.
classifier_num_particles (int): Number of samples from the classifier distribution. Defaults to 1.
latent_num_particles (int): Number of samples from the latent distribution. Defaults to 1.
temp_scheduler (Callable[[int], float] | Literal["const"]): Maps the training step to the Gumbel-Softmax temperature; "const" keeps the temperature fixed at 1. Defaults to "const".
kl_estimator_num_samples (int): If a pair of distributions has no closed-form KL divergence, it is approximated using this number of samples. Defaults to 10.

Warning

This nn.Module does not register the nn.Parameters of the distributions passed to it.

Raises:

ValueError: if the number of tasks is less than 2
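temp_scheduler can be any callable from the training step to a Gumbel-Softmax temperature. A simple exponential annealing schedule is sketched below; the decay constants are illustrative choices, not values from the library.

# Illustrative temperature schedule for the Gumbel-Softmax mixing weights;
# the decay constants are arbitrary choices, not library defaults.
import math

def exp_annealing(step: int, t_start: float = 1.0, t_min: float = 0.1, rate: float = 1e-3) -> float:
    """Exponentially decay the temperature from t_start towards t_min."""
    return max(t_min, t_start * math.exp(-rate * step))

# elbo_module = MultiTaskElbo(..., temp_scheduler=exp_annealing)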

Source code in bmm_multitask_learning/variational/elbo.py
def __init__(
    self,
    task_distrs: list[TargetDistr],
    task_num_samples: list[int],
    classifier_distr: list[distr.Distribution],
    latent_distr: list[LatentDistr],
    classifier_num_particles: int = 1,
    latent_num_particles: int = 1,
    temp_scheduler: Callable[[int], float] | Literal["const"] = "const",
    kl_estimator_num_samples: int = 10
):
    """
    Args:
        task_distrs (list[TargetDistr]): Data distribution for each task p_t(y | z, w)
        task_num_samples (list[int]): Number of train samples for each task. Needed for unbiased ELBO computation in case of batched data.
        classifier_distr (list[distr.Distribution]): Distribution for the classifier q(w | D)
        latent_distr (list[LatentDistr]): Distribution for the latent state q(z | x, D)
        classifier_num_particles (int, optional): num samples from classifier distr. Defaults to 1.
        latent_num_particles (int, optional):  num samples from latent distr. Defaults to 1.
        temp_scheduler (Callable[[int], float] | Literal[&quot;const&quot;], optional): _description_. Defaults to Literal["const"].
        kl_estimator_num_samples (int, optional): if your distrs does not have implicit kl computation, 
        it will be approximated using this number of samples. Defaults to 10.

        Warning:
            This nn.Module does not register nn.Parameters from the distributions inside itself
    Raises:
        ValueError: if number of tasks <= 2
    """
    super().__init__()

    self.task_distrs = task_distrs
    self.classifier_distr = classifier_distr
    self.latent_distr = latent_distr

    self.num_tasks = len(task_distrs)
    if self.num_tasks < 2:
        raise ValueError(f"Number of tasks should be > 2, {self.num_tasks} was given")
    self.task_num_samples = task_num_samples
    self.classifier_num_particles = classifier_num_particles
    self.latent_num_particles = latent_num_particles
    self.kl_estimator_num_samples = kl_estimator_num_samples

    self.temp_scheduler = temp_scheduler if temp_scheduler != "const" else (lambda t: 1.)

    # define Gumbel-Softmax mixing parameters for the classifier and the latent state,
    # initialized uniformly (two separate nn.Parameter instances, not a shared one)
    self._classifier_mixings_params = nn.Parameter(
        torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
    )
    self._latent_mixings_params = nn.Parameter(
        torch.full((self.num_tasks, self.num_tasks), 1 / (self.num_tasks - 1))
    )

forward(data, targets, step)

Computes an ELBO estimate for the variational multitask problem.

Parameters:

data (list[Tensor]): batched inputs (X) for each task. Required.
targets (list[Tensor]): batched targets (y) for each task. Required.
step (int): current training step, used by the temperature scheduler. Required.

Returns:

dict[str, Tensor]: a dictionary with the ELBO estimate ("elbo") and its components ("lh_loss", "lat_kl", "cls_kl")

Source code in bmm_multitask_learning/variational/elbo.py
def forward(self, data: list[torch.Tensor], targets: list[torch.Tensor], step: int) -> dict[str, torch.Tensor]:
    """Computes ELBO estimation for variational multitask problem.

    Args:
        data (list[torch.Tensor]): batched inputs (X) for each task
        targets (list[torch.Tensor]): batched targets (y) for each task
        step (int): current training step, used by the temperature scheduler

    Returns:
        dict[str, torch.Tensor]: the ELBO estimate ("elbo") and its components ("lh_loss", "lat_kl", "cls_kl")
    """
    # get mixing values in form of matrix
    temp = self.temp_scheduler(step)
    classifier_mixing = self._get_gumbelsm_mixing(self._classifier_mixings_params, temp)
    latent_mixing = self._get_gumbelsm_mixing(self._latent_mixings_params, temp)

    # sample classifiers
    # shape = (num_tasks, classifier_num_particles, classifier_shape)
    classifiers = torch.stack(
        list(
            self.classifier_distr |
            select(lambda d: d.rsample((self.classifier_num_particles, )))
        )
    )

    # sample latents
    # list of length num_tasks; each element has shape (batch_size, latent_num_particles, latent_shape)
    latents = []
    for i, latent_cond_distr in enumerate(self.latent_distr):
        latents.append(
            latent_cond_distr(data[i]).rsample((self.latent_num_particles, )).swapaxes(0, 1)
        )

    # get the negative log-likelihood for each task, averaged across latent and classifier particles
    lh_per_task = []
    for i in range(self.num_tasks):
        cur_lh = self._compute_lh_per_task(i, latents[i], classifiers[i], targets[i])
        lh_per_task.append(cur_lh)
    # average lh samples across tasks
    lh_val = torch.stack(lh_per_task).mean()

    # get summed latents kl for each task
    latents_kl = []
    for i in range(self.num_tasks):
        cur_data = data[i]
        cur_mixing = latent_mixing[i]
        cur_kl = self._compute_latent_kl_per_task(i, cur_data, cur_mixing)
        latents_kl.append(cur_kl)
    # average kl among tasks
    latents_kl = torch.stack(latents_kl).mean()

    # get classifiers kl for each task
    classifiers_kl = []
    for i in range(self.num_tasks):
        cur_mixing = classifier_mixing[i]
        cur_kl = self._compute_cls_kl_per_task(i, cur_mixing)
        classifiers_kl.append(cur_kl)
    # average kl among tasks
    classifiers_kl = torch.stack(classifiers_kl).mean()

    elbo = lh_val + latents_kl + classifiers_kl

    return {
        "elbo": elbo,
        "lh_loss": lh_val,
        "lat_kl": latents_kl,
        "cls_kl": classifiers_kl
    }
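Usage sketch: a single optimization loop using the components from the constructor example above. Since forward sums a negative log-likelihood with the KL terms, the returned "elbo" entry is treated here as a loss to minimize; the batch tensors and learning rate are placeholders, and, per the class warning, the distributions' parameters must be handed to the optimizer explicitly.

# Illustrative training loop, reusing elbo_module, latent_distrs, in_dim and
# num_tasks from the constructor sketch above. Batch sizes, data and the
# learning rate are placeholders.
import itertools
import torch

params = itertools.chain(
    elbo_module.parameters(),                       # Gumbel-Softmax mixing parameters
    *(ld.parameters() for ld in latent_distrs),     # latent distributions are not registered on the module
)
optimizer = torch.optim.Adam(params, lr=1e-3)

data = [torch.randn(16, in_dim) for _ in range(num_tasks)]      # one batch of inputs per task
targets = [torch.randn(16) for _ in range(num_tasks)]           # scalar targets per sample

for step in range(100):
    losses = elbo_module(data, targets, step)
    loss = losses["elbo"]        # negative log-likelihood plus KL terms, minimized
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()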