-
-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extract convergence checks function from
SamplerReport
- Loading branch information
1 parent
91dbfd2
commit a574ed5
Showing
6 changed files
with
136 additions
and
116 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import dataclasses | ||
import enum | ||
import logging | ||
|
||
from typing import Any, List, Optional, Sequence | ||
|
||
import arviz | ||
|
||
from pymc.util import get_untransformed_name, is_transformed_name | ||
|
||
_LEVELS = { | ||
"info": logging.INFO, | ||
"error": logging.ERROR, | ||
"warn": logging.WARN, | ||
"debug": logging.DEBUG, | ||
"critical": logging.CRITICAL, | ||
} | ||
|
||
logger = logging.getLogger("pymc") | ||
|
||
|
||
@enum.unique
class WarningType(enum.Enum):
    """Categories of warnings that can be attached to a sampling run.

    Members are numbered consecutively from 1 via ``enum.auto()``, so the
    declaration order below determines each member's value; append new
    members at the end to keep existing values stable.
    """

    # --- HMC / NUTS specific -------------------------------------------
    DIVERGENCE = enum.auto()  # 1
    TUNING_DIVERGENCE = enum.auto()  # 2
    DIVERGENCES = enum.auto()  # 3
    TREEDEPTH = enum.auto()  # 4

    # --- Problematic sampler parameters --------------------------------
    BAD_PARAMS = enum.auto()  # 5

    # --- Signs that the chains did not converge (e.g. Rhat) -------------
    CONVERGENCE = enum.auto()  # 6
    BAD_ACCEPTANCE = enum.auto()  # 7
    BAD_ENERGY = enum.auto()  # 8
|
||
|
||
@dataclasses.dataclass
class SamplerWarning:
    """A single warning record produced while sampling or checking results.

    Only ``kind``, ``message`` and ``level`` are required; every other field
    defaults to ``None``. Field order is part of the constructor signature
    (callers may pass values positionally), so it must not be rearranged.
    """

    # Category of the warning.
    kind: WarningType
    # Human-readable text; this is what gets written to the log.
    message: str
    # Severity name; expected to be one of the keys of ``_LEVELS``.
    level: str
    # NOTE(review): presumably the sampler step at which the warning arose;
    # not used anywhere in this module -- confirm against callers.
    step: Optional[int] = None
    # NOTE(review): presumably exception information; confirm against callers.
    exec_info: Optional[Any] = None
    # Free-form payload; the convergence checks store the rhat / ess results here.
    extra: Optional[Any] = None
    # NOTE(review): the three divergence_* fields are not populated in this
    # module; their exact contents should be confirmed against the samplers.
    divergence_point_source: Optional[dict] = None
    divergence_point_dest: Optional[dict] = None
    divergence_info: Optional[Any] = None
|
||
|
||
def run_convergence_checks(idata: arviz.InferenceData, model) -> List[SamplerWarning]:
    """Run rhat / effective-sample-size checks on the posterior of ``idata``.

    Parameters
    ----------
    idata
        Inference data; the checks read its ``posterior`` group.
    model
        Object exposing ``free_RVs`` and ``deterministics`` (lists of
        variables with a ``name`` attribute). Only free variables that are
        present in the posterior are checked.

    Returns
    -------
    list of SamplerWarning
        * A single ``BAD_PARAMS`` info warning when there is no posterior,
          only one chain, or fewer than four chains. In those cases the
          rhat / ess checks are skipped entirely.
        * Otherwise zero, one or two ``CONVERGENCE`` warnings (rhat above
          1.01, effective sample size per chain below 100), each carrying
          the corresponding arviz result in ``extra``.
        * An empty list when none of the model's free variables is found in
          the posterior, since there is nothing to diagnose.
    """
    if not hasattr(idata, "posterior"):
        msg = "No posterior samples. Unable to run convergence checks"
        return [SamplerWarning(WarningType.BAD_PARAMS, msg, "info")]

    posterior = idata["posterior"]
    n_chains = posterior.sizes["chain"]

    if n_chains == 1:
        msg = (
            "Only one chain was sampled, this makes it impossible to "
            "run some convergence checks"
        )
        return [SamplerWarning(WarningType.BAD_PARAMS, msg, "info")]

    if n_chains < 4:
        msg = (
            "We recommend running at least 4 chains for robust computation of "
            "convergence diagnostics"
        )
        return [SamplerWarning(WarningType.BAD_PARAMS, msg, "info")]

    # Prefer the untransformed name of a transformed free variable, but only
    # when that untransformed name is itself a known model variable.
    valid_names = {rv.name for rv in model.free_RVs + model.deterministics}
    varnames = []
    for rv in model.free_RVs:
        rv_name = rv.name
        if is_transformed_name(rv_name):
            untransformed = get_untransformed_name(rv_name)
            if untransformed in valid_names:
                rv_name = untransformed
        if rv_name in posterior:
            varnames.append(rv_name)

    if not varnames:
        # Nothing to diagnose; without this guard the max()/min() reductions
        # below would raise ValueError on an empty sequence.
        return []

    ess = arviz.ess(idata, var_names=varnames)
    rhat = arviz.rhat(idata, var_names=varnames)

    warnings = []

    rhat_max = max(val.max() for val in rhat.values())
    if rhat_max > 1.01:
        msg = (
            "The rhat statistic is larger than 1.01 for some "
            "parameters. This indicates problems during sampling. "
            "See https://arxiv.org/abs/1903.08008 for details"
        )
        warnings.append(SamplerWarning(WarningType.CONVERGENCE, msg, "info", extra=rhat))

    eff_min = min(val.min() for val in ess.values())
    eff_per_chain = eff_min / n_chains
    if eff_per_chain < 100:
        msg = (
            "The effective sample size per chain is smaller than 100 for some parameters. "
            " A higher number is needed for reliable rhat and ess computation. "
            "See https://arxiv.org/abs/1903.08008 for details"
        )
        warnings.append(SamplerWarning(WarningType.CONVERGENCE, msg, "error", extra=ess))

    return warnings
|
||
|
||
def log_warning(warn: SamplerWarning):
    """Emit ``warn.message`` on the module logger at the severity named by ``warn.level``.

    Raises ``KeyError`` if ``warn.level`` is not one of the keys of ``_LEVELS``.
    """
    logger.log(_LEVELS[warn.level], warn.message)
|
||
|
||
def log_warnings(warnings: Sequence[SamplerWarning]):
    """Log every warning in ``warnings``, in order, via ``log_warning``."""
    for entry in warnings:
        log_warning(entry)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters