Support for mutable params #1016

Merged · Jun 23, 2021 · 35 commits

Commits
f99e5df
sketch for mutable states
fehiepsi Apr 21, 2021
425cf8f
sketch how mutable work
fehiepsi Apr 22, 2021
c39f67e
modifying optim is a bad design, but it helps figure out what is need…
fehiepsi Apr 23, 2021
0966c8c
master
fehiepsi Apr 23, 2021
8ebef6d
Merge remote-tracking branch 'upstream/master' into mutable-params
fehiepsi Apr 25, 2021
decfe92
add a reasonable implementation
fehiepsi Apr 25, 2021
ae21555
add test for mutable_state
fehiepsi Apr 25, 2021
d43a393
add more complete test for mutable state
fehiepsi Apr 25, 2021
392c509
fix bug at svi
fehiepsi Apr 25, 2021
078ba2e
add has_aux to Minimize
fehiepsi Apr 25, 2021
5a4bcc7
add apply_rng to the construction of haiku
fehiepsi Apr 26, 2021
7f27055
apply with state
fehiepsi Apr 26, 2021
f63f277
make lint
fehiepsi Apr 26, 2021
76109bb
update black
fehiepsi Apr 26, 2021
8bb2b7d
Merge remote-tracking branch 'upstream/master' into mutable-params
fehiepsi May 29, 2021
5101309
make the docs clearer
fehiepsi May 29, 2021
62eaff8
add note that optimizers will skip mutable parameters
fehiepsi May 29, 2021
d999815
make sure that module is mutable
fehiepsi Jun 6, 2021
950a0ee
update flax.nn to flax.linen
fehiepsi Jun 9, 2021
3344323
temp save
fehiepsi Jun 9, 2021
a31e21d
Merge remote-tracking branch 'upstream/master' into mutable-params
fehiepsi Jun 13, 2021
ff7de9d
temporary save the working files
fehiepsi Jun 17, 2021
bca7f9c
support mutable state in ELBO, clean up the API
fehiepsi Jun 17, 2021
a9989db
support mutable in plate
fehiepsi Jun 17, 2021
328c023
add TODO for elbo
fehiepsi Jun 17, 2021
a3525ce
clean up the api
fehiepsi Jun 20, 2021
f0ba52c
make mutable work for flax
fehiepsi Jun 20, 2021
777e1d2
simplify the logic: state is None == params is None
fehiepsi Jun 20, 2021
49a0a12
not expose mutable primitive
fehiepsi Jun 20, 2021
c8224be
restrict tfp version
fehiepsi Jun 20, 2021
265948e
use mutable msg in TraceMeanField_ELBO
fehiepsi Jun 20, 2021
6c5ac70
fix bug that make test_mutable fail
fehiepsi Jun 20, 2021
c040b5f
Merge branch 'master' into mutable-params
fehiepsi Jun 23, 2021
5f9974f
Update tracegraph elbo subclass
fehiepsi Jun 23, 2021
5321b23
Make format
fehiepsi Jun 23, 2021
2 changes: 1 addition & 1 deletion Makefile
@@ -17,7 +17,7 @@ install: FORCE
 	pip install -e .[dev,doc,test,examples]
 
 doctest: FORCE
-	$(MAKE) -C docs doctest
+	JAX_PLATFORM_NAME=cpu $(MAKE) -C docs doctest
 
 test: lint FORCE
 	pytest -v test
3 changes: 2 additions & 1 deletion examples/covtype.py
@@ -174,7 +174,8 @@ def benchmark_hmc(args, features, labels):
     subsample_size = 1000
     guide = AutoBNAFNormal(model, num_flows=1, hidden_factors=[8])
     svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO())
-    params, losses = svi.run(random.PRNGKey(2), 2000, features, labels)
+    svi_result = svi.run(random.PRNGKey(2), 2000, features, labels)
+    params, losses = svi_result.params, svi_result.losses
     plt.plot(losses)
     plt.show()
 
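Note: the example migrations in this PR all follow the same pattern, since svi.run now returns a result object rather than a (params, losses) tuple. A minimal runnable sketch of the new call-site idiom, assuming only the params and losses fields exercised in these diffs (model and guide are toy stand-ins):

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import SVI, Trace_ELBO

    def model():
        numpyro.sample("x", dist.Normal(0.0, 1.0))

    def guide():
        loc = numpyro.param("loc", jnp.zeros(()))
        numpyro.sample("x", dist.Normal(loc, 1.0))

    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO())
    svi_result = svi.run(random.PRNGKey(0), 1000)  # previously: params, losses = svi.run(...)
    params = svi_result.params  # dict of optimized parameter values
    losses = svi_result.losses  # per-step losses, e.g. for plt.plot(losses)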
5 changes: 2 additions & 3 deletions examples/hmcecs.py
@@ -50,9 +50,8 @@ def run_hmcecs(hmcecs_key, args, data, obs, inner_kernel):
     optimizer = numpyro.optim.Adam(step_size=1e-3)
     guide = autoguide.AutoDelta(model)
     svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
-    params, losses = svi.run(
-        svi_key, args.num_svi_steps, data, obs, args.subsample_size
-    )
+    svi_result = svi.run(svi_key, args.num_svi_steps, data, obs, args.subsample_size)
+    params, losses = svi_result.params, svi_result.losses
     ref_params = {"theta": params["theta_auto_loc"]}
 
     # taylor proxy estimates log likelihood (ll) by
2 changes: 1 addition & 1 deletion numpyro/contrib/control_flow/cond.py
@@ -111,7 +111,7 @@ def cond(pred, true_fun, false_fun, operand):
     ...     return cond(cluster > 0, true_fun, false_fun, None)
     >>>
     >>> svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=100))
-    >>> params, losses = svi.run(random.PRNGKey(0), num_steps=2500)
+    >>> svi_result = svi.run(random.PRNGKey(0), num_steps=2500)
 
     .. warning:: This is an experimental utility function that allows users to use
         JAX control flow with NumPyro's effect handlers. Currently, `sample` and
218 changes: 178 additions & 40 deletions numpyro/contrib/module.py

Large diffs are not rendered by default.
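The module.py diff is collapsed, but the commit log ("apply with state", "make mutable work for flax", "update flax.nn to flax.linen") indicates it threads mutable collections through the flax/haiku wrappers. A hedged sketch of the intended usage; the mutable= keyword and the "batch_stats" collection name are assumptions based on those commits and flax.linen conventions, not something rendered here:

    import numpyro
    import numpyro.distributions as dist
    from flax import linen as nn
    from numpyro.contrib.module import flax_module

    def model(x, y=None):
        # BatchNorm keeps running statistics in flax's "batch_stats" collection.
        # Declaring that collection mutable (assumed keyword from this PR) lets
        # NumPyro track the statistics as mutable state instead of as
        # parameters to optimize.
        net = flax_module(
            "net",
            nn.BatchNorm(use_running_average=False, momentum=0.9),
            input_shape=(4,),
            mutable=["batch_stats"],
        )
        numpyro.sample("y", dist.Normal(net(x).sum(-1), 1.0), obs=y)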

4 changes: 2 additions & 2 deletions numpyro/handlers.py
@@ -600,7 +600,7 @@ class scope(Messenger):
     """
     This handler prepend a prefix followed by a divider to the name of sample sites.
 
-    Example::
+    **Example**
 
     .. doctest::
 
@@ -745,7 +745,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
         super(substitute, self).__init__(fn)
 
     def process_message(self, msg):
-        if (msg["type"] not in ("sample", "param", "plate")) or msg.get(
+        if (msg["type"] not in ("sample", "param", "mutable", "plate")) or msg.get(
             "_control_flow_done", False
         ):
             if msg["type"] == "control_flow":
195 changes: 124 additions & 71 deletions numpyro/infer/elbo.py
@@ -15,7 +15,60 @@
 from numpyro.infer.util import get_importance_trace, log_density
 
 
-class Trace_ELBO:
+class ELBO:
+    """
+    Base class for all ELBO objectives.
+
+    Subclasses should implement either :meth:`loss` or :meth:`loss_with_mutable_state`.
+
+    :param num_particles: The number of particles/samples used to form the ELBO
+        (gradient) estimators.
+    """
+
+    def __init__(self, num_particles=1):
+        self.num_particles = num_particles
+
+    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
+        """
+        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
+
+        :param jax.random.PRNGKey rng_key: random number generator seed.
+        :param dict param_map: dictionary of current parameter values keyed by site
+            name.
+        :param model: Python callable with NumPyro primitives for the model.
+        :param guide: Python callable with NumPyro primitives for the guide.
+        :param args: arguments to the model / guide (these can possibly vary during
+            the course of fitting).
+        :param kwargs: keyword arguments to the model / guide (these can possibly vary
+            during the course of fitting).
+        :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
+        """
+        return self.loss_with_mutable_state(
+            rng_key, param_map, model, guide, *args, **kwargs
+        )["loss"]
+
+    def loss_with_mutable_state(
+        self, rng_key, param_map, model, guide, *args, **kwargs
+    ):
+        """
+        Like :meth:`loss`, but also updates and returns the mutable state, which stores
+        the values at :func:`~numpyro.mutable` sites.
+
+        :param jax.random.PRNGKey rng_key: random number generator seed.
+        :param dict param_map: dictionary of current parameter values keyed by site
+            name.
+        :param model: Python callable with NumPyro primitives for the model.
+        :param guide: Python callable with NumPyro primitives for the guide.
+        :param args: arguments to the model / guide (these can possibly vary during
+            the course of fitting).
+        :param kwargs: keyword arguments to the model / guide (these can possibly vary
+            during the course of fitting).
+        :return: a dict with keys ``loss`` and ``mutable_state``.
+        """
+        raise NotImplementedError("This ELBO objective does not support mutable state.")
+
+
+class Trace_ELBO(ELBO):
     """
     A trace implementation of ELBO-based SVI. The estimator is constructed
     along the lines of references [1] and [2]. There are no restrictions on the
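The new base class pins down a simple contract: loss_with_mutable_state returns a dict with keys "loss" and "mutable_state", and loss is derived from it by indexing "loss". A minimal sketch of a conforming subclass; ToyELBO and its constant loss are hypothetical, only the ELBO interface comes from this diff:

    import jax.numpy as jnp
    from numpyro.infer.elbo import ELBO

    class ToyELBO(ELBO):
        """Hypothetical objective illustrating the return contract only."""

        def loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs):
            # Report the loss together with any updated mutable values; use None
            # when the model/guide has no mutable sites.
            return {"loss": jnp.zeros(()), "mutable_state": None}

    # The inherited `loss` simply extracts the "loss" entry:
    # ToyELBO().loss(rng_key, param_map, model, guide) evaluates to 0.0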
@@ -43,52 +96,56 @@ class Trace_ELBO:
     def __init__(self, num_particles=1):
         self.num_particles = num_particles
 
-    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
-        """
-        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
-
-        :param jax.random.PRNGKey rng_key: random number generator seed.
-        :param dict param_map: dictionary of current parameter values keyed by site
-            name.
-        :param model: Python callable with NumPyro primitives for the model.
-        :param guide: Python callable with NumPyro primitives for the guide.
-        :param args: arguments to the model / guide (these can possibly vary during
-            the course of fitting).
-        :param kwargs: keyword arguments to the model / guide (these can possibly vary
-            during the course of fitting).
-        :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
-        """
-
+    def loss_with_mutable_state(
+        self, rng_key, param_map, model, guide, *args, **kwargs
+    ):
         def single_particle_elbo(rng_key):
+            params = param_map.copy()
             model_seed, guide_seed = random.split(rng_key)
             seeded_model = seed(model, model_seed)
             seeded_guide = seed(guide, guide_seed)
             guide_log_density, guide_trace = log_density(
                 seeded_guide, args, kwargs, param_map
             )
+            mutable_params = {
+                name: site["value"]
+                for name, site in guide_trace.items()
+                if site["type"] == "mutable"
+            }
+            params.update(mutable_params)
             seeded_model = replay(seeded_model, guide_trace)
-            model_log_density, _ = log_density(seeded_model, args, kwargs, param_map)
+            model_log_density, model_trace = log_density(
+                seeded_model, args, kwargs, params
+            )
+            mutable_params.update(
+                {
+                    name: site["value"]
+                    for name, site in model_trace.items()
+                    if site["type"] == "mutable"
+                }
+            )
 
             # log p(z) - log q(z)
-            elbo = model_log_density - guide_log_density
-            return elbo
+            elbo_particle = model_log_density - guide_log_density
+            if mutable_params:
+                if self.num_particles == 1:
+                    return elbo_particle, mutable_params
+                else:
+                    raise ValueError(
+                        "Currently, we only support mutable states with num_particles=1."
+                    )
+            else:
+                return elbo_particle, None
 
         # Return (-elbo) since by convention we do gradient descent on a loss and
         # the ELBO is a lower bound that needs to be maximized.
         if self.num_particles == 1:
-            return -single_particle_elbo(rng_key)
+            elbo, mutable_state = single_particle_elbo(rng_key)
+            return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
-
-
-class ELBO(Trace_ELBO):
-    def __init__(self, num_particles=1):
-        warnings.warn(
-            "Using ELBO directly in SVI is deprecated. Please use Trace_ELBO class instead.",
-            FutureWarning,
-        )
-        super().__init__(num_particles=num_particles)
+            elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
 
 
 def _get_log_prob_sum(site):
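Concretely, a mutable site lets a model carry state that training updates but the optimizer skips (per the "add note that optimizers will skip mutable parameters" commit). A sketch under stated assumptions: the mutable(name, init_value) primitive lives in numpyro.primitives (the "not expose mutable primitive" commit keeps it off the top-level namespace), the stored container is mutated in place, and num_particles=1 is required, as the ValueError above enforces:

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.primitives import mutable  # assumed location; not re-exported by this PR

    def model(x):
        # Non-optimized state: an exponential moving average of the inputs.
        stats = mutable("stats", {"running_mean": jnp.zeros(x.shape[-1])})
        stats["running_mean"] = 0.9 * stats["running_mean"] + 0.1 * x.mean(0)
        loc = numpyro.param("loc", jnp.zeros(x.shape[-1]))
        numpyro.sample("obs", dist.Normal(loc, 1.0).to_event(1), obs=x.mean(0))

    def guide(x):
        pass  # the model is fully observed; an empty guide suffices for illustration

    x = jnp.ones((8, 3))
    svi = SVI(model, guide, numpyro.optim.Adam(1e-2), Trace_ELBO(num_particles=1))
    svi_result = svi.run(random.PRNGKey(0), 50, x)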
@@ -128,7 +185,7 @@ def _check_mean_field_requirement(model_trace, guide_trace):
         )
 
 
-class TraceMeanField_ELBO(Trace_ELBO):
+class TraceMeanField_ELBO(ELBO):
     """
     A trace implementation of ELBO-based SVI. This is currently the only
     ELBO estimator in NumPyro that uses analytic KL divergences when those
@@ -146,30 +203,31 @@ class TraceMeanField_ELBO(Trace_ELBO):
     dependency structures.
     """
 
-    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
-        """
-        Evaluates the ELBO with an estimator that uses num_particles many samples/particles.
-
-        :param jax.random.PRNGKey rng_key: random number generator seed.
-        :param dict param_map: dictionary of current parameter values keyed by site
-            name.
-        :param model: Python callable with NumPyro primitives for the model.
-        :param guide: Python callable with NumPyro primitives for the guide.
-        :param args: arguments to the model / guide (these can possibly vary during
-            the course of fitting).
-        :param kwargs: keyword arguments to the model / guide (these can possibly vary
-            during the course of fitting).
-        :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
-        """
-
+    def loss_with_mutable_state(
+        self, rng_key, param_map, model, guide, *args, **kwargs
+    ):
         def single_particle_elbo(rng_key):
+            params = param_map.copy()
             model_seed, guide_seed = random.split(rng_key)
             seeded_model = seed(model, model_seed)
             seeded_guide = seed(guide, guide_seed)
             subs_guide = substitute(seeded_guide, data=param_map)
             guide_trace = trace(subs_guide).get_trace(*args, **kwargs)
-            subs_model = substitute(replay(seeded_model, guide_trace), data=param_map)
+            mutable_params = {
+                name: site["value"]
+                for name, site in guide_trace.items()
+                if site["type"] == "mutable"
+            }
+            params.update(mutable_params)
+            subs_model = substitute(replay(seeded_model, guide_trace), data=params)
             model_trace = trace(subs_model).get_trace(*args, **kwargs)
+            mutable_params.update(
+                {
+                    name: site["value"]
+                    for name, site in model_trace.items()
+                    if site["type"] == "mutable"
+                }
+            )
             _check_mean_field_requirement(model_trace, guide_trace)
 
             elbo_particle = 0
@@ -196,16 +254,26 @@ def single_particle_elbo(rng_key):
                     assert site["infer"].get("is_auxiliary")
                     elbo_particle = elbo_particle - _get_log_prob_sum(site)
 
-            return elbo_particle
+            if mutable_params:
+                if self.num_particles == 1:
+                    return elbo_particle, mutable_params
+                else:
+                    raise ValueError(
+                        "Currently, we only support mutable states with num_particles=1."
+                    )
+            else:
+                return elbo_particle, None
 
         if self.num_particles == 1:
-            return -single_particle_elbo(rng_key)
+            elbo, mutable_state = single_particle_elbo(rng_key)
+            return {"loss": -elbo, "mutable_state": mutable_state}
         else:
             rng_keys = random.split(rng_key, self.num_particles)
-            return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+            elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
+            return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}
 
 
-class RenyiELBO(Trace_ELBO):
+class RenyiELBO(ELBO):
     r"""
     An implementation of Renyi's :math:`\alpha`-divergence
     variational inference following reference [1].
@@ -235,24 +303,9 @@ def __init__(self, alpha=0, num_particles=2):
                 "for the case alpha = 1."
             )
         self.alpha = alpha
-        super(RenyiELBO, self).__init__(num_particles=num_particles)
+        super().__init__(num_particles=num_particles)
 
     def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
-        """
-        Evaluates the Renyi ELBO with an estimator that uses num_particles many samples/particles.
-
-        :param jax.random.PRNGKey rng_key: random number generator seed.
-        :param dict param_map: dictionary of current parameter values keyed by site
-            name.
-        :param model: Python callable with NumPyro primitives for the model.
-        :param guide: Python callable with NumPyro primitives for the guide.
-        :param args: arguments to the model / guide (these can possibly vary during
-            the course of fitting).
-        :param kwargs: keyword arguments to the model / guide (these can possibly vary
-            during the course of fitting).
-        :returns: negative of the Renyi Evidence Lower Bound (ELBO) to be minimized.
-        """
-
         def single_particle_elbo(rng_key):
             model_seed, guide_seed = random.split(rng_key)
             seeded_model = seed(model, model_seed)
@@ -458,7 +511,7 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
     return downstream_costs, downstream_guide_cost_nodes
 
 
-class TraceGraph_ELBO:
+class TraceGraph_ELBO(ELBO):
     """
     A TraceGraph implementation of ELBO-based SVI. The gradient estimator
     is constructed along the lines of reference [1] specialized to the case
@@ -479,7 +532,7 @@ class TraceGraph_ELBO:
     """
 
     def __init__(self, num_particles=1):
-        self.num_particles = num_particles
+        super().__init__(num_particles=num_particles)
 
     def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
         """
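One consequence of the new hierarchy worth spelling out: TraceGraph_ELBO now subclasses ELBO but overrides only loss, so it inherits the base loss_with_mutable_state and cannot be used with mutable sites. A short sketch of that failure mode, assuming only what the base class above defines:

    from numpyro.infer.elbo import TraceGraph_ELBO

    elbo = TraceGraph_ELBO()
    try:
        # Only `loss` is overridden; the inherited base method raises.
        elbo.loss_with_mutable_state(None, {}, None, None)
    except NotImplementedError as err:
        print(err)  # This ELBO objective does not support mutable state.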