Skip to content

Commit

Permalink
Bugfixes for statespace/models/structural.py (#287)
Browse files Browse the repository at this point in the history
Expand test coverage fix bugs in `structural.py`
  • Loading branch information
jessegrabowski authored Dec 18, 2023
1 parent 430c344 commit 656b800
Show file tree
Hide file tree
Showing 9 changed files with 758 additions and 180 deletions.
2 changes: 1 addition & 1 deletion docs/statespace/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Statespace Models
.. autosummary::
:toctree: generated

BayesianARIMA
BayesianSARIMA
BayesianVARMAX

*********************
Expand Down
4 changes: 1 addition & 3 deletions docs/statespace/models/structural.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,4 @@ Structural Components
TimeSeasonality
FrequencySeasonality
MeasurementError

StructuralTimeSeries
Component
CycleComponent
2 changes: 1 addition & 1 deletion pymc_experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
handler = logging.StreamHandler()
_log.addHandler(handler)

from pymc_experimental import distributions, gp, utils
from pymc_experimental import distributions, gp, statespace, utils
from pymc_experimental.inference.fit import fit
from pymc_experimental.model.marginal_model import MarginalModel
from pymc_experimental.model.model_api import as_model
47 changes: 32 additions & 15 deletions pymc_experimental/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,37 @@ def _insert_random_variables(self) -> List[Variable]:
}
self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)

def _register_matrices_with_pymc_model(self) -> List[pt.TensorVariable]:
"""
Add all statespace matrices to the PyMC model currently on the context stack as pm.Deterministic nodes, and
adds named dimensions if they are found.
Returns
-------
registered_matrices: list of pt.TensorVariable
List of statespace matrices, wrapped in pm.Deterministic
"""

pm_mod = modelcontext(None)
matrices = self.unpack_statespace()

registered_matrices = []
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
if not getattr(pm_mod, name, None):
shape, dims = self._get_matrix_shape_and_dims(name)
has_dims = dims is not None

if matrix.ndim == time_varying_ndim and has_dims:
dims = (TIME_DIM,) + dims

x = pm.Deterministic(name, matrix, dims=dims)
registered_matrices.append(x)
else:
registered_matrices.append(matrices[i])

return registered_matrices

def add_exogenous(self, exog: pt.TensorVariable) -> None:
"""
Add an exogenous process to the statespace model
Expand Down Expand Up @@ -746,7 +777,6 @@ def build_statespace_graph(
pm_mod = modelcontext(None)

self._insert_random_variables()
matrices = self.unpack_statespace()
obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)

self.data_len = data.shape[0]
Expand All @@ -758,20 +788,7 @@ def build_statespace_graph(
missing_fill_value=missing_fill_value,
)

registered_matrices = []
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
if not getattr(pm_mod, name, None):
shape, dims = self._get_matrix_shape_and_dims(name)
has_dims = dims is not None

if matrix.ndim == time_varying_ndim and has_dims:
dims = (TIME_DIM,) + dims

x = pm.Deterministic(name, matrix, dims=dims)
registered_matrices.append(x)
else:
registered_matrices.append(matrices[i])
registered_matrices = self._register_matrices_with_pymc_model()

filter_outputs = self.kalman_filter.build_graph(
pt.as_tensor_variable(data),
Expand Down
39 changes: 33 additions & 6 deletions pymc_experimental/statespace/filters/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.graph.basic import Node

floatX = pytensor.config.floatX
COV_ZERO_TOL = 0

lgss_shape_message = (
"The LinearGaussianStateSpace distribution needs shape information to be constructed. "
Expand Down Expand Up @@ -157,8 +158,11 @@ def step_fn(*args):
middle_rng, a_innovation = pm.MvNormal.dist(mu=0, cov=Q, rng=rng).owner.outputs
next_rng, y_innovation = pm.MvNormal.dist(mu=0, cov=H, rng=middle_rng).owner.outputs

a_next = c + T @ a + R @ a_innovation
y_next = d + Z @ a_next + y_innovation
a_mu = c + T @ a
a_next = pt.switch(pt.all(pt.le(Q, COV_ZERO_TOL)), a_mu, a_mu + R @ a_innovation)

y_mu = d + Z @ a_next
y_next = pt.switch(pt.all(pt.le(H, COV_ZERO_TOL)), y_mu, y_mu + y_innovation)

next_state = pt.concatenate([a_next, y_next], axis=0)

Expand All @@ -168,7 +172,11 @@ def step_fn(*args):
Z_init = Z_ if Z_ in non_sequences else Z_[0]
H_init = H_ if H_ in non_sequences else H_[0]

init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng)
init_y_ = pt.switch(
pt.all(pt.le(H_init, COV_ZERO_TOL)),
Z_init @ init_x_,
pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng),
)
init_dist_ = pt.concatenate([init_x_, init_y_], axis=0)

statespace, updates = pytensor.scan(
Expand Down Expand Up @@ -216,6 +224,7 @@ def __new__(
steps=None,
mode=None,
sequence_names=None,
k_endog=None,
**kwargs,
):
dims = kwargs.pop("dims", None)
Expand All @@ -239,11 +248,29 @@ def __new__(
sequence_names=sequence_names,
**kwargs,
)

k_states = T.type.shape[0]

latent_states = latent_obs_combined[..., :k_states]
obs_states = latent_obs_combined[..., k_states:]
if k_endog is None and k_states is None:
raise ValueError("Could not infer number of observed states, explicitly pass k_endog.")
if k_endog is not None and k_states is not None:
total_shape = latent_obs_combined.type.shape[-1]
inferred_endog = total_shape - k_states
if inferred_endog != k_endog:
raise ValueError(
f"Inferred k_endog does not agree with provided value ({inferred_endog} != {k_endog}). "
f"It is not necessary to provide k_endog when the value can be inferred."
)
latent_slice = slice(None, -k_endog)
obs_slice = slice(-k_endog, None)
elif k_endog is None:
latent_slice = slice(None, k_states)
obs_slice = slice(k_states, None)
else:
latent_slice = slice(None, -k_endog)
obs_slice = slice(-k_endog, None)

latent_states = latent_obs_combined[..., latent_slice]
obs_states = latent_obs_combined[..., obs_slice]

latent_states = pm.Deterministic(f"{name}_latent", latent_states, dims=latent_dims)
obs_states = pm.Deterministic(f"{name}_observed", obs_states, dims=obs_dims)
Expand Down
Loading

0 comments on commit 656b800

Please sign in to comment.