Skip to content

Commit

Permalink
Draw ar noise within scan
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Sep 10, 2024
1 parent 3e1ef80 commit c0375c3
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 28 deletions.
40 changes: 19 additions & 21 deletions pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
from numpyro.contrib.control_flow import scan
from numpyro.infer.reparam import LocScaleReparam

from pyrenew.metaclass import RandomVariable, SampledValue
from pyrenew.process.iidrandomsequence import StandardNormalSequence


class ARProcess(RandomVariable):
Expand All @@ -16,21 +17,9 @@ class ARProcess(RandomVariable):
an AR(p) process.
"""

def __init__(self, noise_rv_name: str, *args, **kwargs) -> None:
"""
Default constructor.
Parameters
----------
noise_rv_name : str
A name for the internal RandomVariable
holding the process noise.
"""
super().__init__(*args, **kwargs)
self.noise_rv_ = StandardNormalSequence(element_rv_name=noise_rv_name)

def sample(
self,
noise_name: str,
n: int,
autoreg: ArrayLike,
init_vals: ArrayLike,
Expand Down Expand Up @@ -96,20 +85,29 @@ def sample(
f"order {order}"
)

raw_noise, *_ = self.noise_rv_(n=n, **kwargs)
noise = noise_sd_arr * raw_noise.value
def transition(recent_vals, _): # numpydoc ignore=GL08
with numpyro.handlers.reparam(
config={noise_name: LocScaleReparam(0)}
):
next_noise = numpyro.sample(
noise_name,
numpyro.distributions.Normal(loc=0, scale=noise_sd_arr),
)

def transition(recent_vals, next_noise): # numpydoc ignore=GL08
new_term = jnp.dot(autoreg, recent_vals) + next_noise
new_recent_vals = jnp.hstack(
[new_term, recent_vals[: (order - 1)]]
new_recent_vals = jnp.concatenate(
[new_term, recent_vals[..., : (order - 1)]]
)
return new_recent_vals, new_term

last, ts = scan(transition, init_vals, noise)
last, ts = scan(f=transition, init=init_vals, xs=None, length=n)
return (
SampledValue(
jnp.hstack([init_vals, ts]),
jnp.squeeze(
jnp.concatenate(
[init_vals[::, jnp.newaxis], ts],
)
),
t_start=self.t_start,
t_unit=self.t_unit,
),
Expand Down
7 changes: 4 additions & 3 deletions pyrenew/process/rtperiodicdiffar.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ def __init__(
self.log_rt_rv = log_rt_rv
self.autoreg_rv = autoreg_rv
self.periodic_diff_sd_rv = periodic_diff_sd_rv
self.ar_process_suffix = ar_process_suffix

self.ar_diff = DifferencedProcess(
fundamental_process=ARProcess(
noise_rv_name=f"{name}{ar_process_suffix}"
),
fundamental_process=ARProcess(),
differencing_order=1,
)

Expand Down Expand Up @@ -169,6 +169,7 @@ def sample(
# Running the process

log_rt = self.ar_diff(
noise_name=f"{self.name}{self.ar_process_suffix}",
n=n_periods,
init_vals=jnp.array([log_rt_rv[0]]),
autoreg=b,
Expand Down
16 changes: 13 additions & 3 deletions test/test_ar_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,37 @@ def test_ar_can_be_sampled():
Check that an AR process
can be initialized and sampled from
"""
ar1 = ARProcess(noise_rv_name="ar1process_noise")
ar1 = ARProcess()
with numpyro.handlers.seed(rng_seed=62):
# can sample
ar1(
noise_name="ar1process_noise",
n=3532,
init_vals=jnp.array([50.0]),
autoreg=jnp.array([0.95]),
noise_sd=0.5,
)

ar3 = ARProcess(noise_rv_name="ar3process_noise")
ar3 = ARProcess()

with numpyro.handlers.seed(rng_seed=62):
# can sample
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=[0.25],
)
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
Expand All @@ -50,13 +54,15 @@ def test_ar_can_be_sampled():
# error
with pytest.raises(ValueError, match="must be a scalar"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=jnp.array([1.0, 2.0]),
)
with pytest.raises(ValueError, match="must be a scalar"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
Expand All @@ -66,20 +72,23 @@ def test_ar_can_be_sampled():
# bad dimensionality raises error
with pytest.raises(ValueError, match="Array of autoregressive"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([[0.05, 0.025, 0.025]]),
noise_sd=0.5,
)
with pytest.raises(ValueError, match="Array of initial"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([[50.0, 49.9, 48.2]]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)
with pytest.raises(ValueError, match="same size as the order"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 1, 1, 1]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
Expand All @@ -94,11 +103,12 @@ def test_ar_samples_correctly_distributed():
"""
noise_sd = jnp.array([0.5])
ar_inits = jnp.array([25.0])
ar = ARProcess("arprocess")
ar = ARProcess()
with numpyro.handlers.seed(rng_seed=62):
# check it regresses to mean
# when started away from it
long_ts, *_ = ar(
noise_name="arprocess_noise",
n=10000,
init_vals=ar_inits,
autoreg=jnp.array([0.75]),
Expand Down
2 changes: 1 addition & 1 deletion test/test_differenced_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_integrator_correctness(order, n_diffs):
)
result_proc1 = proc.integrate(inits, diffs)
assert result_proc1.shape == (n_diffs + order,)
assert_array_almost_equal(result_manual, result_proc1, decimal=5)
assert_array_almost_equal(result_manual, result_proc1, decimal=4)
assert result_proc1[0] == inits[0]


Expand Down

0 comments on commit c0375c3

Please sign in to comment.