Skip to content

Commit

Permalink
Improve tuning by skipping the first samples + add new experimental t…
Browse files Browse the repository at this point in the history
…uning method (#5004)

* Fix issue in hmc gradient storage

* Skip first samples during NUTS adaptation

* Add test and doc for jitter+adapt_diag_grad

* Improve tests of init methods

* Add new tuning method to release notes

* Remove old gradient mass matrix adaptation

* Remove weight argument in quadpotential add_sample
  • Loading branch information
aseyboldt authored Sep 22, 2021
1 parent bcc40ce commit 4f8ad5d
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 104 deletions.
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
- The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical
observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)).
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc3/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
- A small change to the mass matrix tuning methods jitter+adapt_diag (the default) and adapt_diag improves performance early on during tuning for some models. [#5004](https://github.com/pymc-devs/pymc3/pull/5004)
- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc3/pull/5004)
- ...

### Maintenance
Expand Down
62 changes: 23 additions & 39 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,25 +287,7 @@ def sample(
by default. See ``discard_tuned_samples``.
init : str
Initialization method to use for auto-assigned NUTS samplers.
* auto: Choose a default initialization method automatically.
Currently, this is ``jitter+adapt_diag``, but this can change in the future.
If you depend on the exact behaviour, choose an initialization method explicitly.
* adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the
variance of the tuning samples. All chains use the test value (usually the prior mean)
as starting point.
* jitter+adapt_diag: Same as ``adapt_diag``, but add uniform jitter in [-1, 1] to the
starting point in each chain.
* advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
sample variance of the tuning samples.
* advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
on the variance of the gradients during tuning. This is **experimental** and might be
removed in a future release.
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map: Use the MAP as starting point. This is discouraged.
* adapt_full: Adapt a dense mass matrix using the sample covariances
See `pm.init_nuts` for a list of all options.
step : function or iterable of functions
A step function or collection of functions. If there are variables without step methods,
step methods for those variables will be assigned automatically. By default the NUTS step
Expand Down Expand Up @@ -516,6 +498,7 @@ def sample(
random_seed=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
**kwargs,
)
if start is None:
Expand Down Expand Up @@ -2078,6 +2061,7 @@ def init_nuts(
random_seed=None,
progressbar=True,
jitter_max_retries=10,
tune=None,
**kwargs,
):
"""Set up the mass matrix initialization for NUTS.
Expand All @@ -2099,11 +2083,11 @@ def init_nuts(
as starting point.
* jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in
[-1, 1] as starting point in each chain.
* jitter+adapt_diag_grad:
An experimental initialization method that uses information from gradients and samples
during tuning.
* advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
sample variance of the tuning samples.
* advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
on the variance of the gradients during tuning. This is **experimental** and might be
removed in a future release.
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map: Use the MAP as starting point. This is discouraged.
Expand Down Expand Up @@ -2174,24 +2158,24 @@ def init_nuts(
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
elif init == "advi+adapt_diag_grad":
approx: pm.MeanField = pm.fit(
random_seed=random_seed,
n=n_init,
method="advi",
model=model,
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
elif init == "jitter+adapt_diag_grad":
start = _init_jitter(model, model.initial_point, chains, jitter_max_retries)
mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0)
var = np.ones_like(mean)
n = len(var)

if tune is not None and tune > 250:
stop_adaptation = tune - 50
else:
stop_adaptation = None

potential = quadpotential.QuadPotentialDiagAdaptExp(
n,
mean,
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
)
start = approx.sample(draws=chains)
start = list(start)
std_apoint = approx.std.eval()
cov = std_apoint ** 2
mean = approx.mean.get_value()
weight = 50
n = len(cov)
potential = quadpotential.QuadPotentialDiagAdaptGrad(n, mean, cov, weight)
elif init == "advi+adapt_diag":
approx = pm.fit(
random_seed=random_seed,
Expand Down
6 changes: 2 additions & 4 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
self.start_energy = np.array(start.energy)

self.left = self.right = start
self.proposal = Proposal(
start.q.data, start.q_grad.data, start.energy, 1.0, start.model_logp
)
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
self.depth = 0
self.log_size = 0
self.log_weighted_accept_sum = -np.inf
Expand Down Expand Up @@ -350,7 +348,7 @@ def _single_step(self, left, epsilon):
log_size = -energy_change
proposal = Proposal(
right.q.data,
right.q_grad.data,
right.q_grad,
right.energy,
log_p_accept_weighted,
right.model_logp,
Expand Down
Loading

0 comments on commit 4f8ad5d

Please sign in to comment.