Source code for kalman.filters

from typing import Tuple, Optional
from overrides import override

import torch
from torch import nn
from typing import Optional

from kalman.gaussian import GaussianState


class BaseFilter(nn.Module):
    """Abstract base class for Kalman filters.

    Subclasses supply the filtering maths in one of two ways: implement the
    tensor-level hooks :meth:`predict_` and :meth:`update_` (the default
    :meth:`predict` and :meth:`update` wrap their results in a
    ``GaussianState``), or override :meth:`predict` and :meth:`update`
    directly.

    The hooks are deliberately not declared with ``abc.abstractmethod``: a
    subclass that overrides ``predict``/``update`` never needs them and must
    remain instantiable.
    """

    def __init__(self, state_dim: int, obs_dim: int, smooth: bool = False):
        """Store the dimensions of the filter.

        Args:
            state_dim: Dimension of the state vector.
            obs_dim: Dimension of an observation vector.
            smooth: Stored as ``self.smooth``; not read anywhere in this class.
        """
        super().__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.smooth = smooth

    def predict_(
        self, state_mean: torch.Tensor, state_cov: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Internal single-step predict operating on raw tensors.

        Args:
            state_mean: Mean of the current state estimate.
            state_cov: Covariance of the current state estimate.

        Returns:
            predicted_state_mean, predicted_state_cov

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        # Fail loudly instead of returning None, which would only surface
        # later as an opaque "cannot unpack NoneType" error in predict().
        raise NotImplementedError(
            f"{type(self).__name__} does not implement predict_()"
        )

    def update_(
        self,
        state_mean: torch.Tensor,
        state_cov: torch.Tensor,
        measurement: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Internal single-step update operating on raw tensors.

        Args:
            state_mean: Mean of the predicted state.
            state_cov: Covariance of the predicted state.
            measurement: Observation used to correct the prediction.

        Returns:
            updated_state_mean, updated_state_cov

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement update_()"
        )

    def predict(self, state: GaussianState) -> GaussianState:
        """Single-step predict.

        Args:
            state: Current estimate; its ``mean`` and ``covariance`` are
                forwarded to :meth:`predict_`.

        Returns:
            GaussianState: The predicted state.
        """
        mean, cov = self.predict_(state.mean, state.covariance)
        return GaussianState(mean, cov)

    def update(self, state: GaussianState, measurement: torch.Tensor) -> GaussianState:
        """Single-step update.

        Args:
            state: Predicted estimate; its ``mean`` and ``covariance`` are
                forwarded to :meth:`update_`.
            measurement: Observation used to correct the prediction.

        Returns:
            GaussianState: The updated state.
        """
        mean, cov = self.update_(state.mean, state.covariance, measurement)
        return GaussianState(mean, cov)

    def predict_update(
        self,
        state: GaussianState,
        measurement: torch.Tensor,
    ) -> GaussianState:
        """Single-step predict followed by update, in one call.

        Args:
            state: Estimate before the time step.
            measurement: Observation for the time step.

        Returns:
            GaussianState: The state after predicting and then updating.
        """
        # Order matters: propagate the state first, then correct it with the
        # measurement. update() is the call that consumes the measurement.
        predicted_state = self.predict(state)
        return self.update(predicted_state, measurement)

    def forward(self, observations: torch.Tensor) -> GaussianState:
        """Process an entire sequence of observations in shape (T, B, obs_dim).

        Args:
            observations: Observation sequence.

        Returns:
            GaussianState: Convenient wrapper holding the whole trajectory.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement forward()"
        )
class KalmanFilter(BaseFilter):
    """Kalman Filter class.

    Attributes:
        process_matrix (torch.Tensor): State transition matrix (F)
            Shape: (*, state_dim, state_dim)
        measurement_matrix (torch.Tensor): Projection matrix (H)
            Shape: (*, obs_dim, state_dim)
        process_noise (torch.Tensor): Uncertainty on the process (Q)
            Shape: (*, state_dim, state_dim)
        measurement_noise (torch.Tensor): Uncertainty on the measure (R)
            Shape: (*, obs_dim, obs_dim)
    """

    def __init__(
        self,
        process_matrix: torch.Tensor,
        measurement_matrix: torch.Tensor,
        process_noise: torch.Tensor,
        measurement_noise: torch.Tensor,
    ):
        """Store the model matrices and derive the dimensions from them.

        Args:
            process_matrix: State transition matrix (F).
            measurement_matrix: Projection matrix (H).
            process_noise: Process noise covariance (Q).
            measurement_noise: Measurement noise covariance (R).
        """
        # H is (*, obs_dim, state_dim): the observation dimension is the
        # second-to-last axis; the last axis is the state dimension.
        super().__init__(process_matrix.shape[-1], measurement_matrix.shape[-2])
        self.process_matrix = process_matrix
        self.measurement_matrix = measurement_matrix
        self.process_noise = process_noise
        self.measurement_noise = measurement_noise

    def predict(
        self,
        state: GaussianState,
        *,
        process_matrix: Optional[torch.Tensor] = None,
        process_noise: Optional[torch.Tensor] = None,
    ) -> GaussianState:
        """Predict step of the Kalman Filter.

        Args:
            state: Current estimate. Its mean is treated as a vector along
                the last axis, e.g. (B, state_dim).
            process_matrix: Overwrite the default transition matrix (F).
            process_noise: Overwrite the default process noise (Q).

        Returns:
            GaussianState: Mean ``F m`` and covariance ``F P F^T + Q``.
        """
        if process_matrix is None:
            process_matrix = self.process_matrix
        if process_noise is None:
            process_noise = self.process_noise

        # Lift the mean to a column vector so the batched matmul applies F to
        # each state, then drop the trailing axis again.
        state_mean = (process_matrix @ state.mean.unsqueeze(-1)).squeeze(-1)
        state_cov = (
            process_matrix @ state.covariance @ process_matrix.transpose(-2, -1)
        ) + process_noise
        return GaussianState(state_mean, state_cov)

    def project(
        self,
        state: GaussianState,
        *,
        measurement_matrix: Optional[torch.Tensor] = None,
        measurement_noise: Optional[torch.Tensor] = None,
        precompute_precision=True,
    ) -> GaussianState:
        """Project the current state (usually the prior) onto the measurement space.

        Args:
            state (GaussianState): Current state estimation (usually the
                result of `predict`).
            measurement_matrix (Optional[torch.Tensor]): Overwrite the default
                projection matrix (H).
                Shape: (*, obs_dim, state_dim)
            measurement_noise (Optional[torch.Tensor]): Overwrite the default
                projection noise (R).
                Shape: (*, obs_dim, obs_dim)
            precompute_precision (bool): Precompute the precision matrix
                (inverse covariance) so it is computed only once.
                Default: True

        Returns:
            GaussianState: The state expressed in measurement space, with mean
            ``H m`` and covariance ``H P H^T + R``. Its precision is filled in
            only when ``precompute_precision`` is true, otherwise it is None.
        """
        if measurement_matrix is None:
            measurement_matrix = self.measurement_matrix
        if measurement_noise is None:
            measurement_noise = self.measurement_noise

        mean = (measurement_matrix @ state.mean.unsqueeze(-1)).squeeze(-1)
        cov = (
            measurement_matrix @ state.covariance @ measurement_matrix.transpose(-2, -1)
        ) + measurement_noise
        precision = (
            torch.linalg.inv(cov).transpose(-2, -1) if precompute_precision else None
        )
        return GaussianState(mean, cov, precision)

    @override
    def update(
        self,
        state: GaussianState,
        measurement: torch.Tensor,
        *,
        projection: Optional[GaussianState] = None,
        measurement_matrix: Optional[torch.Tensor] = None,
        measurement_noise: Optional[torch.Tensor] = None,
    ) -> GaussianState:
        """Update step of the Kalman Filter.

        Args:
            state: Predicted estimate to correct.
            measurement: Observation, a vector along the last axis.
            projection: Already-computed projection of ``state`` onto the
                measurement space. Computed with :meth:`project` when omitted.
            measurement_matrix: Overwrite the default projection matrix (H).
            measurement_noise: Overwrite the default projection noise (R).
                Only used when ``projection`` has to be computed here.

        Returns:
            GaussianState: Mean ``m + K (z - H m)`` and covariance
            ``P - K H P`` with gain ``K = P H^T S^-1``.
        """
        if measurement_matrix is None:
            measurement_matrix = self.measurement_matrix
        if projection is None:
            # project() applies the default noise itself when given None.
            projection = self.project(
                state,
                measurement_matrix=measurement_matrix,
                measurement_noise=measurement_noise,
            )

        # A projection built with precompute_precision=False carries no
        # precision; derive it here rather than failing on `tensor @ None`.
        precision = projection.precision
        if precision is None:
            precision = torch.linalg.inv(projection.covariance).transpose(-2, -1)

        residual = measurement - projection.mean
        kalman_gain = (
            state.covariance @ measurement_matrix.transpose(-2, -1) @ precision
        )
        updated_mean = state.mean + (kalman_gain @ residual.unsqueeze(-1)).squeeze(-1)
        updated_cov = (
            state.covariance - kalman_gain @ measurement_matrix @ state.covariance
        )
        return GaussianState(updated_mean, updated_cov)

    def forward(
        self,
        observations: torch.Tensor,
        initial_state: Optional[GaussianState] = None,
    ) -> GaussianState:
        """Run the Kalman Filter over a sequence of observations.

        Args:
            observations: Observation sequence, iterated along its first
                axis. It must contain at least one step, because
                ``torch.stack`` rejects an empty list.
            initial_state: Estimate before the first observation. Falls back
                to ``self.initial_state``, which this class does not set
                itself and which must then have been assigned by the caller.

        Returns:
            GaussianState: Filtered means and covariances stacked along a new
            leading time axis.
        """
        state = self.initial_state if initial_state is None else initial_state

        means = []
        covs = []
        for obs in observations:
            # Every time step propagates the state before correcting it;
            # updating repeatedly without predicting would treat all
            # observations as measurements of the same instant.
            state = self.predict(state)
            state = self.update(state, obs)
            means.append(state.mean)
            covs.append(state.covariance)
        return GaussianState(torch.stack(means), torch.stack(covs))