From 3bc8c36aacf1ab6055ed11357fc24b9ec5f6a141 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Fri, 18 Nov 2022 19:40:14 +0100 Subject: [PATCH] Require all step methods to return stats The reason for this change is the resulting simplification of code, including simpler branching and less type ambiguity. Closes #6270 --- pymc/blocking.py | 7 ++-- pymc/sampling/mcmc.py | 12 +++---- pymc/sampling/parallel.py | 10 +----- pymc/sampling/population.py | 26 +++++++-------- pymc/step_methods/arraystep.py | 55 ++++++++++++------------------- pymc/step_methods/compound.py | 42 +++++++++-------------- pymc/step_methods/hmc/base_hmc.py | 4 +-- pymc/step_methods/hmc/hmc.py | 1 - pymc/step_methods/hmc/nuts.py | 1 - pymc/step_methods/metropolis.py | 35 +++++++++----------- pymc/step_methods/slicer.py | 19 ++++++----- pymc/tests/sampling/test_mcmc.py | 6 ++-- scripts/run_mypy.py | 1 + 13 files changed, 93 insertions(+), 126 deletions(-) diff --git a/pymc/blocking.py b/pymc/blocking.py index 15e249d51e0..88c71e79ec1 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -20,15 +20,18 @@ from __future__ import annotations from functools import partial -from typing import Callable, Dict, Generic, NamedTuple, TypeVar +from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar import numpy as np +from typing_extensions import TypeAlias + __all__ = ["DictToArrayBijection"] T = TypeVar("T") -PointType = Dict[str, np.ndarray] +PointType: TypeAlias = Dict[str, np.ndarray] +StatsType: TypeAlias = List[Dict[str, Any]] # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for # each of the raveled variables. diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index a0362497fe8..efa5d8c480d 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -912,14 +912,10 @@ def _iter_sample( step.iter_count = 0 if i == tune: step.stop_tuning() - if step.generates_stats: - point, stats = step.step(point) - strace.record(point, stats) - log_warning_stats(stats) - diverging = i > tune and stats and stats[0].get("diverging") - else: - point = step.step(point) - strace.record(point, []) + point, stats = step.step(point) + strace.record(point, stats) + log_warning_stats(stats) + diverging = i > tune and stats and stats[0].get("diverging") if callback is not None: callback( trace=strace, diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py index 94eb749ff71..8a528c07275 100644 --- a/pymc/sampling/parallel.py +++ b/pymc/sampling/parallel.py @@ -173,7 +173,7 @@ def _start_loop(self): if draw < self._draws + self._tune: try: - point, stats = self._compute_point() + point, stats = self._step_method.step(self._point) except SamplingError as e: e = ExceptionWithTraceback(e, e.__traceback__) self._msg_pipe.send(("error", e)) @@ -191,14 +191,6 @@ def _start_loop(self): else: raise ValueError("Unknown message " + msg[0]) - def _compute_point(self): - if self._step_method.generates_stats: - point, stats = self._step_method.step(self._point) - else: - point = self._step_method.step(self._point) - stats = None - return point, stats - def _run_process(*args): _Process(*args).run() diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py index 96ed3cf6e25..07fbd6d340a 100644 --- a/pymc/sampling/population.py +++ b/pymc/sampling/population.py @@ -17,7 +17,7 @@ import logging from copy import copy -from typing import Iterator, List, Sequence, Union +from typing import Iterator, List, Sequence, Tuple, Union import cloudpickle import numpy as np @@ -31,7 +31,11 @@ from pymc.model import modelcontext from pymc.stats.convergence import log_warning_stats from pymc.step_methods import CompoundStep -from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared +from pymc.step_methods.arraystep import ( + BlockedStep, + PopulationArrayStepShared, + StatsType, +) from pymc.util import RandomSeed __all__ = () @@ -224,7 +228,7 @@ def _run_secondary(c, stepper_dumps, secondary_end): _log.exception(f"ChainWalker{c}") return - def step(self, tune_stop: bool, population): + def step(self, tune_stop: bool, population) -> List[Tuple[PointType, StatsType]]: """Step the entire population of chains. Parameters @@ -239,18 +243,18 @@ def step(self, tune_stop: bool, population): update : list List of (Point, stats) tuples for all chains """ - updates = [None] * self.nchains + updates: List[Tuple[PointType, StatsType]] = [] if self.is_parallelized: for c in range(self.nchains): self._primary_ends[c].send((tune_stop, population)) # Blockingly get the step outcomes for c in range(self.nchains): - updates[c] = self._primary_ends[c].recv() + updates.append(self._primary_ends[c].recv()) else: for c in range(self.nchains): if tune_stop: self._steppers[c].stop_tuning() - updates[c] = self._steppers[c].step(population[c]) + updates.append(self._steppers[c].step(population[c])) return updates @@ -378,13 +382,9 @@ def _iter_population( # apply the update to the points and record to the traces for c, strace in enumerate(traces): - if steppers[c].generates_stats: - points[c], stats = updates[c] - strace.record(points[c], stats) - log_warning_stats(stats) - else: - points[c] = updates[c] - strace.record(points[c]) + points[c], stats = updates[c] + strace.record(points[c], stats) + log_warning_stats(stats) # yield the state of all chains in parallel yield traces except KeyboardInterrupt: diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 534755f7523..dad7fb27943 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -14,22 +14,20 @@ from abc import ABC, abstractmethod from enum import IntEnum, unique -from typing import Dict, List, Tuple, TypeVar, Union +from typing import Callable, Dict, List, Tuple, Union, cast import numpy as np from aesara.graph.basic import Variable from numpy.random import uniform -from pymc.blocking import DictToArrayBijection, PointType, RaveledVars +from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType from pymc.model import modelcontext from pymc.step_methods.compound import CompoundStep from pymc.util import get_var_name __all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"] -StatsType = TypeVar("StatsType") - @unique class Competence(IntEnum): @@ -49,7 +47,6 @@ class Competence(IntEnum): class BlockedStep(ABC): - generates_stats = False stats_dtypes: List[Dict[str, type]] = [] vars: List[Variable] = [] @@ -103,7 +100,7 @@ def __getnewargs_ex__(self): return self.__newargs @abstractmethod - def step(point: PointType, *args, **kwargs) -> Union[PointType, Tuple[PointType, StatsType]]: + def step(self, point: PointType) -> Tuple[PointType, StatsType]: """Perform a single step of the sampler.""" @staticmethod @@ -146,19 +143,17 @@ def __init__(self, vars, fs, allvars=False, blocked=True): self.allvars = allvars self.blocked = blocked - def step(self, point: PointType): + def step(self, point: PointType) -> Tuple[PointType, StatsType]: - partial_funcs_and_point = [DictToArrayBijection.mapf(x, start_point=point) for x in self.fs] + partial_funcs_and_point: List[Union[Callable, PointType]] = [ + DictToArrayBijection.mapf(x, start_point=point) for x in self.fs + ] if self.allvars: partial_funcs_and_point.append(point) - apoint = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars}) - step_res = self.astep(apoint, *partial_funcs_and_point) - - if self.generates_stats: - apoint_new, stats = step_res - else: - apoint_new = step_res + var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars} + apoint = DictToArrayBijection.map(var_dict) + apoint_new, stats = self.astep(apoint, *partial_funcs_and_point) if not isinstance(apoint_new, RaveledVars): # We assume that the mapping has stayed the same @@ -166,15 +161,10 @@ def step(self, point: PointType): point_new = DictToArrayBijection.rmap(apoint_new, start_point=point) - if self.generates_stats: - return point_new, stats - - return point_new + return point_new, stats @abstractmethod - def astep( - self, apoint: RaveledVars, point: PointType, *args - ) -> Union[RaveledVars, Tuple[RaveledVars, StatsType]]: + def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: """Perform a single sample step in a raveled and concatenated parameter space.""" @@ -198,19 +188,15 @@ def __init__(self, vars, shared, blocked=True): self.shared = {get_var_name(var): shared for var, shared in shared.items()} self.blocked = blocked - def step(self, point): + def step(self, point: PointType) -> Tuple[PointType, StatsType]: for name, shared_var in self.shared.items(): shared_var.set_value(point[name]) - q = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars}) + var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars} + q = DictToArrayBijection.map(var_dict) - step_res = self.astep(q) - - if self.generates_stats: - apoint, stats = step_res - else: - apoint = step_res + apoint, stats = self.astep(q) if not isinstance(apoint, RaveledVars): # We assume that the mapping has stayed the same @@ -218,10 +204,11 @@ def step(self, point): new_point = DictToArrayBijection.rmap(apoint, start_point=point) - if self.generates_stats: - return new_point, stats + return new_point, stats - return new_point + @abstractmethod + def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]: + """Perform a single sample step in a raveled and concatenated parameter space.""" class PopulationArrayStepShared(ArrayStepShared): @@ -281,7 +268,7 @@ def __init__( super().__init__(vars, func._extra_vars_shared, blocked) - def step(self, point): + def step(self, point) -> Tuple[PointType, StatsType]: self._logp_dlogp_func._extra_are_set = True return super().step(point) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 4bb86056ed0..d8c0c0d8453 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -17,9 +17,12 @@ @author: johnsalvatier """ -from collections import namedtuple -import numpy as np + +from typing import Tuple + +from pymc.blocking import PointType +from pymc.step_methods.arraystep import StatsType class CompoundStep: @@ -28,36 +31,23 @@ class CompoundStep: def __init__(self, methods): self.methods = list(methods) - self.generates_stats = any(method.generates_stats for method in self.methods) self.stats_dtypes = [] for method in self.methods: - if method.generates_stats: - self.stats_dtypes.extend(method.stats_dtypes) + self.stats_dtypes.extend(method.stats_dtypes) self.name = ( f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" ) - def step(self, point): - if self.generates_stats: - states = [] - for method in self.methods: - if method.generates_stats: - point, state = method.step(point) - states.extend(state) - else: - point = method.step(point) - # Model logp can only be the logp of the _last_ state, if there is - # one. Pop all others (if dict), or set to np.nan (if namedtuple). - for state in states[:-1]: - if isinstance(state, dict): - state.pop("model_logp", None) - elif isinstance(state, namedtuple): - state = state._replace(logp=np.nan) - return point, states - else: - for method in self.methods: - point = method.step(point) - return point + def step(self, point) -> Tuple[PointType, StatsType]: + stats = [] + for method in self.methods: + point, sts = method.step(point) + stats.extend(sts) + # Model logp can only be the logp of the _last_ stats, + # if there is one. Pop all others. + for sts in stats[:-1]: + sts.pop("model_logp", None) + return point, stats def stop_tuning(self): for method in self.methods: diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index 0eb6fd77b19..844435d6de9 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -23,7 +23,7 @@ import numpy as np from pymc.aesaraf import floatX -from pymc.blocking import DictToArrayBijection, RaveledVars +from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType from pymc.exceptions import SamplingError from pymc.model import Point, modelcontext from pymc.stats.convergence import SamplerWarning, WarningType @@ -157,7 +157,7 @@ def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData: Subclasses must overwrite this abstract method and return an `HMCStepData` object. """ - def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]: + def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: """Perform a single HMC iteration.""" perf_start = time.perf_counter() process_start = time.process_time() diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index 804fadc3046..bf3a11d4cf5 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -39,7 +39,6 @@ class HamiltonianMC(BaseHMC): name = "hmc" default_blocked = True - generates_stats = True stats_dtypes = [ { "step_size": np.float64, diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index f1251d5c550..a692a19f321 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -97,7 +97,6 @@ class NUTS(BaseHMC): name = "nuts" default_blocked = True - generates_stats = True stats_dtypes = [ { "depth": np.int64, diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index dcd4b3b6d5d..c2e64d71676 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple import aesara import numpy as np @@ -38,6 +38,7 @@ ArrayStepShared, Competence, PopulationArrayStepShared, + StatsType, metrop_select, ) @@ -126,7 +127,6 @@ class Metropolis(ArrayStepShared): name = "metropolis" default_blocked = False - generates_stats = True stats_dtypes = [ { "accept": np.float64, @@ -244,7 +244,7 @@ def reset_tuning(self): self.accepted_sum[:] = 0 return - def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: + def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]: point_map_info = q0.point_map_info q0 = q0.data @@ -374,7 +374,6 @@ class BinaryMetropolis(ArrayStep): name = "binary_metropolis" - generates_stats = True stats_dtypes = [ { "accept": np.float64, @@ -400,8 +399,8 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None): super().__init__(vars, [model.compile_logp()]) - def astep(self, q0: RaveledVars, logp) -> Tuple[RaveledVars, List[Dict[str, Any]]]: - + def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: + logp = args[0] logp_q0 = logp(q0) point_map_info = q0.point_map_info q0 = q0.data @@ -502,8 +501,8 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None): super().__init__(vars, [model.compile_logp()]) - def astep(self, q0: RaveledVars, logp: Callable[[RaveledVars], np.ndarray]) -> RaveledVars: - + def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: + logp: Callable[[RaveledVars], np.ndarray] = args[0] order = self.order if self.shuffle_dims: nr.shuffle(order) @@ -522,7 +521,7 @@ def astep(self, q0: RaveledVars, logp: Callable[[RaveledVars], np.ndarray]) -> R if accepted: logp_curr = logp_prop - return q + return q, [] @staticmethod def competence(var): @@ -616,8 +615,8 @@ def __init__(self, vars, proposal="uniform", order="random", model=None): super().__init__(vars, [model.compile_logp()]) - def astep_unif(self, q0: RaveledVars, logp) -> RaveledVars: - + def astep_unif(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: + logp = args[0] point_map_info = q0.point_map_info q0 = q0.data @@ -635,9 +634,9 @@ def astep_unif(self, q0: RaveledVars, logp) -> RaveledVars: if accepted: logp_curr = logp_prop - return q + return q, [] - def astep_prop(self, q0: RaveledVars, logp) -> RaveledVars: + def astep_prop(self, q0: RaveledVars, logp) -> Tuple[RaveledVars, StatsType]: point_map_info = q0.point_map_info q0 = q0.data @@ -652,9 +651,9 @@ def astep_prop(self, q0: RaveledVars, logp) -> RaveledVars: for dim, k in dimcats: logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k) - return q + return q, [] - def astep(self, q0, logp): + def astep(self, q0, logp) -> Tuple[RaveledVars, StatsType]: raise NotImplementedError() def metropolis_proportional(self, q, logp, logp_curr, dim, k): @@ -744,7 +743,6 @@ class DEMetropolis(PopulationArrayStepShared): name = "DEMetropolis" default_blocked = True - generates_stats = True stats_dtypes = [ { "accept": np.float64, @@ -804,7 +802,7 @@ def __init__( self.delta_logp = delta_logp(initial_values, model.logp(), vars, shared) super().__init__(vars, shared) - def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: + def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]: point_map_info = q0.point_map_info q0 = q0.data @@ -894,7 +892,6 @@ class DEMetropolisZ(ArrayStepShared): name = "DEMetropolisZ" default_blocked = True - generates_stats = True stats_dtypes = [ { "accept": np.float64, @@ -974,7 +971,7 @@ def reset_tuning(self): setattr(self, attr, initial_value) return - def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]: + def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]: point_map_info = q0.point_map_info q0 = q0.data diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 1344b4166dd..bf72cadfd72 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -14,10 +14,12 @@ # Modified from original implementation by Dominik Wabersich (2013) +from typing import Tuple + import numpy as np import numpy.random as nr -from pymc.blocking import RaveledVars +from pymc.blocking import RaveledVars, StatsType from pymc.model import modelcontext from pymc.step_methods.arraystep import ArrayStep, Competence from pymc.util import get_value_vars_from_user_vars @@ -47,7 +49,6 @@ class Slice(ArrayStep): name = "slice" default_blocked = False - generates_stats = True stats_dtypes = [ { "nstep_out": int, @@ -69,8 +70,10 @@ def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, * super().__init__(vars, [self.model.compile_logp()], **kwargs) - def astep(self, q0, logp): - q0_val = q0.data + def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]: + # The arguments are determined by the list passed via `super().__init__(..., fs, ...)` + logp = args[0] + q0_val = apoint.data self.w = np.resize(self.w, len(q0_val)) # this is a repmat nstep_out = nstep_in = 0 @@ -81,9 +84,9 @@ def astep(self, q0, logp): # The points are not copied, so it's fine to update them inplace in the # loop below - q_ra = RaveledVars(q, q0.point_map_info) - ql_ra = RaveledVars(ql, q0.point_map_info) - qr_ra = RaveledVars(qr, q0.point_map_info) + q_ra = RaveledVars(q, apoint.point_map_info) + ql_ra = RaveledVars(ql, apoint.point_map_info) + qr_ra = RaveledVars(qr, apoint.point_map_info) for i, wi in enumerate(self.w): # uniformly sample from 0 to p(q), but in log space @@ -142,7 +145,7 @@ def astep(self, q0, logp): "nstep_in": nstep_in, } - return q, (stats,) + return RaveledVars(q, apoint.point_map_info), [stats] @staticmethod def competence(var, has_grad): diff --git a/pymc/tests/sampling/test_mcmc.py b/pymc/tests/sampling/test_mcmc.py index ba8eb287177..625170f051c 100644 --- a/pymc/tests/sampling/test_mcmc.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -642,7 +642,7 @@ def test_step_args(): npt.assert_allclose(idata1.sample_stats.scaling, 0) -class ApolypticMetropolis(pm.Metropolis): +class ApocalypticMetropolis(pm.Metropolis): """A stepper that warns in every iteration.""" stats_dtypes = [ @@ -673,7 +673,7 @@ def test_logs_sampler_warnings(caplog, cores): draws=3, cores=cores, chains=cores, - step=ApolypticMetropolis(), + step=ApocalypticMetropolis(), compute_convergence_checks=False, discard_tuned_samples=False, keep_warning_stat=True, @@ -702,7 +702,7 @@ def test_keep_warning_stat_setting(keep_warning_stat): sample_kwargs["keep_warning_stat"] = True with pm.Model(): pm.Normal("n") - idata = pm.sample(step=ApolypticMetropolis(), **sample_kwargs) + idata = pm.sample(step=ApocalypticMetropolis(), **sample_kwargs) if keep_warning_stat: assert "warning" in idata.warmup_sample_stats diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 4ba0d40e0dc..35dc1c91d32 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -62,6 +62,7 @@ pymc/stats/__init__.py pymc/stats/convergence.py pymc/step_methods/__init__.py +pymc/step_methods/arraystep.py pymc/step_methods/compound.py pymc/step_methods/hmc/__init__.py pymc/step_methods/hmc/base_hmc.py