Skip to content

Commit

Permalink
Require all step methods to return stats
Browse files Browse the repository at this point in the history
This change simplifies the code,
with simpler branching and less type ambiguity.
It also made it possible to fix many type hints
and method signatures on step methods.

Closes pymc-devs#6270
  • Loading branch information
michaelosthege authored and wrongu committed Dec 1, 2022
1 parent 92b6026 commit 7349614
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 131 deletions.
7 changes: 5 additions & 2 deletions pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
from __future__ import annotations

from functools import partial
from typing import Callable, Dict, Generic, NamedTuple, TypeVar
from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar

import numpy as np

from typing_extensions import TypeAlias

__all__ = ["DictToArrayBijection"]


T = TypeVar("T")
PointType = Dict[str, np.ndarray]
PointType: TypeAlias = Dict[str, np.ndarray]
StatsType: TypeAlias = List[Dict[str, Any]]

# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
# each of the raveled variables.
Expand Down
12 changes: 4 additions & 8 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,14 +912,10 @@ def _iter_sample(
step.iter_count = 0
if i == tune:
step.stop_tuning()
if step.generates_stats:
point, stats = step.step(point)
strace.record(point, stats)
log_warning_stats(stats)
diverging = i > tune and stats and stats[0].get("diverging")
else:
point = step.step(point)
strace.record(point, [])
point, stats = step.step(point)
strace.record(point, stats)
log_warning_stats(stats)
diverging = i > tune and stats and stats[0].get("diverging")
if callback is not None:
callback(
trace=strace,
Expand Down
10 changes: 1 addition & 9 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _start_loop(self):

if draw < self._draws + self._tune:
try:
point, stats = self._compute_point()
point, stats = self._step_method.step(self._point)
except SamplingError as e:
e = ExceptionWithTraceback(e, e.__traceback__)
self._msg_pipe.send(("error", e))
Expand All @@ -191,14 +191,6 @@ def _start_loop(self):
else:
raise ValueError("Unknown message " + msg[0])

def _compute_point(self):
if self._step_method.generates_stats:
point, stats = self._step_method.step(self._point)
else:
point = self._step_method.step(self._point)
stats = None
return point, stats


def _run_process(*args):
_Process(*args).run()
Expand Down
26 changes: 13 additions & 13 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging

from copy import copy
from typing import Iterator, List, Sequence, Union
from typing import Iterator, List, Sequence, Tuple, Union

import cloudpickle
import numpy as np
Expand All @@ -31,7 +31,11 @@
from pymc.model import modelcontext
from pymc.stats.convergence import log_warning_stats
from pymc.step_methods import CompoundStep
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.arraystep import (
BlockedStep,
PopulationArrayStepShared,
StatsType,
)
from pymc.util import RandomSeed

__all__ = ()
Expand Down Expand Up @@ -224,7 +228,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
_log.exception(f"ChainWalker{c}")
return

def step(self, tune_stop: bool, population):
def step(self, tune_stop: bool, population) -> List[Tuple[PointType, StatsType]]:
"""Step the entire population of chains.
Parameters
Expand All @@ -239,18 +243,18 @@ def step(self, tune_stop: bool, population):
update : list
List of (Point, stats) tuples for all chains
"""
updates = [None] * self.nchains
updates: List[Tuple[PointType, StatsType]] = []
if self.is_parallelized:
for c in range(self.nchains):
self._primary_ends[c].send((tune_stop, population))
# Blockingly get the step outcomes
for c in range(self.nchains):
updates[c] = self._primary_ends[c].recv()
updates.append(self._primary_ends[c].recv())
else:
for c in range(self.nchains):
if tune_stop:
self._steppers[c].stop_tuning()
updates[c] = self._steppers[c].step(population[c])
updates.append(self._steppers[c].step(population[c]))
return updates


Expand Down Expand Up @@ -378,13 +382,9 @@ def _iter_population(

# apply the update to the points and record to the traces
for c, strace in enumerate(traces):
if steppers[c].generates_stats:
points[c], stats = updates[c]
strace.record(points[c], stats)
log_warning_stats(stats)
else:
points[c] = updates[c]
strace.record(points[c])
points[c], stats = updates[c]
strace.record(points[c], stats)
log_warning_stats(stats)
# yield the state of all chains in parallel
yield traces
except KeyboardInterrupt:
Expand Down
55 changes: 21 additions & 34 deletions pymc/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,20 @@

from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Dict, List, Tuple, TypeVar, Union
from typing import Callable, Dict, List, Tuple, Union, cast

import numpy as np

from aesara.graph.basic import Variable
from numpy.random import uniform

from pymc.blocking import DictToArrayBijection, PointType, RaveledVars
from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
from pymc.model import modelcontext
from pymc.step_methods.compound import CompoundStep
from pymc.util import get_var_name

__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"]

StatsType = TypeVar("StatsType")


@unique
class Competence(IntEnum):
Expand All @@ -49,7 +47,6 @@ class Competence(IntEnum):

class BlockedStep(ABC):

generates_stats = False
stats_dtypes: List[Dict[str, type]] = []
vars: List[Variable] = []

Expand Down Expand Up @@ -103,7 +100,7 @@ def __getnewargs_ex__(self):
return self.__newargs

@abstractmethod
def step(point: PointType, *args, **kwargs) -> Union[PointType, Tuple[PointType, StatsType]]:
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
"""Perform a single step of the sampler."""

@staticmethod
Expand Down Expand Up @@ -146,35 +143,28 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
self.allvars = allvars
self.blocked = blocked

def step(self, point: PointType):
def step(self, point: PointType) -> Tuple[PointType, StatsType]:

partial_funcs_and_point = [DictToArrayBijection.mapf(x, start_point=point) for x in self.fs]
partial_funcs_and_point: List[Union[Callable, PointType]] = [
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
]
if self.allvars:
partial_funcs_and_point.append(point)

apoint = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
step_res = self.astep(apoint, *partial_funcs_and_point)

if self.generates_stats:
apoint_new, stats = step_res
else:
apoint_new = step_res
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
apoint = DictToArrayBijection.map(var_dict)
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)

if not isinstance(apoint_new, RaveledVars):
# We assume that the mapping has stayed the same
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)

point_new = DictToArrayBijection.rmap(apoint_new, start_point=point)

if self.generates_stats:
return point_new, stats

return point_new
return point_new, stats

@abstractmethod
def astep(
self, apoint: RaveledVars, point: PointType, *args
) -> Union[RaveledVars, Tuple[RaveledVars, StatsType]]:
def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
"""Perform a single sample step in a raveled and concatenated parameter space."""


Expand All @@ -198,30 +188,27 @@ def __init__(self, vars, shared, blocked=True):
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
self.blocked = blocked

def step(self, point):
def step(self, point: PointType) -> Tuple[PointType, StatsType]:

for name, shared_var in self.shared.items():
shared_var.set_value(point[name])

q = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
q = DictToArrayBijection.map(var_dict)

step_res = self.astep(q)

if self.generates_stats:
apoint, stats = step_res
else:
apoint = step_res
apoint, stats = self.astep(q)

if not isinstance(apoint, RaveledVars):
# We assume that the mapping has stayed the same
apoint = RaveledVars(apoint, q.point_map_info)

new_point = DictToArrayBijection.rmap(apoint, start_point=point)

if self.generates_stats:
return new_point, stats
return new_point, stats

return new_point
@abstractmethod
def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
"""Perform a single sample step in a raveled and concatenated parameter space."""


class PopulationArrayStepShared(ArrayStepShared):
Expand Down Expand Up @@ -281,7 +268,7 @@ def __init__(

super().__init__(vars, func._extra_vars_shared, blocked)

def step(self, point):
def step(self, point) -> Tuple[PointType, StatsType]:
self._logp_dlogp_func._extra_are_set = True
return super().step(point)

Expand Down
41 changes: 15 additions & 26 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
@author: johnsalvatier
"""
from collections import namedtuple

import numpy as np

from typing import Tuple

from pymc.blocking import PointType, StatsType


class CompoundStep:
Expand All @@ -28,36 +30,23 @@ class CompoundStep:

def __init__(self, methods):
self.methods = list(methods)
self.generates_stats = any(method.generates_stats for method in self.methods)
self.stats_dtypes = []
for method in self.methods:
if method.generates_stats:
self.stats_dtypes.extend(method.stats_dtypes)
self.stats_dtypes.extend(method.stats_dtypes)
self.name = (
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
)

def step(self, point):
if self.generates_stats:
states = []
for method in self.methods:
if method.generates_stats:
point, state = method.step(point)
states.extend(state)
else:
point = method.step(point)
# Model logp can only be the logp of the _last_ state, if there is
# one. Pop all others (if dict), or set to np.nan (if namedtuple).
for state in states[:-1]:
if isinstance(state, dict):
state.pop("model_logp", None)
elif isinstance(state, namedtuple):
state = state._replace(logp=np.nan)
return point, states
else:
for method in self.methods:
point = method.step(point)
return point
def step(self, point) -> Tuple[PointType, StatsType]:
stats = []
for method in self.methods:
point, sts = method.step(point)
stats.extend(sts)
# Model logp can only be the logp of the _last_ stats,
# if there is one. Pop all others.
for sts in stats[:-1]:
sts.pop("model_logp", None)
return point, stats

def stop_tuning(self):
for method in self.methods:
Expand Down
4 changes: 2 additions & 2 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np

from pymc.aesaraf import floatX
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
from pymc.exceptions import SamplingError
from pymc.model import Point, modelcontext
from pymc.stats.convergence import SamplerWarning, WarningType
Expand Down Expand Up @@ -157,7 +157,7 @@ def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
Subclasses must overwrite this abstract method and return an `HMCStepData` object.
"""

def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]:
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
"""Perform a single HMC iteration."""
perf_start = time.perf_counter()
process_start = time.process_time()
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class HamiltonianMC(BaseHMC):

name = "hmc"
default_blocked = True
generates_stats = True
stats_dtypes = [
{
"step_size": np.float64,
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class NUTS(BaseHMC):
name = "nuts"

default_blocked = True
generates_stats = True
stats_dtypes = [
{
"depth": np.int64,
Expand Down
Loading

0 comments on commit 7349614

Please sign in to comment.