From a574ed5c5409aec4c1def096ba05393293b6a4df Mon Sep 17 00:00:00 2001
From: Michael Osthege
Date: Fri, 7 Oct 2022 12:01:57 +0200
Subject: [PATCH] Extract convergence checks function from `SamplerReport`

---
 pymc/backends/report.py           | 123 +++---------------------
 pymc/sampling.py                  |   6 +-
 pymc/stats/convergence.py         | 117 ++++++++++++++++++++++
 pymc/step_methods/hmc/base_hmc.py |   2 +-
 pymc/step_methods/hmc/nuts.py     |   2 +-
 pymc/step_methods/step_sizes.py   |   2 +-
 6 files changed, 136 insertions(+), 116 deletions(-)
 create mode 100644 pymc/stats/convergence.py

diff --git a/pymc/backends/report.py b/pymc/backends/report.py
index 8be6d4f938b..11cdf931397 100644
--- a/pymc/backends/report.py
+++ b/pymc/backends/report.py
@@ -13,63 +13,28 @@
 # limitations under the License.
 
 import dataclasses
-import enum
 import logging
 
-from typing import Any, Optional
+from typing import Dict, List, Optional
 
 import arviz
 
-from pymc.util import get_untransformed_name, is_transformed_name
+from pymc.stats.convergence import (
+    _LEVELS,
+    SamplerWarning,
+    log_warnings,
+    run_convergence_checks,
+)
 
 logger = logging.getLogger("pymc")
 
 
-@enum.unique
-class WarningType(enum.Enum):
-    # For HMC and NUTS
-    DIVERGENCE = 1
-    TUNING_DIVERGENCE = 2
-    DIVERGENCES = 3
-    TREEDEPTH = 4
-    # Problematic sampler parameters
-    BAD_PARAMS = 5
-    # Indications that chains did not converge, eg Rhat
-    CONVERGENCE = 6
-    BAD_ACCEPTANCE = 7
-    BAD_ENERGY = 8
-
-
-@dataclasses.dataclass
-class SamplerWarning:
-    kind: WarningType
-    message: str
-    level: str
-    step: Optional[int] = None
-    exec_info: Optional[Any] = None
-    extra: Optional[Any] = None
-    divergence_point_source: Optional[dict] = None
-    divergence_point_dest: Optional[dict] = None
-    divergence_info: Optional[Any] = None
-
-
-_LEVELS = {
-    "info": logging.INFO,
-    "error": logging.ERROR,
-    "warn": logging.WARN,
-    "debug": logging.DEBUG,
-    "critical": logging.CRITICAL,
-}
-
-
 class SamplerReport:
     """Bundle warnings, convergence stats and metadata of a sampling run."""
 
     def __init__(self):
-        self._chain_warnings = {}
-        self._global_warnings = []
-        self._ess = None
-        self._rhat = None
+        self._chain_warnings: Dict[int, List[SamplerWarning]] = {}
+        self._global_warnings: List[SamplerWarning] = []
         self._n_tune = None
         self._n_draws = None
         self._t_sampling = None
@@ -109,65 +74,7 @@ def raise_ok(self, level="error"):
             raise ValueError("Serious convergence issues during sampling.")
 
     def _run_convergence_checks(self, idata: arviz.InferenceData, model):
-        if not hasattr(idata, "posterior"):
-            msg = "No posterior samples. Unable to run convergence checks"
-            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
-            self._add_warnings([warn])
-            return
-
-        if idata["posterior"].sizes["chain"] == 1:
-            msg = (
-                "Only one chain was sampled, this makes it impossible to "
-                "run some convergence checks"
-            )
-            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
-            self._add_warnings([warn])
-            return
-
-        elif idata["posterior"].sizes["chain"] < 4:
-            msg = (
-                "We recommend running at least 4 chains for robust computation of "
-                "convergence diagnostics"
-            )
-            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
-            self._add_warnings([warn])
-            return
-
-        valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
-        varnames = []
-        for rv in model.free_RVs:
-            rv_name = rv.name
-            if is_transformed_name(rv_name):
-                rv_name2 = get_untransformed_name(rv_name)
-                rv_name = rv_name2 if rv_name2 in valid_name else rv_name
-            if rv_name in idata["posterior"]:
-                varnames.append(rv_name)
-
-        self._ess = ess = arviz.ess(idata, var_names=varnames)
-        self._rhat = rhat = arviz.rhat(idata, var_names=varnames)
-
-        warnings = []
-        rhat_max = max(val.max() for val in rhat.values())
-        if rhat_max > 1.01:
-            msg = (
-                "The rhat statistic is larger than 1.01 for some "
-                "parameters. This indicates problems during sampling. "
-                "See https://arxiv.org/abs/1903.08008 for details"
-            )
-            warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
-            warnings.append(warn)
-
-        eff_min = min(val.min() for val in ess.values())
-        eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
-        if eff_per_chain < 100:
-            msg = (
-                "The effective sample size per chain is smaller than 100 for some parameters. "
-                " A higher number is needed for reliable rhat and ess computation. "
-                "See https://arxiv.org/abs/1903.08008 for details"
-            )
-            warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
-            warnings.append(warn)
-
+        warnings = run_convergence_checks(idata, model)
         self._add_warnings(warnings)
 
     def _add_warnings(self, warnings, chain=None):
@@ -178,15 +85,9 @@ def _add_warnings(self, warnings, chain=None):
         warn_list.extend(warnings)
 
     def _log_summary(self):
-        def log_warning(warn):
-            level = _LEVELS[warn.level]
-            logger.log(level, warn.message)
-
         for chain, warns in self._chain_warnings.items():
-            for warn in warns:
-                log_warning(warn)
-        for warn in self._global_warnings:
-            log_warning(warn)
+            log_warnings(warns)
+        log_warnings(self._global_warnings)
 
     def _slice(self, start, stop, step):
         report = SamplerReport()
diff --git a/pymc/sampling.py b/pymc/sampling.py
index ad26b309b94..3a54b598869 100644
--- a/pymc/sampling.py
+++ b/pymc/sampling.py
@@ -70,6 +70,7 @@
 )
 from pymc.model import Model, modelcontext
 from pymc.parallel_sampling import Draw, _cpu_count
+from pymc.stats.convergence import run_convergence_checks
 from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
@@ -677,7 +678,7 @@ def sample(
         _log.info(
             f'Sampling {n_chains} chain{"s" if n_chains > 1 else ""} for {n_tune:_d} tune and {n_draws:_d} draw iterations '
             f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
-            f"took {mtrace.report.t_sampling:.0f} seconds."
+            f"took {t_sampling:.0f} seconds."
         )
 
         mtrace.report._log_summary()
@@ -695,7 +696,8 @@ def sample(
                 stacklevel=2,
             )
         else:
-            mtrace.report._run_convergence_checks(idata, model)
+            convergence_warnings = run_convergence_checks(idata, model)
+            mtrace.report._add_warnings(convergence_warnings)
 
     if return_inferencedata:
         return idata
diff --git a/pymc/stats/convergence.py b/pymc/stats/convergence.py
new file mode 100644
index 00000000000..3288f5e8816
--- /dev/null
+++ b/pymc/stats/convergence.py
@@ -0,0 +1,117 @@
+import dataclasses
+import enum
+import logging
+
+from typing import Any, List, Optional, Sequence
+
+import arviz
+
+from pymc.util import get_untransformed_name, is_transformed_name
+
+_LEVELS = {
+    "info": logging.INFO,
+    "error": logging.ERROR,
+    "warn": logging.WARN,
+    "debug": logging.DEBUG,
+    "critical": logging.CRITICAL,
+}
+
+logger = logging.getLogger("pymc")
+
+
+@enum.unique
+class WarningType(enum.Enum):
+    # For HMC and NUTS
+    DIVERGENCE = 1
+    TUNING_DIVERGENCE = 2
+    DIVERGENCES = 3
+    TREEDEPTH = 4
+    # Problematic sampler parameters
+    BAD_PARAMS = 5
+    # Indications that chains did not converge, eg Rhat
+    CONVERGENCE = 6
+    BAD_ACCEPTANCE = 7
+    BAD_ENERGY = 8
+
+
+@dataclasses.dataclass
+class SamplerWarning:
+    kind: WarningType
+    message: str
+    level: str
+    step: Optional[int] = None
+    exec_info: Optional[Any] = None
+    extra: Optional[Any] = None
+    divergence_point_source: Optional[dict] = None
+    divergence_point_dest: Optional[dict] = None
+    divergence_info: Optional[Any] = None
+
+
+def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWarning]:
+    if not hasattr(idata, "posterior"):
+        msg = "No posterior samples. Unable to run convergence checks"
+        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info", None, None, None)
+        return [warn]
+
+    if idata["posterior"].sizes["chain"] == 1:
+        msg = (
+            "Only one chain was sampled, this makes it impossible to " "run some convergence checks"
+        )
+        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
+        return [warn]
+
+    elif idata["posterior"].sizes["chain"] < 4:
+        msg = (
+            "We recommend running at least 4 chains for robust computation of "
+            "convergence diagnostics"
+        )
+        warn = SamplerWarning(WarningType.BAD_PARAMS, msg, "info")
+        return [warn]
+
+    warnings = []
+    valid_name = [rv.name for rv in model.free_RVs + model.deterministics]
+    varnames = []
+    for rv in model.free_RVs:
+        rv_name = rv.name
+        if is_transformed_name(rv_name):
+            rv_name2 = get_untransformed_name(rv_name)
+            rv_name = rv_name2 if rv_name2 in valid_name else rv_name
+        if rv_name in idata["posterior"]:
+            varnames.append(rv_name)
+
+    ess = arviz.ess(idata, var_names=varnames)
+    rhat = arviz.rhat(idata, var_names=varnames)
+
+    warnings = []
+    rhat_max = max(val.max() for val in rhat.values())
+    if rhat_max > 1.01:
+        msg = (
+            "The rhat statistic is larger than 1.01 for some "
+            "parameters. This indicates problems during sampling. "
+            "See https://arxiv.org/abs/1903.08008 for details"
+        )
+        warn = SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat)
+        warnings.append(warn)
+
+    eff_min = min(val.min() for val in ess.values())
+    eff_per_chain = eff_min / idata["posterior"].sizes["chain"]
+    if eff_per_chain < 100:
+        msg = (
+            "The effective sample size per chain is smaller than 100 for some parameters. "
+            " A higher number is needed for reliable rhat and ess computation. "
+            "See https://arxiv.org/abs/1903.08008 for details"
+        )
+        warn = SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess)
+        warnings.append(warn)
+
+    return warnings
+
+
+def log_warning(warn: SamplerWarning):
+    level = _LEVELS[warn.level]
+    logger.log(level, warn.message)
+
+
+def log_warnings(warnings: Sequence[SamplerWarning]):
+    for warn in warnings:
+        log_warning(warn)
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index 1f45856ab97..61b1cc50820 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -21,10 +21,10 @@
 import numpy as np
 
 from pymc.aesaraf import floatX
-from pymc.backends.report import SamplerWarning, WarningType
 from pymc.blocking import DictToArrayBijection, RaveledVars
 from pymc.exceptions import SamplingError
 from pymc.model import Point, modelcontext
+from pymc.stats.convergence import SamplerWarning, WarningType
 from pymc.step_methods import step_sizes
 from pymc.step_methods.arraystep import GradientSharedStep
 from pymc.step_methods.hmc import integration
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index c9005c96447..5dd2231a8ea 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -17,8 +17,8 @@
 import numpy as np
 
 from pymc.aesaraf import floatX
-from pymc.backends.report import SamplerWarning, WarningType
 from pymc.math import logbern
+from pymc.stats.convergence import SamplerWarning, WarningType
 from pymc.step_methods.arraystep import Competence
 from pymc.step_methods.hmc.base_hmc import BaseHMC, DivergenceInfo, HMCStepData
 from pymc.step_methods.hmc.integration import IntegrationError
diff --git a/pymc/step_methods/step_sizes.py b/pymc/step_methods/step_sizes.py
index e9e374700d0..739fb737b5f 100644
--- a/pymc/step_methods/step_sizes.py
+++ b/pymc/step_methods/step_sizes.py
@@ -16,7 +16,7 @@
 
 from scipy import stats
 
-from pymc.backends.report import SamplerWarning, WarningType
+from pymc.stats.convergence import SamplerWarning, WarningType
 
 
 class DualAverageAdaptation: