Skip to content

Variational Rényi Divergence

%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
import numpy as np
import matplotlib.pyplot as plt

import torchvision
import torchvision.transforms as transforms

import torch
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F
class Classifier(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images.

    Two 3x3 convolutions (1 -> 32 -> 64 channels, padding 1) are each followed
    by ReLU and 2x2 max pooling, which halves the spatial size twice
    (28 -> 14 -> 7). The resulting 64 * 7 * 7 features feed two fully
    connected layers that produce one logit per class.

    Args:
        classes: Number of output classes (size of the final linear layer).
    """

    def __init__(self, classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores (logits) for a batch of images.

        Args:
            x: Batch of shape (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, classes).

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                64 * 7 * 7 features per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 64 * 7 * 7)``, this never folds surplus features into the
        # batch dimension, so a wrongly sized input fails loudly in ``fc1``
        # instead of silently producing extra rows.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)
# imports for model changes
from src.methods.bayes.variational.net import VarBayesNet
from src.methods.bayes.variational.net import NormalVarBayesLayer
from src.methods.bayes.variational.optimization import VarRenuiLoss
# imports for trainer
from src.methods.bayes.variational.trainer import VarBayesTrainer, VarTrainerParams, Beta_Scheduler_Plato, CallbackLossAccuracy
from src.methods.report.base import ReportChain
from src.methods.report.variational import VarBaseReport
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device
device(type='cuda')
# MNIST ships two splits: train=True selects the training images and
# train=False the held-out test images. The variable names must match the
# split they load, because train_dataset is later divided into the
# training and validation subsets.
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

Define parameters for training

# --- optimisation -----------------------------------------------------------
EPOCHS = 5          # number of training epochs
BATCH_SIZE = 1024   # batch size for both the training and validation loaders
LR = 0.001          # Adam learning rate (5e-4 was also tried)

# --- data split -------------------------------------------------------------
# Fraction (not a percentage, despite the name) of the data held out for
# validation.
VAL_PERCENT = 0.2

# --- variational training ---------------------------------------------------
SAMPLES = 10        # passed to the trainer as num_samples
BETA = 0.01         # passed to the trainer as beta (5e-5 was also tried)
BETA_FAC = 0.5      # NOTE(review): not used in the cells shown; presumably a beta-scheduler factor
PRUNE = -5          # passed to the trainer as prune_threshold (1.99 and 2.1 were also tried)
PLATO_TOL = 20      # NOTE(review): not used in the cells shown; presumably a plateau tolerance

# Build the plain classifier, then wrap it into a variational Bayesian network
# with a single NormalVarBayesLayer covering the whole module.
base_module = Classifier()
var_module = NormalVarBayesLayer(base_module)
model = VarBayesNet(base_module, nn.ModuleList([var_module]))

# Loss terms: the data-fit term is summed (not averaged) over the batch;
# the distribution term comes from the project's VarRenuiLoss.
fit_loss = nn.CrossEntropyLoss(reduction="sum")
renui_loss = VarRenuiLoss()
beta = BETA

# Adam over every parameter exposed by the Bayesian model.
optimizer = optim.Adam(model.parameters(), lr=LR)

# Bundle everything the trainer needs into one parameter object.
train_params = VarTrainerParams(
    num_epochs=EPOCHS,
    optimizer=optimizer,
    fit_loss=fit_loss,
    dist_loss=renui_loss,
    num_samples=SAMPLES,
    prune_threshold=PRUNE,
    beta=beta,
    callback_losses={'accuracy': CallbackLossAccuracy()},
)

Prepare the dataloaders

# Hold out a VAL_PERCENT fraction of train_dataset for validation; the
# remaining samples are used for training.
val_size = int(VAL_PERCENT * len(train_dataset))
train_size = len(train_dataset) - val_size
t_dataset, v_dataset = torch.utils.data.random_split(
    train_dataset, [train_size, val_size]
)

# Only the training loader shuffles; the validation loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    t_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True
)
eval_loader = torch.utils.data.DataLoader(
    v_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True
)

Train the model

# Move the model to the selected device before handing it to the trainer.
model.to(device)

trainer = VarBayesTrainer(
    train_params,
    ReportChain([VarBaseReport()]),
    train_loader,
    eval_loader,
)

# As the last expression of the cell, the value returned by train() is displayed.
trainer.train(model)
  0%|          | 0/5 [00:00<?, ?it/s]


Epoch [1/5],Loss:4226.0458984375, KL Loss: 186235.15625. FitLoss: 2485.269287109375,Accuracy:0.10749605618990384,Validation Loss:4084.62255859375,Validation Accuracy:0.09221311475409837, Prune parameters: 596.0/421642,Beta: 0.01
Epoch [2/5],Loss:4147.62109375, KL Loss: 186219.140625. FitLoss: 2379.785400390625,Accuracy:0.1213172325721154,Validation Loss:4065.800537109375,Validation Accuracy:0.21413934426229508, Prune parameters: 634.0/421642,Beta: 0.01
Epoch [3/5],Loss:4097.51220703125, KL Loss: 186340.484375. FitLoss: 2320.719482421875,Accuracy:0.13963153545673082,Validation Loss:4037.357666015625,Validation Accuracy:0.2776639344262295, Prune parameters: 656.0/421642,Beta: 0.01
Epoch [4/5],Loss:4031.77392578125, KL Loss: 186475.59375. FitLoss: 2275.6953125,Accuracy:0.16670015775240382,Validation Loss:3954.070556640625,Validation Accuracy:0.4354508196721312, Prune parameters: 646.0/421642,Beta: 0.01
Epoch [5/5],Loss:3908.925537109375, KL Loss: 186752.4375. FitLoss: 2179.834228515625,Accuracy:0.22488356370192303,Validation Loss:3777.31298828125,Validation Accuracy:0.6127049180327869, Prune parameters: 651.0/421642,Beta: 0.01





<src.methods.bayes.variational.net_distribution.VarBayesModuleNetDistribution at 0x7fd84c833890>

Show that the trained model can be pruned well

# Evaluate the trained model at ten prune thresholds spread evenly over [-5, 5].
thresholds = np.linspace(-5, 5, 10)
threshold_results = trainer.eval_thresholds(model, thresholds)

# Plot how accuracy depends on the share of parameters that are NOT pruned.
# Both axes are labelled in percent, so the fractions are scaled by 100.
x_s = [(a.cnt_prune_parameters / a.cnt_params).cpu().numpy() for a in threshold_results]
x_s = 100. * (1. - np.array(x_s))  # pruned fraction -> non-pruned percentage
# NOTE(review): assumes "val_accuracy" is reported as a fraction in [0, 1],
# like the accuracies printed in the training log - confirm.
y_s = [100. * a.custom_losses["val_accuracy"] for a in threshold_results]

plt.plot(x_s, y_s)
plt.xlabel("Non pruned params, %")
plt.ylabel("Accuracy, %")
plt.title("Accuracy on count of nonpruned params")
plt.grid()

png