Skip to content

Commit

Permalink
Circular dependency fix
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Dec 12, 2024
1 parent 58bc697 commit b044c03
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 13 deletions.
9 changes: 1 addition & 8 deletions pymc_extras/model/marginal/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,6 @@
from pymc_extras.distributions import DiscreteMarkovChain


def get_support_axes(op) -> tuple[tuple[int, ...], ...]:
if hasattr(op, "support_axes"):
return op.support_axes
else:
# For vanilla RVs, the support axes are the last ndim_supp
return (tuple(range(-op.ndim_supp, 0)),)


class MarginalRV(OpFromGraph, MeasurableOp):
"""Base class for Marginalized RVs"""

Expand Down Expand Up @@ -99,6 +91,7 @@ def reduce_batch_dependent_logps(
as well as transpose the remaining axis of dep1 logp before adding the two element-wise.
"""
from pymc_extras.model.marginal.graph_analysis import get_support_axes

reduced_logps = []
for dependent_op, dependent_logp, dependent_dims_connection in zip(
Expand Down
4 changes: 1 addition & 3 deletions pymc_extras/statespace/models/structural.py
Original file line number Diff line number Diff line change
Expand Up @@ -1481,11 +1481,9 @@ def __init__(
k_endog=k_endog,
k_states=k_states,
k_posdef=k_posdef,
state_names=self.state_names,
measurement_error=False,
combine_hidden_states=True,
exog_names=[f"data_{name}"],
obs_state_idxs=np.ones(k_states),
obs_state_idxs=obs_state_idx,
)

def make_symbolic_graph(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/statespace/utilities/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother
from pymc_extras.statespace.utils.constants import (
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MATRIX_NAMES,
MISSING_FILL,
SHORT_NAME_TO_LONG,
)
Expand Down Expand Up @@ -210,7 +210,7 @@ def delete_rvs_from_model(rv_names: list[str]) -> None:


def unpack_statespace(ssm):
return [ssm[SHORT_NAME_TO_LONG[x]] for x in LONG_MATRIX_NAMES]
return [ssm[SHORT_NAME_TO_LONG[x]] for x in MATRIX_NAMES]


def unpack_symbolic_matrices_with_params(mod, param_dict, data_dict=None, mode="FAST_COMPILE"):
Expand Down

0 comments on commit b044c03

Please sign in to comment.