diff --git a/dynamax/linear_gaussian_ssm/parallel_inference.py b/dynamax/linear_gaussian_ssm/parallel_inference.py
index 2e50e8d6..516387e1 100644
--- a/dynamax/linear_gaussian_ssm/parallel_inference.py
+++ b/dynamax/linear_gaussian_ssm/parallel_inference.py
@@ -33,7 +33,7 @@
 import jax.numpy as jnp
 from jax import vmap, lax
 from jaxtyping import Array, Float
-from typing import NamedTuple
+from typing import NamedTuple, Optional
 from dynamax.types import PRNGKey
 from functools import partial
 import warnings
@@ -45,6 +45,7 @@
 from jax.scipy.linalg import cho_solve, cho_factor
 from dynamax.utils.utils import symmetrize, psd_solve
 from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM
+from dynamax.linear_gaussian_ssm.inference import _zeros_if_none
 
 
 def _get_one_param(x, dim, t):
@@ -56,14 +57,16 @@ def _get_one_param(x, dim, t):
     else:
         return x
 
 
-def _get_params(params, num_timesteps, t):
+def _get_params(params: ParamsLGSSM, num_timesteps, t):
     """Helper function to get parameters at time t."""
     assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."
     F = _get_one_param(params.dynamics.weights, 2, t)
+    B = _get_one_param(params.dynamics.input_weights, 2, t)
     b = _get_one_param(params.dynamics.bias, 1, t)
     Q = _get_one_param(params.dynamics.cov, 2, t)
     H = _get_one_param(params.emissions.weights, 2, t+1)
+    D = _get_one_param(params.emissions.input_weights, 2, t+1)
     d = _get_one_param(params.emissions.bias, 1, t+1)
 
     if len(params.emissions.cov.shape) == 1:
@@ -81,7 +84,7 @@ def _get_params(params, num_timesteps, t):
             "The covariance will be interpreted as static and non-diagonal. To "
             "specify a dynamic and diagonal covariance, pass it as a 3D array.")
 
-    return F, b, Q, H, d, R
+    return F, B, b, Q, H, D, d, R
 
 
 #---------------------------------------------------------------------------#
@@ -151,13 +154,18 @@ class FilterMessage(NamedTuple):
     logZ: Float[Array, "ntime"]
 
 
-def _initialize_filtering_messages(params, emissions):
+def _initialize_filtering_messages(
+    params: ParamsLGSSM,
+    emissions: Float[Array, "ntime emission_dim"],
+    inputs: Optional[Float[Array, "ntime input_dim"]]=None
+    ):
     """Preprocess observations to construct input for filtering assocative scan."""
     num_timesteps = emissions.shape[0]
+    inputs = _zeros_if_none(inputs, (num_timesteps, 0))
 
-    def _first_message(params, y):
-        H, d, R = _get_params(params, num_timesteps, -1)[3:]
+    def _first_message(params, y, u):
+        H, D, d, R = _get_params(params, num_timesteps, -1)[4:]
         m = params.initial.mean
         P = params.initial.cov
@@ -166,34 +174,35 @@ def _first_message(params, y):
         K = P @ H.T @ S_inv
 
         A = jnp.zeros_like(P)
-        b = m + K @ (y - H @ m - d)
+        b = m + K @ (y - H @ m - D @ u - d)
         C = symmetrize(P - K @ S @ K.T)
 
         eta = jnp.zeros_like(b)
         J = jnp.eye(len(b))
 
-        logZ = _marginal_loglik_elem(P, H, R, y)
+        logZ = _marginal_loglik_elem(P, H, R, y - H @ m - D @ u - d)
         return A, b, C, J, eta, logZ
 
-    @partial(vmap, in_axes=(None, 0, 0))
-    def _generic_message(params, y, t):
-        F, b, Q, H, d, R = _get_params(params, num_timesteps, t)
+    @partial(vmap, in_axes=(None, 0, 0, 0))
+    def _generic_message(params, y, u, t):
+        F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)
         S_inv = _emissions_scale(Q, H, R)
         K = Q @ H.T @ S_inv
 
-        eta = F.T @ H.T @ S_inv @ (y - H @ b - d)
+        innov = (y - H @ b - D @ u - d)
+        eta = F.T @ H.T @ S_inv @ innov
         J = symmetrize(F.T @ H.T @ S_inv @ H @ F)
 
         A = F - K @ H @ F
-        b = b + K @ (y - H @ b - d)
+        b = b + B @ u + K @ innov
         C = symmetrize(Q - K @ H @ Q)
-        logZ = _marginal_loglik_elem(Q, H, R, y)
+        logZ = _marginal_loglik_elem(Q, H, R, innov)
         return A, b, C, J, eta, logZ
 
-    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
-    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1))
+    A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0])
+    At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], inputs[1:], jnp.arange(len(emissions)-1))
 
     return FilterMessage(
         A=jnp.concatenate([A0[None], At]),
@@ -208,7 +217,8 @@ def _generic_message(params, y, t):
 
 def lgssm_filter(
     params: ParamsLGSSM,
-    emissions: Float[Array, "ntime emission_dim"]
+    emissions: Float[Array, "ntime emission_dim"],
+    inputs: Optional[Float[Array, "ntime input_dim"]]=None
 ) -> PosteriorGSSMFiltered:
     """A parallel version of the lgssm filtering algorithm.
 
@@ -238,7 +248,7 @@ def _operator(elem1, elem2):
         logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1)
         return FilterMessage(A, b, C, J, eta, logZ)
 
-    initial_messages = _initialize_filtering_messages(params, emissions)
+    initial_messages = _initialize_filtering_messages(params, emissions, inputs)
     final_messages = lax.associative_scan(_operator, initial_messages)
 
     return PosteriorGSSMFiltered(
@@ -265,25 +275,30 @@ class SmoothMessage(NamedTuple):
     L: Float[Array, "ntime state_dim state_dim"]
 
 
-def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
+def _initialize_smoothing_messages(params: ParamsLGSSM,
+                                   filtered_means: Float[Array, "ntime state_dim"],
+                                   filtered_covariances: Float[Array, "ntime state_dim state_dim"],
+                                   inputs: Optional[Float[Array, "ntime input_dim"]]=None
+) -> SmoothMessage:
     """Preprocess filtering output to construct input for smoothing assocative scan."""
     def _last_message(m, P):
         return jnp.zeros_like(P), m, P
 
     num_timesteps = filtered_means.shape[0]
+    inputs = _zeros_if_none(inputs, (num_timesteps, 0))
 
-    @partial(vmap, in_axes=(None, 0, 0, 0))
-    def _generic_message(params, m, P, t):
-        F, b, Q = _get_params(params, num_timesteps, t)[:3]
+    @partial(vmap, in_axes=(None, 0, 0, 0, 0))
+    def _generic_message(params, m, P, u, t):
+        F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
         CF, low = cho_factor(F @ P @ F.T + Q)
         E = cho_solve((CF, low), F @ P).T
-        g = m - E @ (F @ m + b)
+        g = m - E @ (F @ m + b + B @ u)
         L = symmetrize(P - E @ F @ P)
         return E, g, L
 
     En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
-    Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1))
+    Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], inputs[:-1], jnp.arange(len(filtered_means)-1))
 
     return SmoothMessage(
         E=jnp.concatenate([Et, En[None]]),
@@ -294,7 +309,8 @@ def _generic_message(params, m, P, t):
 
 def lgssm_smoother(
     params: ParamsLGSSM,
-    emissions: Float[Array, "ntime emission_dim"]
+    emissions: Float[Array, "ntime emission_dim"],
+    inputs: Optional[Float[Array, "ntime input_dim"]]=None
 ) -> PosteriorGSSMSmoothed:
     """A parallel version of the lgssm smoothing algorithm.
 
@@ -302,7 +318,7 @@ def lgssm_smoother(
 
-    Note: This function does not yet handle `inputs` to the system.
+    Note: `inputs` to the system are now supported by this function.
""" - filtered_posterior = lgssm_filter(params, emissions) + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances @@ -315,7 +331,7 @@ def _operator(elem1, elem2): L = symmetrize(E2 @ L1 @ E2.T + L2) return E, g, L - initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs) + initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs, inputs) final_messages = lax.associative_scan(_operator, initial_messages, reverse=True) return PosteriorGSSMSmoothed( @@ -343,7 +359,9 @@ class SampleMessage(NamedTuple): h: Float[Array, "ntime state_dim"] -def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances): +def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances, + inputs: Optional[Float[Array, "ntime input_dim"]]=None +) -> SampleMessage: """A parallel version of the lgssm sampling algorithm. Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`, @@ -356,7 +374,8 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian def lgssm_posterior_sample( key: PRNGKey, params: ParamsLGSSM, - emissions: Float[Array, "ntime emission_dim"] + emissions: Float[Array, "ntime emission_dim"], + inputs: Optional[Float[Array, "ntime input_dim"]]=None ) -> Float[Array, "ntime state_dim"]: """A parallel version of the lgssm sampling algorithm. @@ -364,7 +383,7 @@ def lgssm_posterior_sample( Note: This function does not yet handle `inputs` to the system. """ - filtered_posterior = lgssm_filter(params, emissions) + filtered_posterior = lgssm_filter(params, emissions, inputs) filtered_means = filtered_posterior.filtered_means filtered_covs = filtered_posterior.filtered_covariances @@ -377,6 +396,6 @@ def _operator(elem1, elem2): h = E2 @ h1 + h2 return E, h - initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs) + initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs, inputs) _, samples = lax.associative_scan(_operator, initial_messages, reverse=True) return samples \ No newline at end of file diff --git a/dynamax/linear_gaussian_ssm/parallel_inference_test.py b/dynamax/linear_gaussian_ssm/parallel_inference_test.py index cd6376b3..6ba6dd3e 100644 --- a/dynamax/linear_gaussian_ssm/parallel_inference_test.py +++ b/dynamax/linear_gaussian_ssm/parallel_inference_test.py @@ -5,8 +5,8 @@ from dynamax.linear_gaussian_ssm import LinearGaussianSSM from dynamax.linear_gaussian_ssm import lgssm_joint_sample -from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother -from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother +from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother, lgssm_filter as serial_lgssm_filter +from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother, parallel_lgssm_filter from dynamax.linear_gaussian_ssm import lgssm_posterior_sample as serial_lgssm_posterior_sample from dynamax.linear_gaussian_ssm import parallel_lgssm_posterior_sample from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov @@ -45,6 +45,37 @@ def make_static_lgssm_params(): emission_weights=H, emission_covariance=R) return params, lgssm + + +def make_lgssm_params_with_inputs(): + dt = 0.1 + F = jnp.eye(4) + dt * jnp.eye(4, k=2) + B = jnp.array([[0., 0.], [1., 0.], [0., 0.], [0., 1.]]) * dt + Q = 1. 
+
+    H = jnp.eye(2, 4)
+    D = jnp.ones((2, 2))
+    R = 0.5 ** 2 * jnp.eye(2)
+    μ0 = jnp.array([0.,0.,1.,-1.])
+    Σ0 = jnp.eye(4)
+
+    latent_dim = 4
+    observation_dim = 2
+    input_dim = 2
+
+    lgssm = LinearGaussianSSM(latent_dim, observation_dim, input_dim)
+    params, _ = lgssm.initialize(jr.PRNGKey(0),
+                                 initial_mean=μ0,
+                                 initial_covariance= Σ0,
+                                 dynamics_weights=F,
+                                 dynamics_input_weights=B,
+                                 dynamics_covariance=Q,
+                                 emission_weights=H,
+                                 emission_input_weights=D,
+                                 emission_covariance=R)
+    return params, lgssm
 
 
 def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0):
@@ -114,6 +145,41 @@ def test_marginal_loglik(self):
         assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1)
         assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)
 
+class TestParallelLGSSMSmootherWithInputs:
+    """ Compare parallel and serial lgssm smoothing implementations."""
+
+    num_timesteps = 50
+    key = jr.PRNGKey(1)
+
+    params, lgssm = make_lgssm_params_with_inputs()
+    params_diag = flatten_diagonal_emission_cov(params)
+    inputs = jnp.ones((num_timesteps, 2))
+    _, emissions = lgssm_joint_sample(params, key, num_timesteps, inputs)
+
+
+    serial_posterior = serial_lgssm_smoother(params, emissions, inputs)
+    parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs)
+    parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs)
+
+    def test_filtered_means(self):
+        assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior.filtered_means)
+        assert allclose(self.serial_posterior.filtered_means, self.parallel_posterior_diag.filtered_means)
+
+    def test_filtered_covariances(self):
+        assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior.filtered_covariances)
+        assert allclose(self.serial_posterior.filtered_covariances, self.parallel_posterior_diag.filtered_covariances)
+
+    def test_smoothed_means(self):
+        assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior.smoothed_means)
+        assert allclose(self.serial_posterior.smoothed_means, self.parallel_posterior_diag.smoothed_means)
+
+    def test_smoothed_covariances(self):
+        assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior.smoothed_covariances)
+        assert allclose(self.serial_posterior.smoothed_covariances, self.parallel_posterior_diag.smoothed_covariances)
+
+    def test_marginal_loglik(self):
+        assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior.marginal_loglik, atol=2e-1)
+        assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)
 
 
 class TestTimeVaryingParallelLGSSMSmoother:
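For reviewers who want to exercise the new code path end to end, here is a minimal usage sketch. It mirrors `make_lgssm_params_with_inputs` and `TestParallelLGSSMSmootherWithInputs` above and only uses calls that appear in this patch; the `atol` tolerances below are illustrative choices, not values taken from the test suite.

```python
import jax.numpy as jnp
import jax.random as jr

from dynamax.linear_gaussian_ssm import LinearGaussianSSM, lgssm_joint_sample
from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother
from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother, parallel_lgssm_posterior_sample

# Constant-velocity tracking model with a 2D control input, as in the new test fixture.
dt = 0.1
F = jnp.eye(4) + dt * jnp.eye(4, k=2)
B = jnp.array([[0., 0.], [1., 0.], [0., 0.], [0., 1.]]) * dt
Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
                             [dt**2/2, dt]]),
                  jnp.eye(2))
H = jnp.eye(2, 4)
D = jnp.ones((2, 2))
R = 0.5 ** 2 * jnp.eye(2)

lgssm = LinearGaussianSSM(4, 2, 2)  # latent dim, observation dim, input dim (positional, as in the test)
params, _ = lgssm.initialize(jr.PRNGKey(0),
                             initial_mean=jnp.array([0., 0., 1., -1.]),
                             initial_covariance=jnp.eye(4),
                             dynamics_weights=F,
                             dynamics_input_weights=B,
                             dynamics_covariance=Q,
                             emission_weights=H,
                             emission_input_weights=D,
                             emission_covariance=R)

# Simulate emissions driven by a known input sequence, then smooth them both ways.
num_timesteps = 50
inputs = jnp.ones((num_timesteps, 2))
_, emissions = lgssm_joint_sample(params, jr.PRNGKey(1), num_timesteps, inputs)

serial_post = serial_lgssm_smoother(params, emissions, inputs)
parallel_post = parallel_lgssm_smoother(params, emissions, inputs)
assert jnp.allclose(serial_post.smoothed_means, parallel_post.smoothed_means, atol=1e-1)
assert jnp.allclose(serial_post.marginal_loglik, parallel_post.marginal_loglik, atol=2e-1)

# The parallel posterior sampler takes the same optional `inputs` argument.
sample = parallel_lgssm_posterior_sample(jr.PRNGKey(2), params, emissions, inputs)
print(sample.shape)  # (50, 4)
```

For models built with `input_dim=0`, omitting `inputs` keeps the previous behaviour: `_zeros_if_none` substitutes an empty `(num_timesteps, 0)` array, so the `B @ u` and `D @ u` terms vanish and existing no-input callers are unaffected.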
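The algebra of the change is compact: the control input enters each filtering element through `B @ u` in the conditional mean and `D @ u` in the innovation, exactly as in the updated `_generic_message`. The standalone sketch below restates that construction so it can be checked outside dynamax; the explicit inverse of `H Q Hᵀ + R` stands in for the library's `_emissions_scale` helper, which is an assumption about that helper rather than a quote of it, and all names are local to the example.

```python
import jax.numpy as jnp


def symmetrize(X):
    """Numerically symmetrize a square matrix."""
    return 0.5 * (X + X.T)


def filtering_element_with_inputs(F, B, b, Q, H, D, d, R, y, u):
    """Build one associative-scan filtering element (A, b, C, J, eta) for observation y and input u.

    Mirrors the input-aware `_generic_message`: the conditional mean gains `B @ u`
    and the innovation subtracts `D @ u` alongside the biases b and d.
    """
    S = H @ Q @ H.T + R                              # predicted-emission covariance (assumed form)
    S_inv = jnp.linalg.solve(S, jnp.eye(S.shape[0]))
    K = Q @ H.T @ S_inv                              # gain applied to the innovation

    innov = y - H @ b - D @ u - d                    # input-corrected innovation
    A = F - K @ H @ F
    b_out = b + B @ u + K @ innov                    # input enters the mean directly and via the gain
    C = symmetrize(Q - K @ H @ Q)
    eta = F.T @ H.T @ S_inv @ innov
    J = symmetrize(F.T @ H.T @ S_inv @ H @ F)
    return A, b_out, C, J, eta
```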