Add Improper Distribution #612

Merged
merged 14 commits into from
Jun 6, 2020
16 changes: 16 additions & 0 deletions docs/source/distributions.rst
@@ -17,6 +17,14 @@ ExpandedDistribution
:show-inheritance:
:member-order: bysource

ImproperUniform
---------------
.. autoclass:: numpyro.distributions.distribution.ImproperUniform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Independent
-----------
.. autoclass:: numpyro.distributions.distribution.Independent
@@ -372,6 +380,14 @@ ZeroInflatedPoisson
Constraints
===========

Constraint
----------
.. autoclass:: numpyro.distributions.constraints.Constraint
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

boolean
-------
.. autodata:: numpyro.distributions.constraints.boolean
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
@@ -51,6 +51,7 @@
from numpyro.distributions.distribution import (
Distribution,
ExpandedDistribution,
ImproperUniform,
Independent,
MaskedDistribution,
TransformedDistribution,
@@ -87,6 +88,7 @@
'Gumbel',
'HalfCauchy',
'HalfNormal',
'ImproperUniform',
'Independent',
'InverseGamma',
'LKJ',
13 changes: 13 additions & 0 deletions numpyro/distributions/constraints.py
@@ -54,9 +54,22 @@


class Constraint(object):
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.
    """
    def __call__(self, x):
        raise NotImplementedError

    def check(self, value):
Member Author

We expose this doc so I want to match PyTorch behavior here.

        """
        Returns a byte tensor of `sample_shape + batch_shape` indicating
        whether each event in value satisfies this constraint.
        """
        return self(value)


class _Boolean(Constraint):
def __call__(self, x):
62 changes: 61 additions & 1 deletion numpyro/distributions/distribution.py
@@ -32,7 +32,7 @@
import jax.numpy as np
from jax import lax

from numpyro.distributions.constraints import is_dependent, real
from numpyro.distributions.constraints import Constraint, is_dependent, real # noqa: F401
from numpyro.distributions.transforms import Transform
from numpyro.distributions.util import lazy_property, sum_rightmost, validate_sample
from numpyro.util import not_jax_tracer
@@ -350,6 +350,66 @@ def variance(self):
return np.broadcast_to(self.base_dist.variance, self.batch_shape + self.event_shape)


class ImproperUniform(Distribution):
    """
    A helper distribution with zero :meth:`log_prob` over the `support` domain.

    .. note:: `sample` method is not implemented for this distribution.

    **Usage**::

       >>> from numpyro.distributions import constraints
       >>>
       >>> def model():
       ...     # ordered vector with length 10
       ...     x = sample('x', ImproperUniform(constraints.ordered_vector, (), event_shape=(10,)))
       ...
       ...     # real matrix with shape (3, 4)
       ...     y = sample('y', ImproperUniform(constraints.real, (), event_shape=(3, 4)))
       ...
       ...     # a shape-(6, 8) batch of length-5 vectors greater than 3
       ...     z = sample('z', ImproperUniform(constraints.greater_than(3), (6, 8), event_shape=(5,)))

    If you want to set an improper prior over all values greater than `a`, where `a` is
    another random variable, you might use

       >>> x = sample('x', ImproperUniform(constraints.greater_than(a), (), event_shape=()))

    or if you want to reparameterize it

       >>> from numpyro.distributions import constraints, transforms
       >>> from numpyro.contrib.reparam import reparam, TransformReparam
       >>>
       >>> with reparam(config={'x': TransformReparam()}):
       ...     x = sample('x',
       ...                TransformedDistribution(ImproperUniform(constraints.positive, (), ()),
       ...                                        transforms.AffineTransform(a, 1)))

    :param Constraint support: the support of this distribution.
    :param tuple batch_shape: batch shape of this distribution. It is usually safe to
        set `batch_shape=()`.
    :param tuple event_shape: event shape of this distribution.
    """
    arg_constraints = {}

    def __init__(self, support, batch_shape, event_shape, validate_args=None):
        self.support = support
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value):
        batch_shape = np.shape(value)[:np.ndim(value) - len(self.event_shape)]
        batch_shape = lax.broadcast_shapes(batch_shape, self.batch_shape)
        return np.zeros(batch_shape)

    def _validate_sample(self, value):
        mask = super(ImproperUniform, self)._validate_sample(value)
        batch_dim = np.ndim(value) - len(self.event_shape)
        if batch_dim < np.ndim(mask):
            mask = np.all(np.reshape(mask, np.shape(mask)[:batch_dim] + (-1,)), -1)
        return mask


class Independent(Distribution):
"""
Reinterprets batch dimensions of a distribution as event dims by shifting
31 changes: 23 additions & 8 deletions numpyro/infer/initialization.py
@@ -12,7 +12,8 @@

def init_to_median(site=None, num_samples=15):
"""
Initialize to the prior median.
Initialize to the prior median. For priors with no `.sample` method implemented,
we defer to the :func:`init_to_uniform` strategy.

:param int num_samples: number of prior points to calculate median.
"""
@@ -22,13 +23,17 @@ def init_to_median(site=None, num_samples=15):
if site['type'] == 'sample' and not site['is_observed']:
rng_key = site['kwargs'].get('rng_key')
sample_shape = site['kwargs'].get('sample_shape')
samples = site['fn'].sample(rng_key, sample_shape=(num_samples,) + sample_shape)
return np.median(samples, axis=0)
try:
samples = site['fn'].sample(rng_key, sample_shape=(num_samples,) + sample_shape)
return np.median(samples, axis=0)
except NotImplementedError:
return init_to_uniform(site)
Member

Can we just return prototype_value = np.full(site['fn'].shape(), np.nan) instead of calling init_to_uniform, so that we don't need to do additional work like calling random.split?

Member Author

For substitute_fn, I think it only works if we use substitute(seed(model), ...), because seed provides the rng_key for the site we want to apply substitute_fn to. Inside substitute_fn, we don't use the sample primitive, so there is no random.split here, IIUC.
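
The fallback discussed in this thread can be sketched without NumPyro. The following is only an illustration under stated assumptions: `ToyNormal`, `ToyImproperUniform`, `init_to_uniform_toy`, and `init_to_median_toy` are hypothetical stand-ins, not the functions in `numpyro/infer/initialization.py`, and the final mapping through `biject_to(fn.support)` is left out.

```python
import random
import statistics


class ToyNormal:
    """Hypothetical stand-in for a proper prior with a working `sample`."""

    def sample(self, rng):
        return rng.gauss(0.0, 1.0)


class ToyImproperUniform:
    """Hypothetical stand-in for an improper prior: `sample` is unsupported."""

    def sample(self, rng):
        raise NotImplementedError


def init_to_uniform_toy(rng, radius=2.0):
    # Draw from Uniform(-radius, radius) in unconstrained space.
    return rng.uniform(-radius, radius)


def init_to_median_toy(fn, rng, num_samples=15):
    # Try the prior median first; fall back when sampling is unsupported.
    try:
        draws = [fn.sample(rng) for _ in range(num_samples)]
        return statistics.median(draws)
    except NotImplementedError:
        return init_to_uniform_toy(rng)


rng = random.Random(0)
proper_init = init_to_median_toy(ToyNormal(), rng)
improper_init = init_to_median_toy(ToyImproperUniform(), rng)
```

Here `improper_init` always lands in `[-2, 2]`, because the toy improper prior cannot be sampled and the uniform strategy takes over.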



def init_to_prior(site=None):
"""
Initialize to a prior sample.
Initialize to a prior sample. For priors with no `.sample` method implemented,
we defer to the :func:`init_to_uniform` strategy.
"""
return init_to_median(site, num_samples=1)

@@ -47,13 +52,23 @@ def init_to_uniform(site=None, radius=2):
rng_key = site['kwargs'].get('rng_key')
sample_shape = site['kwargs'].get('sample_shape')
rng_key, subkey = random.split(rng_key)
transform = biject_to(fn.support)
# this is used to interpret the changes of event_shape in
# domain and codomain spaces
prototype_value = fn.sample(subkey, sample_shape=())
transform = biject_to(fn.support)
unconstrained_event_shape = np.shape(transform.inv(prototype_value))
try:
prototype_value = fn.sample(subkey, sample_shape=())
unconstrained_shape = np.shape(transform.inv(prototype_value))
except NotImplementedError:
# XXX: this works for ImproperUniform prior,
# we can't use this logic for general priors
# because some distributions such as TransformedDistribution might
# have wrong event_shape.
prototype_value = np.full(fn.event_shape, np.nan)
unconstrained_event_shape = np.shape(transform.inv(prototype_value))
unconstrained_shape = fn.batch_shape + unconstrained_event_shape

unconstrained_samples = dist.Uniform(-radius, radius).sample(
rng_key, sample_shape=sample_shape + unconstrained_event_shape)
rng_key, sample_shape=sample_shape + unconstrained_shape)
return transform(unconstrained_samples)


11 changes: 6 additions & 5 deletions numpyro/infer/util.py
@@ -253,10 +253,9 @@ def _find_valid_params(rng_key, exit_early=False):
return (init_params, pe, z_grad), is_valid


def get_model_transforms(model, model_args=(), model_kwargs=None):
def _get_model_transforms(model, model_args=(), model_kwargs=None):
model_kwargs = {} if model_kwargs is None else model_kwargs
seeded_model = seed(model, random.PRNGKey(0))
model_trace = trace(seeded_model).get_trace(*model_args, **model_kwargs)
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
inv_transforms = {}
# model code may need to be replayed in the presence of deterministic sites
replay_model = False
@@ -344,8 +343,10 @@ def initialize_model(rng_key, model,
at `deterministic` sites in the model.
"""
model_kwargs = {} if model_kwargs is None else model_kwargs
inv_transforms, replay_model, model_trace = get_model_transforms(
model, model_args, model_kwargs)
substituted_model = substitute(seed(model, rng_key if np.ndim(rng_key) == 1 else rng_key[0]),
substitute_fn=init_strategy)
inv_transforms, replay_model, model_trace = _get_model_transforms(
substituted_model, model_args, model_kwargs)
constrained_values = {k: v['value'] for k, v in model_trace.items()
if v['type'] == 'sample' and not v['is_observed']}

8 changes: 6 additions & 2 deletions test/test_distributions.py
@@ -556,8 +556,8 @@ def test_gamma_poisson_log_prob(shape):
def test_log_prob_gradient(jax_dist, sp_dist, params):
if jax_dist in [dist.LKJ, dist.LKJCholesky]:
pytest.skip('we have separated tests for LKJCholesky distribution')
rng_key = random.PRNGKey(0)

rng_key = random.PRNGKey(0)
value = jax_dist(*params).sample(rng_key)

def fn(*args):
@@ -670,7 +670,7 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
assert jax_dist(*oob_params)

# Invalid parameter values throw ValueError
if not dependent_constraint:
if not dependent_constraint and jax_dist is not dist.ImproperUniform:
with pytest.raises(ValueError):
jax_dist(*oob_params, validate_args=True)

@@ -792,6 +792,10 @@ def test_biject_to(constraint, shape):
x = random.normal(rng_key, shape)
y = transform(x)

# test that inv works for NaN arrays:
x_nan = transform.inv(np.full(np.shape(y), np.nan))
assert (x_nan.shape == x.shape)

# test codomain
batch_shape = shape if event_dim == 0 else shape[:-1]
assert_array_equal(transform.codomain(y), np.ones(batch_shape, dtype=np.bool_))
14 changes: 14 additions & 0 deletions test/test_infer_util.py
@@ -253,3 +253,17 @@ def model(data):
for name, p in init_params[0].items():
# XXX: the result is equal if we disable fast-math-mode
assert_allclose(p[i], init_params_i[0][name], atol=1e-6)


def test_improper_expand():

    def model():
        population = np.array([1000., 2000., 3000.])
        with numpyro.plate("region", 3):
            numpyro.sample("incidence",
                           dist.ImproperUniform(support=constraints.interval(0, population),
Member

@neerajprad neerajprad Jun 5, 2020

Can we also add a log_prob shape check here?

d = dist.ImproperUniform(support=constraints.interval(0, population),
                         batch_shape=(3,),
                         event_shape=event_shape)
incidence = numpyro.sample("incidence", d)
assert d.log_prob(incidence).shape == (3,)
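
The shape this check expects follows from how `ImproperUniform.log_prob` in this PR builds its result: drop the trailing event dimensions of `value`, then broadcast the remainder against `batch_shape`. A minimal pure-Python sketch of that shape arithmetic, assuming positive dimension sizes; `broadcast_pair` and `log_prob_shape` are hypothetical helper names used only for illustration:

```python
from itertools import zip_longest


def broadcast_pair(a, b):
    """Broadcast two shape tuples with NumPy-style rules (right-aligned)."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def log_prob_shape(value_shape, batch_shape, event_shape):
    # Keep only the leading (sample + batch) dimensions of the value.
    value_batch = value_shape[:len(value_shape) - len(event_shape)]
    # Broadcast them against the distribution's declared batch shape.
    return broadcast_pair(value_batch, batch_shape)


# A (3, 3) value with batch_shape=(3,) and event_shape=(3,), as in the test.
shape_in_test = log_prob_shape((3, 3), (3,), (3,))
# Extra leading sample dimensions are preserved.
shape_with_samples = log_prob_shape((7, 6, 8, 5), (6, 8), (5,))
```

For the test's configuration the result is `(3,)`, which is the shape the suggested assertion checks.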

                                                batch_shape=(3,),
                                                event_shape=(3,)))

    model_info = initialize_model(random.PRNGKey(0), model)
    assert model_info.param_info.z['incidence'].shape == (3, 3)
Member

What I wanted to verify for this example is that with event_shape=(), this should be (3,). Is that correct?

Member Author

@fehiepsi fehiepsi Jun 5, 2020

Let me test it. I don't know what happens when users provide an invalid event_shape.

edit: Oh, now I see what you and Fritz meant before. Here, support is batched... interesting.

Member Author

Hmm, the result is still (3, 3). Something is wrong...

Member Author

@fehiepsi fehiepsi Jun 5, 2020

Thanks, @neerajprad! I just added that test case and fixed the issue in init_to_uniform. Does the fix sound correct to you?

  • before
prototype_value = np.full(site['fn'].event_shape, np.nan)
unconstrained_event_shape = np.shape(transform.inv(prototype_value))
unconstrained_shape = site['fn'].batch_shape + unconstrained_event_shape
  • after
prototype_value = np.full(site['fn'].shape(), np.nan)
unconstrained_shape = np.shape(transform.inv(prototype_value))

Member

Great, that makes sense.
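
One way to see why the first version produced `(3, 3)` even with `event_shape=()`: the support `interval(0, population)` carries a batched bound of shape `(3,)`, so an elementwise inverse transform broadcasts a scalar prototype up to `(3,)`, and prepending `batch_shape` then counts the batch dimension twice. The sketch below reproduces only this shape arithmetic in pure Python. It assumes the inverse transform's output shape is the broadcast of the prototype shape with the bound shape; `broadcast_pair`, `unconstrained_shape_before`, and `unconstrained_shape_after` are hypothetical names.

```python
from itertools import zip_longest


def broadcast_pair(a, b):
    """Broadcast two shape tuples with NumPy-style rules (right-aligned)."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def unconstrained_shape_before(batch_shape, event_shape, bound_shape):
    # Prototype has only the event shape; inv() broadcasts against the bounds,
    # and the batch shape is then prepended on top of that.
    inv_shape = broadcast_pair(event_shape, bound_shape)
    return batch_shape + inv_shape


def unconstrained_shape_after(batch_shape, event_shape, bound_shape):
    # Prototype has the full fn.shape(), i.e. batch_shape + event_shape.
    return broadcast_pair(batch_shape + event_shape, bound_shape)


bound_shape = (3,)  # shape of `population` in the test
before = unconstrained_shape_before((3,), (), bound_shape)
after = unconstrained_shape_after((3,), (), bound_shape)
```

With `event_shape=()`, `before` is `(3, 3)` and `after` is `(3,)`; with `event_shape=(3,)` both versions give `(3, 3)`, which matches the assertion in `test_improper_expand`.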

2 changes: 1 addition & 1 deletion test/test_mcmc.py
@@ -257,7 +257,7 @@ def test_improper_prior():

def model(data):
mean = numpyro.sample('mean', dist.Normal(0, 1).mask(False))
std = numpyro.sample('std', dist.LogNormal(0, 1).mask(False))
std = numpyro.sample('std', dist.ImproperUniform(dist.constraints.positive, ()))
return numpyro.sample('obs', dist.Normal(mean, std), obs=data)

data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))