Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring and addition of helpers to handle flat stats #6443

Merged
merged 2 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@

T = TypeVar("T")
PointType: TypeAlias = Dict[str, np.ndarray]
StatsType: TypeAlias = List[Dict[str, Any]]
StatsDict: TypeAlias = Dict[str, Any]
StatsType: TypeAlias = List[StatsDict]

# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
# each of the raveled variables.
Expand Down
105 changes: 4 additions & 101 deletions pymc/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,116 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Callable, Dict, List, Tuple, Union, cast
from abc import abstractmethod
from typing import Callable, List, Tuple, Union, cast

import numpy as np

from numpy.random import uniform
from pytensor.graph.basic import Variable

from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
from pymc.model import modelcontext
from pymc.step_methods.compound import CompoundStep
from pymc.step_methods.compound import BlockedStep
from pymc.util import get_var_name

__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"]


@unique
class Competence(IntEnum):
    """Ranking of how well a step method suits a variable, worst to best.

    Members:
    0: INCOMPATIBLE
    1: COMPATIBLE
    2: PREFERRED
    3: IDEAL
    """

    INCOMPATIBLE = 0
    COMPATIBLE = 1
    PREFERRED = 2
    IDEAL = 3


class BlockedStep(ABC):
    """Abstract base class for step methods that update a block of variables.

    Subclasses implement `step`. Construction is routed through `__new__`:
    with ``blocked=False`` and multiple variables, one instance per variable
    is created and the instances are wrapped in a `CompoundStep`.
    """

    # One {stat name: dtype} dict per stats dict emitted by `step`.
    stats_dtypes: List[Dict[str, type]] = []
    # Value variables this step method updates.
    vars: List[Variable] = []

    def __new__(cls, *args, **kwargs):
        blocked = kwargs.get("blocked")
        if blocked is None:
            # Try to look up default value from class
            blocked = getattr(cls, "default_blocked", True)
        kwargs["blocked"] = blocked

        # Resolve the model from the context stack when not passed explicitly.
        model = modelcontext(kwargs.get("model"))
        kwargs.update({"model": model})

        # vars can either be first arg or a kwarg
        if "vars" not in kwargs and len(args) >= 1:
            vars = args[0]
            args = args[1:]
        elif "vars" in kwargs:
            vars = kwargs.pop("vars")
        else:  # Assume all model variables
            vars = model.value_vars

        if not isinstance(vars, (tuple, list)):
            vars = [vars]

        if len(vars) == 0:
            raise ValueError("No free random variables to sample.")

        if not blocked and len(vars) > 1:
            # In this case we create a separate sampler for each var
            # and append them to a CompoundStep
            steps = []
            for var in vars:
                step = super().__new__(cls)
                # If we don't return the instance we have to manually
                # call __init__
                step.__init__([var], *args, **kwargs)
                # Hack for creating the class correctly when unpickling.
                step.__newargs = ([var],) + args, kwargs
                steps.append(step)

            return CompoundStep(steps)
        else:
            step = super().__new__(cls)
            # Hack for creating the class correctly when unpickling.
            step.__newargs = (vars,) + args, kwargs
            return step

    # Hack for creating the class correctly when unpickling.
    def __getnewargs_ex__(self):
        # Returned (args, kwargs) are fed back into __new__ by pickle.
        return self.__newargs

    @abstractmethod
    def step(self, point: PointType) -> Tuple[PointType, StatsType]:
        """Perform a single step of the sampler."""

    @staticmethod
    def competence(var, has_grad):
        # Default: incompatible unless a subclass overrides this.
        return Competence.INCOMPATIBLE

    @classmethod
    def _competence(cls, vars, have_grad):
        # Evaluate `competence` per variable, tolerating subclasses whose
        # `competence` takes only a single argument.
        vars = np.atleast_1d(vars)
        have_grad = np.atleast_1d(have_grad)
        competences = []
        for var, has_grad in zip(vars, have_grad):
            try:
                competences.append(cls.competence(var, has_grad))
            except TypeError:
                competences.append(cls.competence(var))
        return competences

    def stop_tuning(self):
        # Only step methods that tune expose a `tune` attribute.
        if hasattr(self, "tune"):
            self.tune = False
__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select"]


class ArrayStep(BlockedStep):
Expand Down
147 changes: 145 additions & 2 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,113 @@
@author: johnsalvatier
"""

from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Dict, List, Sequence, Tuple, Union

from typing import Tuple
import numpy as np

from pymc.blocking import PointType, StatsType
from pytensor.graph.basic import Variable

from pymc.blocking import PointType, StatsDict, StatsType
from pymc.model import modelcontext

__all__ = ("Competence", "CompoundStep")


@unique
class Competence(IntEnum):
    """How suitable a step method is for a given variable, from worst to best.

    0: INCOMPATIBLE
    1: COMPATIBLE
    2: PREFERRED
    3: IDEAL
    """

    INCOMPATIBLE = 0
    COMPATIBLE = 1
    PREFERRED = 2
    IDEAL = 3


class BlockedStep(ABC):
    """Abstract base class for step methods that update a block of variables.

    Subclasses implement `step`. Construction is routed through `__new__`:
    with ``blocked=False`` and multiple variables, one instance per variable
    is created and the instances are wrapped in a `CompoundStep`.
    """

    # One {stat name: dtype} dict per stats dict emitted by `step`.
    stats_dtypes: List[Dict[str, type]] = []
    # Value variables this step method updates.
    vars: List[Variable] = []

    def __new__(cls, *args, **kwargs):
        blocked = kwargs.get("blocked")
        if blocked is None:
            # Try to look up default value from class
            blocked = getattr(cls, "default_blocked", True)
        kwargs["blocked"] = blocked

        # Resolve the model from the context stack when not passed explicitly.
        model = modelcontext(kwargs.get("model"))
        kwargs.update({"model": model})

        # vars can either be first arg or a kwarg
        if "vars" not in kwargs and len(args) >= 1:
            vars = args[0]
            args = args[1:]
        elif "vars" in kwargs:
            vars = kwargs.pop("vars")
        else:  # Assume all model variables
            vars = model.value_vars

        if not isinstance(vars, (tuple, list)):
            vars = [vars]

        if len(vars) == 0:
            raise ValueError("No free random variables to sample.")

        if not blocked and len(vars) > 1:
            # In this case we create a separate sampler for each var
            # and append them to a CompoundStep
            steps = []
            for var in vars:
                step = super().__new__(cls)
                # If we don't return the instance we have to manually
                # call __init__
                step.__init__([var], *args, **kwargs)
                # Hack for creating the class correctly when unpickling.
                step.__newargs = ([var],) + args, kwargs
                steps.append(step)

            return CompoundStep(steps)
        else:
            step = super().__new__(cls)
            # Hack for creating the class correctly when unpickling.
            step.__newargs = (vars,) + args, kwargs
            return step

    # Hack for creating the class correctly when unpickling.
    def __getnewargs_ex__(self):
        # Returned (args, kwargs) are fed back into __new__ by pickle.
        return self.__newargs

    @abstractmethod
    def step(self, point: PointType) -> Tuple[PointType, StatsType]:
        """Perform a single step of the sampler."""

    @staticmethod
    def competence(var, has_grad):
        # Default: incompatible unless a subclass overrides this.
        return Competence.INCOMPATIBLE

    @classmethod
    def _competence(cls, vars, have_grad):
        # Evaluate `competence` per variable, tolerating subclasses whose
        # `competence` takes only a single argument.
        vars = np.atleast_1d(vars)
        have_grad = np.atleast_1d(have_grad)
        competences = []
        for var, has_grad in zip(vars, have_grad):
            try:
                competences.append(cls.competence(var, has_grad))
            except TypeError:
                competences.append(cls.competence(var))
        return competences

    def stop_tuning(self):
        # Only step methods that tune expose a `tune` attribute.
        if hasattr(self, "tune"):
            self.tune = False


class CompoundStep:
Expand Down Expand Up @@ -60,3 +163,43 @@ def reset_tuning(self):
@property
def vars(self):
return [var for method in self.methods for var in method.vars]


def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
    """Flatten a hierarchy of step methods to a list."""
    # A single step is its own (trivial) flattening.
    if isinstance(step, BlockedStep):
        return [step]
    if not isinstance(step, CompoundStep):
        raise ValueError(f"Unexpected type of step method: {step}")
    # Recurse depth-first into nested compound steps.
    flat: List[BlockedStep] = []
    for sub in step.methods:
        flat.extend(flatten_steps(sub))
    return flat


class StatsBijection:
    """Map between a `list` of stats to `dict` of stats.

    Each sampler in a compound step emits its own stats dict; this bijection
    flattens the list of dicts into one dict keyed ``sampler_{i}__{statname}``
    (and back), so stats from samplers sharing a stat name cannot collide.
    """

    def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
        """Build the name mapping.

        Parameters
        ----------
        sampler_stats_dtypes
            One ``{stat name: dtype}`` dict per sampler, in sampler order.
        """
        # Keep a list of flat vs. original stat names
        self._stat_groups: List[List[Tuple[str, str]]] = [
            [(f"sampler_{s}__{statname}", statname) for statname in names_dtypes]
            for s, names_dtypes in enumerate(sampler_stats_dtypes)
        ]

    def map(self, stats_list: StatsType) -> StatsDict:
        """Combine stats dicts of multiple samplers into one dict."""
        stats_dict = {}
        for s, sts in enumerate(stats_list):
            for statname, sval in sts.items():
                # Prefix with the sampler index so flat names are unique.
                stats_dict[f"sampler_{s}__{statname}"] = sval
        return stats_dict

    def rmap(self, stats_dict: StatsDict) -> StatsType:
        """Split a global stats dict into a list of sampler-wise stats dicts.

        Raises
        ------
        KeyError
            If a flat stat name registered at construction is missing.
        """
        return [
            {statname: stats_dict[sname] for sname, statname in namemap}
            for namemap in self._stat_groups
        ]
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.arraystep import Competence
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
from pymc.vartypes import discrete_types
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pymc.math import logbern
from pymc.pytensorf import floatX
from pymc.stats.convergence import SamplerWarning
from pymc.step_methods.arraystep import Competence
from pymc.step_methods.compound import Competence
from pymc.step_methods.hmc import integration
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError, State
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/metropolis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
from pymc.step_methods.arraystep import (
ArrayStep,
ArrayStepShared,
Competence,
PopulationArrayStepShared,
StatsType,
metrop_select,
)
from pymc.step_methods.compound import Competence

__all__ = [
"Metropolis",
Expand Down
3 changes: 2 additions & 1 deletion pymc/step_methods/slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from pymc.blocking import RaveledVars, StatsType
from pymc.model import modelcontext
from pymc.step_methods.arraystep import ArrayStep, Competence
from pymc.step_methods.arraystep import ArrayStep
from pymc.step_methods.compound import Competence
from pymc.util import get_value_vars_from_user_vars
from pymc.vartypes import continuous_types

Expand Down
39 changes: 39 additions & 0 deletions pymc/tests/step_methods/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Metropolis,
Slice,
)
from pymc.step_methods.compound import StatsBijection, flatten_steps
from pymc.tests.helpers import StepMethodTester, fast_unstable_sampling_mode
from pymc.tests.models import simple_2model_continuous

Expand Down Expand Up @@ -91,3 +92,41 @@ def test_compound_step(self):
step2 = NUTS([c2])
step = CompoundStep([step1, step2])
assert {m.rvs_to_values[c1], m.rvs_to_values[c2]} == set(step.vars)


class TestStatsBijection:
    def test_flatten_steps(self):
        # Build a two-level step hierarchy and check it flattens depth-first.
        with pm.Model():
            a = pm.Normal("a")
            b = pm.Normal("b")
            c = pm.Normal("c")
            step_a = Metropolis([a])
            step_b = Metropolis([b])
            inner = CompoundStep([step_a, step_b])
            step_c = NUTS([c])
            outer = CompoundStep([inner, step_c])
        assert flatten_steps(step_a) == [step_a]
        assert flatten_steps(outer) == [step_a, step_b, step_c]
        with pytest.raises(ValueError, match="Unexpected type"):
            flatten_steps("not a step")

    def test_stats_bijection(self):
        # Two samplers sharing the stat name "a" must not collide after mapping.
        bij = StatsBijection(
            [
                {"a": float, "b": int},
                {"a": float, "c": int},
            ]
        )
        stats_list = [
            {"a": 1.5, "b": 3},
            {"a": 2.5, "c": 4},
        ]
        stats_dict = bij.map(stats_list)
        assert isinstance(stats_dict, dict)
        assert stats_dict["sampler_0__a"] == 1.5
        assert stats_dict["sampler_0__b"] == 3
        assert stats_dict["sampler_1__a"] == 2.5
        assert stats_dict["sampler_1__c"] == 4
        roundtripped = bij.rmap(stats_dict)
        assert isinstance(roundtripped, list)
        assert len(roundtripped) == len(stats_list)
        assert roundtripped == stats_list