Source code for kalman.unscented

# kalman/unscented.py
from typing import Callable, Tuple, Optional

import torch

from kalman.gaussian import GaussianState          # <-- already in gaussian.py
from kalman.filters import BaseFilter              # <-- your abstract base


[docs] class UnscentedKalmanFilter(BaseFilter): r""" Scaled–sigma‑point Unscented Kalman Filter. Parameters ---------- state_dim, obs_dim : int Dimensions *n* and *m*. f, h : Callable[[torch.Tensor], torch.Tensor] Process / measurement models that expect sigma‑points of shape ``(..., 2n + 1, n)`` and return the propagated sigma‑points with the same shape. alpha, beta, kappa : float Standard UKF scaling parameters. Q, R : torch.Tensor, optional Process‑ and measurement‑noise covariances. init_mean, init_cov : torch.Tensor, optional Initial posterior (after a fictitious step 0 update). eps : float Jitter added to all Cholesky factorizations and post‑update covariances for numerical stability. """ # ------------------------------------------------------------------ # # constructor # # ------------------------------------------------------------------ # def __init__( self, state_dim: int, obs_dim: int, f: Callable[[torch.Tensor], torch.Tensor], h: Callable[[torch.Tensor], torch.Tensor], *, alpha: float = 1e-3, beta: float = 2.0, kappa: float = 0.0, Q: Optional[torch.Tensor] = None, R: Optional[torch.Tensor] = None, init_mean: Optional[torch.Tensor] = None, init_cov: Optional[torch.Tensor] = None, eps=1e-7 ): super().__init__(state_dim, obs_dim) self.f, self.h = f, h self.register_buffer('Q', Q) self.register_buffer('R', R) n = state_dim lam = alpha**2 * (n + kappa) - n self._gamma = torch.sqrt(torch.tensor(n + lam, dtype=Q.dtype)) self.eps = eps # weights (registered so they automatically follow device / dtype) Wm = torch.empty(2 * n + 1, dtype=Q.dtype) Wc = torch.empty_like(Wm) Wm[0] = lam / (n + lam) Wc[0] = Wm[0] + (1 - alpha**2 + beta) Wm[1:] = 0.5 / (n + lam) Wc[1:] = Wm[1:] self.register_buffer('Wm', Wm) self.register_buffer('Wc', Wc) # noises eye_x = torch.eye(state_dim) eye_z = torch.eye(obs_dim) self.register_buffer("Q", eye_x if Q is None else Q) self.register_buffer("R", eye_z if R is None else R) # initial posterior self.register_buffer("_init_mean", torch.zeros(state_dim) if init_mean is None else init_mean) self.register_buffer("_init_cov", eye_x.clone() if init_cov is None else init_cov) # ------------------------------------------------ sigma‑points def _sigma_points(self, mean: torch.Tensor, cov: torch.Tensor) -> torch.Tensor: """ mean (..., n) → sigma (..., 2n+1, n) """ n = mean.shape[-1] #chol = torch.linalg.cholesky(cov) # (..., n, n) try: chol = torch.linalg.cholesky(cov) # (..., n, n) except RuntimeError: # If Cholesky fails, add small diagonal noise and try again eps = self.eps * torch.eye(self.state_dim, device=cov.device, dtype=cov.dtype) chol = torch.linalg.cholesky(cov + eps) sigma = [mean] # μ scaled = self._gamma * chol # (..., n, n) for i in range(n): col = scaled[..., :, i] # (..., n) sigma.append(mean + col) sigma.append(mean - col) return torch.stack(sigma, dim=-2) # (..., 2n+1, n) # ------------------------------------------------ unscented transform def _unscented_transform( self, sigma: torch.Tensor, noise_cov: torch.Tensor, fn: Callable[[torch.Tensor], torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ sigma (..., 2n+1, n) → (mean (..., d), cov (..., d, d), y_sigma (..., 2n+1, d)) """ y_sigma = fn(sigma) # (..., 2n+1, d) # mean y_mean = torch.sum(self.Wm[..., None] * y_sigma, dim=-2) # covariance y_diff = y_sigma - y_mean.unsqueeze(-2) # (..., 2n+1, d) y_cov = torch.sum( self.Wc[..., None, None] * y_diff.unsqueeze(-1) * y_diff.unsqueeze(-2), # outer prod dim=-3, ) y_cov = y_cov + noise_cov return y_mean, y_cov, y_sigma # ------------------------------------------------ predict / update @torch.no_grad() def predict_(self, state_mean: torch.Tensor, state_cov: torch.Tensor): sigma = self._sigma_points(state_mean, state_cov) f = lambda x: self.f(x) # expects (..., n) → (..., n) pred_mean, pred_cov, _ = self._unscented_transform(sigma, self.Q, f) return pred_mean, pred_cov @torch.no_grad() def update_(self, state_mean: torch.Tensor, state_cov: torch.Tensor, measurement: torch.Tensor): # 1. сигма‑точки из априорного распределения sigma = self._sigma_points(state_mean, state_cov) # 2. прогноз состояния f = lambda x: self.f(x) x_mean, x_cov, x_sigma = self._unscented_transform(sigma, self.Q, f) # 3. прогноз измерения h = lambda x: self.h(x) z_mean, z_cov, z_sigma = self._unscented_transform(x_sigma, self.R, h) # 4. перекрестная ковариация P_xz x_diff = x_sigma - x_mean.unsqueeze(-2) # (..., 2n+1, n) z_diff = z_sigma - z_mean.unsqueeze(-2) # (..., 2n+1, m) P_xz = torch.sum( self.Wc[..., None, None] * x_diff.unsqueeze(-1) * z_diff.unsqueeze(-2), # (...,2n+1,n,m) dim=-3, ) # (..., n, m) # 5. Калман‑усиление K = P_xz @ torch.linalg.inv(z_cov) # (..., n, m) # 6. корректировка y = measurement - z_mean # (..., m) upd_mean = x_mean + K @ y # (..., n) upd_cov = x_cov - K @ z_cov @ K.mT # (..., n, n) return upd_mean, upd_cov # ================================================================== # # high‑level GaussianState API # # ================================================================== #
[docs] def predict(self, state: GaussianState) -> GaussianState: m, P = self.predict_(state.mean, state.covariance) return GaussianState(m, P)
[docs] def update(self, state: GaussianState, measurement: torch.Tensor) -> GaussianState: m, P = self.update_(state.mean, state.covariance, measurement) return GaussianState(m, P)
[docs] def predict_update(self, state: GaussianState, measurement: torch.Tensor) -> GaussianState: return self.update(self.predict(state), measurement)
# ================================================================== # # full sequence pass # # ================================================================== #
[docs] def forward(self, observations: torch.Tensor): """ Run the UKF over a sequence of observations. Parameters ---------- observations : torch.Tensor Shape (T, B, obs_dim) Returns ------- all_states : GaussianState ‑‑ convenient wrapper holding the whole trajectory """ T, B = observations.shape[:2] self._init_mean = self._init_mean.to(observations.device) self._init_cov = self._init_cov.to(observations.device) means = [] covs = [] for t in range(T): if t > 0: self._init_mean, self._init_cov = self.predict_(self._init_mean, self._init_cov) self._init_mean, self._init_cov = self.update_(self._init_mean, self._init_cov, observations[t]) means.append(self._init_mean) covs.append(self._init_cov) all_means = torch.stack(means, dim=0) # (T, B, state_dim) all_covs = torch.stack(covs, dim=0) # (T, B, state_dim, state_dim) return GaussianState(all_means, all_covs)