diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index d6694ff5f38..a9af4411d7e 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -42,6 +42,8 @@ This new version of `Theano-PyMC` comes with an experimental JAX backend which, - Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)). - Numerically improved stickbreaking transformation - e.g. for the `Dirichlet` distribution. [#4129](https://github.com/pymc-devs/pymc3/pull/4129) - Enabled the `Multinomial` distribution to handle batch sizes that have more than 2 dimensions. [#4169](https://github.com/pymc-devs/pymc3/pull/4169) +- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116)) +- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721)) ### Documentation - Added a new notebook demonstrating how to incorporate sampling from a conjugate Dirichlet-multinomial posterior density in conjunction with other step methods (see [#4199](https://github.com/pymc-devs/pymc3/pull/4199)). diff --git a/pymc3/model.py b/pymc3/model.py index 55ca443d0b2..96c317912cc 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1368,7 +1368,7 @@ def check_test_point(self, test_point=None, round_vals=2): test_point = self.test_point return Series( - {RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs}, + {RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs}, name="Log-probability of test_point", ) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 4d03674e6dd..37375bf22f3 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -54,6 +54,7 @@ PGBART, ) from .util import ( + check_start_vals, update_start_vals, get_untransformed_name, is_transformed_name, @@ -419,7 +420,16 @@ def sample( """ model = modelcontext(model) + if start is None: + start = model.test_point + else: + if isinstance(start, dict): + update_start_vals(start, model.test_point, model) + else: + for chain_start_vals in start: + update_start_vals(chain_start_vals, model.test_point, model) + check_start_vals(start, model) if cores is None: cores = min(4, _cpu_count()) @@ -487,6 +497,7 @@ def sample( progressbar=progressbar, **kwargs, ) + check_start_vals(start_, model) if start is None: start = start_ except (AttributeError, NotImplementedError, tg.NullTypeGradError): diff --git a/pymc3/tests/test_examples.py b/pymc3/tests/test_examples.py index b9eb850b0d2..d31b2bfd3a0 100644 --- a/pymc3/tests/test_examples.py +++ b/pymc3/tests/test_examples.py @@ -274,7 +274,7 @@ def build_model(self): # Estimated mean count theta = pm.Uniform("theta", 0, 100) # Poisson likelihood - pm.ZeroInflatedPoisson("y", theta, psi, observed=self.y) + pm.ZeroInflatedPoisson("y", psi, theta, observed=self.y) return model def test_run(self): diff --git a/pymc3/tests/test_hmc.py b/pymc3/tests/test_hmc.py index 1dd3e42acba..384501f2fd2 100644 --- a/pymc3/tests/test_hmc.py +++ b/pymc3/tests/test_hmc.py @@ -17,9 +17,7 @@ from . import models from pymc3.step_methods.hmc.base_hmc import BaseHMC -from pymc3.exceptions import SamplingError import pymc3 -import pytest import logging from pymc3.theanof import floatX @@ -57,16 +55,3 @@ def test_nuts_tuning(): assert not step.tune assert np.all(trace["step_size"][5:] == trace["step_size"][5]) - - -def test_nuts_error_reporting(caplog): - model = pymc3.Model() - with caplog.at_level(logging.CRITICAL) and pytest.raises(SamplingError): - with model: - pymc3.HalfNormal("a", sigma=1, transform=None, testval=-1) - pymc3.HalfNormal("b", sigma=1, transform=None) - trace = pymc3.sample(init="adapt_diag", chains=1) - assert ( - "Bad initial energy, check any log probabilities that are inf or -inf: a -inf\nb" - in caplog.text - ) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index 61ba1b7f651..e115bdcb171 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -27,7 +27,6 @@ simple_2model_continuous, ) from pymc3.sampling import assign_step_methods, sample -from pymc3.parallel_sampling import ParallelSamplingError from pymc3.exceptions import SamplingError from pymc3.model import Model, Potential, set_data @@ -963,15 +962,15 @@ def test_bad_init_nonparallel(self): HalfNormal("a", sigma=1, testval=-1, transform=None) with pytest.raises(SamplingError) as error: sample(init=None, chains=1, random_seed=1) - error.match("Bad initial") + error.match("Initial evaluation") @pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") def test_bad_init_parallel(self): with Model(): HalfNormal("a", sigma=1, testval=-1, transform=None) - with pytest.raises(ParallelSamplingError) as error: + with pytest.raises(SamplingError) as error: sample(init=None, cores=2, random_seed=1) - error.match("Bad initial") + error.match("Initial evaluation") def test_linalg(self, caplog): with Model(): diff --git a/pymc3/tests/test_util.py b/pymc3/tests/test_util.py index cf1f632a5a8..63fda204579 100644 --- a/pymc3/tests/test_util.py +++ b/pymc3/tests/test_util.py @@ -95,6 +95,40 @@ def test_soft_update_parent(self): assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"]) +class TestCheckStartVals(SeededTest): + def setup_method(self): + super().setup_method() + + def test_valid_start_point(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + + start = {"a": 0.3, "b": 2.1} + pm.util.update_start_vals(start, model.test_point, model) + pm.util.check_start_vals(start, model) + + def test_invalid_start_point(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + + start = {"a": np.nan, "b": np.nan} + pm.util.update_start_vals(start, model.test_point, model) + with pytest.raises(pm.exceptions.SamplingError): + pm.util.check_start_vals(start, model) + + def test_invalid_variable_name(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + + start = {"a": 0.3, "b": 2.1, "c": 1.0} + pm.util.update_start_vals(start, model.test_point, model) + with pytest.raises(KeyError): + pm.util.check_start_vals(start, model) + + class TestExceptions: def test_shape_error(self): with pytest.raises(pm.exceptions.ShapeError) as exinfo: diff --git a/pymc3/tuning/starting.py b/pymc3/tuning/starting.py index db49ef52aef..6c83af33f0c 100644 --- a/pymc3/tuning/starting.py +++ b/pymc3/tuning/starting.py @@ -28,7 +28,7 @@ from ..theanof import inputvars import theano.gradient as tg from ..blocking import DictToArrayBijection, ArrayOrdering -from ..util import update_start_vals, get_default_varnames, get_var_name +from ..util import check_start_vals, update_start_vals, get_default_varnames, get_var_name import warnings from inspect import getargspec @@ -89,13 +89,7 @@ def find_MAP( else: update_start_vals(start, model.test_point, model) - if not set(start.keys()).issubset(model.named_vars.keys()): - extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys())) - valid_keys = ", ".join(model.named_vars.keys()) - raise KeyError( - "Some start parameters do not appear in the model!\n" - "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys) - ) + check_start_vals(start, model) if vars is None: vars = model.cont_vars diff --git a/pymc3/util.py b/pymc3/util.py index d332510c2f8..8b663f01e33 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -16,10 +16,11 @@ import functools from typing import List, Dict, Tuple, Union +import numpy as np import xarray import arviz -from numpy import ndarray +from pymc3.exceptions import SamplingError from theano.tensor import TensorVariable @@ -188,6 +189,48 @@ def update_start_vals(a, b, model): a.update({k: v for k, v in b.items() if k not in a}) +def check_start_vals(start, model): + r"""Check that the starting values for MCMC do not cause the relevant log probability + to evaluate to something invalid (e.g. Inf or NaN) + + Parameters + ---------- + start : dict, or array of dict + Starting point in parameter space (or partial point) + Defaults to ``trace.point(-1))`` if there is a trace provided and model.test_point if not + (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can + overwrite the default. + model : Model object + Raises + ______ + KeyError if the parameters provided by `start` do not agree with the parameters contained + within `model` + pymc3.exceptions.SamplingError if the evaluation of the parameters in `start` leads to an + invalid (i.e. non-finite) state + Returns + ------- + None + """ + start_points = [start] if isinstance(start, dict) else start + for elem in start_points: + if not set(elem.keys()).issubset(model.named_vars.keys()): + extra_keys = ", ".join(set(elem.keys()) - set(model.named_vars.keys())) + valid_keys = ", ".join(model.named_vars.keys()) + raise KeyError( + "Some start parameters do not appear in the model!\n" + "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys) + ) + + initial_eval = model.check_test_point(test_point=elem) + + if not np.all(np.isfinite(initial_eval)): + raise SamplingError( + "Initial evaluation of model at starting point failed!\n" + "Starting values:\n{}\n\n" + "Initial evaluation results:\n{}".format(elem, str(initial_eval)) + ) + + def get_transformed(z): if hasattr(z, "transformed"): z = z.transformed @@ -214,13 +257,13 @@ def enhanced(*args, **kwargs): # FIXME: this function is poorly named, because it returns a LIST of # points, not a dictionary of points. -def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]: +def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]: # grab posterior samples for each variable - _samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()} + _samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()} # make dicts - points: List[Dict[str, ndarray]] = [] + points: List[Dict[str, np.ndarray]] = [] vn: str - s: ndarray + s: np.ndarray for c in ds.chain: for d in ds.draw: points.append({vn: s[c, d] for vn, s in _samples.items()})