# Source code for kalman.vkf

# https://users.aalto.fi/~ssarkka/pub/mvb-akf-mlsp.pdf

from typing import Optional, Tuple
import torch
from torch import nn
from kalman.gaussian import GaussianState
from kalman.filters import BaseFilter

class VBKalmanFilter(BaseFilter):
    """Variational Bayesian Adaptive Kalman Filter (VB-AKF).

    A linear Kalman filter that estimates the measurement-noise covariance
    ``R`` online. ``R`` is modelled with an inverse-Wishart distribution whose
    degrees of freedom ``nu`` and scale matrix ``V`` are stored on the
    instance and advanced by every :meth:`predict` / :meth:`update` call, so
    the filter object is stateful.

    NOTE(review): ``self.state_dim`` and ``self.obs_dim`` are read by the
    methods below but never assigned here; they are presumably set by
    ``BaseFilter.__init__(state_dim, obs_dim)`` -- confirm.
    """

    def __init__(self,
                 process_matrix: torch.Tensor,
                 measurement_matrix: torch.Tensor,
                 process_noise: torch.Tensor,
                 initial_measurement_cov: torch.Tensor,
                 rho: float = 0.95,
                 B: Optional[torch.Tensor] = None,
                 state_dim: Optional[int] = None,
                 obs_dim: Optional[int] = None,
                 n_iter: int = 5):
        """Store the model matrices and initialise the noise posterior.

        Args:
            process_matrix: State transition matrix ``F``
                (state_dim, state_dim).
            measurement_matrix: Observation matrix ``H``
                (obs_dim, state_dim).
            process_noise: Process noise covariance ``Q``
                (state_dim, state_dim).
            initial_measurement_cov: Initial guess for the measurement noise
                covariance (obs_dim, obs_dim).
            rho: Forgetting factor applied to the noise statistics in
                :meth:`predict`.
            B: Matrix propagating the scale matrix in :meth:`predict`
                (``V <- B V B^T``). Defaults to ``sqrt(rho) * I``.
            state_dim: State dimension; defaults to the last dimension of
                ``process_matrix``.
            obs_dim: Observation dimension; defaults to the second-to-last
                dimension of ``measurement_matrix``.
            n_iter: Number of fixed-point iterations run by :meth:`update`.

        Raises:
            ValueError: If ``n_iter`` is smaller than 1.
        """
        if n_iter < 1:
            raise ValueError(f"n_iter must be >= 1, got {n_iter}")
        if state_dim is None:
            state_dim = process_matrix.shape[-1]
        if obs_dim is None:
            obs_dim = measurement_matrix.shape[-2]
        super().__init__(state_dim, obs_dim)

        # Core model parameters
        self.process_matrix = process_matrix          # F (state_dim, state_dim)
        self.measurement_matrix = measurement_matrix  # H (obs_dim, state_dim)
        self.process_noise = process_noise            # Q (state_dim, state_dim)

        # Covariance adaptation parameters
        self.rho = rho
        self.n_iter = n_iter
        if B is None:
            # Build the default B with the dtype/device of the covariance it
            # will be multiplied with, so ``B @ V @ B^T`` does not fail for
            # float64 or GPU inputs.
            if initial_measurement_cov.is_floating_point():
                dtype = initial_measurement_cov.dtype
            else:
                dtype = torch.get_default_dtype()
            device = initial_measurement_cov.device
            scale = torch.sqrt(torch.tensor(rho, dtype=dtype, device=device))
            B = scale * torch.eye(obs_dim, dtype=dtype, device=device)
        self.B = B

        # Inverse-Wishart parameters. With nu = obs_dim + 2 the factor
        # (nu - obs_dim - 1) equals 1, so V / (nu - obs_dim - 1) starts out
        # equal to ``initial_measurement_cov``.
        self.nu = obs_dim + 2                                       # degrees of freedom
        self.V = (self.nu - obs_dim - 1) * initial_measurement_cov  # scale matrix

    def predict(self,
                state: GaussianState,
                process_matrix: Optional[torch.Tensor] = None) -> GaussianState:
        """Prediction step for the state and for the noise statistics.

        Args:
            state: Current state estimate; ``state.mean`` is treated as a
                vector whose last dimension is ``state_dim``.
            process_matrix: Optional transition matrix overriding the one
                given at construction time for this call only.

        Returns:
            The predicted ``GaussianState``.

        Side effects:
            Overwrites ``self.nu`` and ``self.V`` with their predicted values.
        """
        F = process_matrix if process_matrix is not None else self.process_matrix
        Q = self.process_noise
        F_t = F.transpose(-1, -2)

        # Predict state. The mean is stored as a (row) vector, so F m is
        # written ``m @ F^T`` -- the same convention update() uses for H m.
        predicted_mean = state.mean @ F_t
        predicted_cov = F @ state.covariance @ F_t + Q

        # Predict the noise statistics: shrink the effective sample count
        # (nu - obs_dim - 1) by rho and propagate the scale matrix through B.
        self.nu = self.rho * (self.nu - self.obs_dim - 1) + self.obs_dim + 1
        self.V = self.B @ self.V @ self.B.transpose(-1, -2)

        return GaussianState(predicted_mean, predicted_cov)

    def update(self, state: GaussianState, measurement: torch.Tensor) -> GaussianState:
        """Iterative variational measurement update.

        Alternates ``self.n_iter`` times between a Kalman update that uses the
        current noise estimate and a refresh of that noise estimate from the
        resulting posterior.

        Args:
            state: Predicted state (output of :meth:`predict`).
            measurement: Observation vector of length ``obs_dim``.

        Returns:
            The posterior ``GaussianState``.

        Side effects:
            Stores the updated ``self.nu`` and ``self.V``.
        """
        H = self.measurement_matrix
        H_t = H.transpose(-1, -2)
        y = measurement

        m_pred = state.mean
        P_pred = state.covariance
        V_pred = self.V
        # One new measurement adds one degree of freedom.
        nu = self.nu + 1
        dof = nu - self.obs_dim - 1

        # These depend only on the predicted state, so they are loop-invariant:
        # every iteration restarts the Kalman update from the prediction and
        # only the noise estimate changes between iterations.
        innovation = y - m_pred @ H_t
        PHt = P_pred @ H_t
        HPHt = H @ PHt

        m, P, V = m_pred, P_pred, V_pred
        for _ in range(self.n_iter):
            R = V / dof  # current estimate of the measurement covariance
            S = HPHt + R
            K = PHt @ torch.inverse(S)

            m = m_pred + torch.einsum('hv,v->h', K, innovation)
            P = P_pred - K @ S @ K.transpose(-1, -2)

            # Refresh the scale matrix from the posterior residual
            residual = y - m @ H_t
            V = V_pred + H @ P @ H_t + torch.einsum('i,j->ij', residual, residual)

        # Persist both noise parameters. Without storing nu the quantity
        # (nu - obs_dim - 1) would decay to zero over successive predict calls.
        self.nu = nu
        self.V = V

        return GaussianState(m, P)

    def forward(self, observations: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Filter a full observation sequence.

        Starts from a zero mean and identity covariance and runs
        predict/update for each time step. The noise statistics stored on the
        instance are not reset, so they carry over between calls.

        Args:
            observations: Tensor of shape (T, obs_dim).

        Returns:
            Tuple ``(means, covs)`` with shapes (T, state_dim) and
            (T, state_dim, state_dim).

        Raises:
            ValueError: If ``observations`` is not 2-dimensional.
        """
        if observations.dim() != 2:
            raise ValueError(
                "observations must have shape (T, obs_dim), got "
                f"{tuple(observations.shape)}"
            )
        T = observations.shape[0]

        # Allocate with the dtype/device of the model matrices so the filter
        # also works for float64 or GPU parameters.
        dtype = self.process_matrix.dtype
        device = self.process_matrix.device
        means = torch.zeros(T, self.state_dim, dtype=dtype, device=device)
        covs = torch.zeros(T, self.state_dim, self.state_dim, dtype=dtype, device=device)

        # Initial state
        current_state = GaussianState(
            torch.zeros(self.state_dim, dtype=dtype, device=device),
            torch.eye(self.state_dim, dtype=dtype, device=device),
        )

        for t in range(T):
            predicted_state = self.predict(current_state)
            updated_state = self.update(predicted_state, observations[t])

            means[t] = updated_state.mean
            covs[t] = updated_state.covariance

            current_state = updated_state

        return means, covs

    def get_measurement_covariance(self) -> torch.Tensor:
        """Return the current estimate of the measurement covariance.

        Computed as ``V / (nu - obs_dim - 1)`` from the stored parameters.
        """
        return self.V / (self.nu - self.obs_dim - 1)