Skip to content

Commit

Permalink
Add tests for DistributionalRV factory and classes
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Aug 16, 2024
1 parent 8ca2dd2 commit 25fa99a
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions model/src/test/test_distributional_rv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Tests for the distributional RV classes
"""
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
import pytest
from numpy.testing import assert_array_equal
from pyrenew.metaclass import (
DistributionalRV,
DynamicDistributionalRV,
StaticDistributionalRV,
)


class NonCallableTestClass:
"""
Generic non-callable object to test
callable checking for DynamicDistributionalRV.
"""

def __init__(self):
"""
Initialization method for generic non-callable
object
"""
pass


@pytest.mark.parametrize("not_a_dist", [1, "test", NonCallableTestClass()])
def test_invalid_constructor_args(not_a_dist):
"""
Test that the constructor errors
appropriately when given incorrect input
"""

with pytest.raises(
ValueError, match="distribution argument to DistributionalRV"
):
DistributionalRV(name="this should fail", distribution=not_a_dist)
with pytest.raises(
ValueError,
match=(
"distribution should be an instance of "
"numpyro.distributions.Distribution"
),
):
StaticDistributionalRV.validate(not_a_dist)
with pytest.raises(ValueError, match="must provide a Callable"):
DynamicDistributionalRV.validate(not_a_dist)


@pytest.mark.parametrize(
["valid_static_dist_arg", "valid_dynamic_dist_arg"],
[
[dist.Normal(0, 1), dist.Normal],
[dist.Cauchy(3.0, 5.0), dist.Cauchy],
[dist.Poisson(0.25), dist.Poisson],
],
)
def test_factory_triage(valid_static_dist_arg, valid_dynamic_dist_arg):
"""
Test that passing a numpyro.distributions.Distribution
instance to the DistributionalRV factory instaniates
a StaticDistributionalRV, while passing a callable
instaniates a DynamicDistributionalRV
"""
static = DistributionalRV(
name="test static", distribution=valid_static_dist_arg
)
assert isinstance(static, StaticDistributionalRV)
dynamic = DistributionalRV(
name="test dynamic", distribution=valid_dynamic_dist_arg
)
assert isinstance(dynamic, DynamicDistributionalRV)


@pytest.mark.parametrize(
["dist", "params"],
[
[dist.Normal, {"loc": 0.0, "scale": 0.5}],
[dist.Poisson, {"rate": 0.35265}],
[
dist.Cauchy,
{
"loc": jnp.array([1.0, 5.0, -0.25]),
"scale": jnp.array([0.02, 0.15, 2]),
},
],
],
)
def test_sampling_equivalent(dist, params):
"""
Test that sampling a DynamicDistributionalRV
with a given parameterization is equivalent to
sampling a StaticDistributionalRV with the
same parameterization and the same random seed
"""
static = DistributionalRV(name="static", distribution=dist(**params))
dynamic = DistributionalRV(name="dynamic", distribution=dist)
assert isinstance(static, StaticDistributionalRV)
assert isinstance(dynamic, DynamicDistributionalRV)
with numpyro.handlers.seed(rng_seed=5):
static_samp, *_ = static()
with numpyro.handlers.seed(rng_seed=5):
dynamic_samp, *_ = dynamic(**params)
assert_array_equal(static_samp.value, dynamic_samp.value)

0 comments on commit 25fa99a

Please sign in to comment.