Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require all step methods to return stats #6313

Merged
merged 2 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from aeppl.logprob import CheckParameterValue
from aeppl.transforms import RVTransform
from aesara import scalar
from aesara.compile.mode import Mode, get_mode
from aesara.compile import Function, Mode, get_mode
from aesara.gradient import grad
from aesara.graph import node_rewriter, rewrite_graph
from aesara.graph.basic import (
Expand Down Expand Up @@ -1044,7 +1044,7 @@ def compile_pymc(
random_seed: SeedSequenceSeed = None,
mode=None,
**kwargs,
) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
) -> Function:
"""Use ``aesara.function`` with specialized pymc rewrites always enabled.

This function also ensures shared RandomState/Generator used by RandomVariables
Expand Down
7 changes: 5 additions & 2 deletions pymc/blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
from __future__ import annotations

from functools import partial
from typing import Callable, Dict, Generic, NamedTuple, TypeVar
from typing import Any, Callable, Dict, Generic, List, NamedTuple, TypeVar

import numpy as np

from typing_extensions import TypeAlias

__all__ = ["DictToArrayBijection"]


T = TypeVar("T")
PointType = Dict[str, np.ndarray]
PointType: TypeAlias = Dict[str, np.ndarray]
StatsType: TypeAlias = List[Dict[str, Any]]

# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
# each of the raveled variables.
Expand Down
12 changes: 4 additions & 8 deletions pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,14 +912,10 @@ def _iter_sample(
step.iter_count = 0
if i == tune:
step.stop_tuning()
if step.generates_stats:
point, stats = step.step(point)
strace.record(point, stats)
log_warning_stats(stats)
diverging = i > tune and stats and stats[0].get("diverging")
else:
point = step.step(point)
strace.record(point, [])
point, stats = step.step(point)
strace.record(point, stats)
log_warning_stats(stats)
diverging = i > tune and stats and stats[0].get("diverging")
if callback is not None:
callback(
trace=strace,
Expand Down
10 changes: 1 addition & 9 deletions pymc/sampling/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _start_loop(self):

if draw < self._draws + self._tune:
try:
point, stats = self._compute_point()
point, stats = self._step_method.step(self._point)
except SamplingError as e:
e = ExceptionWithTraceback(e, e.__traceback__)
self._msg_pipe.send(("error", e))
Expand All @@ -191,14 +191,6 @@ def _start_loop(self):
else:
raise ValueError("Unknown message " + msg[0])

def _compute_point(self):
if self._step_method.generates_stats:
point, stats = self._step_method.step(self._point)
else:
point = self._step_method.step(self._point)
stats = None
return point, stats


def _run_process(*args):
_Process(*args).run()
Expand Down
26 changes: 13 additions & 13 deletions pymc/sampling/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging

from copy import copy
from typing import Iterator, List, Sequence, Union
from typing import Iterator, List, Sequence, Tuple, Union

import cloudpickle
import numpy as np
Expand All @@ -31,7 +31,11 @@
from pymc.model import modelcontext
from pymc.stats.convergence import log_warning_stats
from pymc.step_methods import CompoundStep
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.arraystep import (
BlockedStep,
PopulationArrayStepShared,
StatsType,
)
from pymc.util import RandomSeed

__all__ = ()
Expand Down Expand Up @@ -224,7 +228,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
_log.exception(f"ChainWalker{c}")
return

def step(self, tune_stop: bool, population):
def step(self, tune_stop: bool, population) -> List[Tuple[PointType, StatsType]]:
"""Step the entire population of chains.

Parameters
Expand All @@ -239,18 +243,18 @@ def step(self, tune_stop: bool, population):
update : list
List of (Point, stats) tuples for all chains
"""
updates = [None] * self.nchains
updates: List[Tuple[PointType, StatsType]] = []
if self.is_parallelized:
for c in range(self.nchains):
self._primary_ends[c].send((tune_stop, population))
# Blockingly get the step outcomes
for c in range(self.nchains):
updates[c] = self._primary_ends[c].recv()
updates.append(self._primary_ends[c].recv())
else:
for c in range(self.nchains):
if tune_stop:
self._steppers[c].stop_tuning()
updates[c] = self._steppers[c].step(population[c])
updates.append(self._steppers[c].step(population[c]))
return updates


Expand Down Expand Up @@ -378,13 +382,9 @@ def _iter_population(

# apply the update to the points and record to the traces
for c, strace in enumerate(traces):
if steppers[c].generates_stats:
points[c], stats = updates[c]
strace.record(points[c], stats)
log_warning_stats(stats)
else:
points[c] = updates[c]
strace.record(points[c])
points[c], stats = updates[c]
strace.record(points[c], stats)
log_warning_stats(stats)
# yield the state of all chains in parallel
yield traces
except KeyboardInterrupt:
Expand Down
57 changes: 22 additions & 35 deletions pymc/step_methods/arraystep.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,20 @@

from abc import ABC, abstractmethod
from enum import IntEnum, unique
from typing import Dict, List, Tuple, TypeVar, Union
from typing import Callable, Dict, List, Tuple, Union, cast

import numpy as np

from aesara.graph.basic import Variable
from numpy.random import uniform

from pymc.blocking import DictToArrayBijection, PointType, RaveledVars
from pymc.blocking import DictToArrayBijection, PointType, RaveledVars, StatsType
from pymc.model import modelcontext
from pymc.step_methods.compound import CompoundStep
from pymc.util import get_var_name

__all__ = ["ArrayStep", "ArrayStepShared", "metrop_select", "Competence"]

StatsType = TypeVar("StatsType")


@unique
class Competence(IntEnum):
Expand All @@ -49,7 +47,6 @@ class Competence(IntEnum):

class BlockedStep(ABC):

generates_stats = False
stats_dtypes: List[Dict[str, type]] = []
vars: List[Variable] = []

Expand Down Expand Up @@ -103,7 +100,7 @@ def __getnewargs_ex__(self):
return self.__newargs

@abstractmethod
def step(point: PointType, *args, **kwargs) -> Union[PointType, Tuple[PointType, StatsType]]:
def step(self, point: PointType) -> Tuple[PointType, StatsType]:
"""Perform a single step of the sampler."""

@staticmethod
Expand Down Expand Up @@ -146,35 +143,28 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
self.allvars = allvars
self.blocked = blocked

def step(self, point: PointType):
def step(self, point: PointType) -> Tuple[PointType, StatsType]:

partial_funcs_and_point = [DictToArrayBijection.mapf(x, start_point=point) for x in self.fs]
partial_funcs_and_point: List[Union[Callable, PointType]] = [
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
]
if self.allvars:
partial_funcs_and_point.append(point)

apoint = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
step_res = self.astep(apoint, *partial_funcs_and_point)

if self.generates_stats:
apoint_new, stats = step_res
else:
apoint_new = step_res
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
apoint = DictToArrayBijection.map(var_dict)
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)

if not isinstance(apoint_new, RaveledVars):
# We assume that the mapping has stayed the same
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)

point_new = DictToArrayBijection.rmap(apoint_new, start_point=point)

if self.generates_stats:
return point_new, stats

return point_new
return point_new, stats

@abstractmethod
def astep(
self, apoint: RaveledVars, point: PointType, *args
) -> Union[RaveledVars, Tuple[RaveledVars, StatsType]]:
def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

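The exact composition of `*args` depends on the constructor parameters (one entry per function in `fs`, plus the full point when `allvars=True`), so it is not spelled out in this signature 🤐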

"""Perform a single sample step in a raveled and concatenated parameter space."""


Expand All @@ -198,30 +188,27 @@ def __init__(self, vars, shared, blocked=True):
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
self.blocked = blocked

def step(self, point):
def step(self, point: PointType) -> Tuple[PointType, StatsType]:

for name, shared_var in self.shared.items():
shared_var.set_value(point[name])

q = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
var_dict = {cast(str, v.name): point[cast(str, v.name)] for v in self.vars}
q = DictToArrayBijection.map(var_dict)

step_res = self.astep(q)

if self.generates_stats:
apoint, stats = step_res
else:
apoint = step_res
apoint, stats = self.astep(q)

if not isinstance(apoint, RaveledVars):
# We assume that the mapping has stayed the same
apoint = RaveledVars(apoint, q.point_map_info)

new_point = DictToArrayBijection.rmap(apoint, start_point=point)

if self.generates_stats:
return new_point, stats
return new_point, stats

return new_point
@abstractmethod
def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
"""Perform a single sample step in a raveled and concatenated parameter space."""


class PopulationArrayStepShared(ArrayStepShared):
Expand Down Expand Up @@ -281,12 +268,12 @@ def __init__(

super().__init__(vars, func._extra_vars_shared, blocked)

def step(self, point):
def step(self, point) -> Tuple[PointType, StatsType]:
self._logp_dlogp_func._extra_are_set = True
return super().step(point)


def metrop_select(mr, q, q0):
def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> Tuple[np.ndarray, bool]:
"""Perform rejection/acceptance step for Metropolis class samplers.

Returns the new sample q if a uniform random number is less than the
Expand Down
41 changes: 15 additions & 26 deletions pymc/step_methods/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

@author: johnsalvatier
"""
from collections import namedtuple

import numpy as np

from typing import Tuple

from pymc.blocking import PointType, StatsType


class CompoundStep:
Expand All @@ -28,36 +30,23 @@ class CompoundStep:

def __init__(self, methods):
self.methods = list(methods)
self.generates_stats = any(method.generates_stats for method in self.methods)
self.stats_dtypes = []
for method in self.methods:
if method.generates_stats:
self.stats_dtypes.extend(method.stats_dtypes)
self.stats_dtypes.extend(method.stats_dtypes)
self.name = (
f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]"
)

def step(self, point):
if self.generates_stats:
states = []
for method in self.methods:
if method.generates_stats:
point, state = method.step(point)
states.extend(state)
else:
point = method.step(point)
# Model logp can only be the logp of the _last_ state, if there is
# one. Pop all others (if dict), or set to np.nan (if namedtuple).
for state in states[:-1]:
if isinstance(state, dict):
state.pop("model_logp", None)
elif isinstance(state, namedtuple):
state = state._replace(logp=np.nan)
return point, states
else:
for method in self.methods:
point = method.step(point)
return point
def step(self, point) -> Tuple[PointType, StatsType]:
stats = []
for method in self.methods:
point, sts = method.step(point)
stats.extend(sts)
# Model logp can only be the logp of the _last_ stats,
# if there is one. Pop all others.
for sts in stats[:-1]:
sts.pop("model_logp", None)
return point, stats

def stop_tuning(self):
for method in self.methods:
Expand Down
4 changes: 2 additions & 2 deletions pymc/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np

from pymc.aesaraf import floatX
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.blocking import DictToArrayBijection, RaveledVars, StatsType
from pymc.exceptions import SamplingError
from pymc.model import Point, modelcontext
from pymc.stats.convergence import SamplerWarning, WarningType
Expand Down Expand Up @@ -157,7 +157,7 @@ def _hamiltonian_step(self, start, p0, step_size) -> HMCStepData:
Subclasses must overwrite this abstract method and return an `HMCStepData` object.
"""

def astep(self, q0: RaveledVars) -> tuple[RaveledVars, list[dict[str, Any]]]:
def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
"""Perform a single HMC iteration."""
perf_start = time.perf_counter()
process_start = time.process_time()
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class HamiltonianMC(BaseHMC):

name = "hmc"
default_blocked = True
generates_stats = True
stats_dtypes = [
{
"step_size": np.float64,
Expand Down
1 change: 0 additions & 1 deletion pymc/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class NUTS(BaseHMC):
name = "nuts"

default_blocked = True
generates_stats = True
stats_dtypes = [
{
"depth": np.int64,
Expand Down
Loading