diff --git a/pymc/blocking.py b/pymc/blocking.py index 9d5c29849a2..72fc4d72cdf 100644 --- a/pymc/blocking.py +++ b/pymc/blocking.py @@ -31,7 +31,8 @@ T = TypeVar("T") PointType: TypeAlias = Dict[str, np.ndarray] -StatsType: TypeAlias = List[Dict[str, Any]] +StatsDict: TypeAlias = Dict[str, Any] +StatsType: TypeAlias = List[StatsDict] # `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for # each of the raveled variables. diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index c46324d31f5..a0d83c04918 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -12,116 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC, abstractmethod -from enum import IntEnum, unique -from typing import Callable, Dict, List, Tuple, Union, cast +from abc import abstractmethod +from typing import Callable, List, Tuple, Union, cast import numpy as np from numpy.random import uniform -from pytensor.graph.basic import Variable from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType from pymc.model import modelcontext -from pymc.step_methods.compound import CompoundStep +from pymc.step_methods.compound import BlockedStep from pymc.util import get_var_name -__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"] - - -@unique -class Competence(IntEnum): - """Enum for characterizing competence classes of step methods. - Values include: - 0: INCOMPATIBLE - 1: COMPATIBLE - 2: PREFERRED - 3: IDEAL - """ - - INCOMPATIBLE = 0 - COMPATIBLE = 1 - PREFERRED = 2 - IDEAL = 3 - - -class BlockedStep(ABC): - - stats_dtypes: List[Dict[str, type]] = [] - vars: List[Variable] = [] - - def __new__(cls, *args, **kwargs): - blocked = kwargs.get("blocked") - if blocked is None: - # Try to look up default value from class - blocked = getattr(cls, "default_blocked", True) - kwargs["blocked"] = blocked - - model = modelcontext(kwargs.get("model")) - kwargs.update({"model": model}) - - # vars can either be first arg or a kwarg - if "vars" not in kwargs and len(args) >= 1: - vars = args[0] - args = args[1:] - elif "vars" in kwargs: - vars = kwargs.pop("vars") - else: # Assume all model variables - vars = model.value_vars - - if not isinstance(vars, (tuple, list)): - vars = [vars] - - if len(vars) == 0: - raise ValueError("No free random variables to sample.") - - if not blocked and len(vars) > 1: - # In this case we create a separate sampler for each var - # and append them to a CompoundStep - steps = [] - for var in vars: - step = super().__new__(cls) - # If we don't return the instance we have to manually - # call __init__ - step.__init__([var], *args, **kwargs) - # Hack for creating the class correctly when unpickling. - step.__newargs = ([var],) + args, kwargs - steps.append(step) - - return CompoundStep(steps) - else: - step = super().__new__(cls) - # Hack for creating the class correctly when unpickling. - step.__newargs = (vars,) + args, kwargs - return step - - # Hack for creating the class correctly when unpickling. - def __getnewargs_ex__(self): - return self.__newargs - - @abstractmethod - def step(self, point: PointType) -> Tuple[PointType, StatsType]: - """Perform a single step of the sampler.""" - - @staticmethod - def competence(var, has_grad): - return Competence.INCOMPATIBLE - - @classmethod - def _competence(cls, vars, have_grad): - vars = np.atleast_1d(vars) - have_grad = np.atleast_1d(have_grad) - competences = [] - for var, has_grad in zip(vars, have_grad): - try: - competences.append(cls.competence(var, has_grad)) - except TypeError: - competences.append(cls.competence(var)) - return competences - - def stop_tuning(self): - if hasattr(self, "tune"): - self.tune = False +__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select"] class ArrayStep(BlockedStep): diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index 486f2fc8f10..b1161164412 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -18,10 +18,113 @@ @author: johnsalvatier """ +from abc import ABC, abstractmethod +from enum import IntEnum, unique +from typing import Dict, List, Sequence, Tuple, Union -from typing import Tuple +import numpy as np -from pymc.blocking import PointType, StatsType +from pytensor.graph.basic import Variable + +from pymc.blocking import PointType, StatsDict, StatsType +from pymc.model import modelcontext + +__all__ = ("Competence", "CompoundStep") + + +@unique +class Competence(IntEnum): + """Enum for characterizing competence classes of step methods. + Values include: + 0: INCOMPATIBLE + 1: COMPATIBLE + 2: PREFERRED + 3: IDEAL + """ + + INCOMPATIBLE = 0 + COMPATIBLE = 1 + PREFERRED = 2 + IDEAL = 3 + + +class BlockedStep(ABC): + + stats_dtypes: List[Dict[str, type]] = [] + vars: List[Variable] = [] + + def __new__(cls, *args, **kwargs): + blocked = kwargs.get("blocked") + if blocked is None: + # Try to look up default value from class + blocked = getattr(cls, "default_blocked", True) + kwargs["blocked"] = blocked + + model = modelcontext(kwargs.get("model")) + kwargs.update({"model": model}) + + # vars can either be first arg or a kwarg + if "vars" not in kwargs and len(args) >= 1: + vars = args[0] + args = args[1:] + elif "vars" in kwargs: + vars = kwargs.pop("vars") + else: # Assume all model variables + vars = model.value_vars + + if not isinstance(vars, (tuple, list)): + vars = [vars] + + if len(vars) == 0: + raise ValueError("No free random variables to sample.") + + if not blocked and len(vars) > 1: + # In this case we create a separate sampler for each var + # and append them to a CompoundStep + steps = [] + for var in vars: + step = super().__new__(cls) + # If we don't return the instance we have to manually + # call __init__ + step.__init__([var], *args, **kwargs) + # Hack for creating the class correctly when unpickling. + step.__newargs = ([var],) + args, kwargs + steps.append(step) + + return CompoundStep(steps) + else: + step = super().__new__(cls) + # Hack for creating the class correctly when unpickling. + step.__newargs = (vars,) + args, kwargs + return step + + # Hack for creating the class correctly when unpickling. + def __getnewargs_ex__(self): + return self.__newargs + + @abstractmethod + def step(self, point: PointType) -> Tuple[PointType, StatsType]: + """Perform a single step of the sampler.""" + + @staticmethod + def competence(var, has_grad): + return Competence.INCOMPATIBLE + + @classmethod + def _competence(cls, vars, have_grad): + vars = np.atleast_1d(vars) + have_grad = np.atleast_1d(have_grad) + competences = [] + for var, has_grad in zip(vars, have_grad): + try: + competences.append(cls.competence(var, has_grad)) + except TypeError: + competences.append(cls.competence(var)) + return competences + + def stop_tuning(self): + if hasattr(self, "tune"): + self.tune = False class CompoundStep: @@ -60,3 +163,43 @@ def reset_tuning(self): @property def vars(self): return [var for method in self.methods for var in method.vars] + + +def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]: + """Flatten a hierarchy of step methods to a list.""" + if isinstance(step, BlockedStep): + return [step] + steps = [] + if not isinstance(step, CompoundStep): + raise ValueError(f"Unexpected type of step method: {step}") + for sm in step.methods: + steps += flatten_steps(sm) + return steps + + +class StatsBijection: + """Map between a `list` of stats to `dict` of stats.""" + + def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None: + # Keep a list of flat vs. original stat names + self._stat_groups: List[List[Tuple[str, str]]] = [ + [(f"sampler_{s}__{statname}", statname) for statname, _ in names_dtypes.items()] + for s, names_dtypes in enumerate(sampler_stats_dtypes) + ] + + def map(self, stats_list: StatsType) -> StatsDict: + """Combine stats dicts of multiple samplers into one dict.""" + stats_dict = {} + for s, sts in enumerate(stats_list): + for statname, sval in sts.items(): + sname = f"sampler_{s}__{statname}" + stats_dict[sname] = sval + return stats_dict + + def rmap(self, stats_dict: StatsDict) -> StatsType: + """Split a global stats dict into a list of sampler-wise stats dicts.""" + stats_list = [] + for namemap in self._stat_groups: + d = {statname: stats_dict[sname] for sname, statname in namemap} + stats_list.append(d) + return stats_list diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py index e3ffbf4d77d..c9fb48a30cf 100644 --- a/pymc/step_methods/hmc/hmc.py +++ b/pymc/step_methods/hmc/hmc.py @@ -19,7 +19,7 @@ import numpy as np from pymc.stats.convergence import SamplerWarning -from pymc.step_methods.arraystep import Competence +from pymc.step_methods.compound import Competence from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State from pymc.vartypes import discrete_types diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 993c3f4224b..fc448ae09aa 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -21,7 +21,7 @@ from pymc.math import logbern from pymc.pytensorf import floatX from pymc.stats.convergence import SamplerWarning -from pymc.step_methods.arraystep import Competence +from pymc.step_methods.compound import Competence from pymc.step_methods.hmc import integration from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData from pymc.step_methods.hmc.integration import IntegrationError, State diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 839ba417d15..ba5cf365c41 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -36,11 +36,11 @@ from pymc.step_methods.arraystep import ( ArrayStep, ArrayStepShared, - Competence, PopulationArrayStepShared, StatsType, metrop_select, ) +from pymc.step_methods.compound import Competence __all__ = [ "Metropolis", diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 59d2a690090..c7335638bcb 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -21,7 +21,8 @@ from pymc.blocking import RaveledVars, StatsType from pymc.model import modelcontext -from pymc.step_methods.arraystep import ArrayStep, Competence +from pymc.step_methods.arraystep import ArrayStep +from pymc.step_methods.compound import Competence from pymc.util import get_value_vars_from_user_vars from pymc.vartypes import continuous_types diff --git a/pymc/tests/step_methods/test_compound.py b/pymc/tests/step_methods/test_compound.py index 954c3ca19ab..a2fd41d4b63 100644 --- a/pymc/tests/step_methods/test_compound.py +++ b/pymc/tests/step_methods/test_compound.py @@ -25,6 +25,7 @@ Metropolis, Slice, ) +from pymc.step_methods.compound import StatsBijection, flatten_steps from pymc.tests.helpers import StepMethodTester, fast_unstable_sampling_mode from pymc.tests.models import simple_2model_continuous @@ -91,3 +92,41 @@ def test_compound_step(self): step2 = NUTS([c2]) step = CompoundStep([step1, step2]) assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars) + + +class TestStatsBijection: + def test_flatten_steps(self): + with pm.Model(): + a = pm.Normal("a") + b = pm.Normal("b") + c = pm.Normal("c") + s1 = Metropolis([a]) + s2 = Metropolis([b]) + c1 = CompoundStep([s1, s2]) + s3 = NUTS([c]) + c2 = CompoundStep([c1, s3]) + assert flatten_steps(s1) == [s1] + assert flatten_steps(c2) == [s1, s2, s3] + with pytest.raises(ValueError, match="Unexpected type"): + flatten_steps("not a step") + + def test_stats_bijection(self): + step_stats_dtypes = [ + {"a": float, "b": int}, + {"a": float, "c": int}, + ] + bij = StatsBijection(step_stats_dtypes) + stats_l = [ + dict(a=1.5, b=3), + dict(a=2.5, c=4), + ] + stats_d = bij.map(stats_l) + assert isinstance(stats_d, dict) + assert stats_d["sampler_0__a"] == 1.5 + assert stats_d["sampler_0__b"] == 3 + assert stats_d["sampler_1__a"] == 2.5 + assert stats_d["sampler_1__c"] == 4 + rev = bij.rmap(stats_d) + assert isinstance(rev, list) + assert len(rev) == len(stats_l) + assert rev == stats_l