From a4aad78f84ce13d99109dbe5226c1c3a098dd372 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Tue, 25 Jun 2024 14:00:39 -0400 Subject: [PATCH 1/2] Use reverse=True keyword argument in lax.scan --- dynamax/generalized_gaussian_ssm/inference.py | 19 ++-- dynamax/hidden_markov_model/inference.py | 37 ++++--- dynamax/linear_gaussian_ssm/inference.py | 101 +++++++++--------- dynamax/linear_gaussian_ssm/info_inference.py | 16 +-- .../nonlinear_gaussian_ssm/inference_ekf.py | 38 ++++--- .../nonlinear_gaussian_ssm/inference_ukf.py | 18 ++-- dynamax/nonlinear_gaussian_ssm/sarkka_lib.py | 12 +-- 7 files changed, 132 insertions(+), 109 deletions(-) diff --git a/dynamax/generalized_gaussian_ssm/inference.py b/dynamax/generalized_gaussian_ssm/inference.py index 21daa57e..fa9af8fa 100644 --- a/dynamax/generalized_gaussian_ssm/inference.py +++ b/dynamax/generalized_gaussian_ssm/inference.py @@ -18,7 +18,6 @@ _jacfwd_2d = lambda f, x: jnp.atleast_2d(jacfwd(f)(x)) - class EKFIntegrals(NamedTuple): """ Lightweight container for EKF Gaussian integrals.""" gaussian_expectation: Callable = lambda f, m, P: jnp.atleast_1d(f(m)) @@ -85,7 +84,7 @@ def compute_weights_and_sigmas(self, m, P): def _predict(m, P, f, Q, u, g_ev, g_cov): """Predict next mean and covariance under an additive-noise Gaussian filter - + p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred) where mu_pred = gev(f, m, P) @@ -337,13 +336,17 @@ def _step(carry, args): return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov) # Run the smoother - init_carry = (filtered_means[-1], filtered_covs[-1]) - args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1]) - _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args) + _, (smoothed_means, smoothed_covs) = lax.scan( + _step, + (filtered_means[-1], filtered_covs[-1]), + (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]), + reverse=True + ) + + # Concatenate the last smoothed mean and covariance + smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...])) + smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...])) - # Reverse the arrays and return - smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...])) - smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...])) return PosteriorGSSMSmoothed( marginal_loglik=ll, filtered_means=filtered_means, diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py index 7576a707..338a2ee8 100644 --- a/dynamax/hidden_markov_model/inference.py +++ b/dynamax/hidden_markov_model/inference.py @@ -143,7 +143,6 @@ def _step(carry, t): return post - @partial(jit, static_argnames=["transition_fn"]) def hmm_backward_filter( transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], @@ -184,9 +183,9 @@ def _step(carry, t): next_backward_pred_probs = _predict(backward_filt_probs, A.T) return (log_normalizer, next_backward_pred_probs), backward_pred_probs - carry = (0.0, jnp.ones(num_states)) - (log_normalizer, _), rev_backward_pred_probs = lax.scan(_step, carry, jnp.arange(num_timesteps)[::-1]) - backward_pred_probs = rev_backward_pred_probs[::-1] + (log_normalizer, _), backward_pred_probs = lax.scan( + _step, (0.0, jnp.ones(num_states)), jnp.arange(num_timesteps), reverse=True + ) return log_normalizer, backward_pred_probs @@ -273,7 +272,7 @@ def hmm_smoother( posterior distribution """ - num_timesteps, num_states = log_likelihoods.shape + num_timesteps = 
log_likelihoods.shape[0] # Run the HMM filter post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn) @@ -298,12 +297,15 @@ def _step(carry, args): return smoothed_probs, smoothed_probs # Run the HMM smoother - carry = filtered_probs[-1] - args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_probs[:-1][::-1], predicted_probs[1:][::-1]) - _, rev_smoothed_probs = lax.scan(_step, carry, args) + _, smoothed_probs = lax.scan( + _step, + filtered_probs[-1], + (jnp.arange(num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:]), + reverse=True, + ) - # Reverse the arrays and return - smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]]) + # Concatenate the arrays and return + smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]]) # Package into a posterior posterior = HMMPosterior( @@ -467,10 +469,9 @@ def _backward_pass(best_next_score, t): return best_next_score, best_next_state num_states = log_likelihoods.shape[1] - best_second_score, rev_best_next_states = lax.scan( - _backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 2, -1, -1) + best_second_score, best_next_states = lax.scan( + _backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 1), reverse=True ) - best_next_states = rev_best_next_states[::-1] # Run the forward pass def _forward_pass(state, best_next_state): @@ -530,11 +531,13 @@ def _step(carry, args): # Run the HMM smoother rngs = jr.split(rng, num_timesteps) last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1]) - args = (jnp.arange(num_timesteps - 1, 0, -1), rngs[:-1][::-1], filtered_probs[:-1][::-1]) - _, rev_states = lax.scan(_step, last_state, args) + _, states = lax.scan( + _step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]), + reverse=True + ) - # Reverse the arrays and return - states = jnp.concatenate([rev_states[::-1], jnp.array([last_state])]) + # Add the last state + states = jnp.concatenate([states, jnp.array([last_state])]) return log_normalizer, states def _compute_sum_transition_probs( diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 18840bab..1588ad05 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -45,21 +45,21 @@ class ParamsLGSSMDynamics(NamedTuple): :param cov: dynamics covariance $Q$ """ - weights: Union[ParameterProperties, - Float[Array, "state_dim state_dim"], + weights: Union[ParameterProperties, + Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"]] - + bias: Union[ParameterProperties, - Float[Array, "state_dim"], + Float[Array, "state_dim"], Float[Array, "ntime state_dim"]] - + input_weights: Union[ParameterProperties, - Float[Array, "state_dim input_dim"], + Float[Array, "state_dim input_dim"], Float[Array, "ntime state_dim input_dim"]] - - cov: Union[ParameterProperties, - Float[Array, "state_dim state_dim"], - Float[Array, "ntime state_dim state_dim"], + + cov: Union[ParameterProperties, + Float[Array, "state_dim state_dim"], + Float[Array, "ntime state_dim state_dim"], Float[Array, "state_dim_triu"]] @@ -77,22 +77,22 @@ class ParamsLGSSMEmissions(NamedTuple): """ weights: Union[ParameterProperties, - Float[Array, "emission_dim state_dim"], + Float[Array, "emission_dim state_dim"], Float[Array, "ntime emission_dim state_dim"]] - + bias: Union[ParameterProperties, - Float[Array, "emission_dim"], + Float[Array, "emission_dim"], Float[Array, "ntime emission_dim"]] - + 
input_weights: Union[ParameterProperties, - Float[Array, "emission_dim input_dim"], + Float[Array, "emission_dim input_dim"], Float[Array, "ntime emission_dim input_dim"]] - + cov: Union[ParameterProperties, - Float[Array, "emission_dim emission_dim"], - Float[Array, "ntime emission_dim emission_dim"], - Float[Array, "emission_dim"], - Float[Array, "ntime emission_dim"], + Float[Array, "emission_dim emission_dim"], + Float[Array, "ntime emission_dim emission_dim"], + Float[Array, "emission_dim"], + Float[Array, "ntime emission_dim"], Float[Array, "emission_dim_triu"]] @@ -166,9 +166,9 @@ def _get_params(params, num_timesteps, t): D = _get_one_param(params.emissions.input_weights, 2, t) d = _get_one_param(params.emissions.bias, 1, t) - if len(params.emissions.cov.shape) == 1: + if len(params.emissions.cov.shape) == 1: R = _get_one_param(params.emissions.cov, 1, t) - elif len(params.emissions.cov.shape) > 2: + elif len(params.emissions.cov.shape) > 2: R = _get_one_param(params.emissions.cov, 2, t) elif params.emissions.cov.shape[0] != num_timesteps: R = _get_one_param(params.emissions.cov, 2, t) @@ -278,20 +278,20 @@ def _condition_on(m, P, H, D, d, R, u, y): if R.ndim == 2: S = R + H @ P @ H.T K = psd_solve(S, H @ P).T - else: + else: # Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I # (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity) I = jnp.eye(P.shape[0]) U = H @ jnp.linalg.cholesky(P) X = U / R[:, None] - S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) + S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T) """ # Could alternatively use U=H and C=P R_inv = jnp.diag(1.0 / R) P_inv = psd_solve(P, jnp.eye(P.shape[0])) S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv) """ - K = P @ H.T @ S_inv + K = P @ H.T @ S_inv S = jnp.diag(R) + H @ P @ H.T Sigma_cond = P - K @ S @ K.T @@ -361,8 +361,6 @@ def wrapper(*args, **kwargs): return wrapper - - def lgssm_joint_sample( params: ParamsLGSSM, key: PRNGKey, @@ -371,7 +369,7 @@ def lgssm_joint_sample( )-> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: r"""Sample from the joint distribution to produce state and emission trajectories. - + Args: params: model parameters inputs: optional array of inputs. 
@@ -390,7 +388,7 @@ def _sample_emission(key, H, D, d, R, x, u): mean = H @ x + D @ u + d R = jnp.diag(R) if R.ndim==1 else R return MVN(mean, R).sample(seed=key) - + def _sample_initial(key, params, inputs): key1, key2 = jr.split(key) @@ -417,7 +415,7 @@ def _step(prev_state, args): # Sample the initial state key1, key2 = jr.split(key) - + initial_state, initial_emission = _sample_initial(key1, params, inputs) # Sample the remaining emissions and states @@ -462,7 +460,7 @@ def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y): else: L = H @ jnp.linalg.cholesky(pred_cov) return MVNLowRank(m, R, L).log_prob(y) - + def _step(carry, t): ll, pred_mean, pred_cov = carry @@ -539,14 +537,17 @@ def _step(carry, args): return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov, smoothed_cross) # Run the Kalman smoother - init_carry = (filtered_means[-1], filtered_covs[-1]) - args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1]) - _, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(_step, init_carry, args) - - # Reverse the arrays and return - smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...])) - smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...])) - smoothed_cross = smoothed_cross[::-1] + _, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan( + _step, + (filtered_means[-1], filtered_covs[-1]), + (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]), + reverse=True, + ) + + # Concatenate the arrays and return + smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...])) + smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...])) + return PosteriorGSSMSmoothed( marginal_loglik=ll, filtered_means=filtered_means, @@ -563,7 +564,7 @@ def lgssm_posterior_sample( emissions: Float[Array, "ntime emission_dim"], inputs: Optional[Float[Array, "ntime input_dim"]]=None, jitter: Optional[Scalar]=0 - + ) -> Float[Array, "ntime state_dim"]: r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$. 
@@ -603,12 +604,16 @@ def _step(carry, args): key, this_key = jr.split(key, 2) last_state = MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key) - args = ( - jr.split(key, num_timesteps - 1), - filtered_means[:-1][::-1], - filtered_covs[:-1][::-1], - jnp.arange(num_timesteps - 2, -1, -1), + _, states = lax.scan( + _step, + last_state, + ( + jr.split(key, num_timesteps - 1), + filtered_means[:-1], + filtered_covs[:-1], + jnp.arange(num_timesteps - 1), + ), + reverse=True, ) - _, reversed_states = lax.scan(_step, last_state, args) - states = jnp.vstack([reversed_states[::-1], last_state]) - return states + + return jnp.vstack([states, last_state]) diff --git a/dynamax/linear_gaussian_ssm/info_inference.py b/dynamax/linear_gaussian_ssm/info_inference.py index fade1cbf..99ca5125 100644 --- a/dynamax/linear_gaussian_ssm/info_inference.py +++ b/dynamax/linear_gaussian_ssm/info_inference.py @@ -271,13 +271,17 @@ def _smooth_step(carry, args): return (smoothed_eta, smoothed_prec), (smoothed_eta, smoothed_prec) # Run the Kalman smoother - init_carry = (filtered_etas[-1], filtered_precisions[-1]) - args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_etas[:-1][::-1], filtered_precisions[:-1][::-1]) - _, (smoothed_etas, smoothed_precisions) = lax.scan(_smooth_step, init_carry, args) + _, (smoothed_etas, smoothed_precisions) = lax.scan( + _smooth_step, + (filtered_etas[-1], filtered_precisions[-1]), + (jnp.arange(num_timesteps - 1), filtered_etas[:-1], filtered_precisions[:-1]), + reverse=True + ) + + # Concatenate the arrays and return + smoothed_etas = jnp.vstack((smoothed_etas, filtered_etas[-1][None, ...])) + smoothed_precisions = jnp.vstack((smoothed_precisions, filtered_precisions[-1][None, ...])) - # Reverse the arrays and return - smoothed_etas = jnp.vstack((smoothed_etas[::-1], filtered_etas[-1][None, ...])) - smoothed_precisions = jnp.vstack((smoothed_precisions[::-1], filtered_precisions[-1][None, ...])) return PosteriorGSSMInfoSmoothed( marginal_loglik=ll, filtered_etas=filtered_etas, diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py index c375d7b5..d936ed45 100644 --- a/dynamax/nonlinear_gaussian_ssm/inference_ekf.py +++ b/dynamax/nonlinear_gaussian_ssm/inference_ekf.py @@ -237,13 +237,17 @@ def _step(carry, args): return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov) # Run the extended Kalman smoother - init_carry = (filtered_means[-1], filtered_covs[-1]) - args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1]) - _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args) + _, (smoothed_means, smoothed_covs) = lax.scan( + _step, + (filtered_means[-1], filtered_covs[-1]), + (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]), + reverse=True, + ) + + # Concatenate the arrays and return + smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...])) + smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...])) - # Reverse the arrays and return - smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...])) - smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...])) return PosteriorGSSMSmoothed( marginal_loglik=ll, filtered_means=filtered_means, @@ -253,8 +257,6 @@ def _step(carry, args): ) - - def extended_kalman_posterior_sample( key: PRNGKey, params: ParamsNLGSSM, @@ -304,16 +306,18 @@ def _step(carry, args): key, this_key = jr.split(key, 2) last_state = 
MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key)
-    args = (
-        jr.split(key, num_timesteps - 1),
-        filtered_means[:-1][::-1],
-        filtered_covs[:-1][::-1],
-        jnp.arange(num_timesteps - 2, -1, -1),
+    _, states = lax.scan(
+        _step,
+        last_state,
+        (
+            jr.split(key, num_timesteps - 1),
+            filtered_means[:-1],
+            filtered_covs[:-1],
+            jnp.arange(num_timesteps - 1),
+        ),
+        reverse=True,
     )
-    _, reversed_states = lax.scan(_step, last_state, args)
-    states = jnp.vstack([reversed_states[::-1], last_state])
-    return states
-
+    return jnp.vstack([states, last_state])
 
 
 def iterated_extended_kalman_smoother(
diff --git a/dynamax/nonlinear_gaussian_ssm/inference_ukf.py b/dynamax/nonlinear_gaussian_ssm/inference_ukf.py
index 1248ad4f..1df35fbe 100644
--- a/dynamax/nonlinear_gaussian_ssm/inference_ukf.py
+++ b/dynamax/nonlinear_gaussian_ssm/inference_ukf.py
@@ -80,7 +80,7 @@ def _predict(m, P, f, Q, lamb, w_mean, w_cov, u):
     Returns:
         m_pred (D_hid,): predicted mean.
         P_pred (D_hid,D_hid): predicted covariance.
-    
+
     """
     n = len(m)
     # Form sigma points and propagate
@@ -271,13 +271,17 @@ def _step(carry, args):
         return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)
 
     # Run the unscented Kalman smoother
-    init_carry = (filtered_means[-1], filtered_covs[-1])
-    args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
-    _, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)
+    _, (smoothed_means, smoothed_covs) = lax.scan(
+        _step,
+        (filtered_means[-1], filtered_covs[-1]),
+        (jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
+        reverse=True,
+    )
+
+    # Concatenate the arrays and return
+    smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
+    smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))
 
-    # Reverse the arrays and return
-    smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
-    smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
     return PosteriorGSSMSmoothed(
         marginal_loglik=ll,
         filtered_means=filtered_means,
diff --git a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
index 43068b4d..46e4d871 100644
--- a/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
+++ b/dynamax/nonlinear_gaussian_ssm/sarkka_lib.py
@@ -63,9 +63,9 @@ def _step(carry, t):
         return (m_sm, P_sm), (m_sm, P_sm)
 
     carry = (m_post[-1], P_post[-1])
-    _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 2, -1, -1))
-    m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm))[::-1]
-    P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm))[::-1]
+    _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True)
+    m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]])))
+    P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]])))
     return m_sm, P_sm
 
 
@@ -196,8 +196,8 @@ def compute_sigmas(m, P, n, lamb):
         return jnp.concatenate((jnp.array([m]), sigma_plus, sigma_minus))
 
     carry = (m_post[-1], P_post[-1])
-    _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 2, -1, -1))
-    m_sm = jnp.concatenate((jnp.array([m_post[-1]]), m_sm))[::-1]
-    P_sm = jnp.concatenate((jnp.array([P_post[-1]]), P_sm))[::-1]
+    _, (m_sm, P_sm) = lax.scan(_step, carry, jnp.arange(num_timesteps - 1), reverse=True)
+    m_sm = jnp.concatenate((m_sm, jnp.array([m_post[-1]])))
+    P_sm = jnp.concatenate((P_sm, jnp.array([P_post[-1]])))
     return m_sm, P_sm
\ No newline at end of file

From 84c024bc7035d65721849d18fb73d732b3883e23 Mon Sep 17 00:00:00 2001
From: Eric Denovellis
Date: Tue, 25 Jun 2024 14:02:41 -0400
Subject: [PATCH 2/2] Pin numpy<2.0

Prevents tensorflow probability error
---
 setup.cfg | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.cfg b/setup.cfg
index e500c549..19a75ba4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,6 +27,7 @@ install_requires =
     scikit-learn
     jaxtyping
     typing-extensions
+    numpy<2.0
 
 packages = find:
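
A note on the pattern this series replaces: lax.scan(..., reverse=True)
consumes the per-timestep inputs from the last element to the first but still
stacks its outputs in the input order, so it subsumes the old "reverse the
inputs, scan forward, reverse the outputs" idiom. A minimal sketch of the
equivalence (the cumulative-sum step is a toy stand-in, not code from this
patch):

    import jax.numpy as jnp
    from jax import lax

    def _step(carry, x):
        carry = carry + x
        return carry, carry  # new carry, per-step output

    xs = jnp.arange(5.0)

    # Old idiom: reverse the inputs, scan forward, then reverse the outputs.
    _, rev_out = lax.scan(_step, 0.0, xs[::-1])
    old = rev_out[::-1]

    # New idiom: let lax.scan run the loop backwards.
    _, new = lax.scan(_step, 0.0, xs, reverse=True)

    assert jnp.allclose(old, new)  # identical, without the [::-1] copies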
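Every smoother in the series then follows the same two-step shape: scan the
first T-1 filtered moments in reverse, and append the terminal filtered
moments, which are already the smoothed values at t = T-1 (this is also why
the sarkka_lib boundary terms must be appended after the scan output, not
prepended). A sketch of that shape with stand-in dimensions and a toy backward
update (dynamax's real RTS update is more involved):

    import jax.numpy as jnp
    from jax import lax

    T, D = 10, 3
    filtered_means = jnp.ones((T, D))
    filtered_covs = jnp.stack([jnp.eye(D)] * T)

    def _step(carry, args):
        next_mean, next_cov = carry
        t, filt_mean, filt_cov = args
        # Toy stand-in for the RTS update, which blends the filtered moments
        # at time t with the smoothed moments at time t + 1.
        smoothed_mean = 0.5 * (filt_mean + next_mean)
        smoothed_cov = 0.5 * (filt_cov + next_cov)
        return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)

    _, (smoothed_means, smoothed_covs) = lax.scan(
        _step,
        (filtered_means[-1], filtered_covs[-1]),
        (jnp.arange(T - 1), filtered_means[:-1], filtered_covs[:-1]),
        reverse=True,
    )

    # The scan emits t = 0, ..., T-2 in forward order; appending the terminal
    # filtered moments completes the smoothed sequence.
    smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
    smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))
    assert smoothed_means.shape == (T, D) and smoothed_covs.shape == (T, D, D)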
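On the second patch: NumPy 2.0 (released June 2024) broke compatibility with
packages built against the 1.x ABI, and the tensorflow-probability releases
available at the time failed at import under 2.x. The numpy<2.0 constraint in
install_requires keeps pip from resolving to a 2.x release. A quick,
illustrative sanity check for an environment installed with the pin:

    import numpy as np

    # With the pin in place, the resolved NumPy must be a 1.x release.
    assert np.__version__.startswith("1."), np.__version__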