Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement more robust jitter init (resolves #4107) #4298

Merged
merged 4 commits into from
Dec 5, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which,
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))
- Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207)
- The `InverseGamma` distribution now implements a `logcdf`. [#3944](https://github.com/pymc-devs/pymc3/pull/3944)
- Make starting jitter methods for nuts sampling more robust by resampling values that lead to non-finite probabilities. A new optional argument `jitter-max-retries` can be passed to `pm.sample()` and `pm.init_nuts()` to control the maximum number of retries per chain. [4298](https://github.com/pymc-devs/pymc3/pull/4298)

### Documentation
- Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)).
Expand Down
67 changes: 53 additions & 14 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
chains_and_samples,
)
from .vartypes import discrete_types
from .exceptions import IncorrectArgumentsError
from .exceptions import IncorrectArgumentsError, SamplingError
from .parallel_sampling import _cpu_count, Draw
from pymc3.step_methods.hmc import quadpotential
import pymc3 as pm
Expand Down Expand Up @@ -246,6 +246,7 @@ def sample(
discard_tuned_samples=True,
compute_convergence_checks=True,
callback=None,
jitter_max_retries=10,
*,
return_inferencedata=None,
idata_kwargs: dict = None,
Expand Down Expand Up @@ -331,6 +332,10 @@ def sample(
the ``draw.chain`` argument can be used to determine which of the active chains the sample
is drawn from.
Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback.
jitter_max_retries: int
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
init methods.
return_inferencedata : bool, default=False
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
Defaults to `False`, but we'll switch to `True` in an upcoming release.
Expand Down Expand Up @@ -490,6 +495,7 @@ def sample(
model=model,
random_seed=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
**kwargs,
)
if start is None:
Expand Down Expand Up @@ -1946,13 +1952,52 @@ def sample_prior_predictive(
return prior


def _init_jitter(model, chains, jitter_max_retries):
"""Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.

pymc3.util.check_start_vals is used to test whether the jittered starting values produce
a finite log probability. Invalid values are resampled unless `jitter_max_retries` is achieved,
in which case the last sampled values are returned
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
model : Model
twiecki marked this conversation as resolved.
Show resolved Hide resolved
chains : int
jitter_max_retries: int
Maximum number of repeated attempts at initializing values (per chain).

Returns
-------
start : ``pymc3.model.Point``
Starting point for sampler
"""
start = []
for _ in range(chains):
for i in range(jitter_max_retries + 1):
mean = {var: val.copy() for var, val in model.test_point.items()}
for val in mean.values():
val[...] += 2 * np.random.rand(*val.shape) - 1

if i < jitter_max_retries:
try:
check_start_vals(mean, model)
except SamplingError:
pass
else:
break

start.append(mean)
return start


def init_nuts(
init="auto",
chains=1,
n_init=500000,
model=None,
random_seed=None,
progressbar=True,
jitter_max_retries=10,
**kwargs,
):
"""Set up the mass matrix initialization for NUTS.
Expand All @@ -1967,7 +2012,7 @@ def init_nuts(
Initialization method to use.

* auto: Choose a default initialization method automatically.
Currently, this is `'jitter+adapt_diag'`, but this can change in the future. If you
Currently, this is ``jitter+adapt_diag``, but this can change in the future. If you
depend on the exact behaviour, choose an initialization method explicitly.
* adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the
variance of the tuning samples. All chains use the test value (usually the prior mean)
Expand All @@ -1994,6 +2039,10 @@ def init_nuts(
model : Model (optional if in ``with`` context)
progressbar : bool
Whether or not to display a progressbar for advi sampling.
jitter_max_retries: int
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
init methods.
**kwargs : keyword arguments
Extra keyword arguments are forwarded to pymc3.NUTS.

Expand Down Expand Up @@ -2038,12 +2087,7 @@ def init_nuts(
var = np.ones_like(mean)
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
elif init == "jitter+adapt_diag":
start = []
for _ in range(chains):
mean = {var: val.copy() for var, val in model.test_point.items()}
for val in mean.values():
val[...] += 2 * np.random.rand(*val.shape) - 1
start.append(mean)
start = _init_jitter(model, chains, jitter_max_retries)
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
var = np.ones_like(mean)
potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10)
Expand Down Expand Up @@ -2125,12 +2169,7 @@ def init_nuts(
cov = np.eye(model.ndim)
potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
elif init == "jitter+adapt_full":
start = []
for _ in range(chains):
mean = {var: val.copy() for var, val in model.test_point.items()}
for val in mean.values():
val[...] += 2 * np.random.rand(*val.shape) - 1
start.append(mean)
start = _init_jitter(model, chains, jitter_max_retries)
mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0)
cov = np.eye(model.ndim)
potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10)
Expand Down
22 changes: 22 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,28 @@ def _mocked_init_nuts(*args, **kwargs):
pm.sample(tune=1, draws=0, chains=1, init=init, start=start)


@pytest.mark.parametrize(
"testval, jitter_max_retries, expectation",
[
(0, 0, pytest.raises(SamplingError)),
(0, 1, pytest.raises(SamplingError)),
(0, 4, does_not_raise()),
(0, 10, does_not_raise()),
(1, 0, does_not_raise()),
],
)
def test_init_jitter(testval, jitter_max_retries, expectation):
with pm.Model() as m:
pm.HalfNormal("x", transform=None, testval=testval)

with expectation:
# Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1)
# and positive (valid) when it returns 1 (jitter = 1)
with mock.patch("numpy.random.rand", side_effect=[0, 0, 0, 1, 0]):
start = pm.sampling._init_jitter(m, chains=1, jitter_max_retries=jitter_max_retries)
pm.util.check_start_vals(start, m)


@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
with pm.Model() as pmodel:
Expand Down