Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create a censored normal distribution #428

Merged
merged 13 commits into from
Sep 6, 2024
7 changes: 7 additions & 0 deletions docs/source/msei_reference/distributions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Distributions
===========

.. automodule:: pyrenew.distributions
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions pyrenew/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# numpydoc ignore=GL08

from pyrenew.distributions.censorednormal import CensoredNormal

__all__ = [
"CensoredNormal",
]
117 changes: 117 additions & 0 deletions pyrenew/distributions/censorednormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# numpydoc ignore=GL08

import jax
import jax.numpy as jnp
import numpyro
import numpyro.util
from numpyro.distributions import constraints
from numpyro.distributions.util import promote_shapes, validate_sample


class CensoredNormal(numpyro.distributions.Distribution):
"""
Censored normal distribution under which samples
are truncated to lie within a specified interval.
This implementation is adapted from
https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
pytree_data_fields = (
"loc",
"scale",
"lower_limit",
"upper_limit",
"_support",
)

def __init__(
self,
loc=0,
scale=1,
lower_limit=-jnp.inf,
upper_limit=jnp.inf,
validate_args=None,
):
"""
Default constructor

Parameters
----------
loc : ArrayLike or float, optional
The mean of the normal distribution.
Defaults to 0.
scale : ArrayLike or float, optional
The standard deviation of the normal
distribution. Must be positive. Defaults to 1.
lower_limit : float, optional
The lower bound of the interval for censoring.
Defaults to -inf (no lower bound).
upper_limit : float, optional
The upper bound of the interval for censoring.
Defaults to inf (no upper bound).
validate_args : bool, optional
If True, checks if parameters are valid.
Defaults to None.

Returns
-------
None
"""
self.loc, self.scale = promote_shapes(loc, scale)
self.lower_limit = lower_limit
self.upper_limit = upper_limit

sbidari marked this conversation as resolved.
Show resolved Hide resolved
batch_shape = jax.lax.broadcast_shapes(
jnp.shape(loc), jnp.shape(scale)
)
self.normal_ = numpyro.distributions.Normal(
loc=loc, scale=scale, validate_args=validate_args
)
super().__init__(batch_shape=batch_shape, validate_args=validate_args)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self): # numpydoc ignore=GL08
return self._support

Check warning on line 75 in pyrenew/distributions/censorednormal.py

View check run for this annotation

Codecov / codecov/patch

pyrenew/distributions/censorednormal.py#L75

Added line #L75 was not covered by tests
dylanhmorris marked this conversation as resolved.
Show resolved Hide resolved

def sample(self, key, sample_shape=()):
"""
Generates samples from the censored normal distribution.

Returns
-------
Array
Containing samples from the censored normal distribution.
"""
assert numpyro.util.is_prng_key(key)
result = self.normal_.sample(key, sample_shape)
return jnp.clip(result, min=self.lower_limit, max=self.upper_limit)

@validate_sample
def log_prob(self, value):
"""
Computes the log probability density of a given value(s) under
the censored normal distribution.

Returns
-------
Array
Containing log probability of the given value(s)
under the censored normal distribution
"""
rescaled_ulim = (self.upper_limit - self.loc) / self.scale
rescaled_llim = (self.lower_limit - self.loc) / self.scale
lim_val = jnp.where(
value <= self.lower_limit,
jax.scipy.special.log_ndtr(rescaled_llim),
jax.scipy.special.log_ndtr(-rescaled_ulim),
)
# we exploit the fact that for the
# standard normal, P(x > a) = P(-x < a)
# to compute the log complementary CDF
inbounds = jnp.logical_and(
value > self.lower_limit, value < self.upper_limit
)
result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val)

return result
134 changes: 134 additions & 0 deletions test/test_censorednormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# numpydoc ignore=GL08

import jax
import jax.numpy as jnp
import numpyro
import pytest
from numpy.testing import assert_array_almost_equal, assert_array_equal

from pyrenew.distributions import CensoredNormal


@pytest.mark.parametrize(
["loc", "scale", "lower_limit", "upper_limit", "in_val", "l_val", "h_val"],
[
[0, 1, -1, 1, jnp.array([0, 0.5]), -2, 2],
],
)
def test_interval_censored_normal_distribution(
loc,
scale,
lower_limit,
upper_limit,
in_val,
l_val,
h_val,
):
"""
Tests the censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(
loc=loc, scale=scale, lower_limit=lower_limit, upper_limit=upper_limit
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp >= lower_limit)
assert jnp.all(samp <= upper_limit)

# test log prob of values within bounds
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values lower than the limit
assert_array_almost_equal(
censored_dist.log_prob(l_val),
jax.scipy.special.log_ndtr((lower_limit - loc) / scale),
)

# test log prob of values higher than the limit
assert_array_almost_equal(
censored_dist.log_prob(h_val),
jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale),
)


@pytest.mark.parametrize(
["loc", "scale", "lower_limit", "in_val", "l_val"],
[
[0, 1, -5, jnp.array([-2, 1]), -6],
],
)
def test_left_censored_normal_distribution(
loc,
scale,
lower_limit,
in_val,
l_val,
):
"""
Tests the lower censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(
loc=loc,
scale=scale,
lower_limit=lower_limit,
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp >= lower_limit)

# test log prob of values within bounds
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values lower than the limit
assert_array_almost_equal(
censored_dist.log_prob(l_val),
jax.scipy.special.log_ndtr((lower_limit - loc) / scale),
)


@pytest.mark.parametrize(
["loc", "scale", "upper_limit", "in_val", "h_val"],
[
[0, 1, 3, jnp.array([1, 2]), 5],
],
)
def test_right_censored_normal_distribution(
loc,
scale,
upper_limit,
in_val,
h_val,
):
"""
Tests the upper censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(
loc=loc, scale=scale, upper_limit=upper_limit
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp <= upper_limit)

# test log prob of values within bounds
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values higher than the limit
assert_array_almost_equal(
censored_dist.log_prob(h_val),
jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale),
)