Extended Kalman Filter

class kalman.extended.ExtendedKalmanFilter(*args: Any, **kwargs: Any)

Generic Extended Kalman Filter (EKF) for batched non-linear state-space models.
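For reference, this is the state-space model an EKF of this kind targets. It is written with the f, h, Q, R, init_mean and init_cov parameters documented below and assumes additive, zero-mean Gaussian noise. The class does not state its noise model explicitly, so treat this as a sketch:

$$
x_0 \sim \mathcal{N}(\text{init\_mean}, \text{init\_cov}), \qquad
x_{k+1} = f(x_k) + w_k, \quad w_k \sim \mathcal{N}(0, Q), \qquad
z_k = h(x_k) + v_k, \quad v_k \sim \mathcal{N}(0, R).
$$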

Parameters

state_dim : int

Dimension of latent state x.

obs_dim : int

Dimension of observation z.

f : Callable[[torch.Tensor], torch.Tensor]

Non-linear transition function x_{k|k} → x_{k+1|k}. Must support arbitrary leading batch dimensions. Shape: (*, state_dim) → (*, state_dim).

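h : Callable[[torch.Tensor], torch.Tensor]

Non-linear measurement function mapping a state to the predicted observation z_pred. In the update step it is evaluated at the predicted state x_{k|k-1}. Shape: (*, state_dim) → (*, obs_dim).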

F_jacobian : Optional[Callable[[torch.Tensor], torch.Tensor]]

Function returning the Jacobian of f with respect to x, with shape (*, state_dim, state_dim). If None, it is computed by autograd.

H_jacobian : Optional[Callable[[torch.Tensor], torch.Tensor]]

Function returning the Jacobian of h with respect to x, with shape (*, obs_dim, state_dim). If None, it is computed by autograd.

Q : Optional[torch.Tensor]

Process-noise covariance of shape (state_dim, state_dim), broadcastable over the batch. Defaults to the identity.

R : Optional[torch.Tensor]

Measurement-noise covariance of shape (obs_dim, obs_dim), broadcastable over the batch. Defaults to the identity.

init_mean : Optional[torch.Tensor]

Initial state mean of shape (state_dim,) or (B, state_dim). If None, zeros are used.

init_cov : Optional[torch.Tensor]

Initial state covariance of shape (state_dim, state_dim) or (B, state_dim, state_dim). If None, the identity is used.

smooth : bool

Currently ignored; placeholder for a future Rauch–Tung–Striebel (RTS) smoother.

eps : float

Small jitter added to matrix diagonals for numerical stability.
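The following minimal sketch shows transition and measurement functions that follow the shape conventions above. It also shows one way autograd can produce Jacobians with the documented shapes when F_jacobian or H_jacobian is None. The sketch assumes PyTorch 2.0 or later for torch.func. The pendulum model, DT, and batched_jacobian are hypothetical. How the class itself computes autograd Jacobians is not documented here.

```python
import torch
from torch.func import jacrev, vmap

DT = 0.1  # hypothetical time step


def f(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical pendulum transition, (*, 2) -> (*, 2); "..." indexing broadcasts over batch dims.
    theta, omega = x[..., 0], x[..., 1]
    return torch.stack([theta + DT * omega, omega - DT * 9.81 * torch.sin(theta)], dim=-1)


def h(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical sensor observing sin(theta), (*, 2) -> (*, 1).
    return torch.sin(x[..., :1])


def batched_jacobian(fn, x: torch.Tensor) -> torch.Tensor:
    """Jacobian of fn at every batch element: (*, n) -> (*, out, n)."""
    lead = x.shape[:-1]
    flat = x.reshape(-1, x.shape[-1])   # collapse leading dims to (N, n)
    jac = vmap(jacrev(fn))(flat)        # per-element Jacobians, (N, out, n)
    return jac.reshape(*lead, *jac.shape[-2:])


torch.manual_seed(0)
x = torch.randn(4, 3, 2)                # two leading batch dimensions
F = batched_jacobian(f, x)              # (4, 3, 2, 2), the shape F_jacobian returns
H = batched_jacobian(h, x)              # (4, 3, 1, 2), the shape H_jacobian returns
```

Supplying an analytic F_jacobian or H_jacobian avoids these autograd passes, which can matter for large state dimensions.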

forward(observations: torch.Tensor) → Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]

Run the EKF over a sequence of observations.

Parameters

observations : torch.Tensor

Observation sequence of shape (T, B, obs_dim), where T is the number of time steps and B is the batch size.

Returns

all_states : GaussianState

Convenience wrapper holding the whole trajectory.

(all_means, all_covs) : Tuple[torch.Tensor, torch.Tensor]

Means of shape (T, B, state_dim) and covariances of shape (T, B, state_dim, state_dim).
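This minimal sketch shows what a forward pass over (T, B, obs_dim) observations could look like. It makes three assumptions:

- Each time step applies a time update followed by a measurement update, the order predict_update describes.
- init_mean and init_cov act as the prior for the first step.
- Jacobians come from torch.func autograd, which requires PyTorch 2.0 or later.

ekf_forward and jac are hypothetical names, and the sketch returns only the (all_means, all_covs) pair.

```python
import torch
from torch.func import jacrev, vmap


def jac(fn, x):
    # Per-batch-element Jacobian via autograd: (B, n) -> (B, out, n).
    return vmap(jacrev(fn))(x)


def ekf_forward(observations, f, h, Q, R, init_mean, init_cov):
    """Hypothetical EKF pass over observations of shape (T, B, obs_dim)."""
    mean, cov = init_mean, init_cov           # (B, n), (B, n, n)
    means, covs = [], []
    for z in observations:                    # z: (B, obs_dim)
        # Time update: propagate the mean through f, the covariance through F.
        F = jac(f, mean)
        mean = f(mean)
        cov = F @ cov @ F.transpose(-1, -2) + Q
        # Measurement update: linearise h at the predicted mean.
        H = jac(h, mean)
        S = H @ cov @ H.transpose(-1, -2) + R
        K = torch.linalg.solve(S, H @ cov).transpose(-1, -2)  # K = P H^T S^{-1}
        mean = mean + (K @ (z - h(mean)).unsqueeze(-1)).squeeze(-1)
        cov = cov - K @ S @ K.transpose(-1, -2)
        means.append(mean)
        covs.append(cov)
    # Stack to (T, B, n) and (T, B, n, n), matching the documented shapes.
    return torch.stack(means), torch.stack(covs)


# Usage: a constant-velocity model observed through its position only.
torch.manual_seed(0)
T, B, n, m = 5, 3, 2, 1
A = torch.tensor([[1.0, 0.1], [0.0, 1.0]])


def f(x):
    return x @ A.T


def h(x):
    return x[..., :1]


obs = torch.randn(T, B, m)
means, covs = ekf_forward(obs, f, h, 0.01 * torch.eye(n), 0.1 * torch.eye(m),
                          torch.zeros(B, n), torch.eye(n).expand(B, n, n))
print(means.shape, covs.shape)  # torch.Size([5, 3, 2]) torch.Size([5, 3, 2, 2])
```

The gain is computed with torch.linalg.solve rather than by explicitly inverting S. This is generally cheaper and better conditioned.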

predict(state: GaussianState) → GaussianState

EKF time‑update (prediction) step.

Returns

GaussianState

Predicted state, with mean x̂_{k+1|k} and covariance P_{k+1|k}.
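For reference, here is the standard EKF time update in this page's notation, assuming additive process noise with covariance Q. F_k is the matrix that F_jacobian (or autograd) supplies. Whether and where the class adds the eps jitter is not documented here.

$$
\hat{x}_{k+1|k} = f(\hat{x}_{k|k}), \qquad
F_k = \left.\frac{\partial f}{\partial x}\right|_{x = \hat{x}_{k|k}}, \qquad
P_{k+1|k} = F_k P_{k|k} F_k^{\top} + Q.
$$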

predict_(state_mean: torch.Tensor, state_cov: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

EKF time-update (prediction) step operating directly on mean and covariance tensors.

Returns

Tuple[torch.Tensor, torch.Tensor]

Predicted mean and covariance.
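The following minimal sketch implements a tensor-level time update in the spirit of predict_, assuming batched inputs of shape (B, n) and (B, n, n). The function name ekf_predict and its arguments are hypothetical. Re-symmetrising the covariance and adding eps to the diagonal is one plausible use of eps, not the class's documented behaviour.

```python
import torch
from torch.func import jacrev, vmap


def ekf_predict(mean, cov, f, Q, eps=1e-6, F_jacobian=None):
    """Hypothetical tensor-level time update in the spirit of predict_.

    mean: (B, n), cov: (B, n, n), Q: broadcastable to (B, n, n).
    """
    if F_jacobian is None:
        F = vmap(jacrev(f))(mean)             # autograd Jacobian, (B, n, n)
    else:
        F = F_jacobian(mean)                  # user-supplied Jacobian
    pred_mean = f(mean)
    pred_cov = F @ cov @ F.transpose(-1, -2) + Q
    # Re-symmetrise and add diagonal jitter (an assumed use of eps).
    pred_cov = 0.5 * (pred_cov + pred_cov.transpose(-1, -2))
    n = mean.shape[-1]
    pred_cov = pred_cov + eps * torch.eye(n, dtype=cov.dtype, device=cov.device)
    return pred_mean, pred_cov


A = torch.tensor([[1.0, 0.1], [0.0, 1.0]])


def f(x):
    return x @ A.T


pred_mean, pred_cov = ekf_predict(torch.tensor([[1.0, 2.0]]), torch.eye(2).unsqueeze(0),
                                  f, Q=0.01 * torch.eye(2))
# pred_mean ≈ [[1.2, 2.0]], pred_cov ≈ [[[1.02, 0.1], [0.1, 1.01]]]
```

For a linear f the Jacobian is exactly A, so the step reduces to the ordinary Kalman prediction A P Aᵀ + Q, which makes it an easy sanity check.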

predict_update(state: GaussianState, measurement: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Convenience wrapper that performs a time‑update immediately followed by a measurement‑update.

Parameters

state : GaussianState

Current state estimate x̂_{k-1|k-1}, P_{k-1|k-1}.

measurement : torch.Tensor

Observation z_k.

Returns

new_mean, new_cov : Tuple[torch.Tensor, torch.Tensor]

Posterior x̂_{k|k} and P_{k|k} after incorporating z_k.

update(state: GaussianState, measurement: torch.Tensor) → GaussianState

EKF measurement‑update (correction) step.
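For reference, here is the standard EKF measurement update in this page's notation, assuming additive measurement noise with covariance R. H_k is the matrix that H_jacobian (or autograd) supplies.

$$
\hat{z}_k = h(\hat{x}_{k|k-1}), \qquad
H_k = \left.\frac{\partial h}{\partial x}\right|_{x = \hat{x}_{k|k-1}}, \qquad
S_k = H_k P_{k|k-1} H_k^{\top} + R,
$$

$$
K_k = P_{k|k-1} H_k^{\top} S_k^{-1}, \qquad
\hat{x}_{k|k} = \hat{x}_{k|k-1} + K_k (z_k - \hat{z}_k), \qquad
P_{k|k} = (I - K_k H_k) P_{k|k-1}.
$$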

update_(state_mean: torch.Tensor, state_cov: torch.Tensor, measurement: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

EKF measurement-update (correction) step operating directly on mean and covariance tensors.
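The following minimal sketch implements a tensor-level measurement update in the spirit of update_, assuming batched inputs and a measurement of shape (B, obs_dim). The function name ekf_update, its arguments, and the use of eps on the innovation covariance are hypothetical. The covariance uses the Joseph form, a swapped-in choice that equals (I − KH)P for the optimal gain but stays symmetric positive semi-definite under round-off. The example uses a non-linear range sensor so that the linearisation actually matters.

```python
import torch
from torch.func import jacrev, vmap


def ekf_update(mean, cov, z, h, R, eps=1e-6, H_jacobian=None):
    """Hypothetical tensor-level measurement update in the spirit of update_.

    mean: (B, n), cov: (B, n, n), z: (B, m), R: broadcastable to (B, m, m).
    """
    H = vmap(jacrev(h))(mean) if H_jacobian is None else H_jacobian(mean)  # (B, m, n)
    Ht = H.transpose(-1, -2)
    m, n = z.shape[-1], mean.shape[-1]
    S = H @ cov @ Ht + R + eps * torch.eye(m, dtype=cov.dtype, device=cov.device)
    # Kalman gain K = P H^T S^{-1}, via a linear solve instead of an explicit inverse.
    K = torch.linalg.solve(S, H @ cov).transpose(-1, -2)   # (B, n, m)
    innovation = z - h(mean)                                 # (B, m)
    new_mean = mean + (K @ innovation.unsqueeze(-1)).squeeze(-1)
    # Joseph form: (I - KH) P (I - KH)^T + K R K^T.
    IKH = torch.eye(n, dtype=cov.dtype, device=cov.device) - K @ H
    new_cov = IKH @ cov @ IKH.transpose(-1, -2) + K @ R @ K.transpose(-1, -2)
    return new_mean, new_cov


def h(x):
    # Hypothetical range sensor: distance from the origin, (*, 2) -> (*, 1).
    return torch.linalg.vector_norm(x, dim=-1, keepdim=True)


mean = torch.tensor([[3.0, 4.0]])
cov = torch.eye(2).unsqueeze(0)
z = torch.tensor([[5.5]])
new_mean, new_cov = ekf_update(mean, cov, z, h, R=torch.tensor([[0.25]]))
# new_mean ≈ [[3.24, 4.32]], new_cov ≈ [[[0.712, -0.384], [-0.384, 0.488]]]
```

At (3, 4) the range Jacobian is H = [0.6, 0.8], so S = 1.25 and K = [0.48, 0.64]ᵀ. The 0.5 innovation therefore shifts the mean by [0.24, 0.32].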