Merge branch 'master' into mutable-params
fehiepsi authored Jun 23, 2021
2 parents 6c5ac70 + 42763d4 commit c040b5f
Showing 13 changed files with 1,192 additions and 79 deletions.
11 changes: 11 additions & 0 deletions docs/source/svi.rst
@@ -25,6 +25,17 @@ Trace_ELBO
:show-inheritance:
:member-order: bysource


TraceGraph_ELBO
---------------

.. autoclass:: numpyro.infer.elbo.TraceGraph_ELBO
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource


TraceMeanField_ELBO
-------------------

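For reference, the TraceGraph_ELBO objective documented above plugs into SVI the same way as Trace_ELBO. The sketch below is illustrative only; the toy model, guide, and optimizer settings are assumptions, not part of this commit.

# Hedged sketch: using the newly documented TraceGraph_ELBO with SVI.
# The model, guide, and hyperparameters here are illustrative assumptions.
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceGraph_ELBO
from numpyro.optim import Adam


def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    with numpyro.plate("data", data.shape[0]):
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


def guide(data):
    q_loc = numpyro.param("q_loc", 0.0)
    numpyro.sample("loc", dist.Normal(q_loc, 1.0))


data = jnp.array([0.2, -0.1, 0.4])
svi = SVI(model, guide, Adam(1e-2), loss=TraceGraph_ELBO())
svi_result = svi.run(random.PRNGKey(0), 200, data)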
147 changes: 94 additions & 53 deletions numpyro/contrib/tfp/distributions.py
@@ -1,6 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import inspect

import numpy as np

import jax.numpy as jnp
@@ -103,42 +105,78 @@ def _transform_to_bijector_constraint(constraint):
return BijectorTransform(constraint.bijector)


_TFPDistributionMeta = type(tfd.Distribution)
class _TFPDistributionMeta(type(NumPyroDistribution)):
def __getitem__(cls, tfd_class):
assert issubclass(tfd_class, tfd.Distribution)

def init(self, *args, **kwargs):
self.tfp_dist = tfd_class(*args, **kwargs)

# XXX: we create this mixin class to avoid metaclass conflict between TFP and NumPyro Distribution
class _TFPMixinMeta(_TFPDistributionMeta, type(NumPyroDistribution)):
def __init__(cls, name, bases, dct):
# XXX: _TFPDistributionMeta.__init__ registers cls as a PyTree
# for some reason, when defining metaclass of TFPDistributionMixin to be _TFPMixinMeta,
# TFPDistributionMixin will be registered as a PyTree 2 times, which is not allowed
# in JAX, so we skip registering TFPDistributionMixin as a PyTree.
if name == "TFPDistributionMixin":
super(_TFPDistributionMeta, cls).__init__(name, bases, dct)
else:
super(_TFPMixinMeta, cls).__init__(name, bases, dct)
init.__signature__ = inspect.signature(tfd_class.__init__)

_PyroDist = type(tfd_class.__name__, (TFPDistribution,), {})
_PyroDist.tfd_class = tfd_class
_PyroDist.__init__ = init
return _PyroDist


class TFPDistributionMixin(NumPyroDistribution, metaclass=_TFPMixinMeta):
class TFPDistribution(NumPyroDistribution, metaclass=_TFPDistributionMeta):
"""
A mixin layer to make TensorFlow Probability (TFP) distributions compatible
with NumPyro internals.
A thin wrapper for TensorFlow Probability (TFP) distributions. The constructor
has the same signature as the corresponding TFP distribution.
This class can be used to convert a TFP distribution to a NumPyro-compatible one
as follows::
d = TFPDistribution[tfd.Normal](0, 1)
"""

def __init_subclass__(cls, **kwargs):
# skip register pytree because TFP distributions are already pytrees
super(object, cls).__init_subclass__(**kwargs)
tfd_class = None

def __call__(self, *args, **kwargs):
key = kwargs.pop("rng_key")
sample_intermediates = kwargs.pop("sample_intermediates", False)
if sample_intermediates:
return self.sample(*args, seed=key, **kwargs), []
return self.sample(*args, seed=key, **kwargs)
def __getattr__(self, name):
# return parameters from the constructor
if name in self.tfp_dist.parameters:
return self.tfp_dist.parameters[name]
elif name in ["dtype", "reparameterization_type"]:
return getattr(self.tfp_dist, name)
raise AttributeError(name)

@property
def batch_shape(self):
return self.tfp_dist.batch_shape

@property
def event_shape(self):
return self.tfp_dist.event_shape

@property
def has_rsample(self):
return self.tfp_dist.reparameterization_type is tfd.FULLY_REPARAMETERIZED

def sample(self, key, sample_shape=()):
return self.tfp_dist.sample(sample_shape=sample_shape, seed=key)

def log_prob(self, value):
return self.tfp_dist.log_prob(value)

@property
def mean(self):
return self.tfp_dist.mean()

@property
def variance(self):
return self.tfp_dist.variance()

def cdf(self, value):
return self.tfp_dist.cdf(value)

def icdf(self, q):
return self.tfp_dist.quantile(q)

@property
def support(self):
bijector = self._default_event_space_bijector()
bijector = self.tfp_dist._default_event_space_bijector()
if bijector is not None:
return BijectorConstraint(bijector)
else:
@@ -150,40 +188,43 @@ def is_discrete(self):
return self.support is None
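A hedged sketch of the TFPDistribution[...] indexing pattern introduced above; it assumes TensorFlow Probability's JAX substrate is installed, and the shapes in the final comment are what the sample/log_prob methods shown above imply.

# Hedged sketch of the TFPDistribution[...] wrapper defined above.
from jax import random
from tensorflow_probability.substrates import jax as tfp

from numpyro.contrib.tfp.distributions import TFPDistribution

tfd = tfp.distributions

# Indexing the wrapper by a TFP class builds a NumPyro-compatible distribution
# whose constructor signature matches tfd.Normal.
d = TFPDistribution[tfd.Normal](loc=0.0, scale=1.0)
x = d.sample(random.PRNGKey(0), sample_shape=(3,))
print(x.shape, d.log_prob(x).shape)  # (3,) (3,)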


class InverseGamma(tfd.InverseGamma, TFPDistributionMixin):
arg_constraints = {
"concentration": constraints.positive,
"scale": constraints.positive,
}

InverseGamma = TFPDistribution[tfd.InverseGamma]
InverseGamma.arg_constraints = {
"concentration": constraints.positive,
"scale": constraints.positive,
}

class OneHotCategorical(tfd.OneHotCategorical, TFPDistributionMixin):
arg_constraints = {"logits": constraints.real_vector}
has_enumerate_support = True
support = constraints.simplex
is_discrete = True

def enumerate_support(self, expand=True):
n = self.event_shape[-1]
values = jnp.identity(n, dtype=jnp.result_type(self.dtype))
values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
return values
def _onehot_enumerate_support(self, expand=True):
n = self.event_shape[-1]
values = jnp.identity(n, dtype=jnp.result_type(self.dtype))
values = values.reshape((n,) + (1,) * len(self.batch_shape) + (n,))
if expand:
values = jnp.broadcast_to(values, (n,) + self.batch_shape + (n,))
return values


class OrderedLogistic(tfd.OrderedLogistic, TFPDistributionMixin):
arg_constraints = {"cutpoints": constraints.ordered_vector, "loc": constraints.real}
OneHotCategorical = TFPDistribution[tfd.OneHotCategorical]
OneHotCategorical.arg_constraints = {"logits": constraints.real_vector}
OneHotCategorical.has_enumerate_support = True
OneHotCategorical.support = constraints.simplex
OneHotCategorical.is_discrete = True
OneHotCategorical.enumerate_support = _onehot_enumerate_support

OrderedLogistic = TFPDistribution[tfd.OrderedLogistic]
OrderedLogistic.arg_constraints = {
"cutpoints": constraints.ordered_vector,
"loc": constraints.real,
}

class Pareto(tfd.Pareto, TFPDistributionMixin):
arg_constraints = {
"concentration": constraints.positive,
"scale": constraints.positive,
}
Pareto = TFPDistribution[tfd.Pareto]
Pareto.arg_constraints = {
"concentration": constraints.positive,
"scale": constraints.positive,
}


__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistributionMixin"]
__all__ = ["BijectorConstraint", "BijectorTransform", "TFPDistribution"]
_len_all = len(__all__)
for _name, _Dist in tfd.__dict__.items():
if not isinstance(_Dist, type):
@@ -196,7 +237,7 @@ class Pareto(tfd.Pareto, TFPDistributionMixin):
try:
_PyroDist = locals()[_name]
except KeyError:
_PyroDist = type(_name, (_Dist, TFPDistributionMixin), {})
_PyroDist = TFPDistribution[_Dist]
_PyroDist.__module__ = __name__
if hasattr(numpyro_dist, _name):
numpyro_dist_class = getattr(numpyro_dist, _name)
@@ -212,7 +253,7 @@ class Pareto(tfd.Pareto, TFPDistributionMixin):

_PyroDist.__doc__ = """
Wraps `{}.{} <https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/distributions/{}>`_
with :class:`~numpyro.contrib.tfp.distributions.TFPDistributionMixin`.
with :class:`~numpyro.contrib.tfp.distributions.TFPDistribution`.
""".format(
_Dist.__module__, _Dist.__name__, _Dist.__name__
)
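The registration loop above creates a wrapped counterpart for every TFP distribution class, so names like Normal can be imported directly from numpyro.contrib.tfp.distributions. A hedged usage sketch; the toy model and MCMC settings are illustrative assumptions.

# Hedged sketch: the auto-registered wrappers created by the loop above can be
# used like native NumPyro distributions, e.g. inside an MCMC run.
from jax import random

import numpyro
from numpyro.contrib.tfp import distributions as tfdist
from numpyro.infer import MCMC, NUTS


def model():
    # tfdist.Normal is TFPDistribution[tfd.Normal], created by the loop above.
    numpyro.sample("x", tfdist.Normal(0.0, 1.0))


mcmc = MCMC(NUTS(model), num_warmup=100, num_samples=100)
mcmc.run(random.PRNGKey(0))
print(mcmc.get_samples()["x"].shape)  # (100,)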
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
@@ -1089,7 +1089,7 @@ class Delta(Distribution):
"v": constraints.dependent(is_discrete=False),
"log_density": constraints.real,
}
reparameterized_params = ["v", "log_density"]
reparametrized_params = ["v", "log_density"]

def __init__(self, v=0.0, log_density=0.0, event_dim=0, validate_args=None):
if event_dim > jnp.ndim(v):
9 changes: 8 additions & 1 deletion numpyro/infer/__init__.py
@@ -2,7 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

from numpyro.infer.barker import BarkerMH
from numpyro.infer.elbo import ELBO, RenyiELBO, Trace_ELBO, TraceMeanField_ELBO
from numpyro.infer.elbo import (
ELBO,
RenyiELBO,
Trace_ELBO,
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
from numpyro.infer.hmc import HMC, NUTS
from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs
from numpyro.infer.initialization import (
@@ -43,5 +49,6 @@
"SA",
"SVI",
"Trace_ELBO",
"TraceGraph_ELBO",
"TraceMeanField_ELBO",
]
14 changes: 10 additions & 4 deletions numpyro/infer/autoguide.py
@@ -36,8 +36,8 @@
sum_rightmost,
)
from numpyro.infer.elbo import Trace_ELBO
from numpyro.infer.initialization import init_to_median
from numpyro.infer.util import init_to_uniform, initialize_model
from numpyro.infer.initialization import init_to_median, init_to_uniform
from numpyro.infer.util import helpful_support_errors, initialize_model
from numpyro.nn.auto_reg_nn import AutoregressiveNN
from numpyro.nn.block_neural_arn import BlockNeuralAutoregressiveNN
from numpyro.util import not_jax_tracer
@@ -147,6 +147,10 @@ def _setup_prototype(self, *args, **kwargs):
self._prototype_plate_sizes = {}
for name, site in self.prototype_trace.items():
if site["type"] == "sample":
if not site["is_observed"] and site["fn"].is_discrete:
# raise support errors early for discrete sites
with helpful_support_errors(site):
biject_to(site["fn"].support)
for frame in site["cond_indep_stack"]:
if frame.name in self._prototype_frames:
assert (
@@ -266,7 +270,8 @@ def __call__(self, *args, **kwargs):
):
result[name] = numpyro.sample(name, site_fn)
else:
transform = biject_to(site["fn"].support)
with helpful_support_errors(site):
transform = biject_to(site["fn"].support)
guide_dist = dist.TransformedDistribution(site_fn, transform)
result[name] = numpyro.sample(name, guide_dist)

@@ -485,7 +490,8 @@ def __call__(self, *args, **kwargs):

for name, unconstrained_value in self._unpack_latent(latent).items():
site = self.prototype_trace[name]
transform = biject_to(site["fn"].support)
with helpful_support_errors(site):
transform = biject_to(site["fn"].support)
value = transform(unconstrained_value)
event_ndim = site["fn"].event_dim
if numpyro.get_mask() is False:
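The autoguide changes above wrap biject_to(site["fn"].support) in helpful_support_errors so that unsupported latent sites, such as discrete ones, fail early during guide setup with a clearer message. Below is a hedged sketch of that failure mode; the exact error type and wording are version dependent and not asserted by this commit.

# Hedged sketch: a discrete latent site is not supported by AutoNormal, and the
# helpful_support_errors wrapper shown above surfaces this at guide setup time.
# The printed error type and message are version dependent (illustrative only).
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam


def model():
    z = numpyro.sample("z", dist.Bernoulli(0.3))  # discrete latent site
    numpyro.sample("obs", dist.Normal(z, 1.0), obs=1.0)


svi = SVI(model, AutoNormal(model), Adam(1e-2), Trace_ELBO())
try:
    svi.init(random.PRNGKey(0))  # guide setup traces the model and checks supports
except Exception as err:
    print(type(err).__name__, err)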