diff --git a/pymc/backends/__init__.py b/pymc/backends/__init__.py index fbc9f971d31..d13a6dc836b 100644 --- a/pymc/backends/__init__.py +++ b/pymc/backends/__init__.py @@ -60,7 +60,34 @@ Saved backends can be loaded using `arviz.from_netcdf` """ +from copy import copy +from typing import Dict, List, Optional + from pymc.backends.arviz import predictions_to_inference_data, to_inference_data +from pymc.backends.base import BaseTrace from pymc.backends.ndarray import NDArray, point_list_to_multitrace __all__ = ["to_inference_data", "predictions_to_inference_data"] + + +def _init_trace( + *, + expected_length: int, + chain_number: int, + stats_dtypes: List[Dict[str, type]], + trace: Optional[BaseTrace], + model, +) -> BaseTrace: + """Initializes a trace backend for a chain.""" + strace: BaseTrace + if trace is None: + strace = NDArray(model=model) + elif isinstance(trace, BaseTrace): + if len(trace) > 0: + raise ValueError("Continuation of traces is no longer supported.") + strace = copy(trace) + else: + raise NotImplementedError(f"Unsupported `trace`: {trace}") + + strace.setup(expected_length, chain_number, stats_dtypes) + return strace diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 8b855f33e99..c945cfc9383 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -34,8 +34,8 @@ import pymc as pm +from pymc.backends import _init_trace from pymc.backends.base import BaseTrace, MultiTrace, _choose_chains -from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.exceptions import SamplingError from pymc.initial_point import ( @@ -960,7 +960,7 @@ def _iter_sample( strace: BaseTrace = _init_trace( expected_length=draws + tune, - step=step, + stats_dtypes=step.stats_dtypes, chain_number=chain, trace=trace, model=model, @@ -1229,7 +1229,7 @@ def _prepare_iter_population( traces: List[BaseTrace] = [ _init_trace( expected_length=draws + tune, - step=steppers[c], + stats_dtypes=steppers[c].stats_dtypes, chain_number=c, trace=None, model=model, @@ -1306,32 +1306,6 @@ def _iter_population( steppers[c].report._finalize(strace) -def _init_trace( - *, - expected_length: int, - step: Step, - chain_number: int, - trace: Optional[BaseTrace], - model, -) -> BaseTrace: - """Extracted helper function to create trace backends for each chain.""" - strace: BaseTrace - if trace is None: - strace = NDArray(model=model) - elif isinstance(trace, BaseTrace): - if len(trace) > 0: - raise ValueError("Continuation of traces is no longer supported.") - strace = copy(trace) - else: - raise NotImplementedError(f"Unsupported `trace`: {trace}") - - if step.generates_stats: - strace.setup(expected_length, chain_number, step.stats_dtypes) - else: - strace.setup(expected_length, chain_number) - return strace - - def _mp_sample( draws: int, tune: int, @@ -1393,7 +1367,7 @@ def _mp_sample( traces = [ _init_trace( expected_length=draws + tune, - step=step, + stats_dtypes=step.stats_dtypes, chain_number=chain_number, trace=trace, model=model, diff --git a/pymc/tests/backends/test_base.py b/pymc/tests/backends/test_base.py index 14eca7cf52d..0e8cb95a027 100644 --- a/pymc/tests/backends/test_base.py +++ b/pymc/tests/backends/test_base.py @@ -14,6 +14,9 @@ import numpy as np import pytest +import pymc as pm + +from pymc.backends import _init_trace from pymc.backends.base import _choose_chains @@ -31,3 +34,22 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces): traces, length = _choose_chains([trace_0, trace_1, trace_2], tune=tune) assert length == expected_length assert expected_n_traces == len(traces) + + +class TestInitTrace: + def test_init_trace_continuation_unsupported(self): + with pm.Model() as pmodel: + A = pm.Normal("A") + B = pm.Uniform("B") + strace = pm.backends.ndarray.NDArray(vars=[A, B]) + strace.setup(10, 0) + strace.record({"A": 2, "B_interval__": 0.1}) + assert len(strace) == 1 + with pytest.raises(ValueError, match="Continuation of traces"): + _init_trace( + expected_length=20, + stats_dtypes=pm.Metropolis().stats_dtypes, + chain_number=0, + trace=strace, + model=pmodel, + ) diff --git a/pymc/tests/sampling/test_mcmc.py b/pymc/tests/sampling/test_mcmc.py index 081658e7101..ba8eb287177 100644 --- a/pymc/tests/sampling/test_mcmc.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -544,25 +544,6 @@ def test_constant_named(self): assert np.isclose(res, 0.0) -class TestInitTrace: - def test_init_trace_continuation_unsupported(self): - with pm.Model() as pmodel: - A = pm.Normal("A") - B = pm.Uniform("B") - strace = pm.backends.ndarray.NDArray(vars=[A, B]) - strace.setup(10, 0) - strace.record({"A": 2, "B_interval__": 0.1}) - assert len(strace) == 1 - with pytest.raises(ValueError, match="Continuation of traces"): - pm.sampling.mcmc._init_trace( - expected_length=20, - step=pm.Metropolis(), - chain_number=0, - trace=strace, - model=pmodel, - ) - - def check_exec_nuts_init(method): with pm.Model() as model: pm.Normal("a", mu=0, sigma=1, size=2)