Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple convergence checking from SamplerReport #6453

Merged

Conversation

michaelosthege
Copy link
Member

@michaelosthege michaelosthege commented Jan 14, 2023

The goal was to uncouple sampling functions from MultiTrace and SamplerReport.

Some calls to SamplerReport._log_summary() were unnecessary because MultiTrace._add_warnings() was never called inbetween instantiation and _log_summary(), therefore the traces never contained warnings.

Running convergence checks and logging the warnings can also be done without needing MultiTrace or SamplerReport instances/methods.

Checklist

Minor changes

  • The "The number of samples is too small to check convergence reliably." warning is now an INFO level log message instead of a Warning.
  • SamplerReport._log_summary() and SamplerReport._run_convergence_checks methods were removed.

Maintenance

  • More type hints in SMC code
  • SMC and MCMC sampling functions no longer rely on instantiating a MultiTrace or SamplerReport to compute/log warnings.

@michaelosthege michaelosthege added the trace-backend Traces and ArviZ stuff label Jan 14, 2023
@michaelosthege michaelosthege self-assigned this Jan 14, 2023
@codecov
Copy link

codecov bot commented Jan 14, 2023

Codecov Report

Merging #6453 (ab128bc) into main (6ab0c03) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6453      +/-   ##
==========================================
+ Coverage   94.78%   94.79%   +0.01%     
==========================================
  Files         148      148              
  Lines       27678    27678              
==========================================
+ Hits        26234    26238       +4     
+ Misses       1444     1440       -4     
Impacted Files Coverage Δ
pymc/backends/base.py 84.82% <100.00%> (+0.06%) ⬆️
pymc/backends/report.py 78.84% <100.00%> (-1.16%) ⬇️
pymc/sampling/mcmc.py 93.06% <100.00%> (-0.07%) ⬇️
pymc/smc/kernels.py 97.44% <100.00%> (+0.02%) ⬆️
pymc/smc/sampling.py 86.61% <100.00%> (+0.19%) ⬆️
pymc/stats/convergence.py 95.61% <100.00%> (+2.88%) ⬆️
pymc/step_methods/compound.py 97.45% <100.00%> (ø)
pymc/tests/smc/test_smc.py 100.00% <100.00%> (ø)

@michaelosthege michaelosthege force-pushed the decouple-convergence-checking branch from c6dbdbb to f419ed3 Compare January 14, 2023 18:14
@michaelosthege michaelosthege marked this pull request as ready for review January 14, 2023 18:14
@michaelosthege michaelosthege added the SMC Sequential Monte Carlo label Jan 14, 2023
@michaelosthege michaelosthege force-pushed the decouple-convergence-checking branch from f419ed3 to 6c2f7f2 Compare January 14, 2023 18:42
The goal was to uncouple sampling functions
from `MultiTrace` and `SamplerReport`.

Some calls to `SamplerReport._log_summary()` were unnecessary because
`MultiTrace._add_warnings()` was never called inbetween instantiation
and `_log_summary()`, therefore the traces never contained warnings.

Running convergence checks and logging the warnings can also be done
without needing `MultiTrace` or `SamplerReport` instances/methods.
@michaelosthege michaelosthege force-pushed the decouple-convergence-checking branch from 6c2f7f2 to 49f5263 Compare January 14, 2023 18:46
* Specify covariant input types in `StatsBijection`.
* Annotate `_choose_chains` to be independent of `BaseTrace` type.
@OriolAbril
Copy link
Member

I don't think I am qualified to review this

@michaelosthege
Copy link
Member Author

I don't think I am qualified to review this

I should have added comments to the diff earlier..

GitHub suggested you because you edited the SMC code? Who else is familiar with it?

Comment on lines +524 to +527
S = TypeVar("S", bound=Sized)


def _choose_chains(traces: Sequence[S], tune: int) -> Tuple[List[S], int]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This annotates it as returning a list of the same type of items as given in the input, but with the constraint that these items must be Sized.

@@ -602,7 +606,6 @@ def sample(
f"({n_tune*n_chains:_d} + {n_draws*n_chains:_d} draws total) "
f"took {t_sampling:.0f} seconds."
)
mtrace.report._log_summary()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inbetween the line 574 mtrace = MultiTrace(traces)[:length] where the MultiTrace was created, no warnings were added to mtrace.
Therefore, there are no warnings to log and the _log_summary() call can safely be removed.

Comment on lines -616 to -619
warnings.warn(
"The number of samples is too small to check convergence reliably.",
stacklevel=2,
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now checked by run_convergence_checks, just like it already checked for a minimum number of chains

Comment on lines -929 to -930
multitrace = MultiTrace(traces)
multitrace._report._log_summary()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here too: The multitrace can not have warnings that would be printed by _log_summary() because none were added here or in its __init__

Comment on lines +241 to +245
if idata is None:
idata = to_inference_data(trace, log_likelihood=False)
warns = run_convergence_checks(idata, model)
trace.report._add_warnings(warns)
log_warnings(warns)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This replaces the _compute_convergence_checks function and makes the trace.report be a dead end that can easily be removed in the future

Remember from other changes:

  • "number of samples is too small" warning now done by run_convergence_checks
  • report._add_warnings was done inside report._run_convergence_checks
  • trace.report._log_summary() internally called log_warnings()

@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
class StatsBijection:
"""Map between a `list` of stats to `dict` of stats."""

def __init__(self, sampler_stats_dtypes: Sequence[Dict[str, type]]) -> None:
def __init__(self, sampler_stats_dtypes: Sequence[Mapping[str, type]]) -> None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typing rule of thumb: Generic input types, exact output types.

@OriolAbril
Copy link
Member

GitHub suggested you because you edited the SMC code? Who else is familiar with it?

I have only modified some docstrings 😅. @aloctavodia is the best choice I think

Copy link
Member

@aloctavodia aloctavodia left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@aloctavodia aloctavodia merged commit 5802f12 into pymc-devs:main Jan 20, 2023
@michaelosthege michaelosthege deleted the decouple-convergence-checking branch January 20, 2023 14:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
SMC Sequential Monte Carlo trace-backend Traces and ArviZ stuff
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants