From b044c03613a62bf19f9c8f20f79fb96db8008120 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Wed, 11 Dec 2024 21:14:33 -0600 Subject: [PATCH] Circular dependency fix --- pymc_extras/model/marginal/distributions.py | 9 +-------- pymc_extras/statespace/models/structural.py | 4 +--- tests/statespace/utilities/test_helpers.py | 4 ++-- 3 files changed, 4 insertions(+), 13 deletions(-) diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py index 7e38af37..33f4bc7c 100644 --- a/pymc_extras/model/marginal/distributions.py +++ b/pymc_extras/model/marginal/distributions.py @@ -19,14 +19,6 @@ from pymc_extras.distributions import DiscreteMarkovChain -def get_support_axes(op) -> tuple[tuple[int, ...], ...]: - if hasattr(op, "support_axes"): - return op.support_axes - else: - # For vanilla RVs, the support axes are the last ndim_supp - return (tuple(range(-op.ndim_supp, 0)),) - - class MarginalRV(OpFromGraph, MeasurableOp): """Base class for Marginalized RVs""" @@ -99,6 +91,7 @@ def reduce_batch_dependent_logps( as well as transpose the remaining axis of dep1 logp before adding the two element-wise. """ + from pymc_extras.model.marginal.graph_analysis import get_support_axes reduced_logps = [] for dependent_op, dependent_logp, dependent_dims_connection in zip( diff --git a/pymc_extras/statespace/models/structural.py b/pymc_extras/statespace/models/structural.py index 21d87c64..bc61eab9 100644 --- a/pymc_extras/statespace/models/structural.py +++ b/pymc_extras/statespace/models/structural.py @@ -1481,11 +1481,9 @@ def __init__( k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, - state_names=self.state_names, measurement_error=False, combine_hidden_states=True, - exog_names=[f"data_{name}"], - obs_state_idxs=np.ones(k_states), + obs_state_idxs=obs_state_idx, ) def make_symbolic_graph(self) -> None: diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index e6b4a1ae..6a1cae31 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -10,7 +10,7 @@ from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother from pymc_extras.statespace.utils.constants import ( JITTER_DEFAULT, - LONG_MATRIX_NAMES, + MATRIX_NAMES, MISSING_FILL, SHORT_NAME_TO_LONG, ) @@ -210,7 +210,7 @@ def delete_rvs_from_model(rv_names: list[str]) -> None: def unpack_statespace(ssm): - return [ssm[SHORT_NAME_TO_LONG[x]] for x in LONG_MATRIX_NAMES] + return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES] def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="FAST_COMPILE"):