-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add InferenceObjects integration #465
base: master
Are you sure you want to change the base?
Conversation
|
||
ndraws = size(data, :draw) | ||
nchains = size(data, :chain) | ||
# TODO: optionally post-process idata to convert index variables like Symbol("y[1]") to Symbol("y") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pretty important for the results to be useful with ArviZ but is seemingly non-trivial so will wait for a future PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do actually have some functionality do perform this now:)
There is https://turinglang.org/DynamicPPL.jl/dev/api/#DynamicPPL.value_iterator_from_chain which makes use of this under the hood; in particular, to get the "innermost" VarName
, you can use varname_leaves
Lines 820 to 844 in 1ebe8bc
""" | |
varname_leaves(vn::VarName, val) | |
Return an iterator over all varnames that are represented by `vn` on `val`. | |
# Examples | |
```jldoctest | |
julia> using DynamicPPL: varname_leaves | |
julia> foreach(println, varname_leaves(@varname(x), rand(2))) | |
x[1] | |
x[2] | |
julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) | |
x[1:2][1] | |
x[1:2][2] | |
julia> x = (y = 1, z = [[2.0], [3.0]]); | |
julia> foreach(println, varname_leaves(@varname(x), x)) | |
x.y | |
x.z[1][1] | |
x.z[2][1] | |
``` | |
""" |
Using this you can take the varname
from the varinfo
+ a value, and then determine the varname-leaves.
@@ -0,0 +1,10 @@ | |||
function AbstractPPL.condition( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This whole file is type-piracy
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should it rather be an extension to AbstractPPL? Then it would not be type piracy (or rather, only the one that extensions were designed for).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would then we also make InferenceObjects a full dependency of AbstractPPL for v1.8 and earlier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
AbstractPPL is supposed to be extremely lightweight (https://github.com/TuringLang/AbstractPPL.jl/blob/main/Project.toml), so I don't think that's an attractive option. Maybe an optional dependency with Requires or a full-blown subpackage would be better (one can avoid loading it in newer Julia versions).
Pull Request Test Coverage Report for Build 4221259736
💛 - Coveralls |
Codecov ReportBase: 74.42% // Head: 71.86% // Decreases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #465 +/- ##
==========================================
- Coverage 74.42% 71.86% -2.56%
==========================================
Files 20 26 +6
Lines 2444 2531 +87
==========================================
Hits 1819 1819
- Misses 625 712 +87
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
@@ -0,0 +1,72 @@ | |||
function StatsBase.predict( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
predict
is not part of the DynamicPPL API, isn't it? At least I don't remember it, I think we just use rand
for obtaining samples from a model (and you can of course condition the model on some data before sampling). I think extensions should not (must not?) add new API to a package, so if it's doing the same as rand
on an conditioned model maybe just implement rand
instead? And open an issue about adding predict
to the API (maybe it could be defined just as rand
on a conditioned model)?
@@ -0,0 +1,28 @@ | |||
function DynamicPPL.generated_quantities( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since DynamicPPL places no restrictions on what types these can be, and users might have intermediate types that don't fit the InferenceData format, it would be nice to support users specifying an output type. Either that, or we should document the constraints upon the returned objects in a model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing in particular that should be possible for the user to specify somehow is the variables to include in the chain. Sometimes you might want to return
something that you really don't want to end up in the chain, e.g. the full solution of a ODE solve
(this can be useful for checking convergence as a post-processing step, but you usually don't want the full solution in your chain).
This PR adapts most of the code from https://github.com/sethaxen/DynamicPPLInferenceObjects.jl to be an extension for Julia v1.9 and later and a submodule for earlier Julia versions. Fixes #464