diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 3f73b652c4f..ed6d43e1ab3 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -48,6 +48,7 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which, - Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721)) - Refactored MvNormal.random method with better handling of sample, batch and event shapes. [#4207](https://github.com/pymc-devs/pymc3/pull/4207) - The `InverseGamma` distribution now implements a `logcdf`. [#3944](https://github.com/pymc-devs/pymc3/pull/3944) +- Make starting jitter methods for NUTS sampling more robust by resampling values that lead to non-finite probabilities. A new optional argument `jitter_max_retries` can be passed to `pm.sample()` and `pm.init_nuts()` to control the maximum number of retries per chain. [#4298](https://github.com/pymc-devs/pymc3/pull/4298) ### Documentation - Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)). 
diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 1dfa82c5d79..aa041606ee3 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -60,7 +60,7 @@ chains_and_samples, ) from .vartypes import discrete_types -from .exceptions import IncorrectArgumentsError +from .exceptions import IncorrectArgumentsError, SamplingError from .parallel_sampling import _cpu_count, Draw from pymc3.step_methods.hmc import quadpotential import pymc3 as pm @@ -246,6 +246,7 @@ def sample( discard_tuned_samples=True, compute_convergence_checks=True, callback=None, + jitter_max_retries=10, *, return_inferencedata=None, idata_kwargs: dict = None, @@ -331,6 +332,10 @@ def sample( the ``draw.chain`` argument can be used to determine which of the active chains the sample is drawn from. Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback. + jitter_max_retries : int + Maximum number of repeated attempts (per chain) at creating a starting point with uniform jitter + that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full`` + init methods. return_inferencedata : bool, default=False Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False) Defaults to `False`, but we'll switch to `True` in an upcoming release. @@ -490,6 +495,7 @@ def sample( model=model, random_seed=random_seed, progressbar=progressbar, + jitter_max_retries=jitter_max_retries, **kwargs, ) if start is None: @@ -1946,6 +1952,44 @@ def sample_prior_predictive( return prior +def _init_jitter(model, chains, jitter_max_retries): + """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. + + pymc3.util.check_start_vals is used to test whether the jittered starting values produce + a finite log probability. Invalid values are resampled unless `jitter_max_retries` is reached, + in which case the last sampled values are returned. 
+ + Parameters + ---------- + model : pymc3.Model + chains : int + jitter_max_retries : int + Maximum number of repeated attempts at initializing values (per chain); negative values act as 0. + + Returns + ------- + start : list of dict + Jittered starting point for each chain + """ + start = [] + for _ in range(chains): + for i in range(max(jitter_max_retries, 0) + 1): + mean = {var: val.copy() for var, val in model.test_point.items()} + for val in mean.values(): + val[...] += 2 * np.random.rand(*val.shape) - 1 + + if i < jitter_max_retries: + try: + check_start_vals(mean, model) + except SamplingError: + pass + else: + break + + start.append(mean) + return start + + def init_nuts( init="auto", chains=1, @@ -1953,6 +1997,7 @@ def init_nuts( model=None, random_seed=None, progressbar=True, + jitter_max_retries=10, **kwargs, ): """Set up the mass matrix initialization for NUTS. @@ -1967,7 +2012,7 @@ def init_nuts( Initialization method to use. * auto: Choose a default initialization method automatically. - Currently, this is `'jitter+adapt_diag'`, but this can change in the future. If you + Currently, this is ``jitter+adapt_diag``, but this can change in the future. If you depend on the exact behaviour, choose an initialization method explicitly. * adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the variance of the tuning samples. All chains use the test value (usually the prior mean) @@ -1994,6 +2039,10 @@ def init_nuts( model : Model (optional if in ``with`` context) progressbar : bool Whether or not to display a progressbar for advi sampling. + jitter_max_retries : int + Maximum number of repeated attempts (per chain) at creating a starting point with uniform jitter + that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full`` + init methods. **kwargs : keyword arguments Extra keyword arguments are forwarded to pymc3.NUTS. 
@@ -2038,12 +2087,7 @@ def init_nuts( var = np.ones_like(mean) potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10) elif init == "jitter+adapt_diag": - start = [] - for _ in range(chains): - mean = {var: val.copy() for var, val in model.test_point.items()} - for val in mean.values(): - val[...] += 2 * np.random.rand(*val.shape) - 1 - start.append(mean) + start = _init_jitter(model, chains, jitter_max_retries) mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) var = np.ones_like(mean) potential = quadpotential.QuadPotentialDiagAdapt(model.ndim, mean, var, 10) @@ -2125,12 +2169,7 @@ def init_nuts( cov = np.eye(model.ndim) potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10) elif init == "jitter+adapt_full": - start = [] - for _ in range(chains): - mean = {var: val.copy() for var, val in model.test_point.items()} - for val in mean.values(): - val[...] += 2 * np.random.rand(*val.shape) - 1 - start.append(mean) + start = _init_jitter(model, chains, jitter_max_retries) mean = np.mean([model.dict_to_array(vals) for vals in start], axis=0) cov = np.eye(model.ndim) potential = quadpotential.QuadPotentialFullAdapt(model.ndim, mean, cov, 10) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index c8c6f134838..9d4018354b2 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -815,6 +815,28 @@ def _mocked_init_nuts(*args, **kwargs): pm.sample(tune=1, draws=0, chains=1, init=init, start=start) +@pytest.mark.parametrize( + "testval, jitter_max_retries, expectation", + [ + (0, 0, pytest.raises(SamplingError)), + (0, 1, pytest.raises(SamplingError)), + (0, 4, does_not_raise()), + (0, 10, does_not_raise()), + (1, 0, does_not_raise()), + ], +) +def test_init_jitter(testval, jitter_max_retries, expectation): + with pm.Model() as m: + pm.HalfNormal("x", transform=None, testval=testval) + + with expectation: + # With testval=0, the starting value is negative (invalid) when np.random.rand 
returns 0 (jitter = -1) + # and positive (valid) when it returns 1 (jitter = 1); the mock returns 1 on the fourth draw only + with mock.patch("numpy.random.rand", side_effect=[0, 0, 0, 1, 0]): + start = pm.sampling._init_jitter(m, chains=1, jitter_max_retries=jitter_max_retries) + pm.util.check_start_vals(start, m) + + @pytest.fixture(scope="class") def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]: with pm.Model() as pmodel: