Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InferenceObjects integration #465

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft

Add InferenceObjects integration #465

wants to merge 6 commits into from

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Feb 20, 2023

This PR adapts most of the code from https://github.com/sethaxen/DynamicPPLInferenceObjects.jl to be an extension for Julia v1.9 and later and a submodule for earlier Julia versions. Fixes #464


ndraws = size(data, :draw)
nchains = size(data, :chain)
# TODO: optionally post-process idata to convert index variables like Symbol("y[1]") to Symbol("y")
Copy link
Member Author

@sethaxen sethaxen Feb 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty important for the results to be useful with ArviZ but is seemingly non-trivial so will wait for a future PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do actually have some functionality do perform this now:)

There is https://turinglang.org/DynamicPPL.jl/dev/api/#DynamicPPL.value_iterator_from_chain which makes use of this under the hood; in particular, to get the "innermost" VarName, you can use varname_leaves

DynamicPPL.jl/src/utils.jl

Lines 820 to 844 in 1ebe8bc

"""
varname_leaves(vn::VarName, val)
Return an iterator over all varnames that are represented by `vn` on `val`.
# Examples
```jldoctest
julia> using DynamicPPL: varname_leaves
julia> foreach(println, varname_leaves(@varname(x), rand(2)))
x[1]
x[2]
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2)))
x[1:2][1]
x[1:2][2]
julia> x = (y = 1, z = [[2.0], [3.0]]);
julia> foreach(println, varname_leaves(@varname(x), x))
x.y
x.z[1][1]
x.z[2][1]
```
"""

Using this you can take the varname from the varinfo + a value, and then determine the varname-leaves.

@@ -0,0 +1,10 @@
function AbstractPPL.condition(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This whole file is type-piracy

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it rather be an extension to AbstractPPL? Then it would not be type piracy (or rather, only the one that extensions were designed for).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would then we also make InferenceObjects a full dependency of AbstractPPL for v1.8 and earlier?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AbstractPPL is supposed to be extremely lightweight (https://github.com/TuringLang/AbstractPPL.jl/blob/main/Project.toml), so I don't think that's an attractive option. Maybe an optional dependency with Requires or a full-blown subpackage would be better (one can avoid loading it in newer Julia versions).

@coveralls
Copy link

coveralls commented Feb 20, 2023

Pull Request Test Coverage Report for Build 4221259736

  • 0 of 87 (0.0%) changed or added relevant lines in 6 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-2.5%) to 71.908%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLInferenceObjectsExt/condition.jl 0 4 0.0%
ext/DynamicPPLInferenceObjectsExt/varinfo.jl 0 4 0.0%
ext/DynamicPPLInferenceObjectsExt/utils.jl 0 12 0.0%
ext/DynamicPPLInferenceObjectsExt/generated_quantities.jl 0 15 0.0%
ext/DynamicPPLInferenceObjectsExt/pointwise_loglikelihoods.jl 0 17 0.0%
ext/DynamicPPLInferenceObjectsExt/predict.jl 0 35 0.0%
Totals Coverage Status
Change from base Build 4137356954: -2.5%
Covered Lines: 1820
Relevant Lines: 2531

💛 - Coveralls

@codecov
Copy link

codecov bot commented Feb 20, 2023

Codecov Report

Base: 74.42% // Head: 71.86% // Decreases project coverage by -2.56% ⚠️

Coverage data is based on head (2f7b834) compared to base (8532375).
Patch coverage: 0.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #465      +/-   ##
==========================================
- Coverage   74.42%   71.86%   -2.56%     
==========================================
  Files          20       26       +6     
  Lines        2444     2531      +87     
==========================================
  Hits         1819     1819              
- Misses        625      712      +87     
Impacted Files Coverage Δ
ext/DynamicPPLInferenceObjectsExt/condition.jl 0.00% <0.00%> (ø)
...amicPPLInferenceObjectsExt/generated_quantities.jl 0.00% <0.00%> (ø)
...PPLInferenceObjectsExt/pointwise_loglikelihoods.jl 0.00% <0.00%> (ø)
ext/DynamicPPLInferenceObjectsExt/predict.jl 0.00% <0.00%> (ø)
ext/DynamicPPLInferenceObjectsExt/utils.jl 0.00% <0.00%> (ø)
ext/DynamicPPLInferenceObjectsExt/varinfo.jl 0.00% <0.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@@ -0,0 +1,72 @@
function StatsBase.predict(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

predict is not part of the DynamicPPL API, isn't it? At least I don't remember it, I think we just use rand for obtaining samples from a model (and you can of course condition the model on some data before sampling). I think extensions should not (must not?) add new API to a package, so if it's doing the same as rand on an conditioned model maybe just implement rand instead? And open an issue about adding predict to the API (maybe it could be defined just as rand on a conditioned model)?

@@ -0,0 +1,28 @@
function DynamicPPL.generated_quantities(
Copy link
Member Author

@sethaxen sethaxen Feb 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since DynamicPPL places no restrictions on what types these can be, and users might have intermediate types that don't fit the InferenceData format, it would be nice to support users specifying an output type. Either that, or we should document the constraints upon the returned objects in a model.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing in particular that should be possible for the user to specify somehow is the variables to include in the chain. Sometimes you might want to return something that you really don't want to end up in the chain, e.g. the full solution of a ODE solve (this can be useful for checking convergence as a post-processing step, but you usually don't want the full solution in your chain).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

InferenceObjects integration
5 participants