Bayesian Layers
Drop-in replacements for standard torch.nn layers that implement the Local Reparameterization Trick (LRT).
Bayesian Linear
bensemble.layers.linear.BayesianLinear
BayesianLinear(
in_features: int,
out_features: int,
prior_sigma: float = 1.0,
init_sigma: float = 0.1,
weight_init: str = "kaiming",
)
Bases: BaseBayesianLayer
Bayesian Linear layer implementing Variational Inference with the Local Reparameterization Trick.
Weights and biases are modeled as Gaussian distributions with learnable means and standard deviations (parametrized by rho).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Size of each input sample. | required |
| out_features | int | Size of each output sample. | required |
| prior_sigma | float | Standard deviation of the prior Gaussian distribution. | 1.0 |
| init_sigma | float | Initial standard deviation for the posterior. | 0.1 |
| weight_init | str | Initialization method for weight means ('kaiming', 'xavier', or 'normal'). | 'kaiming' |
Source code in bensemble/layers/linear.py
forward
Forward pass with Local Reparameterization Trick.
Source code in bensemble/layers/linear.py
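With the local reparameterization trick, noise is injected into the pre-activations rather than the weights: each output unit gets a Gaussian whose mean is `x @ mu_W.T + mu_b` and whose variance is `x**2 @ sigma_W**2.T + sigma_b**2`, so only one noise draw per activation is needed. A minimal NumPy sketch of this forward pass (the names `mu_w`, `rho_w`, etc. are illustrative, not the layer's actual attribute names):

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def lrt_linear_forward(x, mu_w, rho_w, mu_b, rho_b, rng):
    """Sample pre-activations directly (local reparameterization trick)."""
    sigma_w = softplus(rho_w)            # posterior std of weights
    sigma_b = softplus(rho_b)            # posterior std of biases
    out_mean = x @ mu_w.T + mu_b         # mean of the pre-activation
    out_var = (x ** 2) @ (sigma_w ** 2).T + sigma_b ** 2
    eps = rng.standard_normal(out_mean.shape)
    return out_mean + np.sqrt(out_var) * eps

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))                      # batch of 4, in_features=3
mu_w, rho_w = rng.standard_normal((2, 3)), np.full((2, 3), -3.0)
mu_b, rho_b = np.zeros(2), np.full(2, -3.0)
y = lrt_linear_forward(x, mu_w, rho_w, mu_b, rho_b, rng)
print(y.shape)  # (4, 2)
```

Sampling activations instead of weights gives lower-variance gradient estimates and avoids materializing a weight sample per batch element.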
Bayesian Conv2d
bensemble.layers.conv.BayesianConv2d
BayesianConv2d(
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
padding: Union[int, Tuple[int, int]] = 0,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
prior_sigma: float = 1.0,
init_sigma: float = 0.1,
)
Bases: BaseBayesianLayer
Bayesian Convolutional Layer (2D) with Local Reparameterization Trick.
Source code in bensemble/layers/conv.py
Base Class
bensemble.layers.base.BaseBayesianLayer
Bases: Module
Base class for all Bayesian layers.
Computes the KL divergence automatically for all parameters whose names end with `_mu` and `_rho`.
Source code in bensemble/layers/base.py
get_pruning_masks
Returns binary masks for parameters satisfying the SNR threshold.
Implements Graves' pruning heuristic where weights with low Signal-to-Noise Ratio are considered redundant and can be removed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| threshold | float | The SNR threshold (|mu|/sigma). Defaults to 0.83, the "safe" threshold suggested by Graves. | 0.83 |
Returns:
| Type | Description |
|---|---|
| dict[str, torch.Tensor] | A dictionary mapping parameter names to binary masks (1.0 for keeping, 0.0 for pruning). |
Source code in bensemble/layers/base.py
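The heuristic keeps a weight only when its signal-to-noise ratio |mu|/sigma exceeds the threshold; weights whose posterior uncertainty swamps their mean carry little information and can be zeroed out. A standalone NumPy sketch of the mask computation (the function name `pruning_mask` is illustrative):

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def pruning_mask(mu, rho, threshold=0.83):
    """Binary mask: 1.0 where the SNR |mu|/sigma clears the threshold."""
    sigma = softplus(rho)
    snr = np.abs(mu) / sigma
    return (snr > threshold).astype(np.float64)

mu = np.array([0.9, 0.01, -0.5])
rho = np.zeros(3)                 # sigma = softplus(0) ~ 0.693
mask = pruning_mask(mu, rho)
print(mask)  # [1. 0. 0.] -- only the first weight's SNR exceeds 0.83
```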
kl_divergence
Computes the KL divergence KL(q || p) for all Bayesian weights of the layer, where the prior is p(w) = N(0, prior_sigma^2) and the posterior is q(w) = N(mu, sigma^2) with sigma = softplus(rho).
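For two Gaussians the KL divergence has a closed form, KL(N(mu, sigma^2) || N(0, p^2)) = log(p / sigma) + (sigma^2 + mu^2) / (2 p^2) - 1/2, summed over all weights. A NumPy sketch of this computation (assumed here to mirror the layer's behavior; not the library's actual implementation):

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def kl_gaussian(mu, rho, prior_sigma=1.0):
    """Closed-form KL(N(mu, sigma^2) || N(0, prior_sigma^2)), summed over weights."""
    sigma = softplus(rho)
    kl = (np.log(prior_sigma / sigma)
          + (sigma ** 2 + mu ** 2) / (2 * prior_sigma ** 2)
          - 0.5)
    return kl.sum()

# Sanity check: KL is zero when the posterior equals the prior.
mu = np.zeros(3)
rho = np.full(3, np.log(np.e - 1.0))   # softplus(rho) = 1.0 = prior_sigma
print(round(kl_gaussian(mu, rho), 6))  # 0.0
```

This term is added to the data loss during training, so minimizing the objective trades off fit against staying close to the prior.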