From 365b11758ebc5f14ead89ef34149ac0eb005aa3d Mon Sep 17 00:00:00 2001 From: Luciano Paz Date: Wed, 28 Aug 2024 12:09:00 +0200 Subject: [PATCH] Add ability to set mode in check_start_vals --- pymc/model/core.py | 12 ++++++++---- tests/model/test_core.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/pymc/model/core.py b/pymc/model/core.py index 3a27417661..b5d3b673c1 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -1747,7 +1747,7 @@ def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]: ) return {name: tuple(shape) for name, shape in zip(names, f())} - def check_start_vals(self, start): + def check_start_vals(self, start, **kwargs): r"""Check that the starting values for MCMC do not cause the relevant log probability to evaluate to something invalid (e.g. Inf or NaN) @@ -1758,6 +1758,8 @@ def check_start_vals(self, start): Defaults to ``trace.point(-1))`` if there is a trace provided and ``model.initial_point`` if not (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can overwrite the default. + Other keyword arguments : + Any other keyword argument is sent to :py:meth:`~pymc.model.core.Model.point_logps`. Raises ------ @@ -1787,7 +1789,7 @@ def check_start_vals(self, start): f"Valid keys are: {valid_keys}, but {extra_keys} was supplied" ) - initial_eval = self.point_logps(point=elem) + initial_eval = self.point_logps(point=elem, **kwargs) if not all(np.isfinite(v) for v in initial_eval.values()): raise SamplingError( @@ -1797,7 +1799,7 @@ def check_start_vals(self, start): "You can call `model.debug()` for more details." ) - def point_logps(self, point=None, round_vals=2): + def point_logps(self, point=None, round_vals=2, **kwargs): """Computes the log probability of `point` for all random variables in the model. Parameters @@ -1807,6 +1809,8 @@ def point_logps(self, point=None, round_vals=2): is used. round_vals : int, default 2 Number of decimals to round log-probabilities. + Other keyword arguments : + Any other keyword argument are sent provided to :py:meth:`~pymc.model.core.Model.compile_fn` Returns ------- @@ -1822,7 +1826,7 @@ def point_logps(self, point=None, round_vals=2): factor.name: np.round(np.asarray(factor_logp), round_vals) for factor, factor_logp in zip( factors, - self.compile_fn(factor_logps_fn)(point), + self.compile_fn(factor_logps_fn, **kwargs)(point), ) } diff --git a/tests/model/test_core.py b/tests/model/test_core.py index 669704bb51..2504f5c79f 100644 --- a/tests/model/test_core.py +++ b/tests/model/test_core.py @@ -756,6 +756,20 @@ def test_invalid_variable_name(self): with pytest.raises(KeyError): model.check_start_vals(start) + @pytest.mark.parametrize("mode", [None, "JAX", "NUMBA"]) + def test_mode(self, mode): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + start = { + "a_interval__": model.rvs_to_transforms[a].forward(0.3, *a.owner.inputs).eval(), + "b_interval__": model.rvs_to_transforms[b].forward(2.1, *b.owner.inputs).eval(), + } + with patch("pymc.model.core.compile_pymc") as patched_compile_pymc: + model.check_start_vals(start, mode=mode) + patched_compile_pymc.assert_called_once() + assert patched_compile_pymc.call_args.kwargs["mode"] == mode + def test_set_initval(): # Make sure the dependencies between variables are maintained when