Skip to content

Commit

Permalink
create a censored normal distribution (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbidari authored Sep 6, 2024
1 parent 3c5fbe7 commit 1e204f6
Show file tree
Hide file tree
Showing 4 changed files with 303 additions and 0 deletions.
7 changes: 7 additions & 0 deletions docs/source/msei_reference/distributions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Distributions
===========

.. automodule:: pyrenew.distributions
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions pyrenew/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# numpydoc ignore=GL08

from pyrenew.distributions.censorednormal import CensoredNormal

__all__ = [
"CensoredNormal",
]
119 changes: 119 additions & 0 deletions pyrenew/distributions/censorednormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# numpydoc ignore=GL08

import jax
import jax.numpy as jnp
import numpyro
import numpyro.util
from numpyro.distributions import constraints
from numpyro.distributions.util import promote_shapes, validate_sample


class CensoredNormal(numpyro.distributions.Distribution):
"""
Censored normal distribution under which samples
are truncated to lie within a specified interval.
This implementation is adapted from
https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py
"""

arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
pytree_data_fields = (
"loc",
"scale",
"lower_limit",
"upper_limit",
"_support",
)

def __init__(
self,
loc=0,
scale=1,
lower_limit=-jnp.inf,
upper_limit=jnp.inf,
validate_args=None,
):
"""
Default constructor
Parameters
----------
loc : ArrayLike or float, optional
The mean of the normal distribution.
Defaults to 0.
scale : ArrayLike or float, optional
The standard deviation of the normal
distribution. Must be positive. Defaults to 1.
lower_limit : float, optional
The lower bound of the interval for censoring.
Defaults to -inf (no lower bound).
upper_limit : float, optional
The upper bound of the interval for censoring.
Defaults to inf (no upper bound).
validate_args : bool, optional
If True, checks if parameters are valid.
Defaults to None.
Returns
-------
None
"""
self.loc, self.scale = promote_shapes(loc, scale)
self.lower_limit = lower_limit
self.upper_limit = upper_limit
self._support = constraints.interval(
self.lower_limit, self.upper_limit
)
batch_shape = jax.lax.broadcast_shapes(
jnp.shape(loc), jnp.shape(scale)
)
self.normal_ = numpyro.distributions.Normal(
loc=loc, scale=scale, validate_args=validate_args
)
super().__init__(batch_shape=batch_shape, validate_args=validate_args)

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self): # numpydoc ignore=GL08
return self._support

def sample(self, key, sample_shape=()):
"""
Generates samples from the censored normal distribution.
Returns
-------
Array
Containing samples from the censored normal distribution.
"""
assert numpyro.util.is_prng_key(key)
result = self.normal_.sample(key, sample_shape)
return jnp.clip(result, min=self.lower_limit, max=self.upper_limit)

@validate_sample
def log_prob(self, value):
"""
Computes the log probability density of a given value(s) under
the censored normal distribution.
Returns
-------
Array
Containing log probability of the given value(s)
under the censored normal distribution
"""
rescaled_ulim = (self.upper_limit - self.loc) / self.scale
rescaled_llim = (self.lower_limit - self.loc) / self.scale
lim_val = jnp.where(
value <= self.lower_limit,
jax.scipy.special.log_ndtr(rescaled_llim),
jax.scipy.special.log_ndtr(-rescaled_ulim),
)
# we exploit the fact that for the
# standard normal, P(x > a) = P(-x < a)
# to compute the log complementary CDF
inbounds = jnp.logical_and(
value > self.lower_limit, value < self.upper_limit
)
result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val)

return result
170 changes: 170 additions & 0 deletions test/test_censorednormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# numpydoc ignore=GL08

import jax
import jax.numpy as jnp
import numpyro
import pytest
from numpy.testing import (
assert_array_almost_equal,
assert_array_equal,
assert_equal,
)

from pyrenew.distributions import CensoredNormal


@pytest.mark.parametrize(
["loc", "scale", "lower_limit", "upper_limit", "in_val", "l_val", "h_val"],
[
[
jnp.array([0]),
jnp.array([2.0, 1.0]),
-1,
1,
jnp.array([0, 0.5]),
-2,
2,
],
[
jnp.array([0, 1]),
jnp.array([1.0]),
-1,
2,
jnp.array([0, 0.5]),
-2,
3,
],
],
)
def test_interval_censored_normal_distribution(
loc,
scale,
lower_limit,
upper_limit,
in_val,
l_val,
h_val,
):
"""
Tests the censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(
loc=loc, scale=scale, lower_limit=lower_limit, upper_limit=upper_limit
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp >= lower_limit)
assert jnp.all(samp <= upper_limit)

# test log prob of values within bounds
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values lower than the limit
assert_array_almost_equal(
censored_dist.log_prob(l_val),
jax.scipy.special.log_ndtr((lower_limit - loc) / scale),
)

# test log prob of values higher than the limit
assert_array_almost_equal(
censored_dist.log_prob(h_val),
jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale),
)

# test_broadcasting
assert_equal(samp.shape[-1], max(loc.shape[0], scale.shape[0]))

# test support of the distribution
assert_equal(censored_dist.support.lower_bound, lower_limit)
assert_equal(censored_dist.support.upper_bound, upper_limit)


@pytest.mark.parametrize(
["loc", "scale", "lower_limit", "in_val", "l_val"],
[
[0, 1, -5, jnp.array([-2, 1]), -6],
],
)
def test_left_censored_normal_distribution(
loc,
scale,
lower_limit,
in_val,
l_val,
):
"""
Tests the lower censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(
loc=loc,
scale=scale,
lower_limit=lower_limit,
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp >= lower_limit)

# test log prob of values within bounds
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values lower than the limit
assert_array_almost_equal(
censored_dist.log_prob(l_val),
jax.scipy.special.log_ndtr((lower_limit - loc) / scale),
)

# test support of the distribution
assert_equal(censored_dist.support.lower_bound, lower_limit)
assert censored_dist.support.upper_bound == jnp.inf


@pytest.mark.parametrize(
["loc", "scale", "upper_limit", "in_val", "h_val"],
[
[0, 1, 3, jnp.array([1, 2]), 5],
],
)
def test_right_censored_normal_distribution(
loc,
scale,
upper_limit,
in_val,
h_val,
):
"""
Tests the upper censored normal distribution samples
within the limit and calculation of log probability
"""
censored_dist = CensoredNormal(
loc=loc, scale=scale, upper_limit=upper_limit
)
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale)

# test samples within the bounds
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,))
assert jnp.all(samp <= upper_limit)

# test log prob of values within bounds
assert_array_equal(
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val)
)

# test log prob of values higher than the limit
assert_array_almost_equal(
censored_dist.log_prob(h_val),
jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale),
)

# test support of the distribution
assert_equal(censored_dist.support.upper_bound, upper_limit)
assert censored_dist.support.lower_bound == -jnp.inf

0 comments on commit 1e204f6

Please sign in to comment.