Decouple convergence checking from SamplerReport
#6453
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6453      +/-   ##
==========================================
+ Coverage   94.78%   94.79%   +0.01%
==========================================
  Files         148      148
  Lines       27678    27678
==========================================
+ Hits        26234    26238       +4
+ Misses       1444     1440       -4
c6dbdbb to f419ed3
f419ed3 to 6c2f7f2
The goal was to uncouple sampling functions from `MultiTrace` and `SamplerReport`. Some calls to `SamplerReport._log_summary()` were unnecessary because `MultiTrace._add_warnings()` was never called between instantiation and `_log_summary()`, so the traces never contained warnings. Running convergence checks and logging the warnings can also be done without needing `MultiTrace` or `SamplerReport` instances/methods.
6c2f7f2 to 49f5263
* Specify covariant input types in `StatsBijection`.
* Annotate `_choose_chains` to be independent of `BaseTrace` type.
I don't think I am qualified to review this.
I should have added comments to the diff earlier. GitHub suggested you because you edited the SMC code? Who else is familiar with it?
S = TypeVar("S", bound=Sized)


def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be `Sized`.
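As a rough illustration of what the bound buys, here is a minimal sketch. `keep_longest` is a hypothetical helper invented for this example, not the PR's `_choose_chains`; it only shows that a `TypeVar` bounded by `Sized` lets the body call `len()` while the return annotation preserves whatever concrete item type the caller passed in.

```python
from typing import List, Sequence, Sized, TypeVar

# S can be any type that implements __len__; the concrete type is preserved
# between the input and the output annotation.
S = TypeVar("S", bound=Sized)


def keep_longest(items: Sequence[S]) -> List[S]:
    """Hypothetical helper: return the items whose length equals the maximum."""
    if not items:
        return []
    longest = max(len(item) for item in items)
    return [item for item in items if len(item) == longest]


# A type checker infers List[List[int]] here, not List[Sized].
chains = keep_longest([[1, 2, 3], [4, 5], [6, 7, 8]])
```

The same function works for strings, tuples, or a custom trace class, as long as the items define `__len__`.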
@@ -602,7 +606,6 @@ def sample(
 f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
 f"took {t_sampling:.0f} seconds."
 )
-mtrace.report._log_summary()
Between line 574, `mtrace = MultiTrace(traces)[:length]`, where the `MultiTrace` was created, and this call, no warnings were added to `mtrace`.
Therefore, there are no warnings to log and the `_log_summary()` call can safely be removed.
warnings.warn(
    "The number of samples is too small to check convergence reliably.",
    stacklevel=2,
)
This is now checked by `run_convergence_checks`, just like it already checked for a minimum number of chains.
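The idea of folding both size checks into one function can be sketched as follows. This is not the PR's `run_convergence_checks`: the function name `precheck_messages`, both thresholds, and the chain-count message are assumptions made for illustration. Only the sample-size message is taken from the diff.

```python
from typing import List

MIN_CHAINS = 2   # assumed threshold for this sketch
MIN_DRAWS = 100  # assumed threshold for this sketch


def precheck_messages(n_chains: int, n_draws: int) -> List[str]:
    """Hypothetical stand-in: return messages instead of calling warnings.warn."""
    if n_chains < MIN_CHAINS:
        # Placeholder wording, not the library's message.
        return ["Too few chains to run convergence checks."]
    if n_draws < MIN_DRAWS:
        return ["The number of samples is too small to check convergence reliably."]
    return []


messages = precheck_messages(n_chains=4, n_draws=50)
```

Returning the messages, rather than emitting them in place, leaves the caller free to decide how and at which level they are reported.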
multitrace = MultiTrace(traces)
multitrace._report._log_summary()
Here too: the `multitrace` cannot have warnings that would be printed by `_log_summary()` because none were added here or in its `__init__`.
if idata is None:
    idata = to_inference_data(trace, log_likelihood=False)
warns = run_convergence_checks(idata, model)
trace.report._add_warnings(warns)
log_warnings(warns)
This replaces the `_compute_convergence_checks` function and makes `trace.report` a dead end that can easily be removed in the future.
Remember from other changes:
- "number of samples is too small" warning now done by `run_convergence_checks`
- `report._add_warnings` was done inside `report._run_convergence_checks`
- `trace.report._log_summary()` internally called `log_warnings()`
@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
 class StatsBijection:
     """Map between a `list` of stats to `dict` of stats."""

-    def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
+    def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
Typing rule of thumb: Generic input types, exact output types.
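A small sketch of that rule of thumb, using a hypothetical helper `collect_stat_names` that is not part of the PR: the parameter is annotated with the read-only, covariant `Mapping`, so callers may pass plain dicts or other mapping types, while the return type is the exact `List[str]` that is actually built.

```python
from types import MappingProxyType
from typing import List, Mapping, Sequence


def collect_stat_names(dtypes: Sequence[Mapping[str, type]]) -> List[str]:
    """Hypothetical helper: accept any mappings, return a concrete list."""
    names: List[str] = []
    for mapping in dtypes:
        names.extend(mapping.keys())
    return names


# Both a plain dict and a read-only proxy satisfy Mapping[str, type].
names = collect_stat_names([{"accept": float}, MappingProxyType({"tune": bool})])
```

With `Dict[str, type]` in the signature, a type checker would reject the `MappingProxyType` argument, and because `Dict` is invariant in its value type it would also reject a variable declared as `Dict[str, Type[float]]`.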
I have only modified some docstrings 😅. @aloctavodia is the best choice I think.
LGTM!
The goal was to uncouple sampling functions from `MultiTrace` and `SamplerReport`.

Some calls to `SamplerReport._log_summary()` were unnecessary because `MultiTrace._add_warnings()` was never called between instantiation and `_log_summary()`, so the traces never contained warnings.

Running convergence checks and logging the warnings can also be done without needing `MultiTrace` or `SamplerReport` instances/methods.

Checklist

Minor changes
- The "The number of samples is too small to check convergence reliably." warning is now an `INFO` level log message instead of a `Warning`.
- The `SamplerReport._log_summary()` and `SamplerReport._run_convergence_checks` methods were removed.

Maintenance
- `MultiTrace` or `SamplerReport` to compute/log warnings.