Extended Kalman Filter
- class kalman.extended.ExtendedKalmanFilter(*args: Any, **kwargs: Any)
Generic Extended Kalman Filter (EKF) for batched non-linear state-space models.
Parameters
- state_dim : int
Dimension of the latent state x.
- obs_dim : int
Dimension of the observation z.
- f : Callable[[torch.Tensor], torch.Tensor]
Non-linear transition function x_{k|k} → x_{k+1|k}. Must broadcast over leading batch dimensions. Shape: (*, state_dim) → (*, state_dim).
- h : Callable[[torch.Tensor], torch.Tensor]
Non-linear measurement function mapping a state x to the predicted observation z_pred. Shape: (*, state_dim) → (*, obs_dim).
- F_jacobian : Optional[Callable[[torch.Tensor], torch.Tensor]]
Function returning the Jacobian of f with respect to x, of shape (*, state_dim, state_dim). If None, it is computed by autograd.
- H_jacobian : Optional[Callable[[torch.Tensor], torch.Tensor]]
Function returning the Jacobian of h with respect to x, of shape (*, obs_dim, state_dim). If None, it is computed by autograd.
- Q : Optional[torch.Tensor]
Process-noise covariance of shape (state_dim, state_dim), broadcastable to the batch. Defaults to the identity.
- R : Optional[torch.Tensor]
Measurement-noise covariance of shape (obs_dim, obs_dim), broadcastable to the batch. Defaults to the identity.
- init_mean : Optional[torch.Tensor]
Initial state mean of shape (state_dim,) or (B, state_dim). If None, zeros.
- init_cov : Optional[torch.Tensor]
Initial state covariance of shape (state_dim, state_dim) or (B, state_dim, state_dim). If None, the identity.
- smooth : bool
Currently ignored; reserved for a future Rauch–Tung–Striebel (RTS) smoother.
- eps : float
Small jitter added to covariance diagonals for numerical stability.
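When F_jacobian or H_jacobian is None, the Jacobians are computed by autograd. The sketch below shows one way to obtain batched Jacobians with the documented shapes (*, state_dim, state_dim) and (*, obs_dim, state_dim) using torch.func.jacrev and torch.func.vmap (PyTorch 2.0+). The constant-velocity model, DT, and batched_jacobian are hypothetical, and the class's own autograd path may be implemented differently.

```python
import torch
from torch.func import jacrev, vmap

DT = 0.1  # hypothetical time step for the example model

def f(x: torch.Tensor) -> torch.Tensor:
    # Constant-velocity transition on state [position, velocity]; (*, 2) -> (*, 2).
    pos, vel = x[..., 0], x[..., 1]
    return torch.stack([pos + DT * vel, vel], dim=-1)

def h(x: torch.Tensor) -> torch.Tensor:
    # Non-linear measurement: observe sin(position); (*, 2) -> (*, 1).
    return torch.sin(x[..., :1])

def batched_jacobian(fn, x: torch.Tensor) -> torch.Tensor:
    # jacrev differentiates fn for a single (state_dim,) input;
    # vmap maps that over the leading batch dimension.
    return vmap(jacrev(fn))(x)

B, state_dim, obs_dim = 4, 2, 1
x = torch.randn(B, state_dim)

F = batched_jacobian(f, x)  # (B, state_dim, state_dim)
H = batched_jacobian(h, x)  # (B, obs_dim, state_dim)

# Identity noise covariances (the documented defaults), broadcast over the batch.
Q = torch.eye(state_dim).expand(B, state_dim, state_dim)
R = torch.eye(obs_dim).expand(B, obs_dim, obs_dim)
print(F.shape, H.shape)  # torch.Size([4, 2, 2]) torch.Size([4, 1, 2])
```

For this model the analytic Jacobians are F = [[1, DT], [0, 1]] and H = [[cos(position), 0]], which makes the sketch easy to check against a hand-written F_jacobian or H_jacobian.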
- forward(observations: torch.Tensor) → Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
Run the EKF over a sequence of observations.
Parameters
- observations : torch.Tensor
Observation sequence of shape (T, B, obs_dim).
Returns
- all_states : GaussianState
Convenience wrapper holding the whole trajectory.
- (all_means, all_covs) : Tuple[torch.Tensor, torch.Tensor]
Means of shape (T, B, state_dim) and covariances of shape (T, B, state_dim, state_dim).
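A standalone sketch of what a batched EKF pass over (T, B, obs_dim) observations can look like, assuming each step predicts before incorporating z_t and that results are stacked along time. ekf_forward and batched_jac are hypothetical names, the jitter placement is an assumption, and the sketch returns only the (all_means, all_covs) pair, not a GaussianState; it is not the class's actual implementation.

```python
import torch
from torch.func import jacrev, vmap

def batched_jac(fn, x: torch.Tensor) -> torch.Tensor:
    # Per-sample Jacobian via autograd, mapped over the leading batch dimension.
    return vmap(jacrev(fn))(x)

def ekf_forward(observations, f, h, Q, R, init_mean, init_cov, eps=1e-6):
    """Hypothetical batched EKF loop. observations: (T, B, obs_dim)."""
    T, B, _ = observations.shape
    n = init_mean.shape[-1]
    eye = torch.eye(n, dtype=init_mean.dtype)
    mean = init_mean.expand(B, n)
    cov = init_cov.expand(B, n, n)
    means, covs = [], []
    for t in range(T):
        # Time update: push the mean through f and the covariance through F = df/dx.
        F = batched_jac(f, mean)
        mean = f(mean)
        cov = F @ cov @ F.transpose(-1, -2) + Q + eps * eye
        # Measurement update with z_t, linearising h at the predicted mean.
        H = batched_jac(h, mean)
        S = H @ cov @ H.transpose(-1, -2) + R                 # innovation covariance (B, m, m)
        K = torch.linalg.solve(S, H @ cov).transpose(-1, -2)  # gain P H^T S^-1, (B, n, m)
        innov = observations[t] - h(mean)                     # innovation (B, m)
        mean = mean + (K @ innov.unsqueeze(-1)).squeeze(-1)
        cov = (eye - K @ H) @ cov
        cov = 0.5 * (cov + cov.transpose(-1, -2))             # re-symmetrise
        means.append(mean)
        covs.append(cov)
    # Stack along time: (T, B, n) and (T, B, n, n).
    return torch.stack(means), torch.stack(covs)

# Usage with a hypothetical constant-velocity model and sin(position) sensor.
T, B, n, m = 5, 3, 2, 1
f = lambda x: torch.stack([x[..., 0] + 0.1 * x[..., 1], x[..., 1]], dim=-1)
h = lambda x: torch.sin(x[..., :1])
obs = torch.randn(T, B, m)
all_means, all_covs = ekf_forward(obs, f, h, torch.eye(n), torch.eye(m),
                                  torch.zeros(n), torch.eye(n))
print(all_means.shape, all_covs.shape)  # torch.Size([5, 3, 2]) torch.Size([5, 3, 2, 2])
```

The gain is computed with torch.linalg.solve rather than an explicit inverse of S, which is the usual choice for numerical stability.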
- predict(state: GaussianState) → GaussianState
EKF time‑update (prediction) step.
- Returns:
Predicted state
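In standard EKF notation, the time update performed by a prediction step can be written as follows, with f the transition function and Q the process-noise covariance from the constructor. This is the textbook form; where exactly the eps jitter enters is not specified here.

```latex
F_k = \left.\frac{\partial f}{\partial x}\right|_{x = \hat{x}_{k|k}}, \qquad
\hat{x}_{k+1|k} = f\!\left(\hat{x}_{k|k}\right), \qquad
P_{k+1|k} = F_k \, P_{k|k} \, F_k^{\top} + Q
```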
- predict_(state_mean: torch.Tensor, state_cov: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
EKF time-update (prediction) step operating directly on a mean tensor and a covariance tensor.
- Returns:
Predicted mean and covariance as a (mean, cov) tuple.
- predict_update(state: GaussianState, measurement: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Convenience wrapper that performs a time‑update immediately followed by a measurement‑update.
Parameters
- state : GaussianState
State estimate from the previous step.
- measurement : torch.Tensor
Observation z_k.
Returns
- new_mean, new_cov : Tuple[torch.Tensor, torch.Tensor]
Posterior x̂_{k|k} and P_{k|k} after incorporating z_k.
- update(state: GaussianState, measurement: torch.Tensor) → GaussianState
EKF measurement‑update (correction) step.
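For reference, the textbook EKF measurement update is shown below, with h the measurement function, R the measurement-noise covariance, and z_k the observation. The class may use an algebraically equivalent variant (for example the Joseph form for the covariance) rather than exactly these expressions.

```latex
H_k = \left.\frac{\partial h}{\partial x}\right|_{x = \hat{x}_{k|k-1}}, \qquad
S_k = H_k \, P_{k|k-1} \, H_k^{\top} + R, \qquad
K_k = P_{k|k-1} \, H_k^{\top} \, S_k^{-1}

\hat{x}_{k|k} = \hat{x}_{k|k-1} + K_k \left( z_k - h\!\left(\hat{x}_{k|k-1}\right) \right), \qquad
P_{k|k} = \left( I - K_k H_k \right) P_{k|k-1}
```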