Skip to content

Commit

Permalink
Refactor _init_trace out of mcmc module
Browse files Browse the repository at this point in the history
This uncouples several things:
* `_init_trace` is now independent of abstract step-methods
* `mcmc` is now unaware of `NDArray`
* code for population-sampling can now be extracted from `mcmc`
  • Loading branch information
michaelosthege committed Nov 5, 2022
1 parent c83a59a commit a4e92cb
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 49 deletions.
27 changes: 27 additions & 0 deletions pymc/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,34 @@
Saved backends can be loaded using `arviz.from_netcdf`
"""
from copy import copy
from typing import Dict, List, Optional

from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
from pymc.backends.base import BaseTrace
from pymc.backends.ndarray import NDArray, point_list_to_multitrace

__all__ = ["to_inference_data", "predictions_to_inference_data"]


def _init_trace(
*,
expected_length: int,
chain_number: int,
stats_dtypes: List[Dict[str, type]],
trace: Optional[BaseTrace],
model,
) -> BaseTrace:
"""Initializes a trace backend for a chain."""
strace: BaseTrace
if trace is None:
strace = NDArray(model=model)
elif isinstance(trace, BaseTrace):
if len(trace) > 0:
raise ValueError("Continuation of traces is no longer supported.")
strace = copy(trace)
else:
raise NotImplementedError(f"Unsupported `trace`: {trace}")

strace.setup(expected_length, chain_number, stats_dtypes)
return strace
34 changes: 4 additions & 30 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@

import pymc as pm

from pymc.backends import _init_trace
from pymc.backends.base import BaseTrace, MultiTrace, _choose_chains
from pymc.backends.ndarray import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.exceptions import SamplingError
from pymc.initial_point import (
Expand Down Expand Up @@ -960,7 +960,7 @@ def _iter_sample(

strace: BaseTrace = _init_trace(
expected_length=draws + tune,
step=step,
stats_dtypes=step.stats_dtypes,
chain_number=chain,
trace=trace,
model=model,
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def _prepare_iter_population(
traces: List[BaseTrace] = [
_init_trace(
expected_length=draws + tune,
step=steppers[c],
stats_dtypes=steppers[c].stats_dtypes,
chain_number=c,
trace=None,
model=model,
Expand Down Expand Up @@ -1306,32 +1306,6 @@ def _iter_population(
steppers[c].report._finalize(strace)


def _init_trace(
*,
expected_length: int,
step: Step,
chain_number: int,
trace: Optional[BaseTrace],
model,
) -> BaseTrace:
"""Extracted helper function to create trace backends for each chain."""
strace: BaseTrace
if trace is None:
strace = NDArray(model=model)
elif isinstance(trace, BaseTrace):
if len(trace) > 0:
raise ValueError("Continuation of traces is no longer supported.")
strace = copy(trace)
else:
raise NotImplementedError(f"Unsupported `trace`: {trace}")

if step.generates_stats:
strace.setup(expected_length, chain_number, step.stats_dtypes)
else:
strace.setup(expected_length, chain_number)
return strace


def _mp_sample(
draws: int,
tune: int,
Expand Down Expand Up @@ -1393,7 +1367,7 @@ def _mp_sample(
traces = [
_init_trace(
expected_length=draws + tune,
step=step,
stats_dtypes=step.stats_dtypes,
chain_number=chain_number,
trace=trace,
model=model,
Expand Down
22 changes: 22 additions & 0 deletions pymc/tests/backends/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
import numpy as np
import pytest

import pymc as pm

from pymc.backends import _init_trace
from pymc.backends.base import _choose_chains


Expand All @@ -31,3 +34,22 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces):
traces, length = _choose_chains([trace_0, trace_1, trace_2], tune=tune)
assert length == expected_length
assert expected_n_traces == len(traces)


class TestInitTrace:
def test_init_trace_continuation_unsupported(self):
with pm.Model() as pmodel:
A = pm.Normal("A")
B = pm.Uniform("B")
strace = pm.backends.ndarray.NDArray(vars=[A, B])
strace.setup(10, 0)
strace.record({"A": 2, "B_interval__": 0.1})
assert len(strace) == 1
with pytest.raises(ValueError, match="Continuation of traces"):
_init_trace(
expected_length=20,
stats_dtypes=pm.Metropolis().stats_dtypes,
chain_number=0,
trace=strace,
model=pmodel,
)
19 changes: 0 additions & 19 deletions pymc/tests/sampling/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,25 +544,6 @@ def test_constant_named(self):
assert np.isclose(res, 0.0)


class TestInitTrace:
def test_init_trace_continuation_unsupported(self):
with pm.Model() as pmodel:
A = pm.Normal("A")
B = pm.Uniform("B")
strace = pm.backends.ndarray.NDArray(vars=[A, B])
strace.setup(10, 0)
strace.record({"A": 2, "B_interval__": 0.1})
assert len(strace) == 1
with pytest.raises(ValueError, match="Continuation of traces"):
pm.sampling.mcmc._init_trace(
expected_length=20,
step=pm.Metropolis(),
chain_number=0,
trace=strace,
model=pmodel,
)


def check_exec_nuts_init(method):
with pm.Model() as model:
pm.Normal("a", mu=0, sigma=1, size=2)
Expand Down

0 comments on commit a4e92cb

Please sign in to comment.