Variational multitask learning elementary example¶
This is a practical demonstration of how to use the variational subpackage on a simple classification example. We are going to solve 3 classification tasks with logistic regression as the model. Additionally, we put a prior on the weights, so the tasks become Bayesian. Two of the tasks will be probabilistically connected; the last will have no probabilistic connections with the others.
First, we apply the variational principle to learn each task individually. We use the pyro package to automatically compute the ELBO and minimize it.
Second, we use the variational subpackage and learn the 3 tasks jointly. Learning is again ELBO minimization, but with a special variational structure - see the docs for more details.
Lastly, we compare the two approaches in terms of accuracy.
from typing import Optional
from pipe import select
from omegaconf import OmegaConf
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, StackDataset
from torch import optim
from torch import distributions as distr
import pyro
import pyro.nn as pnn
import pyro.distributions as pdistr
from pyro.distributions import Delta
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping
from torchmetrics.classification import Accuracy
from bmm_multitask_learning.variational.elbo import MultiTaskElbo
from bmm_multitask_learning.variational.distr import build_predictive
%load_ext tensorboard
The experiment is configurable via a YAML config.
config = OmegaConf.load("config.yaml")
torch.manual_seed(config.seed);
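The config file itself is not shown in the notebook. A hypothetical `config.yaml` consistent with the keys accessed later (`seed`, `size`, `dim`, `test_ratio`, `batch_size`, `mt_elbo.classifier_num_particles`, `mt_elbo.latent_num_particles`, `trainer`) could look like the following; all values here are illustrative assumptions, not the ones used to produce the outputs below.

```yaml
seed: 42
size: 200            # samples per task
dim: 2               # input dimensionality
test_ratio: 0.2
batch_size: 32
mt_elbo:
  classifier_num_particles: 4
  latent_num_particles: 4
trainer:
  max_epochs: 20
```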
Data¶
For each task we generate 2-dimensional inputs $X$, where each individual input $x$ is drawn as
$$ x \sim \mathcal{N}(\mu, \sigma_x^2) $$
For tasks 1 and 2 the means $\mu$ are very close to each other.
Then we generate logistic regression parameter $w$ as
$$ w_1 \sim \mathcal{N}(\mathbb{E}[X_2], \sigma_w^2) \\ w_2 \sim \mathcal{N}(\mathbb{E}[X_1], \sigma_w^2) \\ w_3 \sim \mathcal{N}(\mathbf{1}, \sigma_w^2) $$
Because $X_1$ and $X_2$ are close, the distributions of $w_1$ and $w_2$ are close too. This connection can be exploited by the variational multitask approach.
Finally, the label $y$ for an input $x$ is generated as
$$ y \sim \text{Bern}(\sigma(x^{T}w)) $$
Data generation following this scheme is implemented in data.py
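The actual generator lives in `data.py`; as a rough, self-contained sketch of the scheme above (the function name and this NumPy implementation are my own, not the repository's):

```python
import numpy as np

def make_task(mu, w_mean, size=200, sigma_x=3.0, sigma_w=0.1, rng=None):
    """Generate one synthetic logistic-regression task following the scheme above."""
    rng = rng or np.random.default_rng(0)
    # inputs: x ~ N(mu, sigma_x^2 I)
    X = rng.normal(loc=mu, scale=sigma_x, size=(size, len(mu)))
    # weights: w ~ N(w_mean, sigma_w^2 I)
    w = rng.normal(loc=w_mean, scale=sigma_w)
    # labels: y ~ Bern(sigmoid(x^T w))
    p = 1.0 / (1.0 + np.exp(-X @ w))
    y = (rng.uniform(size=size) < p).astype(np.float32)
    return X, w, y

X, w, y = make_task(mu=np.array([1.0, -2.0]), w_mean=np.array([1.0, 1.0]))
```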
from data import build_linked_datasets, build_solo_dataset
NUM_MODELS = 3
datasets = [*build_linked_datasets(config.size, config.dim), build_solo_dataset(config.size, config.dim)]
# extract w
w_list = list(
datasets | select(lambda w_dataset: w_dataset[0])
)
# extract (X, y) pairs
datasets = list(
datasets |
select(lambda w_dataset: w_dataset[1]) |
select(lambda d: random_split(d, [1 - config.test_ratio, config.test_ratio]))
)
train_datasets, test_datasets = zip(*datasets)
Let's visualize the generated inputs and $w$ vectors for each task
points_df = []
for i, dataset in enumerate(train_datasets):
X = dataset.dataset.tensors[0].numpy()
cur_df = pd.DataFrame(X, columns=["x1", "x2"])
cur_df["dataset"] = str(i)
points_df.append(cur_df)
points_df = pd.concat(points_df, axis=0)
points_df.head()
| | x1 | x2 | dataset |
|---|---|---|---|
| 0 | -0.094634 | -2.335958 | 0 |
| 1 | 10.453235 | 3.087356 | 0 |
| 2 | -2.813742 | -2.073199 | 0 |
| 3 | -6.502639 | -0.509825 | 0 |
| 4 | -5.492137 | -2.334638 | 0 |
plane_df = []
for i, w in enumerate(w_list):
w = w.numpy()
cur_df = pd.DataFrame(np.linspace(-20, 20, 10)[:, None] * w[None, :], columns=["x1", "x2"])
cur_df["dataset"] = str(i+1)
plane_df.append(cur_df)
plane_df = pd.concat(plane_df, axis=0)
plane_df.head()
| | x1 | x2 | dataset |
|---|---|---|---|
| 0 | 9.379683 | -1.807698 | 1 |
| 1 | 7.295309 | -1.405987 | 1 |
| 2 | 5.210935 | -1.004277 | 1 |
| 3 | 3.126561 | -0.602566 | 1 |
| 4 | 1.042187 | -0.200855 | 1 |
fig_points = px.scatter(points_df, x="x1", y="x2", color="dataset", symbol="dataset")
fig_plane = px.line(plane_df, x="x1", y="x2", color="dataset")
fig_plane.update_traces(line=dict(width=3))
fig = go.Figure(data=fig_points.data + fig_plane.data)
fig.show()
As we can see, the inputs for tasks 1 and 2 are indeed close. Because $\sigma_w$ is small, $w_1$ and $w_2$ also came out close.
Solo models¶
# here we define probabilistic model of each task using pyro
class SoloModel(pnn.PyroModule):
def __init__(
self,
dim: int = 2,
num_data_samples: Optional[int] = None
):
super().__init__()
self.num_data_samples = num_data_samples
# set parametric prior on w
self.w_loc = pnn.PyroParam(torch.zeros((dim, )))
self.log_w_scale = pnn.PyroParam(torch.zeros((dim, )))
self.w = pnn.PyroSample(
lambda self: pdistr.Normal(self.w_loc, torch.exp(self.log_w_scale)).to_event(1)
)
def forward(self, X: torch.Tensor, y: torch.Tensor = None):
batch_size = X.shape[0]
if self.num_data_samples:
size = self.num_data_samples
subsample_size = batch_size
else:
size = batch_size
subsample_size = None
p = torch.sigmoid(X.matmul(self.w))
with pyro.plate("data_batch", size=size, subsample_size=subsample_size):
pyro.sample("y", pdistr.Bernoulli(p), obs=y)
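The `pyro.plate` with `subsample_size` rescales the mini-batch log-likelihood by `num_data_samples / batch_size`, so the ELBO remains an unbiased estimate of the full-data objective. The rescaling itself is simple; a hand-rolled sketch of that step (independent of pyro, names are my own):

```python
def scaled_batch_log_lik(batch_log_liks, num_data_samples):
    """Unbiased full-data log-likelihood estimate from a mini-batch,
    mimicking the rescaling pyro.plate(..., subsample_size=...) applies."""
    batch_size = len(batch_log_liks)
    return (num_data_samples / batch_size) * sum(batch_log_liks)

# a batch of 2 points out of 10 total data points
est = scaled_batch_log_lik([-0.5, -1.5], num_data_samples=10)
# (10 / 2) * (-2.0) = -10.0
```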
We are going to use lightning to perform training and logging
class LitSoloModel(L.LightningModule):
def __init__(
self,
elbo_f: pyro.infer.elbo.ELBOModule,
predictive: pyro.infer.Predictive,
num_data_samples: int,
):
super().__init__()
self.num_data_samples = num_data_samples
self.accuracy_computer = Accuracy('binary')
self.elbo_f = elbo_f
self.model: SoloModel = elbo_f.model
self.guide = elbo_f.guide
self.predictive = predictive
def training_step(self, batch: tuple[torch.Tensor], batch_idx: int):
X, y = batch
elbo_loss = self.elbo_f(X, y)
self.log("Train/ELBO", elbo_loss, prog_bar=True)
return elbo_loss
def validation_step(self, batch: tuple[torch.Tensor], batch_idx: int):
X, y = batch
y_pred = (self.predictive(X, y=None)["y"].mean(dim=0) > 0.5).to(torch.float32)
self.accuracy_computer.update(y_pred, y)
def on_validation_epoch_end(self):
self.log("Test/Accuracy", self.accuracy_computer.compute())
self.accuracy_computer.reset()
def configure_optimizers(self):
return optim.Adam(self.elbo_f.parameters())
Now we train the individual models
for i in range(NUM_MODELS):
print(f"Training model {i}\n")
model = SoloModel(config.dim, config.size)
guide = AutoNormal(model)
    # the number of ELBO particles here matches the variational multitask setting
num_elbo_particles = config.mt_elbo.classifier_num_particles * config.mt_elbo.latent_num_particles
elbo_f = Trace_ELBO(num_elbo_particles)(model, guide)
    # All relevant parameters need to be initialized before ``configure_optimizers`` is called.
    # Since we use an AutoNormal guide, the parameters have not been initialized yet,
    # so we initialize the model and guide by running one mini-batch through the loss.
    mini_batch = next(iter(DataLoader(train_datasets[i], batch_size=1)))
elbo_f(*mini_batch)
    # num_predictive_particles trades prediction quality against compute; reusing the ELBO particle count is a balanced choice
num_predictive_particles = num_elbo_particles
predictive = pyro.infer.Predictive(model, guide=guide, num_samples=num_predictive_particles)
lit_model = LitSoloModel(elbo_f, predictive, config.size)
train_dataloader = DataLoader(train_datasets[i], batch_size=config.batch_size, shuffle=True)
test_dataloader = DataLoader(test_datasets[i], batch_size=config.batch_size)
logger = TensorBoardLogger("mt_logs/solo", name=f"solo_{i}")
callbacks = [
EarlyStopping(monitor="Train/ELBO", min_delta=1e-3, patience=10, mode="min")
]
trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
trainer.fit(lit_model, train_dataloader, test_dataloader)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name              | Type           | Params | Mode
-------------------------------------------------------------
0 | accuracy_computer | BinaryAccuracy | 0      | train
1 | elbo_f            | ELBOModule     | 8      | train
2 | model             | SoloModel      | 4      | train
3 | guide             | AutoNormal     | 4      | train
4 | predictive        | Predictive     | 8      | train
-------------------------------------------------------------
8         Trainable params
0         Non-trainable params
8         Total params
0.000     Total estimated model params size (MB)
7         Modules in train mode
0         Modules in eval mode
Training model 0
`Trainer.fit` stopped: `max_epochs=20` reached.
Training model 1
Training model 2
`Trainer.fit` stopped: `max_epochs=20` reached.
Metrics are saved to TensorBoard
%tensorboard --logdir mt_logs/solo
Reusing TensorBoard on port 6006 (pid 30339), started 0:00:16 ago. (Use '!kill 30339' to kill it.)
Variational multitask models¶
First, we register a KL computation for the $\delta$-distribution. There are no state variables here, so this KL never actually contributes, and we won't compute it in full.
@distr.kl.register_kl(Delta, Delta)
def kl_delta_delta(d1: Delta, d2: Delta):
return torch.zeros(d1.batch_shape)
    # strictly, KL(δ_a || δ_b) is 0 when the deltas share a support point and +inf otherwise:
    # return torch.zeros(d1.batch_shape) if torch.allclose(d1.v, d2.v) else torch.full(d1.batch_shape, torch.inf)
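The exact KL between two Dirac deltas is 0 when they share the same support point and +∞ otherwise; the zero-returning shortcut above is safe because deltas are only ever compared at identical points in this notebook. A plain-Python sketch of the exact rule (scalar case, names are my own):

```python
import math

def kl_delta_delta_exact(v1: float, v2: float) -> float:
    """KL(delta_{v1} || delta_{v2}): zero iff both deltas sit on the same point."""
    return 0.0 if math.isclose(v1, v2) else math.inf
```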
Define the batched distributions for the tasks. We assume latents arrive with shape `(batch_size, num_latent_particles, ...)` and classifier samples with shape `(num_classifier_samples, ...)`.
# same for all tasks
def target_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
return distr.Bernoulli(logits=torch.tensordot(Z, W, dims=[[-1], [-1]]))
# same for all tasks
def predictive_distr(Z: torch.Tensor, W: torch.Tensor) -> distr.Distribution:
return distr.Bernoulli(logits=torch.tensordot(Z, W, dims=[[-1], [-1]]).flatten(1, 2))
task_distrs = [target_distr for _ in range(NUM_MODELS)]
task_num_samples = list(train_datasets | select(len))
# we have no true latents here, but the API formally requires a latent distribution
def latent_distr(X: torch.Tensor) -> distr.Distribution:
return Delta(X, event_dim=1)
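The shape bookkeeping in `target_distr` and `predictive_distr` is worth spelling out. With latents `Z` of shape `(batch_size, num_latent_particles, dim)` and classifier samples `W` of shape `(num_classifier_samples, dim)`, contracting the trailing `dim` axes yields logits of shape `(batch_size, num_latent_particles, num_classifier_samples)`, and `flatten(1, 2)` merges the two particle axes. A NumPy illustration (the shapes follow the conventions stated above; the numbers are arbitrary):

```python
import numpy as np

batch, n_lat, n_cls, dim = 5, 3, 4, 2
Z = np.random.randn(batch, n_lat, dim)   # latent particles
W = np.random.randn(n_cls, dim)          # classifier samples

# contract the trailing `dim` axes, as torch.tensordot(..., dims=[[-1], [-1]])
logits = np.tensordot(Z, W, axes=([2], [1]))
print(logits.shape)        # (5, 3, 4)

# merge the particle axes, as .flatten(1, 2) in torch
flat = logits.reshape(batch, n_lat * n_cls)
print(flat.shape)          # (5, 12)
```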
class NormalLogits(distr.Normal):
    """Normal distribution with its scale parametrized on the log scale.

    Keeping the log-scale as the learnable tensor guarantees a positive scale
    for any real-valued parameter.
    """
    def __init__(self, loc, logit, validate_args=None):
        self.logit = logit
        super().__init__(loc, torch.exp(logit), validate_args)
    def __getattribute__(self, name):
        # always recompute the scale from the current value of ``logit``,
        # so gradients keep flowing to the underlying parameter
        if name == "scale":
            return self.logit.exp()
        else:
            return super().__getattribute__(name)
# parametric variational distr for classifiers
classifier_distrs_params = {}
classifier_distrs = []
for i in range(NUM_MODELS):
    # set initial values for the distribution's parameters
loc, scale_logit = nn.Parameter(torch.zeros((config.dim, ))), nn.Parameter(torch.zeros((config.dim, )))
classifier_distrs_params.update({
f"distr_{i}": [loc, scale_logit]
})
classifier_distrs.append(
distr.Independent(
NormalLogits(loc, scale_logit),
reinterpreted_batch_ndims=1
)
)
# parametric variational distr for latents
latent_distrs = [latent_distr for _ in range(NUM_MODELS)]
# temperature must decrease over steps
temp_scheduler = lambda step: 1. / torch.sqrt(torch.tensor(step + 1))
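The scheduler decays as $1/\sqrt{t+1}$; a quick check of the first few values in plain Python (equivalent to the torch lambda, without tensors):

```python
import math

def temp(step: int) -> float:
    """Decaying temperature schedule: 1 / sqrt(step + 1)."""
    return 1.0 / math.sqrt(step + 1)

print([round(temp(s), 3) for s in (0, 1, 3, 99)])
# [1.0, 0.707, 0.5, 0.1]
```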
# create variational multitask elbo module
mt_elbo = MultiTaskElbo(
task_distrs,
task_num_samples,
classifier_distrs,
latent_distrs,
temp_scheduler=temp_scheduler,
**dict(config.mt_elbo)
)
We are going to use lightning to perform training and logging
class LitMtModel(L.LightningModule):
def __init__(
self,
mt_elbo: MultiTaskElbo
):
super().__init__()
num_tasks = mt_elbo.num_tasks
self.accuracy_computers = [Accuracy('binary') for _ in range(num_tasks)]
self.mt_elbo = mt_elbo
self.distr_params = nn.ParameterList()
for param_list in classifier_distrs_params.values():
self.distr_params.extend(
param_list
)
def training_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
mt_loss_dict = self.mt_elbo(*list(zip(*batch)), step=self.global_step)
self.log_dict(mt_loss_dict, prog_bar=True)
return mt_loss_dict["elbo"]
def on_train_batch_end(self, outputs, batch, batch_idx):
with torch.no_grad():
for distr_name, distr_params in classifier_distrs_params.items():
params_grad_norm = sum(distr_params | select(lambda param: param.grad.norm()))
self.log(f"{distr_name}_grad", params_grad_norm)
def on_train_epoch_end(self):
# log mixing
fig, ax = plt.subplots()
cax = ax.matshow(self.mt_elbo.latent_mixings_params.detach().numpy())
fig.colorbar(cax)
self.logger.experiment.add_figure("Latent_mixing", fig, self.global_step)
fig, ax = plt.subplots()
cax = ax.matshow(self.mt_elbo.classifier_mixings_params.detach().numpy())
fig.colorbar(cax)
self.logger.experiment.add_figure("Classifier_mixing", fig, self.global_step)
def validation_step(self, batch: tuple[tuple[torch.Tensor]], batch_idx: int):
for i, (X, y) in enumerate(batch):
NUM_PREDICTIVE_SAMPLES = 10
cur_predictive = build_predictive(
predictive_distr,
classifier_distrs[i],
latent_distrs[i],
X,
config.mt_elbo.classifier_num_particles,
config.mt_elbo.latent_num_particles
)
y_pred = (cur_predictive.sample((NUM_PREDICTIVE_SAMPLES, )).mean(dim=0) > 0.5).float()
self.accuracy_computers[i].update(y_pred, y)
def on_validation_epoch_end(self):
for i, accuracy_computer in enumerate(self.accuracy_computers):
self.log(f"Test/Accuracy_{i}", accuracy_computer.compute())
accuracy_computer.reset()
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=1e-3)
lit_mt_model = LitMtModel(mt_elbo)
# stack task datasets
unified_train_dataset = StackDataset(*train_datasets)
unified_test_dataset = StackDataset(*test_datasets)
mt_train_dataloader = DataLoader(unified_train_dataset, config.batch_size, shuffle=True)
mt_test_dataloader = DataLoader(unified_test_dataset, config.batch_size, shuffle=False)
logger = TensorBoardLogger("mt_logs", name="multitask")
callbacks = [
EarlyStopping(monitor="elbo", min_delta=1e-3, patience=10, mode="min")
]
trainer = L.Trainer(logger=logger, callbacks=callbacks, **dict(config.trainer))
trainer.fit(lit_mt_model, mt_train_dataloader, mt_test_dataloader)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name         | Type          | Params | Mode
-------------------------------------------------------
0 | mt_elbo      | MultiTaskElbo | 9      | train
1 | distr_params | ParameterList | 12     | train
-------------------------------------------------------
21        Trainable params
0         Non-trainable params
21        Total params
0.000     Total estimated model params size (MB)
2         Modules in train mode
0         Modules in eval mode
`Trainer.fit` stopped: `max_epochs=20` reached.
Metrics are saved to TensorBoard
%tensorboard --logdir mt_logs/multitask
Reusing TensorBoard on port 6007 (pid 30356), started 0:00:04 ago. (Use '!kill 30356' to kill it.)
As we can see, training is successful. The visualized classifier mixing parameters clearly show that the algorithm has found the connection between tasks 1 and 2.
Solo and multitask comparison¶
Final accuracy metrics for the solo and multitask models are equal within the margin of error. Because the tasks are simple and synthetic, no performance difference shows up here. For harder tasks with genuine probabilistic connections, however, the variational multitask approach may give a significant boost.