Source code for kalman.gaussian

# Author: torch-kf
# https://github.com/raphaelreme/torch-kf/blob/main/torch_kf/kalman_filter.py

import torch
from torch import nn
from typing import Optional, overload


[docs] class GaussianState: """ Class for Gaussian state. Attributes: mean (torch.Tensor): Mean of the distribution Shape: (*, dim, 1) covariance (torch.Tensor): Covariance of the distribution Shape: (*, dim, dim) precision (Optional[torch.Tensor]): Optional inverse covariance matrix This may be useful for some computations (E.G mahalanobis distance, likelihood) after a predict step. Shape: (*, dim, dim) """ def __init__(self, mean: torch.Tensor, covariance: torch.Tensor, precision: Optional[torch.Tensor] = None): self.mean = mean self.covariance = covariance self.precision = precision
[docs] def clone(self) -> "GaussianState": """Clone the Gaussian State using `torch.Tensor.clone` Returns: GaussianState: A copy of the Gaussian state """ return GaussianState( self.mean.clone(), self.covariance.clone(), self.precision.clone() if self.precision is not None else None )
def __getitem__(self, idx) -> "GaussianState": return GaussianState( self.mean[idx], self.covariance[idx], self.precision[idx] if self.precision is not None else None ) def __setitem__(self, idx, value) -> None: if isinstance(value, GaussianState): self.mean[idx] = value.mean self.covariance[idx] = value.covariance if self.precision is not None and value.precision is not None: self.precision[idx] = value.precision return raise NotImplementedError() @overload def to(self, dtype: torch.dtype) -> "GaussianState": ... @overload def to(self, device: torch.device) -> "GaussianState": ...
[docs] def to(self, fmt): """Convert a GaussianState to a specific device or dtype Args: fmt (torch.dtype | torch.device): Memory format to send the state to. Returns: GaussianState: The GaussianState with the right format """ return GaussianState( self.mean.to(fmt), self.covariance.to(fmt), self.precision.to(fmt) if self.precision is not None else None, )
[docs] def mahalanobis_squared(self, measure: torch.Tensor) -> torch.Tensor: """Computes the squared mahalanobis distance of given measure It supports batch computation: You can provide multiple measurements and have multiple states You just need to ensure that shapes are broadcastable. Args: measure (torch.Tensor): Points to consider Shape: (*, dim, 1) Returns: torch.Tensor: Squared mahalanobis distance for each measure/state Shape: (*) """ diff = self.mean - measure # You are responsible for broadcast if self.precision is None: # The inverse is transposed (back) to be contiguous: as it is symmetric # This is equivalent and faster to hold on the contiguous verison # But this may slightly increase floating errors. self.precision = self.covariance.inverse().mT return (diff.mT @ self.precision @ diff)[..., 0, 0] # Delete trailing dimensions
[docs] def mahalanobis(self, measure: torch.Tensor) -> torch.Tensor: """Computes the mahalanobis distance of given measure Computations of the sqrt can be slow. If you want to compare with a given threshold, you should rather compare the squared mahalanobis with the squared threshold. It supports batch computation: You can provide multiple measurements and have multiple states You just need to ensure that shapes are broadcastable. Args: measure (torch.Tensor): Points to consider Shape: (*, dim, 1) Returns: torch.Tensor: Mahalanobis distance for each measure/state Shape: (*) """ return self.mahalanobis_squared(measure).sqrt()
[docs] def log_likelihood(self, measure: torch.Tensor) -> torch.Tensor: """Computes the log-likelihood of given measure It supports batch computation: You can provide multiple measurements and have multiple states You just need to ensure that shapes are broadcastable. Args: measure (torch.Tensor): Points to consider Shape: (*, dim, 1) Returns: torch.Tensor: Log-likelihood for each measure/state Shape: (*, 1) """ maha_2 = self.mahalanobis_squared(measure) log_det = torch.log(torch.det(self.covariance)) return -0.5 * (self.covariance.shape[-1] * torch.log(2 * torch.tensor(torch.pi)) + log_det + maha_2)
[docs] def likelihood(self, measure: torch.Tensor) -> torch.Tensor: """Computes the likelihood of given measure It supports batch computation: You can provide multiple measurements and have multiple states You just need to ensure that shapes are broadcastable. Args: measure (torch.Tensor): Points to consider Shape: (*, dim, 1) Returns: torch.Tensor: Likelihood for each measure/state Shape: (*, 1) """ return self.log_likelihood(measure).exp()