From c0375c376bc04ee5e8a32bb0e30c742f83792b4e Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Tue, 10 Sep 2024 13:54:43 -0400 Subject: [PATCH] Draw ar noise within scan --- pyrenew/process/ar.py | 40 ++++++++++++++--------------- pyrenew/process/rtperiodicdiffar.py | 7 ++--- test/test_ar_process.py | 16 +++++++++--- test/test_differenced_process.py | 2 +- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/pyrenew/process/ar.py b/pyrenew/process/ar.py index 5153f3e6..06ade93b 100644 --- a/pyrenew/process/ar.py +++ b/pyrenew/process/ar.py @@ -3,11 +3,12 @@ from __future__ import annotations import jax.numpy as jnp +import numpyro from jax.typing import ArrayLike from numpyro.contrib.control_flow import scan +from numpyro.infer.reparam import LocScaleReparam from pyrenew.metaclass import RandomVariable, SampledValue -from pyrenew.process.iidrandomsequence import StandardNormalSequence class ARProcess(RandomVariable): @@ -16,21 +17,9 @@ class ARProcess(RandomVariable): an AR(p) process. """ - def __init__(self, noise_rv_name: str, *args, **kwargs) -> None: - """ - Default constructor. - - Parameters - ---------- - noise_rv_name : str - A name for the internal RandomVariable - holding the process noise. - """ - super().__init__(*args, **kwargs) - self.noise_rv_ = StandardNormalSequence(element_rv_name=noise_rv_name) - def sample( self, + noise_name: str, n: int, autoreg: ArrayLike, init_vals: ArrayLike, @@ -96,20 +85,29 @@ def sample( f"order {order}" ) - raw_noise, *_ = self.noise_rv_(n=n, **kwargs) - noise = noise_sd_arr * raw_noise.value + def transition(recent_vals, _): # numpydoc ignore=GL08 + with numpyro.handlers.reparam( + config={noise_name: LocScaleReparam(0)} + ): + next_noise = numpyro.sample( + noise_name, + numpyro.distributions.Normal(loc=0, scale=noise_sd_arr), + ) - def transition(recent_vals, next_noise): # numpydoc ignore=GL08 new_term = jnp.dot(autoreg, recent_vals) + next_noise - new_recent_vals = jnp.hstack( - [new_term, recent_vals[: (order - 1)]] + new_recent_vals = jnp.concatenate( + [new_term, recent_vals[..., : (order - 1)]] ) return new_recent_vals, new_term - last, ts = scan(transition, init_vals, noise) + last, ts = scan(f=transition, init=init_vals, xs=None, length=n) return ( SampledValue( - jnp.hstack([init_vals, ts]), + jnp.squeeze( + jnp.concatenate( + [init_vals[::, jnp.newaxis], ts], + ) + ), t_start=self.t_start, t_unit=self.t_unit, ), diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 9186b9ef..1bf4abfd 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -98,10 +98,10 @@ def __init__( self.log_rt_rv = log_rt_rv self.autoreg_rv = autoreg_rv self.periodic_diff_sd_rv = periodic_diff_sd_rv + self.ar_process_suffix = ar_process_suffix + self.ar_diff = DifferencedProcess( - fundamental_process=ARProcess( - noise_rv_name=f"{name}{ar_process_suffix}" - ), + fundamental_process=ARProcess(), differencing_order=1, ) @@ -169,6 +169,7 @@ def sample( # Running the process log_rt = self.ar_diff( + noise_name=f"{self.name}{self.ar_process_suffix}", n=n_periods, init_vals=jnp.array([log_rt_rv[0]]), autoreg=b, diff --git a/test/test_ar_process.py b/test/test_ar_process.py index b1df31af..629dfd5b 100755 --- a/test/test_ar_process.py +++ b/test/test_ar_process.py @@ -13,33 +13,37 @@ def test_ar_can_be_sampled(): Check that an AR process can be initialized and sampled from """ - ar1 = ARProcess(noise_rv_name="ar1process_noise") + ar1 = ARProcess() with numpyro.handlers.seed(rng_seed=62): # can sample ar1( + noise_name="ar1process_noise", n=3532, init_vals=jnp.array([50.0]), autoreg=jnp.array([0.95]), noise_sd=0.5, ) - ar3 = ARProcess(noise_rv_name="ar3process_noise") + ar3 = ARProcess() with numpyro.handlers.seed(rng_seed=62): # can sample ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 48.2]), autoreg=jnp.array([0.05, 0.025, 0.025]), noise_sd=0.5, ) ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 48.2]), autoreg=jnp.array([0.05, 0.025, 0.025]), noise_sd=[0.25], ) ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 48.2]), autoreg=jnp.array([0.05, 0.025, 0.025]), @@ -50,6 +54,7 @@ def test_ar_can_be_sampled(): # error with pytest.raises(ValueError, match="must be a scalar"): ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 48.2]), autoreg=jnp.array([0.05, 0.025, 0.025]), @@ -57,6 +62,7 @@ def test_ar_can_be_sampled(): ) with pytest.raises(ValueError, match="must be a scalar"): ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 48.2]), autoreg=jnp.array([0.05, 0.025, 0.025]), @@ -66,6 +72,7 @@ def test_ar_can_be_sampled(): # bad dimensionality raises error with pytest.raises(ValueError, match="Array of autoregressive"): ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 48.2]), autoreg=jnp.array([[0.05, 0.025, 0.025]]), @@ -73,6 +80,7 @@ def test_ar_can_be_sampled(): ) with pytest.raises(ValueError, match="Array of initial"): ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([[50.0, 49.9, 48.2]]), autoreg=jnp.array([0.05, 0.025, 0.025]), @@ -80,6 +88,7 @@ def test_ar_can_be_sampled(): ) with pytest.raises(ValueError, match="same size as the order"): ar3( + noise_name="ar3process_noise", n=1230, init_vals=jnp.array([50.0, 49.9, 1, 1, 1]), autoreg=jnp.array([0.05, 0.025, 0.025]), @@ -94,11 +103,12 @@ def test_ar_samples_correctly_distributed(): """ noise_sd = jnp.array([0.5]) ar_inits = jnp.array([25.0]) - ar = ARProcess("arprocess") + ar = ARProcess() with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it long_ts, *_ = ar( + noise_name="arprocess_noise", n=10000, init_vals=ar_inits, autoreg=jnp.array([0.75]), diff --git a/test/test_differenced_process.py b/test/test_differenced_process.py index ba4e95c9..01267b60 100644 --- a/test/test_differenced_process.py +++ b/test/test_differenced_process.py @@ -122,7 +122,7 @@ def test_integrator_correctness(order, n_diffs): ) result_proc1 = proc.integrate(inits, diffs) assert result_proc1.shape == (n_diffs + order,) - assert_array_almost_equal(result_manual, result_proc1, decimal=5) + assert_array_almost_equal(result_manual, result_proc1, decimal=4) assert result_proc1[0] == inits[0]