Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hamiltonian Monte Carlo with Energy Conserving Subsampling #905

Merged
merged 100 commits into from
Feb 11, 2021
Merged
Show file tree
Hide file tree
Changes from 85 commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
4321595
start
LysSanzMoreta Sep 8, 2020
b8001a9
start hmcecs two
LysSanzMoreta Sep 8, 2020
26219ce
structuring
LysSanzMoreta Sep 14, 2020
9a326db
small fix
LysSanzMoreta Sep 14, 2020
44d2fc1
ADDED: verlet, new log density
LysSanzMoreta Sep 16, 2020
4eeb1f0
FIXED: initialization model parameters
LysSanzMoreta Sep 18, 2020
ca9dece
FIXED: Arguments potential function
LysSanzMoreta Sep 21, 2020
f01c027
FIXED: Arguments mess
LysSanzMoreta Sep 22, 2020
cc2f1b0
FIXED? shapes error
LysSanzMoreta Sep 25, 2020
c3de253
Sampling working
LysSanzMoreta Sep 28, 2020
b17d53d
Seems to be working
LysSanzMoreta Sep 28, 2020
40be6c3
Added: Plotting and save samples to example
LysSanzMoreta Sep 29, 2020
f58dbf7
ADDED: Assertion errors
LysSanzMoreta Sep 29, 2020
3a25523
working on more than 1 chain
LysSanzMoreta Sep 30, 2020
a075857
ADDED: more plotting
LysSanzMoreta Sep 30, 2020
b40f662
ADDED: More tests and proxies
LysSanzMoreta Sep 30, 2020
54bca12
Small state fix
LysSanzMoreta Oct 1, 2020
765b3d6
Fixed : Proxies and init
LysSanzMoreta Oct 6, 2020
1fc6740
Working examples
LysSanzMoreta Oct 7, 2020
3993afd
Maybe working
LysSanzMoreta Oct 19, 2020
9dd7025
Started adding Block-Poisson
LysSanzMoreta Nov 2, 2020
bdcd352
small stuff
LysSanzMoreta Nov 2, 2020
742162c
Started adding poisson stuff
LysSanzMoreta Nov 2, 2020
81b8956
Working on documentation and poisson
LysSanzMoreta Nov 4, 2020
ffbaa9e
Added: Poisson stuff (missing initialization)
LysSanzMoreta Nov 6, 2020
1e83bb7
BlockPoissonRunning
LysSanzMoreta Nov 9, 2020
3f918ed
FIXED: potential estimator
LysSanzMoreta Nov 10, 2020
9fc0f3f
FINISHED: Block-poisson
LysSanzMoreta Nov 10, 2020
24e1c1b
MISSING: Postprocessing
LysSanzMoreta Nov 12, 2020
4243abe
FIXED: sign
LysSanzMoreta Nov 13, 2020
f996d18
More debugging
LysSanzMoreta Nov 13, 2020
1095f19
Fixed style.
OlaRonning Dec 9, 2020
0c18026
HMCECS working, fixed problems with SVI MAP and factored code.
OlaRonning Dec 14, 2020
d04f651
Merge remote-tracking branch 'origin/hmcecs' into hmcecs
OlaRonning Dec 15, 2020
b8f8830
Added MNIST BNN example using flax.
OlaRonning Dec 15, 2020
65531c2
Working potential with algebraic effect handlers.
OlaRonning Jan 7, 2021
216c2cf
Potential estimator integrated with ECS class.
OlaRonning Jan 7, 2021
d6e6700
ECS wrapper working on toy example.
OlaRonning Jan 8, 2021
c59b317
cleaned code.
OlaRonning Jan 8, 2021
dafaa6e
renamed hmcecs_utils to ecs_utils and added todos.
OlaRonning Jan 8, 2021
243e7bc
debugging taylor expansion.
OlaRonning Jan 12, 2021
c4252bb
Updated comments with reference and added test for num_blocks={} (the…
OlaRonning Jan 13, 2021
20b7350
Added pystan
OlaRonning Jan 13, 2021
4d7e4ed
Added components for variational proxy.
OlaRonning Jan 15, 2021
f867af1
Added variational_proxy, todo: fix estimator.
OlaRonning Jan 15, 2021
89c8ffe
Integrated variational proxy into ecs.
OlaRonning Jan 16, 2021
1403dfe
checkpoint: before redoing estimator.
OlaRonning Jan 18, 2021
1c6af82
Variational proxy running!
OlaRonning Jan 18, 2021
2a8cc23
Fixed minor bugs and example of hmcecs with variational proxy on logi…
OlaRonning Jan 18, 2021
7c41cee
merging
OlaRonning Jan 26, 2021
6f1f222
merged upstream
OlaRonning Jan 26, 2021
dd0426c
Refactored taylor_estimator into taylor_proxy and a difference estima…
OlaRonning Jan 26, 2021
60e0912
Sketched variational proxy in hmc_gibbs.
OlaRonning Jan 27, 2021
85957dc
Variational proxy running.
OlaRonning Jan 27, 2021
c01738f
Examples.
OlaRonning Jan 29, 2021
2151895
Moved estimate_likelihood
OlaRonning Jan 29, 2021
82ef761
Merge remote-tracking branch 'origin/feature/ecs' into feature/ecs
OlaRonning Jan 29, 2021
e46cb40
Added two moons
OlaRonning Jan 31, 2021
5b8af6e
Merge remote-tracking branch 'origin/feature/ecs' into feature/ecs
OlaRonning Jan 31, 2021
41bda0c
add gibbs_state and fix bugs
fehiepsi Feb 1, 2021
ab7888e
Integrated taylor proxy and updated API.
OlaRonning Feb 1, 2021
a9d2c0e
Bugs fixed and taylor working!
OlaRonning Feb 1, 2021
e4bf263
Updated variational proxy to new API.
OlaRonning Feb 2, 2021
508e96a
Variational proxy running on breast cancer!
OlaRonning Feb 2, 2021
3761905
Working regression
OlaRonning Feb 2, 2021
a695046
Fixed problems in variational; todo rethink dummy_sample ([] doesn't …
OlaRonning Feb 2, 2021
896cd19
add covtype example
fehiepsi Feb 2, 2021
4e9a192
Merge remote-tracking branch 'origin/feature/ecs' into feature/ecs
OlaRonning Feb 2, 2021
f5e8894
fix some bugs to substitute empty subsample indices and add some FIXME
fehiepsi Feb 2, 2021
b571228
FIXED ELBO computation and changed the weight scheme in variational p…
OlaRonning Feb 3, 2021
7d9cd11
fixed proxy_sum and added equations.
OlaRonning Feb 3, 2021
5dbac85
VECS working with AutoNormal on BreastCancer.
OlaRonning Feb 3, 2021
bb783f8
Using Likelihood as weight.
OlaRonning Feb 3, 2021
7644a08
factored out VECS
OlaRonning Feb 4, 2021
a997acb
Added simple test case.
OlaRonning Feb 5, 2021
ad2f799
Merge branch 'master' of github.com:pyro-ppl/numpyro into feature/ecs
OlaRonning Feb 5, 2021
fb95035
Cleaned.
OlaRonning Feb 5, 2021
e1150ea
Removed old HMCECS logistic examples.
OlaRonning Feb 5, 2021
c0a1c4c
removed old autoguide
OlaRonning Feb 5, 2021
5583162
Fixed linting.
OlaRonning Feb 5, 2021
8e08d76
Merge branch 'feature/ecs' of github.com:aleatory-science/numpyro int…
OlaRonning Feb 5, 2021
48cd7ef
fixed lint.
OlaRonning Feb 5, 2021
ff28eb0
Remove Poisson, factored out pandas for loading HIGGs dataset, added …
OlaRonning Feb 7, 2021
5febf7e
Fixed _block_update refactor. Missing new test cases, 2 more TODOs.
OlaRonning Feb 7, 2021
2c26173
fixed isort
OlaRonning Feb 7, 2021
6359292
Fixed comments, some 3 TODOs left.
OlaRonning Feb 8, 2021
0b886d7
Conditioned gradient computation and moved to unconstraint sapce for …
OlaRonning Feb 8, 2021
469a1f2
Fixed test for HMCECS and bumped jaxlib version.
OlaRonning Feb 8, 2021
0fb8d01
Fixed test.
OlaRonning Feb 8, 2021
5a6f629
Fixed lint.
OlaRonning Feb 8, 2021
29f3708
Corrected taylor_proxy works in unconstraint space. Added docstring a…
OlaRonning Feb 9, 2021
2aef856
Flipped syntax for geq in setup.py
OlaRonning Feb 9, 2021
cc2669e
Made default device for covtype example cpu.
OlaRonning Feb 9, 2021
63323ec
Added taylor proxy test.
OlaRonning Feb 10, 2021
89a99a2
Added test for variance.
OlaRonning Feb 10, 2021
0131f3e
Fixed lint.
OlaRonning Feb 10, 2021
c697e91
Added all log_density computation to test_estimate_likelihood and ass…
OlaRonning Feb 10, 2021
400389f
Fixed typo and isort.
OlaRonning Feb 10, 2021
2e1dddb
isort not included in previous commit.
OlaRonning Feb 10, 2021
6474df3
Fixed shadowing log_prob.
OlaRonning Feb 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import nbsphinx
import sphinx_rtd_theme


# import pkg_resources

# -*- coding: utf-8 -*-
Expand All @@ -33,6 +32,7 @@

# HACK: This is to ensure that local functions are documented by sphinx.
from numpyro.infer.hmc import hmc # noqa: E402

hmc(None, None)

# -- Project information -----------------------------------------------------
Expand Down
91 changes: 78 additions & 13 deletions examples/covtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
import argparse
import time

import matplotlib.pyplot as plt

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.examples.datasets import COVTYPE, load_dataset
from numpyro.infer import MCMC, NUTS
from numpyro.infer import HMC, HMCECS, MCMC, NUTS, SA, SVI, Trace_ELBO, init_to_value
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.infer.hmc_gibbs import taylor_proxy
from numpyro.infer.reparam import NeuTraReparam


def _load_dataset():
Expand All @@ -33,22 +38,76 @@ def _load_dataset():
return features, labels


def model(data, labels):
def model(data, labels, subsample_size=None):
dim = data.shape[1]
coefs = numpyro.sample('coefs', dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
logits = jnp.dot(data, coefs)
return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels)
with numpyro.plate("N", data.shape[0], subsample_size=subsample_size) as idx:
logits = jnp.dot(data[idx], coefs)
return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=labels[idx])


def benchmark_hmc(args, features, labels):
step_size = jnp.sqrt(0.5 / features.shape[0])
trajectory_length = step_size * args.num_steps
rng_key = random.PRNGKey(1)
start = time.time()
kernel = NUTS(model, trajectory_length=trajectory_length)
mcmc = MCMC(kernel, 0, args.num_samples)
mcmc.run(rng_key, features, labels)
mcmc.print_summary()
# a MAP estimate at the following source
# https://github.com/google/edward2/blob/master/examples/no_u_turn_sampler/logistic_regression.py#L117
ref_params = {"coefs": jnp.array([
+2.03420663e+00, -3.53567265e-02, -1.49223924e-01, -3.07049364e-01,
-1.00028366e-01, -1.46827862e-01, -1.64167881e-01, -4.20344204e-01,
+9.47479829e-02, -1.12681836e-02, +2.64442056e-01, -1.22087866e-01,
-6.00568838e-02, -3.79419506e-01, -1.06668741e-01, -2.97053963e-01,
-2.05253899e-01, -4.69537191e-02, -2.78072730e-02, -1.43250525e-01,
-6.77954629e-02, -4.34899796e-03, +5.90927452e-02, +7.23133609e-02,
+1.38526391e-02, -1.24497898e-01, -1.50733739e-02, -2.68872194e-02,
-1.80925727e-02, +3.47936489e-02, +4.03552800e-02, -9.98773426e-03,
+6.20188080e-02, +1.15002751e-01, +1.32145107e-01, +2.69109547e-01,
+2.45785132e-01, +1.19035013e-01, -2.59744357e-02, +9.94279515e-04,
+3.39266285e-02, -1.44057125e-02, -6.95222765e-02, -7.52013028e-02,
+1.21171586e-01, +2.29205526e-02, +1.47308692e-01, -8.34354162e-02,
-9.34122875e-02, -2.97472421e-02, -3.03937674e-01, -1.70958012e-01,
-1.59496680e-01, -1.88516974e-01, -1.20889175e+00])}
if args.algo == "HMC":
step_size = jnp.sqrt(0.5 / features.shape[0])
trajectory_length = step_size * args.num_steps
kernel = HMC(model, step_size=step_size, trajectory_length=trajectory_length, adapt_step_size=False,
dense_mass=args.dense_mass)
subsample_size = None
elif args.algo == "NUTS":
kernel = NUTS(model, dense_mass=args.dense_mass)
subsample_size = None
OlaRonning marked this conversation as resolved.
Show resolved Hide resolved
elif args.algo == "HMCECS":
subsample_size = 1000
inner_kernel = NUTS(model, init_strategy=init_to_value(values=ref_params),
dense_mass=args.dense_mass)
# note: if num_blocks=100, we'll update 10 index at each MCMC step
# so it took 50000 MCMC steps to iterative the whole dataset
kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(ref_params))
elif args.algo == "SA":
# NB: this kernel requires large num_warmup and num_samples
# and running on GPU is much faster than on CPU
kernel = SA(model, adapt_state_size=1000, init_strategy=init_to_value(values=ref_params))
subsample_size = None
elif args.algo == "FlowHMCECS":
subsample_size = 1000
guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
params, losses = svi.run(random.PRNGKey(2), 2000, features, labels)
plt.plot(losses)
plt.show()

neutra = NeuTraReparam(guide, params)
neutra_model = neutra.reparam(model)
neutra_ref_params = {"auto_shared_latent": jnp.zeros(55)}
# no need to adapt mass matrix if the flow does a good job
inner_kernel = NUTS(neutra_model, init_strategy=init_to_value(values=neutra_ref_params),
adapt_mass_matrix=False)
kernel = HMCECS(inner_kernel, num_blocks=100, proxy=taylor_proxy(neutra_ref_params))
else:
raise ValueError("Invalid algorithm, either 'HMC', 'NUTS', or 'HMCECS'.")
mcmc = MCMC(kernel, args.num_warmup, args.num_samples)
mcmc.run(rng_key, features, labels, subsample_size, extra_fields=("accept_prob",))
print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
mcmc.print_summary(exclude_deterministic=False)
print('\nMCMC elapsed time:', time.time() - start)


Expand All @@ -60,14 +119,20 @@ def main(args):
if __name__ == '__main__':
assert numpyro.__version__.startswith('0.5.0')
parser = argparse.ArgumentParser(description="parse args")
parser.add_argument('-n', '--num-samples', default=100, type=int, help='number of samples')
parser.add_argument('-n', '--num-samples', default=1000, type=int, help='number of samples')
parser.add_argument('--num-warmup', default=1000, type=int, help='number of warmup steps')
parser.add_argument('--num-steps', default=10, type=int, help='number of steps (for "HMC")')
parser.add_argument('--num-chains', nargs='?', default=1, type=int)
parser.add_argument('--algo', default='NUTS', type=str, help='whether to run "HMC" or "NUTS"')
parser.add_argument('--device', default='cpu', type=str, help='use "cpu" or "gpu".')
parser.add_argument('--algo', default='HMCECS', type=str,
help='whether to run "HMCECS", "NUTS", "HMCECS", "SA" or "FlowHMCECS"')
parser.add_argument('--dense-mass', action="store_true")
parser.add_argument('--x64', action="store_true")
parser.add_argument('--device', default='gpu', type=str, help='use "cpu" or "gpu".')
args = parser.parse_args()

numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
if args.x64:
numpyro.enable_x64()

main(args)
1 change: 1 addition & 0 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def effective_sample_size(x):
:return: effective sample size of ``x``.
:rtype: numpy.ndarray
"""

assert x.ndim >= 2
assert x.shape[1] >= 2

Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def sample(self, key, sample_shape=()):

@validate_sample
def log_prob(self, value):
normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale)
normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * self.scale) # TODO:Added jnp.abs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this TODO mean there is a bug here?

Copy link
Member Author

@OlaRonning OlaRonning Feb 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think this is necessary as the constraint on self.scale is already positive. @LysSanzMoreta let me know if you disagree.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this TODO too. I think the math here is right.

value_scaled = (value - self.loc) / self.scale
return -0.5 * value_scaled ** 2 - normalize_term

Expand Down
34 changes: 24 additions & 10 deletions numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from collections import namedtuple
import csv
import gzip
import io
import os
import pickle
import struct
from urllib.parse import urlparse
from urllib.request import urlretrieve
import warnings
import zipfile

import numpy as np
Expand All @@ -23,52 +25,47 @@
'.data'))
os.makedirs(DATA_DIR, exist_ok=True)


dset = namedtuple('dset', ['name', 'urls'])


BASEBALL = dset('baseball', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/EfronMorrisBB.txt',
])


COVTYPE = dset('covtype', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/covtype.zip',
])


DIPPER_VOLE = dset('dipper_vole', [
'https://github.com/pyro-ppl/datasets/blob/master/dipper_vole.zip?raw=true',
])


MNIST = dset('mnist', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-images-idx3-ubyte.gz',
'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/train-labels-idx1-ubyte.gz',
'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-images-idx3-ubyte.gz',
'https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/t10k-labels-idx1-ubyte.gz',
])


SP500 = dset('SP500', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/SP500.csv',
])


UCBADMIT = dset('ucbadmit', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/UCBadmit.csv',
])


LYNXHARE = dset('lynxhare', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/LynxHare.txt',
])


JSB_CHORALES = dset('jsb_chorales', [
'https://d2hg8soec8ck9v.cloudfront.net/datasets/polyphonic/jsb_chorales.pickle',
])

HIGGS = dset("higgs", [
"https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz",
])


def _download(dset):
for url in dset.urls:
Expand All @@ -86,7 +83,7 @@ def _load_baseball():
def train_test_split(file):
train, test, player_names = [], [], []
with open(file, 'r') as f:
csv_reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
csv_reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
for row in csv_reader:
player_names.append(row['FirstName'] + ' ' + row['LastName'])
at_bats, hits = row['At-Bats'], row['Hits']
Expand Down Expand Up @@ -240,6 +237,21 @@ def _load_jsb_chorales():
return processed_dataset


def _load_higgs():
warnings.warn("Higgs is a 2.6 GB dataset")
_download(HIGGS)

file_path = os.path.join(DATA_DIR, 'HIGGS.csv.gz')
with io.TextIOWrapper(gzip.open(file_path, 'rb')) as f:
csv_reader = csv.reader(f, delimiter=',', quoting=csv.QUOTE_NONE)
obs = []
data = []
for row in csv_reader:
obs.append(row[0])
data.append(row[1:])
return np.stack(obs), np.stack(data)
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved


def _load(dset):
if dset == BASEBALL:
return _load_baseball()
Expand All @@ -257,6 +269,8 @@ def _load(dset):
return _load_lynxhare()
elif dset == JSB_CHORALES:
return _load_jsb_chorales()
elif dset == HIGGS:
return _load_higgs()
raise ValueError('Dataset - {} not found.'.format(dset.name))


Expand Down
21 changes: 16 additions & 5 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

"""
This provides a small set of effect handlers in NumPyro that are modeled
after Pyro's `poutine <http://docs.pyro.ai/en/stable/poutine.html>`_ module.
Expand Down Expand Up @@ -136,6 +135,7 @@ class trace(Messenger):
'type': 'sample',
'value': DeviceArray(-0.20584235, dtype=float32)})])
"""

def __enter__(self):
super(trace, self).__enter__()
self.trace = OrderedDict()
Expand All @@ -146,7 +146,7 @@ def postprocess_message(self, msg):
# skip recording helper messages e.g. `control_flow`, `to_data`, `to_funsor`
# which has no name
return
assert not(msg['type'] == 'sample' and msg['name'] in self.trace), \
assert not (msg['type'] == 'sample' and msg['name'] in self.trace), \
'all sites must have unique names but got `{}` duplicated'.format(msg['name'])
self.trace[msg['name']] = msg.copy()

Expand Down Expand Up @@ -191,6 +191,7 @@ class replay(Messenger):
-0.20584235
>>> assert replayed_trace['a']['value'] == exec_trace['a']['value']
"""

def __init__(self, fn=None, guide_trace=None):
assert guide_trace is not None
self.guide_trace = guide_trace
Expand Down Expand Up @@ -234,6 +235,7 @@ class block(Messenger):
>>> assert 'a' not in trace_block_a
>>> assert 'b' in trace_block_a
"""

def __init__(self, fn=None, hide_fn=None, hide=None):
if hide_fn is not None:
self.hide_fn = hide_fn
Expand Down Expand Up @@ -350,6 +352,7 @@ class condition(Messenger):
>>> assert exec_trace['a']['value'] == -1
>>> assert exec_trace['a']['is_observed']
"""

def __init__(self, fn=None, data=None, condition_fn=None):
self.condition_fn = condition_fn
self.data = data
Expand Down Expand Up @@ -386,6 +389,7 @@ class infer_config(Messenger):
:param fn: a stochastic function (callable containing NumPyro primitive calls)
:param config_fn: a callable taking a site and returning an infer dict
"""

def __init__(self, fn=None, config_fn=None):
super().__init__(fn)
self.config_fn = config_fn
Expand Down Expand Up @@ -470,6 +474,7 @@ class mask(Messenger):
:param mask: a boolean or a boolean-valued array for masking elementwise log
probability of sample sites (`True` includes a site, `False` excludes a site).
"""

def __init__(self, fn=None, mask=True):
if lax.dtype(mask) != 'bool':
raise ValueError("`mask` should be a bool array.")
Expand Down Expand Up @@ -506,6 +511,7 @@ class reparam(Messenger):
:class:`~numpyro.infer.reparam.Reparam` or None.
:type config: dict or callable
"""

def __init__(self, fn=None, config=None):
assert isinstance(config, dict) or callable(config)
self.config = config
Expand Down Expand Up @@ -550,6 +556,7 @@ class scale(Messenger):
of log probability.
:type scale: float or numpy.ndarray
"""

def __init__(self, fn=None, scale=1.):
if not_jax_tracer(scale):
if np.any(np.less_equal(scale, 0)):
Expand Down Expand Up @@ -587,6 +594,7 @@ class scope(Messenger):
:param str prefix: a string to prepend to sample names
:param str divider: a string to join the prefix and sample name; default to `'/'`
"""

def __init__(self, fn=None, prefix='', divider='/'):
self.prefix = prefix
self.divider = divider
Expand Down Expand Up @@ -638,6 +646,7 @@ class seed(Messenger):
>>> y = handlers.seed(model, rng_seed=1)()
>>> assert x == y
"""

def __init__(self, fn=None, rng_seed=None):
if isinstance(rng_seed, int) or (isinstance(rng_seed, jnp.ndarray) and not jnp.shape(rng_seed)):
rng_seed = random.PRNGKey(rng_seed)
Expand All @@ -647,10 +656,10 @@ def __init__(self, fn=None, rng_seed=None):
super(seed, self).__init__(fn)

def process_message(self, msg):
if (msg['type'] == 'sample' and not msg['is_observed'] and
msg['kwargs']['rng_key'] is None) or msg['type'] in ['prng_key', 'plate', 'control_flow']:
# no need to create a new key when value is available
if (msg['type'] == 'sample' and not msg['is_observed'] and msg['kwargs']['rng_key'] is None) \
or msg['type'] in ['prng_key', 'plate', 'control_flow']:
if msg['value'] is not None:
# no need to create a new key when value is available
return
self.rng_key, rng_key_sample = random.split(self.rng_key)
msg['kwargs']['rng_key'] = rng_key_sample
Expand Down Expand Up @@ -691,6 +700,7 @@ class substitute(Messenger):
>>> exec_trace = trace(substitute(model, {'a': -1})).get_trace()
>>> assert exec_trace['a']['value'] == -1
"""

def __init__(self, fn=None, data=None, substitute_fn=None):
self.substitute_fn = substitute_fn
self.data = data
Expand Down Expand Up @@ -760,6 +770,7 @@ class do(Messenger):
>>> assert not exec_trace['z'].get('stop', None)
>>> assert z_square == 1
"""

def __init__(self, fn=None, data=None):
self.data = data
self._intervener_id = str(id(self))
Expand Down
Loading