Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add InferenceObjects as a chain_type #1913

Closed
wants to merge 3 commits into from
Closed

Conversation

sethaxen
Copy link
Member

@sethaxen sethaxen commented Nov 24, 2022

This PR is a prototype of adding InferenceObjects.InferenceData as a chain_type for sample. Here's a working example:

julia> using Turing, InferenceObjects

julia> @model function foo()
           x ~ Normal()
           y ~ filldist(Normal(), 2)
           z = similar(y, 2, 3)
           for i in axes(z, 1), j in axes(z, 2)
               z[i, j] ~ Normal()
           end
       end;

julia> idata_prior = sample(foo(), Prior(), 1_000; chain_type=InferenceData)
Sampling 100%|█████████████████████████████████████████████████████████████████| Time: 0:00:00
InferenceData with groups:
  > prior
  > sample_stats_prior

julia> idata_post = sample(foo(), NUTS(), MCMCThreads(), 1_000, 4; chain_type=InferenceData)
┌ Info: Found initial step size
└   ϵ = 0.9
┌ Info: Found initial step size
└   ϵ = 3.2
┌ Info: Found initial step size
└   ϵ = 1.6
┌ Info: Found initial step size
└   ϵ = 1.6
Sampling (4 threads) 100%|█████████████████████████████████████████████████████| Time: 0:00:02
InferenceData with groups:
  > posterior
  > sample_stats

julia> idata = merge(idata_post, idata_prior)
InferenceData with groups:
  > posterior
  > sample_stats
  > prior
  > sample_stats_prior

julia> idata.posterior
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points,
  Dim{:y_dim_1} Sampled{Int64} Base.OneTo(2) ForwardOrdered Regular Points,
  Dim{:z_dim_1} Sampled{Int64} Base.OneTo(6) ForwardOrdered Regular Points
and 3 layers:
  :x Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :y Float64 dims: Dim{:y_dim_1}, Dim{:draw}, Dim{:chain} (2×1000×4)
  :z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (6×1000×4)

with metadata Dict{String, Any} with 3 entries:
  "start_time" => 1.66929e9
  "created_at" => "2022-11-24T11:29:24.947"
  "stop_time"  => 1.66929e9

julia> idata.sample_stats
Dataset with dimensions: 
  Dim{:draw} Sampled{Int64} Base.OneTo(1000) ForwardOrdered Regular Points,
  Dim{:chain} Sampled{Int64} Base.OneTo(4) ForwardOrdered Regular Points
and 12 layers:
  :lp               Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :n_steps          Int64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :is_accept        Bool dims: Dim{:draw}, Dim{:chain} (1000×4)
  :acceptance_rate  Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :log_density      Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :energy           Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :energy_error     Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :max_energy_error Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :tree_depth       Int64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :diverging        Bool dims: Dim{:draw}, Dim{:chain} (1000×4)
  :step_size        Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)
  :step_size_nom    Float64 dims: Dim{:draw}, Dim{:chain} (1000×4)

with metadata Dict{String, Any} with 3 entries:
  "start_time" => 1.66929e9
  "created_at" => "2022-11-24T11:29:24.947"
  "stop_time"  => 1.66929e9

Note that all of this is type piracy, and it would probably be much cleaner to move this code into a DynamicPPLInferenceObjects.jl glue package that Turing instead depends on. However, we currently need the metadata function in order to retrieve sampling statistics and logevidence. Is there any way that functionality could be moved to DynamicPPL?

Relates TuringLang/MCMCChains.jl#381

@sethaxen
Copy link
Member Author

sethaxen commented Nov 24, 2022

Also, the getparams functionality would need to be in DynamicPPL or AbstractMCMC.

@sethaxen
Copy link
Member Author

sethaxen commented Nov 24, 2022

  :z Float64 dims: Dim{:z_dim_1}, Dim{:draw}, Dim{:chain} (6×1000×4)

It seems getparams still flattens z. Is there a way to access an unflattened form?

Edit: Ah, this is because DynamicPPL.tonamedtuple does the flattening. Same question but for that function.

Edit2: Ah, but when sampling e.g. with NUTS, ts contains Transitions, which are already flattened. So we're back to how to get an unflattened form.

@coveralls
Copy link

coveralls commented Nov 24, 2022

Pull Request Test Coverage Report for Build 5708532201

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 0 of 31 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage remained the same at 0.0%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/inference/Inference.jl 0 31 0.0%
Totals Coverage Status
Change from base Build 5708133281: 0.0%
Covered Lines: 0
Relevant Lines: 1483

💛 - Coveralls

@codecov
Copy link

codecov bot commented Nov 24, 2022

Codecov Report

Patch and project coverage have no change.

Comparison is base (161d9bc) 0.00% compared to head (8024f24) 0.00%.
Report is 1 commits behind head on master.

Additional details and impacted files
@@          Coverage Diff           @@
##           master   #1913   +/-   ##
======================================
  Coverage    0.00%   0.00%           
======================================
  Files          22      22           
  Lines        1452    1483   +31     
======================================
- Misses       1452    1483   +31     
Files Changed Coverage Δ
src/inference/Inference.jl 0.00% <0.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@sethaxen
Copy link
Member Author

Note that all of this is type piracy, and it would probably be much cleaner to move this code into a DynamicPPLInferenceObjects.jl glue package that Turing instead depends on. However, we currently need the metadata function in order to retrieve sampling statistics and logevidence.

For demonstration purposes, I've coded up a more complete example of what I mean at https://github.com/sethaxen/DynamicPPLInferenceObjects.jl

@yebai
Copy link
Member

yebai commented Jul 30, 2023

It might be reasonable to convert this into an extension that weakly depends on InferenceObjects.

It's a useful functionality, in my view. I am happy to merge this PR quickly.

@sethaxen
Copy link
Member Author

Ah, this has evolved substantially since this PR. TuringLang/DynamicPPL.jl#465 supersedes this PR by adding an InferenceObjects extension to DynamicPPL. That was held up by trying to get predict as part of DPPL or APPL's API (TuringLang/DynamicPPL.jl#466 and TuringLang/AbstractPPL.jl#81). Personally I agree with TuringLang/AbstractPPL.jl#81 (review) that it makes more sense to add it to DPPL.

I do think that work can be done fairly quickly, but I need to wrap up a few things first before resuming.

@yebai
Copy link
Member

yebai commented Jul 31, 2023

Closed in favour of TuringLang/DynamicPPL.jl#465

@yebai yebai closed this Jul 31, 2023
@yebai yebai deleted the inferenceobjects branch July 31, 2023 11:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants