
Calibration

bensemble.calibration.scaling

TemperatureScaling

TemperatureScaling(init_temp: float = 1.5)

Bases: Module

Temperature Scaling for model calibration.

Divides the logits by a learnable scalar parameter T (temperature). With T > 1 this softens the predicted probabilities, calibrating the model's confidence without changing its accuracy (the argmax is unchanged for any T > 0).

Parameters:

    init_temp (float, default: 1.5)
        Initial value for temperature. Modern deep networks are typically
        overconfident, so starting with T > 1.0 is recommended.
Source code in bensemble/calibration/scaling.py
def __init__(self, init_temp: float = 1.5):
    """
    Args:
        init_temp (float, optional): Initial value for temperature.
            Modern deep networks are typically overconfident, so starting
            with T > 1.0 is recommended. Defaults to 1.5.
    """
    super().__init__()
    self.temperature = nn.Parameter(torch.ones(1) * init_temp)
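
To see the effect, here is a minimal plain-Python sketch (no torch, hypothetical logit values) of what dividing by a temperature T > 1 does to the softmax probabilities:

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of logits.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 1.0, 0.2]  # hypothetical overconfident logits
T = 1.5                   # the default init_temp

probs_raw = softmax(logits)
probs_scaled = softmax([x / T for x in logits])

# Dividing by T > 1 flattens the distribution: the top probability drops,
# while the predicted class (the argmax) stays the same.
print(max(probs_raw), max(probs_scaled))
```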

fit

fit(
    logits: Tensor, labels: Tensor, max_iter: int = 50
) -> TemperatureScaling

Finds the optimal temperature T using a validation set.

Uses the L-BFGS optimizer, which is the standard and most efficient algorithm for this 1D convex optimization problem.

Parameters:

    logits (Tensor, required)
        Unscaled logits from a hold-out validation set. Shape: [N, Num_classes].

    labels (Tensor, required)
        Ground truth class indices. Shape: [N].

    max_iter (int, default: 50)
        Maximum number of L-BFGS iterations.

Returns:

    TemperatureScaling
        The fitted model itself.

Source code in bensemble/calibration/scaling.py
def fit(
    self, logits: torch.Tensor, labels: torch.Tensor, max_iter: int = 50
) -> "TemperatureScaling":
    """
    Finds the optimal temperature T using a validation set.

    Uses the L-BFGS optimizer, which is the standard and most efficient
    algorithm for this 1D convex optimization problem.

    Args:
        logits (torch.Tensor): Unscaled logits from a hold-out validation set. Shape: [N, Num_classes].
        labels (torch.Tensor): Ground truth class indices. Shape: [N].
        max_iter (int, optional): Maximum number of L-BFGS iterations. Defaults to 50.

    Returns:
        TemperatureScaling: The fitted model itself.
    """

    logits = logits.detach()
    optimizer = optim.LBFGS([self.temperature], lr=0.01, max_iter=max_iter)

    def eval_loss():
        optimizer.zero_grad()
        scaled_logits = self.forward(logits)
        loss = F.cross_entropy(scaled_logits, labels)
        loss.backward()
        return loss

    optimizer.step(eval_loss)

    return self
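
Because the objective is a one-dimensional convex function of T, even a coarse grid search recovers a near-optimal temperature. The sketch below (plain Python with a made-up toy validation set, not the library's L-BFGS routine) illustrates the objective that fit minimizes:

```python
import math

def nll(logits_batch, labels, T):
    # Average cross-entropy (negative log-likelihood) of temperature-scaled logits.
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        scaled = [x / T for x in logits]
        m = max(scaled)
        log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
        total += log_z - scaled[y]
    return total / len(labels)

# Tiny hypothetical validation set: confident logits, one prediction wrong.
val_logits = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [3.5, 0.0, 0.5]]
val_labels = [0, 1, 2]

# Grid search over T in (0.5, 5.0); the overconfident model ends up with T > 1.
best_T = min((t / 100 for t in range(50, 500)),
             key=lambda t: nll(val_logits, val_labels, t))
print(best_T)
```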

forward

forward(logits: Tensor) -> Tensor

Applies temperature scaling to the input logits.

Parameters:

    logits (Tensor, required)
        Raw uncalibrated logits of shape [Batch, Num_classes].

Returns:

    Tensor
        Scaled logits of the same shape.

Source code in bensemble/calibration/scaling.py
def forward(self, logits: torch.Tensor) -> torch.Tensor:
    """
    Applies temperature scaling to the input logits.

    Args:
        logits (torch.Tensor): Raw uncalibrated logits of shape [Batch, Num_classes].

    Returns:
        torch.Tensor: Scaled logits of the same shape.
    """
    return logits / self.temperature
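
A quick plain-Python check (hypothetical logits) that the transformation is order-preserving, which is why temperature scaling never changes the predicted class:

```python
logits = [2.3, -0.7, 1.1, 0.4]  # hypothetical raw logits
predicted = max(range(len(logits)), key=lambda i: logits[i])

# Dividing every logit by the same positive T preserves their ordering,
# so the argmax (and hence accuracy) is unchanged for any T > 0.
for T in (0.5, 1.5, 10.0):
    scaled = [x / T for x in logits]
    assert max(range(len(scaled)), key=lambda i: scaled[i]) == predicted
print(predicted)
```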

VectorScaling

VectorScaling(num_classes: int)

Bases: Module

Vector Scaling (Multi-class extension of Platt Scaling).

Applies a per-class affine transformation to the uncalibrated logits: calibrated_logits = logits * a + b

Parameters:

    num_classes (int, required)
        Number of classes in the classification task.
Source code in bensemble/calibration/scaling.py
def __init__(self, num_classes: int):
    """
    Args:
        num_classes (int): Number of classes in the classification task.
    """
    super().__init__()
    self.a = nn.Parameter(torch.ones(num_classes))
    self.b = nn.Parameter(torch.zeros(num_classes))

fit

fit(
    logits: Tensor, labels: Tensor, max_iter: int = 50
) -> VectorScaling

Finds the optimal vectors 'a' and 'b' using a validation set.

Parameters:

    logits (Tensor, required)
        Unscaled logits from a hold-out validation set. Shape: [N, Num_classes].

    labels (Tensor, required)
        Ground truth class indices. Shape: [N].

    max_iter (int, default: 50)
        Maximum number of L-BFGS iterations.

Returns:

    VectorScaling
        The fitted model itself.

Source code in bensemble/calibration/scaling.py
def fit(
    self, logits: torch.Tensor, labels: torch.Tensor, max_iter: int = 50
) -> "VectorScaling":
    """
    Finds the optimal vectors 'a' and 'b' using a validation set.

    Args:
        logits (torch.Tensor): Unscaled logits from a hold-out validation set. Shape: [N, Num_classes].
        labels (torch.Tensor): Ground truth class indices. Shape: [N].
        max_iter (int, optional): Maximum number of L-BFGS iterations. Defaults to 50.

    Returns:
        VectorScaling: The fitted model itself.
    """
    logits = logits.detach()
    optimizer = optim.LBFGS([self.a, self.b], lr=0.01, max_iter=max_iter)

    def eval_loss():
        optimizer.zero_grad()
        scaled_logits = self.forward(logits)
        loss = F.cross_entropy(scaled_logits, labels)
        loss.backward()
        return loss

    optimizer.step(eval_loss)

    return self

forward

forward(logits: Tensor) -> Tensor

Applies the learned affine transformation to the logits.

Parameters:

    logits (Tensor, required)
        Raw uncalibrated logits of shape [Batch, Num_classes].

Returns:

    Tensor
        Calibrated logits of the same shape.

Source code in bensemble/calibration/scaling.py
def forward(self, logits: torch.Tensor) -> torch.Tensor:
    """
    Applies the learned affine transformation to the logits.

    Args:
        logits (torch.Tensor): Raw uncalibrated logits of shape [Batch, Num_classes].

    Returns:
        torch.Tensor: Calibrated logits of the same shape.
    """
    return logits * self.a + self.b
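
A small plain-Python sketch (made-up parameter values) of the per-class affine map. Note that, unlike temperature scaling, vector scaling can reorder the classes, so it may change predictions as well as confidences:

```python
logits = [2.0, 1.5, 0.0]  # hypothetical raw logits
a = [0.5, 1.2, 1.0]       # hypothetical learned per-class scales
b = [0.0, 0.3, -0.1]      # hypothetical learned per-class shifts

calibrated = [x * ai + bi for x, ai, bi in zip(logits, a, b)]

# Here class 1 overtakes class 0 after calibration, something a single
# shared temperature could never do.
print(calibrated.index(max(calibrated)))
```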