-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
create a censored normal distribution (#428)
- Loading branch information
Showing
4 changed files
with
303 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
Distributions | ||
=========== | ||
|
||
.. automodule:: pyrenew.distributions | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# numpydoc ignore=GL08 | ||
|
||
from pyrenew.distributions.censorednormal import CensoredNormal | ||
|
||
__all__ = [ | ||
"CensoredNormal", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
# numpydoc ignore=GL08 | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpyro | ||
import numpyro.util | ||
from numpyro.distributions import constraints | ||
from numpyro.distributions.util import promote_shapes, validate_sample | ||
|
||
|
||
class CensoredNormal(numpyro.distributions.Distribution): | ||
""" | ||
Censored normal distribution under which samples | ||
are truncated to lie within a specified interval. | ||
This implementation is adapted from | ||
https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py | ||
""" | ||
|
||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} | ||
pytree_data_fields = ( | ||
"loc", | ||
"scale", | ||
"lower_limit", | ||
"upper_limit", | ||
"_support", | ||
) | ||
|
||
def __init__( | ||
self, | ||
loc=0, | ||
scale=1, | ||
lower_limit=-jnp.inf, | ||
upper_limit=jnp.inf, | ||
validate_args=None, | ||
): | ||
""" | ||
Default constructor | ||
Parameters | ||
---------- | ||
loc : ArrayLike or float, optional | ||
The mean of the normal distribution. | ||
Defaults to 0. | ||
scale : ArrayLike or float, optional | ||
The standard deviation of the normal | ||
distribution. Must be positive. Defaults to 1. | ||
lower_limit : float, optional | ||
The lower bound of the interval for censoring. | ||
Defaults to -inf (no lower bound). | ||
upper_limit : float, optional | ||
The upper bound of the interval for censoring. | ||
Defaults to inf (no upper bound). | ||
validate_args : bool, optional | ||
If True, checks if parameters are valid. | ||
Defaults to None. | ||
Returns | ||
------- | ||
None | ||
""" | ||
self.loc, self.scale = promote_shapes(loc, scale) | ||
self.lower_limit = lower_limit | ||
self.upper_limit = upper_limit | ||
self._support = constraints.interval( | ||
self.lower_limit, self.upper_limit | ||
) | ||
batch_shape = jax.lax.broadcast_shapes( | ||
jnp.shape(loc), jnp.shape(scale) | ||
) | ||
self.normal_ = numpyro.distributions.Normal( | ||
loc=loc, scale=scale, validate_args=validate_args | ||
) | ||
super().__init__(batch_shape=batch_shape, validate_args=validate_args) | ||
|
||
@constraints.dependent_property(is_discrete=False, event_dim=0) | ||
def support(self): # numpydoc ignore=GL08 | ||
return self._support | ||
|
||
def sample(self, key, sample_shape=()): | ||
""" | ||
Generates samples from the censored normal distribution. | ||
Returns | ||
------- | ||
Array | ||
Containing samples from the censored normal distribution. | ||
""" | ||
assert numpyro.util.is_prng_key(key) | ||
result = self.normal_.sample(key, sample_shape) | ||
return jnp.clip(result, min=self.lower_limit, max=self.upper_limit) | ||
|
||
@validate_sample | ||
def log_prob(self, value): | ||
""" | ||
Computes the log probability density of a given value(s) under | ||
the censored normal distribution. | ||
Returns | ||
------- | ||
Array | ||
Containing log probability of the given value(s) | ||
under the censored normal distribution | ||
""" | ||
rescaled_ulim = (self.upper_limit - self.loc) / self.scale | ||
rescaled_llim = (self.lower_limit - self.loc) / self.scale | ||
lim_val = jnp.where( | ||
value <= self.lower_limit, | ||
jax.scipy.special.log_ndtr(rescaled_llim), | ||
jax.scipy.special.log_ndtr(-rescaled_ulim), | ||
) | ||
# we exploit the fact that for the | ||
# standard normal, P(x > a) = P(-x < a) | ||
# to compute the log complementary CDF | ||
inbounds = jnp.logical_and( | ||
value > self.lower_limit, value < self.upper_limit | ||
) | ||
result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val) | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# numpydoc ignore=GL08 | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpyro | ||
import pytest | ||
from numpy.testing import ( | ||
assert_array_almost_equal, | ||
assert_array_equal, | ||
assert_equal, | ||
) | ||
|
||
from pyrenew.distributions import CensoredNormal | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["loc", "scale", "lower_limit", "upper_limit", "in_val", "l_val", "h_val"], | ||
[ | ||
[ | ||
jnp.array([0]), | ||
jnp.array([2.0, 1.0]), | ||
-1, | ||
1, | ||
jnp.array([0, 0.5]), | ||
-2, | ||
2, | ||
], | ||
[ | ||
jnp.array([0, 1]), | ||
jnp.array([1.0]), | ||
-1, | ||
2, | ||
jnp.array([0, 0.5]), | ||
-2, | ||
3, | ||
], | ||
], | ||
) | ||
def test_interval_censored_normal_distribution( | ||
loc, | ||
scale, | ||
lower_limit, | ||
upper_limit, | ||
in_val, | ||
l_val, | ||
h_val, | ||
): | ||
""" | ||
Tests the censored normal distribution samples | ||
within the limit and calculation of log probability | ||
""" | ||
censored_dist = CensoredNormal( | ||
loc=loc, scale=scale, lower_limit=lower_limit, upper_limit=upper_limit | ||
) | ||
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale) | ||
|
||
# test samples within the bounds | ||
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,)) | ||
assert jnp.all(samp >= lower_limit) | ||
assert jnp.all(samp <= upper_limit) | ||
|
||
# test log prob of values within bounds | ||
assert_array_equal( | ||
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val) | ||
) | ||
|
||
# test log prob of values lower than the limit | ||
assert_array_almost_equal( | ||
censored_dist.log_prob(l_val), | ||
jax.scipy.special.log_ndtr((lower_limit - loc) / scale), | ||
) | ||
|
||
# test log prob of values higher than the limit | ||
assert_array_almost_equal( | ||
censored_dist.log_prob(h_val), | ||
jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale), | ||
) | ||
|
||
# test_broadcasting | ||
assert_equal(samp.shape[-1], max(loc.shape[0], scale.shape[0])) | ||
|
||
# test support of the distribution | ||
assert_equal(censored_dist.support.lower_bound, lower_limit) | ||
assert_equal(censored_dist.support.upper_bound, upper_limit) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["loc", "scale", "lower_limit", "in_val", "l_val"], | ||
[ | ||
[0, 1, -5, jnp.array([-2, 1]), -6], | ||
], | ||
) | ||
def test_left_censored_normal_distribution( | ||
loc, | ||
scale, | ||
lower_limit, | ||
in_val, | ||
l_val, | ||
): | ||
""" | ||
Tests the lower censored normal distribution samples | ||
within the limit and calculation of log probability | ||
""" | ||
censored_dist = CensoredNormal( | ||
loc=loc, | ||
scale=scale, | ||
lower_limit=lower_limit, | ||
) | ||
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale) | ||
|
||
# test samples within the bounds | ||
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,)) | ||
assert jnp.all(samp >= lower_limit) | ||
|
||
# test log prob of values within bounds | ||
assert_array_equal( | ||
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val) | ||
) | ||
|
||
# test log prob of values lower than the limit | ||
assert_array_almost_equal( | ||
censored_dist.log_prob(l_val), | ||
jax.scipy.special.log_ndtr((lower_limit - loc) / scale), | ||
) | ||
|
||
# test support of the distribution | ||
assert_equal(censored_dist.support.lower_bound, lower_limit) | ||
assert censored_dist.support.upper_bound == jnp.inf | ||
|
||
|
||
@pytest.mark.parametrize( | ||
["loc", "scale", "upper_limit", "in_val", "h_val"], | ||
[ | ||
[0, 1, 3, jnp.array([1, 2]), 5], | ||
], | ||
) | ||
def test_right_censored_normal_distribution( | ||
loc, | ||
scale, | ||
upper_limit, | ||
in_val, | ||
h_val, | ||
): | ||
""" | ||
Tests the upper censored normal distribution samples | ||
within the limit and calculation of log probability | ||
""" | ||
censored_dist = CensoredNormal( | ||
loc=loc, scale=scale, upper_limit=upper_limit | ||
) | ||
normal_dist = numpyro.distributions.Normal(loc=loc, scale=scale) | ||
|
||
# test samples within the bounds | ||
samp = censored_dist.sample(jax.random.PRNGKey(0), sample_shape=(100,)) | ||
assert jnp.all(samp <= upper_limit) | ||
|
||
# test log prob of values within bounds | ||
assert_array_equal( | ||
censored_dist.log_prob(in_val), normal_dist.log_prob(in_val) | ||
) | ||
|
||
# test log prob of values higher than the limit | ||
assert_array_almost_equal( | ||
censored_dist.log_prob(h_val), | ||
jax.scipy.special.log_ndtr(-(upper_limit - loc) / scale), | ||
) | ||
|
||
# test support of the distribution | ||
assert_equal(censored_dist.support.upper_bound, upper_limit) | ||
assert censored_dist.support.lower_bound == -jnp.inf |