From 8fa8318ffd4711e64e3a2764af67461cb230d9c8 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Wed, 1 Jul 2020 19:06:00 +0200
Subject: [PATCH] Add more info to divergence warnings

---
 pymc3/backends/report.py           | 48 ++++++++++++++++++------------
 pymc3/step_methods/hmc/base_hmc.py | 40 +++++++++++++++++--------
 pymc3/step_methods/hmc/nuts.py     |  5 ++--
 pymc3/step_methods/step_sizes.py   |  2 +-
 4 files changed, 61 insertions(+), 34 deletions(-)

diff --git a/pymc3/backends/report.py b/pymc3/backends/report.py
index 4384b85cbf5..42f6b8a9768 100644
--- a/pymc3/backends/report.py
+++ b/pymc3/backends/report.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
 import logging
 import enum
-import typing
+from typing import Any, Optional
+import dataclasses
+
 from ..util import is_transformed_name, get_untransformed_name
 
 import arviz
@@ -38,9 +39,17 @@ class WarningType(enum.Enum):
     BAD_ENERGY = 8
 
 
-SamplerWarning = namedtuple(
-    'SamplerWarning',
-    "kind, message, level, step, exec_info, extra")
+@dataclasses.dataclass
+class SamplerWarning:
+    kind: WarningType
+    message: str
+    level: str
+    step: Optional[int] = None
+    exec_info: Optional[Any] = None
+    extra: Optional[Any] = None
+    divergence_point_source: Optional[dict] = None
+    divergence_point_dest: Optional[dict] = None
+    divergence_info: Optional[Any] = None
 
 
 _LEVELS = {
@@ -53,7 +62,8 @@ class WarningType(enum.Enum):
 
 
 class SamplerReport:
-    """This object bundles warnings, convergence statistics and metadata of a sampling run."""
+    """Bundle warnings, convergence stats and metadata of a sampling run."""
+
     def __init__(self):
         self._chain_warnings = {}
         self._global_warnings = []
@@ -75,17 +85,17 @@ def ok(self):
                    for warn in self._warnings)
 
     @property
-    def n_tune(self) -> typing.Optional[int]:
+    def n_tune(self) -> Optional[int]:
         """Number of tune iterations - not necessarily kept in trace!"""
         return self._n_tune
 
     @property
-    def n_draws(self) -> typing.Optional[int]:
+    def n_draws(self) -> Optional[int]:
         """Number of draw iterations."""
         return self._n_draws
 
     @property
-    def t_sampling(self) -> typing.Optional[float]:
+    def t_sampling(self) -> Optional[float]:
         """
         Number of seconds that the sampling procedure took.
 
@@ -110,8 +120,7 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
         if idata.posterior.sizes['chain'] == 1:
             msg = ("Only one chain was sampled, this makes it impossible to "
                    "run some convergence checks")
-            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
-                                  None, None, None)
+            warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
             self._add_warnings([warn])
             return
 
@@ -134,41 +143,42 @@ def _run_convergence_checks(self, idata: arviz.InferenceData, model):
             msg = ("The rhat statistic is larger than 1.4 for some "
                    "parameters. The sampler did not converge.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'error', extra=rhat)
             warnings.append(warn)
         elif rhat_max > 1.2:
             msg = ("The rhat statistic is larger than 1.2 for some "
                    "parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'warn', extra=rhat)
             warnings.append(warn)
         elif rhat_max > 1.05:
             msg = ("The rhat statistic is larger than 1.05 for some "
                    "parameters. This indicates slight problems during "
                    "sampling.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
+                WarningType.CONVERGENCE, msg, 'info', extra=rhat)
             warnings.append(warn)
 
         eff_min = min(val.min() for val in ess.values())
-        n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
+        sizes = idata.posterior.sizes
+        n_samples = sizes['chain'] * sizes['draw']
         if eff_min < 200 and n_samples >= 500:
             msg = ("The estimated number of effective samples is smaller than "
                    "200 for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'error', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'error', extra=ess)
             warnings.append(warn)
         elif eff_min / n_samples < 0.1:
             msg = ("The number of effective samples is smaller than "
                    "10% for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'warn', extra=ess)
             warnings.append(warn)
         elif eff_min / n_samples < 0.25:
             msg = ("The number of effective samples is smaller than "
                    "25% for some parameters.")
             warn = SamplerWarning(
-                WarningType.CONVERGENCE, msg, 'info', None, None, ess)
+                WarningType.CONVERGENCE, msg, 'info', extra=ess)
             warnings.append(warn)
 
         self._add_warnings(warnings)
@@ -201,7 +211,7 @@ def filter_warns(warnings):
                 filtered.append(warn)
             elif (start <= warn.step < stop and
                     (warn.step - start) % step == 0):
-                warn = warn._replace(step=warn.step - start)
+                warn = dataclasses.replace(warn, step=warn.step - start)
                 filtered.append(warn)
         return filtered
 
diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py
index a426f491116..f1431794fab 100644
--- a/pymc3/step_methods/hmc/base_hmc.py
+++ b/pymc3/step_methods/hmc/base_hmc.py
@@ -28,10 +28,16 @@
 logger = logging.getLogger("pymc3")
 
-HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
+HMCStepData = namedtuple(
+    "HMCStepData",
+    "end, accept_stat, divergence_info, stats"
+)
+DivergenceInfo = namedtuple(
+    "DivergenceInfo",
+    "message, exec_info, state, state_div"
+)
 
 
-DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")
 
 
 class BaseHMC(arraystep.GradientSharedStep):
     """Superclass to implement Hamiltonian/hybrid monte carlo."""
@@ -151,8 +157,6 @@ def astep(self, q0):
                 message_energy,
                 "critical",
                 self.iter_count,
-                None,
-                None,
             )
             self._warnings.append(warning)
             raise SamplingError("Bad initial energy")
@@ -170,19 +174,30 @@ def astep(self, q0):
         self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
         if hmc_step.divergence_info:
             info = hmc_step.divergence_info
+            point = None
+            point_dest = None
+            info_store = None
             if self.tune:
                 kind = WarningType.TUNING_DIVERGENCE
-                point = None
             else:
                 kind = WarningType.DIVERGENCE
                 self._num_divs_sample += 1
                 # We don't want to fill up all memory with divergence info
                 if self._num_divs_sample < 100:
                     point = self._logp_dlogp_func.array_to_dict(info.state.q)
-                else:
-                    point = None
+                    point_dest = self._logp_dlogp_func.array_to_dict(
+                        info.state_div.q
+                    )
+                    info_store = info
             warning = SamplerWarning(
-                kind, info.message, "debug", self.iter_count, info.exec_info, point
+                kind,
+                info.message,
+                "debug",
+                self.iter_count,
+                info.exec_info,
+                divergence_point_source=point,
+                divergence_point_dest=point_dest,
+                divergence_info=info_store,
             )
             self._warnings.append(warning)
 
@@ -191,7 +206,10 @@ def astep(self, q0):
         if not self.tune:
             self._samples_after_tune += 1
 
-        stats = {"tune": self.tune, "diverging": bool(hmc_step.divergence_info)}
+        stats = {
+            "tune": self.tune,
+            "diverging": bool(hmc_step.divergence_info)
+        }
 
         stats.update(hmc_step.stats)
         stats.update(self.step_adapt.stats())
@@ -230,9 +248,7 @@ def warnings(self):
             )
 
         if message:
-            warning = SamplerWarning(
-                WarningType.DIVERGENCES, message, "error", None, None, None
-            )
+            warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
             warnings.append(warning)
 
         warnings.extend(self.step_adapt.warnings())
diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py
index d2409acca25..d39f509d941 100644
--- a/pymc3/step_methods/hmc/nuts.py
+++ b/pymc3/step_methods/hmc/nuts.py
@@ -200,7 +200,7 @@ def warnings(self):
                 "The chain reached the maximum tree depth. Increase "
                 "max_treedepth, increase target_accept or reparameterize."
             )
-            warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
+            warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn')
             warnings.append(warn)
 
         return warnings
@@ -321,6 +321,7 @@ def _single_step(self, left, epsilon):
         except IntegrationError as err:
             error_msg = str(err)
             error = err
+            right = None
         else:
             # h - H0
             energy_change = right.energy - self.start_energy
@@ -353,7 +354,7 @@ def _single_step(self, left, epsilon):
             )
             error = None
         tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
-        divergance_info = DivergenceInfo(error_msg, error, left)
+        divergance_info = DivergenceInfo(error_msg, error, left, right)
         return tree, divergance_info, False
 
     def _build_subtree(self, left, depth, epsilon):
diff --git a/pymc3/step_methods/step_sizes.py b/pymc3/step_methods/step_sizes.py
index 51262b82397..bdf0683c621 100644
--- a/pymc3/step_methods/step_sizes.py
+++ b/pymc3/step_methods/step_sizes.py
@@ -77,7 +77,7 @@ def warnings(self):
                    % (mean_accept, target_accept))
             info = {'target': target_accept, 'actual': mean_accept}
             warning = SamplerWarning(
-                WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
+                WarningType.BAD_ACCEPTANCE, msg, 'warn', extra=info)
             return [warning]
         else:
             return []