diff --git a/docs/source/msei_reference/distributions.rst b/docs/source/msei_reference/distributions.rst
new file mode 100644
index 00000000..baa7f383
--- /dev/null
+++ b/docs/source/msei_reference/distributions.rst
@@ -0,0 +1,7 @@
+Distributions
+=============
+
+.. automodule:: pyrenew.distributions
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/pyrenew/distributions/__init__.py b/pyrenew/distributions/__init__.py
new file mode 100644
index 00000000..42679d9a
--- /dev/null
+++ b/pyrenew/distributions/__init__.py
@@ -0,0 +1,7 @@
+# numpydoc ignore=GL08
+
+from pyrenew.distributions.censorednormal import CensoredNormal
+
+__all__ = [
+    "CensoredNormal",
+]
diff --git a/pyrenew/distributions/censorednormal.py b/pyrenew/distributions/censorednormal.py
new file mode 100644
index 00000000..a77ed7e9
--- /dev/null
+++ b/pyrenew/distributions/censorednormal.py
@@ -0,0 +1,156 @@
+# numpydoc ignore=GL08
+
+import jax
+import jax.numpy as jnp
+import numpyro
+import numpyro.util
+from numpyro.distributions import constraints
+from numpyro.distributions.util import promote_shapes, validate_sample
+
+
+class CensoredNormal(numpyro.distributions.Distribution):
+    """
+    Censored normal distribution, under which draws from a
+    normal distribution are clipped to lie within a specified
+    interval. Unlike a truncated normal, probability mass
+    outside the interval is not renormalized away: it is
+    placed on the nearest limit.
+
+    This implementation is adapted from
+    https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py
+    """
+
+    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
+    # sample() and log_prob() read ``normal_``, so it is registered as
+    # pytree data alongside the parameters; otherwise it would not be
+    # restored when the distribution is rebuilt from its flattened form.
+    pytree_data_fields = (
+        "loc",
+        "scale",
+        "lower_limit",
+        "upper_limit",
+        "_support",
+        "normal_",
+    )
+
+    def __init__(
+        self,
+        loc=0,
+        scale=1,
+        lower_limit=-jnp.inf,
+        upper_limit=jnp.inf,
+        validate_args=None,
+    ):
+        """
+        Default constructor
+
+        Parameters
+        ----------
+        loc : ArrayLike or float, optional
+            The mean of the normal distribution.
+            Defaults to 0.
+        scale : ArrayLike or float, optional
+            The standard deviation of the normal
+            distribution. Must be positive. Defaults to 1.
+        lower_limit : float, optional
+            The lower bound of the interval for censoring.
+            Defaults to -inf (no lower bound).
+        upper_limit : float, optional
+            The upper bound of the interval for censoring.
+            Defaults to inf (no upper bound).
+        validate_args : bool, optional
+            If True, checks if parameters are valid.
+            Defaults to None.
+
+        Returns
+        -------
+        None
+        """
+        self.loc, self.scale = promote_shapes(loc, scale)
+        self.lower_limit = lower_limit
+        self.upper_limit = upper_limit
+        self._support = constraints.interval(
+            self.lower_limit, self.upper_limit
+        )
+        batch_shape = jax.lax.broadcast_shapes(
+            jnp.shape(loc), jnp.shape(scale)
+        )
+        self.normal_ = numpyro.distributions.Normal(
+            loc=loc, scale=scale, validate_args=validate_args
+        )
+        super().__init__(batch_shape=batch_shape, validate_args=validate_args)
+
+    @constraints.dependent_property(is_discrete=False, event_dim=0)
+    def support(self):  # numpydoc ignore=GL08
+        return self._support
+
+    def sample(self, key, sample_shape=()):
+        """
+        Generates samples from the censored normal distribution
+        by drawing from the underlying normal and clipping the
+        draws to ``[lower_limit, upper_limit]``.
+
+        Parameters
+        ----------
+        key : ArrayLike
+            A JAX PRNG key.
+        sample_shape : tuple, optional
+            Shape of the sample batch to draw. Defaults to ``()``.
+
+        Returns
+        -------
+        Array
+            Containing samples from the censored normal distribution.
+        """
+        assert numpyro.util.is_prng_key(key)
+        result = self.normal_.sample(key, sample_shape)
+        return jnp.clip(result, min=self.lower_limit, max=self.upper_limit)
+
+    @validate_sample
+    def log_prob(self, value):
+        """
+        Computes the log probability density of a given value(s) under
+        the censored normal distribution.
+
+        Values strictly inside the interval get the normal log density.
+        Values at or below ``lower_limit`` get the log CDF of the normal
+        at ``lower_limit``; all remaining values get the log
+        complementary CDF of the normal at ``upper_limit``.
+
+        Parameters
+        ----------
+        value : ArrayLike
+            The value(s) at which to evaluate the log probability.
+
+        Returns
+        -------
+        Array
+            Containing log probability of the given value(s)
+            under the censored normal distribution
+        """
+        at_lower = value <= self.lower_limit
+        inbounds = jnp.logical_and(
+            value > self.lower_limit, value < self.upper_limit
+        )
+        at_upper = jnp.logical_not(jnp.logical_or(inbounds, at_lower))
+
+        # Only feed a limit into log_ndtr where its branch is selected;
+        # elsewhere substitute loc (a rescaled limit of 0). The selected
+        # values are unchanged, but an infinite limit in an unselected
+        # branch would otherwise turn the gradients with respect to
+        # loc and scale into NaN (0 * inf under jnp.where).
+        safe_llim = jnp.where(at_lower, self.lower_limit, self.loc)
+        safe_ulim = jnp.where(at_upper, self.upper_limit, self.loc)
+        rescaled_llim = (safe_llim - self.loc) / self.scale
+        rescaled_ulim = (safe_ulim - self.loc) / self.scale
+
+        # for the standard normal, P(X > a) = P(X < -a), so the log
+        # complementary CDF at a is the log CDF at -a
+        lim_val = jnp.where(
+            at_lower,
+            jax.scipy.special.log_ndtr(rescaled_llim),
+            jax.scipy.special.log_ndtr(-rescaled_ulim),
+        )
+        result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val)
+
+        return result
diff --git a/test/test_censorednormal.py b/test/test_censorednormal.py
new file mode 100644
index 00000000..b5605fa8
--- /dev/null
+++ b/test/test_censorednormal.py
@@ -0,0 +1,170 @@
+# numpydoc ignore=GL08
+
+import jax
+import jax.numpy as jnp
+import numpyro
+import pytest
+from numpy.testing import (
+    assert_array_almost_equal,
+    assert_array_equal,
+    assert_equal,
+)
+
+from pyrenew.distributions import CensoredNormal
+
+
+@pytest.mark.parametrize(
+    ["loc", "scale", "lower_limit", "upper_limit", "in_val", "l_val", "h_val"],
+    [
+        [
+            jnp.array([0]),
+            jnp.array([2.0, 1.0]),
+            -1,
+            1,
+            jnp.array([0, 0.5]),
+            -2,
+            2,
+        ],
+        [
+            jnp.array([0, 1]),
+            jnp.array([1.0]),
+            -1,
+            2,
+            jnp.array([0, 0.5]),
+            -2,
+            3,
+        ],
+    ],
+)
+def test_interval_censored_normal_distribution(
+    loc,
+    scale,
+    lower_limit,
+    upper_limit,
+    in_val,
+    l_val,
+    h_val,
+):
+    """
+    Tests the censored normal distribution samples
+    within the limit and calculation of log probability
+    """
+    censored_dist = CensoredNormal(
+        loc=loc, scale=scale, lower_limit=lower_limit, upper_limit=upper_limit
+    )
+    normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)
+
+    # test samples within the bounds
+    samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
+    assert jnp.all(samp >= lower_limit)
+    assert jnp.all(samp <= upper_limit)
+
+    # test log prob of values within bounds
+    assert_array_equal(
+        censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
+    )
+
+    # test log prob of values lower than the limit
+    assert_array_almost_equal(
+        censored_dist.log_prob(l_val),
+        jax.scipy.special.log_ndtr((lower_limit - loc) / scale),
+    )
+
+    # test log prob of values higher than the limit
+    assert_array_almost_equal(
+        censored_dist.log_prob(h_val),
+        jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale),
+    )
+
+    # test_broadcasting
+    assert_equal(samp.shape[-1], max(loc.shape[0], scale.shape[0]))
+
+    # test support of the distribution
+    assert_equal(censored_dist.support.lower_bound, lower_limit)
+    assert_equal(censored_dist.support.upper_bound, upper_limit)
+
+
+@pytest.mark.parametrize(
+    ["loc", "scale", "lower_limit", "in_val", "l_val"],
+    [
+        [0, 1, -5, jnp.array([-2, 1]), -6],
+    ],
+)
+def test_left_censored_normal_distribution(
+    loc,
+    scale,
+    lower_limit,
+    in_val,
+    l_val,
+):
+    """
+    Tests the lower censored normal distribution samples
+    within the limit and calculation of log probability
+    """
+    censored_dist = CensoredNormal(
+        loc=loc,
+        scale=scale,
+        lower_limit=lower_limit,
+    )
+    normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)
+
+    # test samples within the bounds
+    samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
+    assert jnp.all(samp >= lower_limit)
+
+    # test log prob of values within bounds
+    assert_array_equal(
+        censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
+    )
+
+    # test log prob of values lower than the limit
+    assert_array_almost_equal(
+        censored_dist.log_prob(l_val),
+        jax.scipy.special.log_ndtr((lower_limit - loc) / scale),
+    )
+
+    # test support of the distribution
+    assert_equal(censored_dist.support.lower_bound, lower_limit)
+    assert censored_dist.support.upper_bound == jnp.inf
+
+
+@pytest.mark.parametrize(
+    ["loc", "scale", "upper_limit", "in_val", "h_val"],
+    [
+        [0, 1, 3, jnp.array([1, 2]), 5],
+    ],
+)
+def test_right_censored_normal_distribution(
+    loc,
+    scale,
+    upper_limit,
+    in_val,
+    h_val,
+):
+    """
+    Tests the upper censored normal distribution samples
+    within the limit and calculation of log probability
+    """
+    censored_dist = CensoredNormal(
+        loc=loc, scale=scale, upper_limit=upper_limit
+    )
+    normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)
+
+    # test samples within the bounds
+    samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
+    assert jnp.all(samp <= upper_limit)
+
+    # test log prob of values within bounds
+    assert_array_equal(
+        censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
+    )
+
+    # test log prob of values higher than the limit
+    assert_array_almost_equal(
+        censored_dist.log_prob(h_val),
+        jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale),
+    )
+
+    # test support of the distribution
+    assert_equal(censored_dist.support.upper_bound, upper_limit)
+    assert censored_dist.support.lower_bound == -jnp.inf