Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel Kalman filter and smoother with inputs #385

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 50 additions & 31 deletions dynamax/linear_gaussian_ssm/parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import jax.numpy as jnp
from jax import vmap, lax
from jaxtyping import Array, Float
from typing import NamedTuple
from typing import NamedTuple, Optional
from dynamax.types import PRNGKey
from functools import partial
import warnings
Expand All @@ -45,6 +45,7 @@
from jax.scipy.linalg import cho_solve, cho_factor
from dynamax.utils.utils import symmetrize, psd_solve
from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM
from dynamax.linear_gaussian_ssm.inference import _zeros_if_none


def _get_one_param(x, dim, t):
Expand All @@ -56,14 +57,16 @@ def _get_one_param(x, dim, t):
else:
return x

def _get_params(params, num_timesteps, t):
def _get_params(params: ParamsLGSSM, num_timesteps, t):
"""Helper function to get parameters at time t."""
assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."

F = _get_one_param(params.dynamics.weights, 2, t)
B = _get_one_param(params.dynamics.input_weights, 2, t)
b = _get_one_param(params.dynamics.bias, 1, t)
Q = _get_one_param(params.dynamics.cov, 2, t)
H = _get_one_param(params.emissions.weights, 2, t+1)
D = _get_one_param(params.emissions.input_weights, 2, t+1)
d = _get_one_param(params.emissions.bias, 1, t+1)

if len(params.emissions.cov.shape) == 1:
Expand All @@ -81,7 +84,7 @@ def _get_params(params, num_timesteps, t):
"The covariance will be interpreted as static and non-diagonal. To "
"specify a dynamic and diagonal covariance, pass it as a 3D array.")

return F, b, Q, H, d, R
return F, B, b, Q, H, D, d, R


#---------------------------------------------------------------------------#
Expand Down Expand Up @@ -151,13 +154,18 @@ class FilterMessage(NamedTuple):
logZ: Float[Array, "ntime"]


def _initialize_filtering_messages(params, emissions):
def _initialize_filtering_messages(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
):
"""Preprocess observations to construct input for filtering assocative scan."""

num_timesteps = emissions.shape[0]
inputs = _zeros_if_none(inputs, (num_timesteps, 0))

def _first_message(params, y):
H, d, R = _get_params(params, num_timesteps, -1)[3:]
def _first_message(params, y, u):
H, D, d, R = _get_params(params, num_timesteps, -1)[4:]
m = params.initial.mean
P = params.initial.cov

Expand All @@ -166,34 +174,35 @@ def _first_message(params, y):
K = P @ H.T @ S_inv

A = jnp.zeros_like(P)
b = m + K @ (y - H @ m - d)
b = m + K @ (y - H @ m - D @ u - d)
C = symmetrize(P - K @ S @ K.T)
eta = jnp.zeros_like(b)
J = jnp.eye(len(b))

logZ = _marginal_loglik_elem(P, H, R, y)
logZ = _marginal_loglik_elem(P, H, R, y - H @ m - D @ u - d)
return A, b, C, J, eta, logZ


@partial(vmap, in_axes=(None, 0, 0))
def _generic_message(params, y, t):
F, b, Q, H, d, R = _get_params(params, num_timesteps, t)
@partial(vmap, in_axes=(None, 0, 0, 0))
def _generic_message(params, y, u, t):
F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)

S_inv = _emissions_scale(Q, H, R)
K = Q @ H.T @ S_inv

eta = F.T @ H.T @ S_inv @ (y - H @ b - d)
innov = (y - H @ b - D @ u - d)
eta = F.T @ H.T @ S_inv @ innov
J = symmetrize(F.T @ H.T @ S_inv @ H @ F)

A = F - K @ H @ F
b = b + K @ (y - H @ b - d)
b = b + B @ u + K @ innov
C = symmetrize(Q - K @ H @ Q)

logZ = _marginal_loglik_elem(Q, H, R, y)
logZ = _marginal_loglik_elem(Q, H, R, innov)
return A, b, C, J, eta, logZ

A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1))
A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0])
At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], inputs[1:], jnp.arange(len(emissions)-1))

return FilterMessage(
A=jnp.concatenate([A0[None], At]),
Expand All @@ -208,7 +217,8 @@ def _generic_message(params, y, t):

def lgssm_filter(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMFiltered:
"""A parallel version of the lgssm filtering algorithm.

Expand Down Expand Up @@ -238,7 +248,7 @@ def _operator(elem1, elem2):
logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1)
return FilterMessage(A, b, C, J, eta, logZ)

initial_messages = _initialize_filtering_messages(params, emissions)
initial_messages = _initialize_filtering_messages(params, emissions, inputs)
final_messages = lax.associative_scan(_operator, initial_messages)

return PosteriorGSSMFiltered(
Expand All @@ -265,25 +275,30 @@ class SmoothMessage(NamedTuple):
L: Float[Array, "ntime state_dim state_dim"]


def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
def _initialize_smoothing_messages(params: ParamsLGSSM,
filtered_means: Float[Array, "ntime state_dim"],
filtered_covariances: Float[Array, "ntime state_dim state_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> SmoothMessage:
"""Preprocess filtering output to construct input for smoothing assocative scan."""

def _last_message(m, P):
return jnp.zeros_like(P), m, P

num_timesteps = filtered_means.shape[0]
inputs = _zeros_if_none(inputs, (num_timesteps, 0))

@partial(vmap, in_axes=(None, 0, 0, 0))
def _generic_message(params, m, P, t):
F, b, Q = _get_params(params, num_timesteps, t)[:3]
@partial(vmap, in_axes=(None, 0, 0, 0, 0))
def _generic_message(params, m, P, u, t):
F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
CF, low = cho_factor(F @ P @ F.T + Q)
E = cho_solve((CF, low), F @ P).T
g = m - E @ (F @ m + b)
g = m - E @ (F @ m + b + B @ u)
L = symmetrize(P - E @ F @ P)
return E, g, L

En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1))
Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], inputs[:-1], jnp.arange(len(filtered_means)-1))

return SmoothMessage(
E=jnp.concatenate([Et, En[None]]),
Expand All @@ -294,15 +309,16 @@ def _generic_message(params, m, P, t):

def lgssm_smoother(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMSmoothed:
"""A parallel version of the lgssm smoothing algorithm.

See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.

    Note: `inputs` to the system are supported; when omitted they default to zeros.
"""
filtered_posterior = lgssm_filter(params, emissions)
filtered_posterior = lgssm_filter(params, emissions, inputs)
filtered_means = filtered_posterior.filtered_means
filtered_covs = filtered_posterior.filtered_covariances

Expand All @@ -315,7 +331,7 @@ def _operator(elem1, elem2):
L = symmetrize(E2 @ L1 @ E2.T + L2)
return E, g, L

initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs)
initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs, inputs)
final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)

return PosteriorGSSMSmoothed(
Expand Down Expand Up @@ -343,7 +359,9 @@ class SampleMessage(NamedTuple):
h: Float[Array, "ntime state_dim"]


def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances,
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> SampleMessage:
"""A parallel version of the lgssm sampling algorithm.

Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
Expand All @@ -356,15 +374,16 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian
def lgssm_posterior_sample(
key: PRNGKey,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> Float[Array, "ntime state_dim"]:
"""A parallel version of the lgssm sampling algorithm.

See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.

    Note: `inputs` to the system are supported; when omitted they default to zeros.
"""
filtered_posterior = lgssm_filter(params, emissions)
filtered_posterior = lgssm_filter(params, emissions, inputs)
filtered_means = filtered_posterior.filtered_means
filtered_covs = filtered_posterior.filtered_covariances

Expand All @@ -377,6 +396,6 @@ def _operator(elem1, elem2):
h = E2 @ h1 + h2
return E, h

initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs, inputs)
_, samples = lax.associative_scan(_operator, initial_messages, reverse=True)
return samples
70 changes: 68 additions & 2 deletions dynamax/linear_gaussian_ssm/parallel_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_joint_sample
from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother
from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother
from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother, lgssm_filter as serial_lgssm_filter
from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother, parallel_lgssm_filter
from dynamax.linear_gaussian_ssm import lgssm_posterior_sample as serial_lgssm_posterior_sample
from dynamax.linear_gaussian_ssm import parallel_lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov
Expand Down Expand Up @@ -45,6 +45,37 @@ def make_static_lgssm_params():
emission_weights=H,
emission_covariance=R)
return params, lgssm


def make_lgssm_params_with_inputs():
dt = 0.1
F = jnp.eye(4) + dt * jnp.eye(4, k=2)
B = jnp.array([[0., 0.], [1., 0.], [0., 0.], [0., 1.]]) * dt
Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
[dt**2/2, dt]]),
jnp.eye(2))

H = jnp.eye(2, 4)
D = jnp.ones((2, 2))
R = 0.5 ** 2 * jnp.eye(2)
μ0 = jnp.array([0.,0.,1.,-1.])
Σ0 = jnp.eye(4)

latent_dim = 4
observation_dim = 2
input_dim = 2

lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim)
params, _ = lgssm.initialize(jr.PRNGKey(0),
initial_mean=μ0,
initial_covariance= Σ0,
dynamics_weights=F,
dynamics_input_weights=B,
dynamics_covariance=Q,
emission_weights=H,
emission_input_weights=D,
emission_covariance=R)
return params, lgssm


def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0):
Expand Down Expand Up @@ -114,6 +145,41 @@ def test_marginal_loglik(self):
assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)


class TestParallelLGSSMSmootherWithInputs:
""" Compare parallel and serial lgssm smoothing implementations."""

num_timesteps = 50
key = jr.PRNGKey(1)

params, lgssm = make_lgssm_params_with_inputs()
params_diag = flatten_diagonal_emission_cov(params)
inputs = jnp.ones((num_timesteps, 2))
_, emissions = lgssm_joint_sample(params, key, num_timesteps, inputs)


serial_posterior = serial_lgssm_smoother(params, emissions, inputs)
parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs)
parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs)

def test_filtered_means(self):
assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means)
assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior_diag.filtered_means)

def test_filtered_covariances(self):
assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior.filtered_covariances)
assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior_diag.filtered_covariances)

def test_smoothed_means(self):
assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior.smoothed_means)
assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior_diag.smoothed_means)

def test_smoothed_covariances(self):
assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances)
assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances)

def test_marginal_loglik(self):
assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1)
assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)


class TestTimeVaryingParallelLGSSMSmoother:
Expand Down
Loading