Extract convergence checks function from SamplerReport
michaelosthege authored and ricardoV94 committed Oct 7, 2022
1 parent 91dbfd2 commit a574ed5
Showing 6 changed files with 136 additions and 116 deletions.
123 changes: 12 additions & 111 deletions pymc/backends/report.py
@@ -13,63 +13,28 @@
# limitations under the License.

import dataclasses
import enum
import logging

from typing import Any, Optional
from typing import Dict, List, Optional

import arviz

from pymc.util import get_untransformed_name, is_transformed_name
from pymc.stats.convergence import (
    _LEVELS,
    SamplerWarning,
    log_warnings,
    run_convergence_checks,
)

logger = logging.getLogger("pymc")


@enum.unique
class WarningType(enum.Enum):
    # For HMC and NUTS
    DIVERGENCE = 1
    TUNING_DIVERGENCE = 2
    DIVERGENCES = 3
    TREEDEPTH = 4
    # Problematic sampler parameters
    BAD_PARAMS = 5
    # Indications that chains did not converge, eg Rhat
    CONVERGENCE = 6
    BAD_ACCEPTANCE = 7
    BAD_ENERGY = 8


@dataclasses.dataclass
class SamplerWarning:
    kind: WarningType
    message: str
    level: str
    step: Optional[int] = None
    exec_info: Optional[Any] = None
    extra: Optional[Any] = None
    divergence_point_source: Optional[dict] = None
    divergence_point_dest: Optional[dict] = None
    divergence_info: Optional[Any] = None


_LEVELS = {
"info": logging.INFO,
"error": logging.ERROR,
"warn": logging.WARN,
"debug": logging.DEBUG,
"critical": logging.CRITICAL,
}


class SamplerReport:
"""Bundle warnings, convergence stats and metadata of a sampling run."""

def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
self._ess = None
self._rhat = None
self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
self._global_warnings: List[SamplerWarning] = []
self._n_tune = None
self._n_draws = None
self._t_sampling = None
@@ -109,65 +74,7 @@ def raise_ok(self, level="error"):
            raise ValueError("Serious convergence issues during sampling.")

    def _run_convergence_checks(self, idata: arviz.InferenceData, model):
        if not hasattr(idata, "posterior"):
            msg = "No posterior samples. Unable to run convergence checks"
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
            self._add_warnings([warn])
            return

        if idata["posterior"].sizes["chain"] == 1:
            msg = (
                "Only one chain was sampled, this makes it impossible to "
                "run some convergence checks"
            )
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        elif idata["posterior"].sizes["chain"] < 4:
            msg = (
                "We recommend running at least 4 chains for robust computation of "
                "convergence diagnostics"
            )
            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
            self._add_warnings([warn])
            return

        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
        varnames = []
        for rv in model.free_RVs:
            rv_name = rv.name
            if is_transformed_name(rv_name):
                rv_name2 = get_untransformed_name(rv_name)
                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
            if rv_name in idata["posterior"]:
                varnames.append(rv_name)

        self._ess = ess = arviz.ess(idata, var_names=varnames)
        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)

        warnings = []
        rhat_max = max(val.max() for val in rhat.values())
        if rhat_max > 1.01:
            msg = (
                "The rhat statistic is larger than 1.01 for some "
                "parameters. This indicates problems during sampling. "
                "See https://arxiv.org/abs/1903.08008 for details"
            )
            warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
            warnings.append(warn)

        eff_min = min(val.min() for val in ess.values())
        eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
        if eff_per_chain < 100:
            msg = (
                "The effective sample size per chain is smaller than 100 for some parameters. "
                " A higher number is needed for reliable rhat and ess computation. "
                "See https://arxiv.org/abs/1903.08008 for details"
            )
            warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
            warnings.append(warn)

        warnings = run_convergence_checks(idata, model)
        self._add_warnings(warnings)

@@ -178,15 +85,9 @@ def _add_warnings(self, warnings, chain=None):
        warn_list.extend(warnings)

    def _log_summary(self):
        def log_warning(warn):
            level = _LEVELS[warn.level]
            logger.log(level, warn.message)

        for chain, warns in self._chain_warnings.items():
            for warn in warns:
                log_warning(warn)
        for warn in self._global_warnings:
            log_warning(warn)
            log_warnings(warns)
        log_warnings(self._global_warnings)

    def _slice(self, start, stop, step):
        report = SamplerReport()
6 changes: 4 additions & 2 deletions pymc/sampling.py
@@ -70,6 +70,7 @@
)
from pymc.model import Model, modelcontext
from pymc.parallel_sampling import Draw, _cpu_count
from pymc.stats.convergence import run_convergence_checks
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
from pymc.step_methods.hmc import quadpotential
@@ -677,7 +678,7 @@ def sample(
    _log.info(
        f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
        f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
        f"took {mtrace.report.t_sampling:.0f} seconds."
        f"took {t_sampling:.0f} seconds."
    )
    mtrace.report._log_summary()

@@ -695,7 +696,8 @@ def sample(
                stacklevel=2,
            )
        else:
            mtrace.report._run_convergence_checks(idata, model)
            convergence_warnings = run_convergence_checks(idata, model)
            mtrace.report._add_warnings(convergence_warnings)

    if return_inferencedata:
        return idata
117 changes: 117 additions & 0 deletions pymc/stats/convergence.py
@@ -0,0 +1,117 @@
import dataclasses
import enum
import logging

from typing import Any, List, Optional, Sequence

import arviz

from pymc.util import get_untransformed_name, is_transformed_name

_LEVELS = {
"info": logging.INFO,
"error": logging.ERROR,
"warn": logging.WARN,
"debug": logging.DEBUG,
"critical": logging.CRITICAL,
}

logger = logging.getLogger("pymc")


@enum.unique
class WarningType(enum.Enum):
    # For HMC and NUTS
    DIVERGENCE = 1
    TUNING_DIVERGENCE = 2
    DIVERGENCES = 3
    TREEDEPTH = 4
    # Problematic sampler parameters
    BAD_PARAMS = 5
    # Indications that chains did not converge, eg Rhat
    CONVERGENCE = 6
    BAD_ACCEPTANCE = 7
    BAD_ENERGY = 8


@dataclasses.dataclass
class SamplerWarning:
    kind: WarningType
    message: str
    level: str
    step: Optional[int] = None
    exec_info: Optional[Any] = None
    extra: Optional[Any] = None
    divergence_point_source: Optional[dict] = None
    divergence_point_dest: Optional[dict] = None
    divergence_info: Optional[Any] = None


def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWarning]:
    if not hasattr(idata, "posterior"):
        msg = "No posterior samples. Unable to run convergence checks"
        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
        return [warn]

    if idata["posterior"].sizes["chain"] == 1:
        msg = (
            "Only one chain was sampled, this makes it impossible to " "run some convergence checks"
        )
        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
        return [warn]

    elif idata["posterior"].sizes["chain"] < 4:
        msg = (
            "We recommend running at least 4 chains for robust computation of "
            "convergence diagnostics"
        )
        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
        return [warn]

    warnings = []
    valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
    varnames = []
    for rv in model.free_RVs:
        rv_name = rv.name
        if is_transformed_name(rv_name):
            rv_name2 = get_untransformed_name(rv_name)
            rv_name = rv_name2 if rv_name2 in valid_name else rv_name
        if rv_name in idata["posterior"]:
            varnames.append(rv_name)

    ess = arviz.ess(idata, var_names=varnames)
    rhat = arviz.rhat(idata, var_names=varnames)

    warnings = []
    rhat_max = max(val.max() for val in rhat.values())
    if rhat_max > 1.01:
        msg = (
            "The rhat statistic is larger than 1.01 for some "
            "parameters. This indicates problems during sampling. "
            "See https://arxiv.org/abs/1903.08008 for details"
        )
        warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
        warnings.append(warn)

    eff_min = min(val.min() for val in ess.values())
    eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
    if eff_per_chain < 100:
        msg = (
            "The effective sample size per chain is smaller than 100 for some parameters. "
            " A higher number is needed for reliable rhat and ess computation. "
            "See https://arxiv.org/abs/1903.08008 for details"
        )
        warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
        warnings.append(warn)

    return warnings


def log_warning(warn: SamplerWarning):
    level = _LEVELS[warn.level]
    logger.log(level, warn.message)


def log_warnings(warnings: Sequence[SamplerWarning]):
    for warn in warnings:
        log_warning(warn)
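
For readers who want to use the extracted helpers directly, the new module functions can be called outside of SamplerReport, for example on an InferenceData object that was sampled or loaded earlier. The following is a minimal sketch, not part of this commit: the toy model and sampler settings are illustrative, and the built-in checks are disabled via compute_convergence_checks so the manual call is not redundant.

import pymc as pm

from pymc.stats.convergence import log_warnings, run_convergence_checks

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    # pm.sample returns an arviz.InferenceData by default in PyMC 4.x.
    idata = pm.sample(draws=500, tune=500, chains=4, compute_convergence_checks=False)

# Run the extracted checks manually and emit any warnings through the "pymc" logger.
convergence_warnings = run_convergence_checks(idata, model)
log_warnings(convergence_warnings)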
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/base_hmc.py
@@ -21,10 +21,10 @@
import numpy as np

from pymc.aesaraf import floatX
from pymc.backends.report import SamplerWarning, WarningType
from pymc.blocking import DictToArrayBijection, RaveledVars
from pymc.exceptions import SamplingError
from pymc.model import Point, modelcontext
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods import step_sizes
from pymc.step_methods.arraystep import GradientSharedStep
from pymc.step_methods.hmc import integration
2 changes: 1 addition & 1 deletion pymc/step_methods/hmc/nuts.py
@@ -17,8 +17,8 @@
import numpy as np

from pymc.aesaraf import floatX
from pymc.backends.report import SamplerWarning, WarningType
from pymc.math import logbern
from pymc.stats.convergence import SamplerWarning, WarningType
from pymc.step_methods.arraystep import Competence
from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
from pymc.step_methods.hmc.integration import IntegrationError
2 changes: 1 addition & 1 deletion pymc/step_methods/step_sizes.py
@@ -16,7 +16,7 @@

from scipy import stats

from pymc.backends.report import SamplerWarning, WarningType
from pymc.stats.convergence import SamplerWarning, WarningType


class DualAverageAdaptation:
