Skip to content

Bayesian Layers

Drop-in replacements for standard torch.nn layers that implement the Local Reparameterization Trick (LRT).

Bayesian Linear

bensemble.layers.linear.BayesianLinear

BayesianLinear(
    in_features: int,
    out_features: int,
    prior_sigma: float = 1.0,
    init_sigma: float = 0.1,
    weight_init: str = "kaiming",
)

Bases: BaseBayesianLayer

Bayesian Linear layer implementing Variational Inference with the Local Reparameterization Trick.

Weights and biases are modeled as Gaussian distributions with learnable means and standard deviations (parametrized by rho).

Parameters:

Name Type Description Default
in_features int

Size of each input sample.

required
out_features int

Size of each output sample.

required
prior_sigma float

Standard deviation of the prior Gaussian distribution.

1.0
init_sigma float

Initial standard deviation for the posterior.

0.1
weight_init str

Initialization method for weight means ('kaiming', 'xavier', or 'normal').

'kaiming'
Source code in bensemble/layers/linear.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    prior_sigma: float = 1.0,
    init_sigma: float = 0.1,
    weight_init: str = "kaiming",
):
    """
    Set up the variational parameters of a Bayesian linear layer.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        prior_sigma: Standard deviation of the prior Gaussian distribution.
        init_sigma: Initial standard deviation for the posterior.
        weight_init: Initialization method for weight means ('kaiming', 'xavier', or 'normal').
    """
    # The base layer stores prior_sigma.
    super().__init__(prior_sigma=prior_sigma)

    # Keep the configuration on the instance.
    self.in_features, self.out_features = in_features, out_features
    self.init_sigma, self.weight_init = init_sigma, weight_init

    weight_shape = (out_features, in_features)
    bias_shape = (out_features,)

    # Mean / rho pairs for the weights and then the bias. The tensors are
    # allocated uninitialised here; reset_parameters() is called right
    # after and is presumably what fills them in (not shown in this view).
    self.w_mu = nn.Parameter(torch.empty(weight_shape))
    self.w_rho = nn.Parameter(torch.empty(weight_shape))

    self.b_mu = nn.Parameter(torch.empty(bias_shape))
    self.b_rho = nn.Parameter(torch.empty(bias_shape))

    self.reset_parameters()

forward

forward(x: Tensor) -> torch.Tensor

Forward pass with Local Reparameterization Trick.

Source code in bensemble/layers/linear.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass with Local Reparameterization Trick.

    In training mode the output is sampled: Gaussian noise is drawn per
    output activation and scaled by the activation's standard deviation,
    which is derived from the weight and bias standard deviations
    (softplus of the rho parameters). Outside training mode the layer is
    deterministic and applies only the mean weights and bias.
    """
    if self.training:
        weight_std = F.softplus(self.w_rho)
        bias_std = F.softplus(self.b_rho)

        # Mean and variance of the pre-activation (bias mean is added last).
        act_mean = F.linear(x, self.w_mu)
        act_var = F.linear(x.pow(2), weight_std.pow(2)) + bias_std.pow(2)

        noise = torch.randn_like(act_mean)
        # 1e-8 is added to the variance before taking the square root.
        act_std = (act_var + 1e-8).sqrt()
        return act_mean + noise * act_std + self.b_mu

    # Evaluation: plain linear map with the posterior means.
    return F.linear(x, self.w_mu, self.b_mu)

Bayesian Conv2d

bensemble.layers.conv.BayesianConv2d

BayesianConv2d(
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    prior_sigma: float = 1.0,
    init_sigma: float = 0.1,
)

Bases: BaseBayesianLayer

Bayesian Convolutional Layer (2D) with Local Reparameterization Trick.

Source code in bensemble/layers/conv.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    kernel_size: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    prior_sigma: float = 1.0,
    init_sigma: float = 0.1,
):
    """
    Set up the variational parameters of a Bayesian 2D convolution.

    Args:
        in_channels: Number of channels in the input.
        out_channels: Number of channels produced by the layer.
        kernel_size: Kernel size; a single int is expanded to a square
            (k, k) kernel, a tuple is used as given.
        stride: Stored unchanged on the instance.
        padding: Stored unchanged on the instance.
        dilation: Stored unchanged on the instance.
        groups: Number of channel groups; each filter spans
            ``in_channels // groups`` input channels.
        prior_sigma: Standard deviation of the prior Gaussian distribution.
        init_sigma: Initial standard deviation for the posterior.

    Raises:
        ValueError: If ``groups`` is not positive, or does not evenly
            divide both ``in_channels`` and ``out_channels``.
    """
    # Fail early with a clear message instead of silently truncating
    # in_channels // groups (or hitting a bare ZeroDivisionError); these
    # are the same constraints torch.nn.Conv2d enforces.
    if groups < 1:
        raise ValueError("groups must be a positive integer")
    if in_channels % groups != 0:
        raise ValueError("in_channels must be divisible by groups")
    if out_channels % groups != 0:
        raise ValueError("out_channels must be divisible by groups")

    # Hand prior_sigma to the base class, as BayesianLinear does, rather
    # than overwriting the attribute after construction.
    super().__init__(prior_sigma=prior_sigma)
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.init_sigma = init_sigma
    self.kernel_size = (
        kernel_size
        if isinstance(kernel_size, tuple)
        else (kernel_size, kernel_size)
    )
    self.stride = stride
    self.padding = padding
    self.dilation = dilation
    self.groups = groups

    weight_shape = (out_channels, in_channels // groups, *self.kernel_size)

    # Mean / rho pairs for the kernel and then the bias. The tensors are
    # allocated uninitialised here; reset_parameters() is called right
    # after and is presumably what fills them in (not shown in this view).
    self.w_mu = nn.Parameter(torch.empty(weight_shape))
    self.w_rho = nn.Parameter(torch.empty(weight_shape))

    self.b_mu = nn.Parameter(torch.empty(out_channels))
    self.b_rho = nn.Parameter(torch.empty(out_channels))

    self.reset_parameters()

Base Class

bensemble.layers.base.BaseBayesianLayer

BaseBayesianLayer(prior_sigma: float = 1.0)

Bases: Module

Base class for all Bayesian layers.

Computes KL-divergence automatically for all parameters ending with _mu and _rho.

Source code in bensemble/layers/base.py
def __init__(self, prior_sigma: float = 1.0) -> None:
    """
    Initialise the module and record the prior's standard deviation.

    Args:
        prior_sigma: Standard deviation of the zero-mean Gaussian prior
            p(w) = N(0, prior_sigma^2) that the layer's KL-divergence is
            computed against.
    """
    super().__init__()
    # Kept as a plain attribute (not a parameter or buffer).
    self.prior_sigma = prior_sigma

get_pruning_masks

get_pruning_masks(threshold: float = 0.83) -> dict

Returns binary masks for parameters satisfying the SNR threshold.

Implements Graves' pruning heuristic where weights with low Signal-to-Noise Ratio are considered redundant and can be removed.

Parameters:

Name Type Description Default
threshold float

The SNR threshold (|mu|/sigma). Defaults to 0.83, the "safe" threshold suggested by Graves.

0.83

Returns:

Type Description
dict

dict[str, torch.Tensor]: A dictionary mapping parameter names to binary masks (1.0 for keeping, 0.0 for pruning).

Source code in bensemble/layers/base.py
def get_pruning_masks(self, threshold: float = 0.83) -> dict:
    """Build keep/prune masks from each parameter's Signal-to-Noise Ratio.

    Follows Graves' pruning heuristic: a weight whose SNR is low is
    treated as redundant and marked for removal.

    Args:
        threshold (float, optional): The SNR threshold (|mu|/sigma).
            Defaults to 0.83, the "safe" threshold suggested by Graves.

    Returns:
        dict[str, torch.Tensor]: A dictionary mapping parameter names to
            binary masks (1.0 for keeping, 0.0 for pruning). An entry is
            kept only when its SNR is strictly greater than ``threshold``.
    """
    masks = {}
    for param_name, snr in self._get_snr_dict().items():
        keep = snr > threshold
        # Convert the boolean comparison into a 0.0 / 1.0 float mask.
        masks[param_name] = keep.float()
    return masks

kl_divergence

kl_divergence() -> torch.Tensor

Computes the KL-divergence KL(q || p) for all Bayesian weights of the layer, where the prior is p(w) = N(0, prior_sigma^2) and the posterior is q(w) = N(mu, sigma^2) with sigma = softplus(rho).

Source code in bensemble/layers/base.py
def kl_divergence(self) -> torch.Tensor:
    """
    Computes KL-divergence KL(q || p) for all bayesian weights of the layer.
    p(w) = N(0, prior_sigma^2)
    q(w) = N(mu, sigma^2), where sigma = softplus(rho)

    Every parameter whose name ends in ``_mu`` is paired with the
    attribute of the same name ending in ``_rho``; ``_mu`` parameters
    without such a partner are skipped.

    Returns:
        The sum of the per-parameter KL terms. If no (mu, rho) pair is
        found, the Python float ``0.0`` is returned.
    """
    mu_suffix = "_mu"
    total_kl = 0.0

    for name, param in self.named_parameters():
        if not name.endswith(mu_suffix):
            continue

        # Swap only the trailing "_mu". str.replace would also rewrite
        # earlier occurrences (e.g. "w_multi_mu" -> "w_rhoti_rho"), so the
        # lookup below would fail and the parameter would silently be
        # left out of the KL sum.
        rho_name = name[: -len(mu_suffix)] + "_rho"

        if hasattr(self, rho_name):
            rho = getattr(self, rho_name)
            total_kl += self._compute_kl_for_param(param, rho)

    return total_kl