Skip to content

Commit

Permalink
Add ability to set mode in check_start_vals
Browse files Browse the repository at this point in the history
  • Loading branch information
lucianopaz committed Aug 28, 2024
1 parent c92a9a9 commit 365b117
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
12 changes: 8 additions & 4 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,7 +1747,7 @@ def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]:
)
return {name: tuple(shape) for name, shape in zip(names, f())}

def check_start_vals(self, start):
def check_start_vals(self, start, **kwargs):
r"""Check that the starting values for MCMC do not cause the relevant log probability
to evaluate to something invalid (e.g. Inf or NaN)
Expand All @@ -1758,6 +1758,8 @@ def check_start_vals(self, start):
Defaults to ``trace.point(-1))`` if there is a trace provided and
``model.initial_point`` if not (defaults to empty dict). Initialization
methods for NUTS (see ``init`` keyword) can overwrite the default.
Other keyword arguments :
Any other keyword argument is sent to :py:meth:`~pymc.model.core.Model.point_logps`.
Raises
------
Expand Down Expand Up @@ -1787,7 +1789,7 @@ def check_start_vals(self, start):
f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
)

initial_eval = self.point_logps(point=elem)
initial_eval = self.point_logps(point=elem, **kwargs)

if not all(np.isfinite(v) for v in initial_eval.values()):
raise SamplingError(
Expand All @@ -1797,7 +1799,7 @@ def check_start_vals(self, start):
"You can call `model.debug()` for more details."
)

def point_logps(self, point=None, round_vals=2):
def point_logps(self, point=None, round_vals=2, **kwargs):
"""Computes the log probability of `point` for all random variables in the model.
Parameters
Expand All @@ -1807,6 +1809,8 @@ def point_logps(self, point=None, round_vals=2):
is used.
round_vals : int, default 2
Number of decimals to round log-probabilities.
Other keyword arguments :
Any other keyword argument are sent provided to :py:meth:`~pymc.model.core.Model.compile_fn`
Returns
-------
Expand All @@ -1822,7 +1826,7 @@ def point_logps(self, point=None, round_vals=2):
factor.name: np.round(np.asarray(factor_logp), round_vals)
for factor, factor_logp in zip(
factors,
self.compile_fn(factor_logps_fn)(point),
self.compile_fn(factor_logps_fn, **kwargs)(point),
)
}

Expand Down
14 changes: 14 additions & 0 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,20 @@ def test_invalid_variable_name(self):
with pytest.raises(KeyError):
model.check_start_vals(start)

@pytest.mark.parametrize("mode", [None, "JAX", "NUMBA"])
def test_mode(self, mode):
with pm.Model() as model:
a = pm.Uniform("a", lower=0.0, upper=1.0)
b = pm.Uniform("b", lower=2.0, upper=3.0)
start = {
"a_interval__": model.rvs_to_transforms[a].forward(0.3, *a.owner.inputs).eval(),
"b_interval__": model.rvs_to_transforms[b].forward(2.1, *b.owner.inputs).eval(),
}
with patch("pymc.model.core.compile_pymc") as patched_compile_pymc:
model.check_start_vals(start, mode=mode)
patched_compile_pymc.assert_called_once()
assert patched_compile_pymc.call_args.kwargs["mode"] == mode


def test_set_initval():
# Make sure the dependencies between variables are maintained when
Expand Down

0 comments on commit 365b117

Please sign in to comment.