
Give user an option to use SimpleVarInfo with sample function #606


Closed
wants to merge 15 commits

Conversation

sunxd3
Member

@sunxd3 sunxd3 commented May 15, 2024

Partially address TuringLang/Turing.jl#2213

An example

julia> using AbstractMCMC, DynamicPPL
[...]

julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#289" = DynamicPPL.TypeWrap{Vector{Float64}}()), NamedTuple(), DefaultContext())

julia> chn = sample(model, SampleFromUniform(), 10; trace_type = SimpleVarInfo)
10-element Vector{SimpleVarInfo{OrderedDict{Any, Any}, Float64, DynamicPPL.NoTransformation}}:
 SimpleVarInfo(OrderedDict{Any, Any}(s[1] => 0.18821524284155777, s[2] => 0.33957731437677985, m[1] => 0.027047762098432387, m[2] => -0.3396883816169604), -27.049779424102084)
 ...

Work with Turing

This should work with Turing's Inference pipeline with almost no modification; the only change needed is https://github.com/TuringLang/Turing.jl/blob/56f64ec5909cec4a5ded4e28555c2b289020bbe1/src/mcmc/Inference.jl#L319 to

function getparams(model::DynamicPPL.Model, vi::Union{DynamicPPL.VarInfo, DynamicPPL.SimpleVarInfo})

This allows bundle_samples to use this function.

Then

julia> AbstractMCMC.step(Random.default_rng(), model, DynamicPPL.Sampler(HMC(0.2, 20), DynamicPPL.Selector()); trace_type = SimpleVarInfo)
(Turing.Inference.Transition{Vector{Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}}, Float64, @NamedTuple{n_steps::Int64, is_accept::Bool, acceptance_rate::Float64, log_density::Float64, hamiltonian_energy::Float64, hamiltonian_energy_error::Float64, numerical_error::Bool, step_size::Float64, nom_step_size::Float64}}(Tuple{AbstractPPL.VarName{sym, Accessors.IndexLens{Tuple{Int64}}} where sym, Float64}[(s[1], 1.6822421472438154), (s[2], 0.8921514354736135), (m[1], -0.1272569385613846), (m[2], 0.8103126419880976)], -7.598386060870171, (n_steps = 20, is_accept = true, acceptance_rate = 1.0, log_density = -7.598386060870171, hamiltonian_energy = 9.707595087582115, hamiltonian_energy_error = -0.0094109681431096, numerical_error = true, step_size = 0.2, nom_step_size = 0.2)), Turing.Inference.HMCState{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}, AdvancedHMC.Hamiltonian{AdvancedHMC.UnitEuclideanMetric{Float64, Tuple{Int64}}, AdvancedHMC.GaussianKinetic, Base.Fix1{typeof(LogDensityProblems.logdensity), LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, 
Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}, Turing.Inference.var"#∂logπ∂θ#32"{LogDensityProblemsADForwardDiffExt.ForwardDiffLogDensity{LogDensityFunction{DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Any}, Float64, DynamicPPL.DynamicTransformation}, DynamicPPL.Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#289")), (), (), Tuple{Vector{Float64}, DynamicPPL.TypeWrap{Vector{Float64}}}, Tuple{}, DynamicPPL.DefaultContext}, DynamicPPL.SamplingContext{DynamicPPL.Sampler{HMC{AutoForwardDiff{nothing, Nothing}, (), AdvancedHMC.UnitEuclideanMetric}}, DynamicPPL.DefaultContext, TaskLocalRNG}}, ForwardDiff.Chunk{4}, ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, ForwardDiff.GradientConfig{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag, Float64}, Float64, 4}}}}}}, AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}, AdvancedHMC.Adaptation.NoAdaptation}(DynamicPPL.SimpleVarInfo{OrderedCollections.OrderedDict{Any, Float64}, Float64, DynamicPPL.DynamicTransformation}(OrderedCollections.OrderedDict{Any, Float64}(s[1] => 0.5201275150675577, s[2] => -0.11411939010121486, m[1] => -0.1272569385613846, m[2] => 0.8103126419880976), -7.598386060870171, DynamicPPL.DynamicTransformation()), 1, AdvancedHMC.HMCKernel{AdvancedHMC.FullMomentumRefreshment, AdvancedHMC.Trajectory{AdvancedHMC.EndPointTS, AdvancedHMC.Leapfrog{Float64}, AdvancedHMC.FixedNSteps}}(AdvancedHMC.FullMomentumRefreshment(), Trajectory{AdvancedHMC.EndPointTS}(integrator=Leapfrog(ϵ=0.2), tc=AdvancedHMC.FixedNSteps(20))), Hamiltonian(metric=UnitEuclideanMetric([1.0, 1.0, 1.0, 1.0]), kinetic=AdvancedHMC.GaussianKinetic()), AdvancedHMC.PhasePoint{Vector{Float64}, AdvancedHMC.DualValue{Float64, Vector{Float64}}}([0.5201275150675577, -0.11411939010121486, -0.1272569385613846, 0.8103126419880976], 
[0.2195756027868936, 2.0379584520488434, 0.11023899654296329, 0.06911815570849916], AdvancedHMC.DualValue{Float64, Vector{Float64}}(-7.598386060870171, [0.4248179767985423, -1.5238746846234323, -1.042961549856163, -0.4252357850238818]), AdvancedHMC.DualValue{Float64, Vector{Float64}}(-2.109209026711945, [-0.2195756027868936, -2.0379584520488434, -0.11023899654296329, -0.06911815570849916])), AdvancedHMC.Adaptation.NoAdaptation()))

julia> chn = sample(model, HMC(0.2, 20), 10; trace_type = SimpleVarInfo)
Chains MCMC chain (10×14×1 Array{Float64, 3}):

Iterations        = 1:1:10
Number of chains  = 1
Samples per chain = 10
Wall duration     = 2.04 seconds
Compute duration  = 2.04 seconds
parameters        = s[1], s[2], m[1], m[2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   ess_per_sec 
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64       Float64 

        s[1]    1.6551    0.9727    0.3219     8.1555    10.0000    1.6048        3.9939
        s[2]    0.9875    0.1954    0.0942     4.5670    10.0000    1.5820        2.2365
        m[1]    0.7053    1.0723    0.4889     5.0832    10.0000    1.4156        2.4893
        m[2]    1.1302    0.5511    0.1743    10.0000    10.0000    0.9406        4.8972

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        s[1]    0.6404    0.9317    1.5316    2.0851    3.5047
        s[2]    0.7580    0.8095    1.0099    1.1126    1.3119
        m[1]   -0.9806   -0.1058    1.0414    1.3320    2.2166
        m[2]    0.3819    0.7973    1.0316    1.6741    1.8228

@sunxd3 sunxd3 marked this pull request as draft May 15, 2024 08:06
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@sunxd3
Member Author

sunxd3 commented May 15, 2024

An alternative approach I can envision is fully adopting LogDensityFunction in Turing through the ExternalSampler interface, but this might require much more serious work encapsulating the InferenceAlgorithms. (Do we have plans to do this in the coming months?)

Also, for SimpleVarInfo I opted for OrderedDict. Using NamedTuple instead requires predetermining the variable names (ref

julia> # (×) If we don't provide the container...
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo(), ctx); vi
ERROR: type NamedTuple has no field x
[...]

), which can be done in principle but needs a bit of work.
This also means that SimpleVarInfo used through the changes proposed in this PR may be less performant.
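For reference, a rough sketch of what "seeding" the NamedTuple-backed variant looks like — the model and values here are illustrative, not from this PR:

```julia
using DynamicPPL, Distributions

@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

m = demo()

# A NamedTuple-backed SimpleVarInfo must be seeded with all variable
# names up front; unlike OrderedDict, keys cannot be added dynamically
# during model evaluation.
vi = SimpleVarInfo((s = 1.0, m = 0.0))
_, vi = DynamicPPL.evaluate!!(m, vi, DefaultContext())
```

With the empty `SimpleVarInfo()` constructor, the same evaluation hits the `type NamedTuple has no field` error shown above.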

@sunxd3
Member Author

sunxd3 commented May 15, 2024

@torfjelde @yebai @devmotion does this PR make sense? If this is desirable, then I'll extend the tests in https://github.com/TuringLang/DynamicPPL.jl/blob/master/test/sampler.jl.

@coveralls

Pull Request Test Coverage Report for Build 9092154133

Details

  • 9 of 17 (52.94%) changed or added relevant lines in 2 files are covered.
  • 2 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-1.0%) to 77.471%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/simple_varinfo.jl 0 2 0.0%
src/sampler.jl 9 15 60.0%
Files with Coverage Reduction New Missed Lines %
src/sampler.jl 2 84.75%
Totals Coverage Status
Change from base Build 9062140401: -1.0%
Covered Lines: 2775
Relevant Lines: 3582

💛 - Coveralls

@coveralls

coveralls commented May 15, 2024

Pull Request Test Coverage Report for Build 9116442883

Details

  • 17 of 20 (85.0%) changed or added relevant lines in 2 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage increased (+0.03%) to 77.596%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/simple_varinfo.jl 0 1 0.0%
src/sampler.jl 17 19 89.47%
Totals Coverage Status
Change from base Build 9099752668: 0.03%
Covered Lines: 2660
Relevant Lines: 3428

💛 - Coveralls

@yebai
Member

yebai commented May 15, 2024

cc @willtebbutt

@yebai yebai left a comment
Member

Yes, let's push this through @sunxd3!

sunxd3 and others added 7 commits May 16, 2024 11:46
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@yebai
Member

yebai commented May 16, 2024

@willtebbutt can you help review this PR?

@yebai yebai requested a review from willtebbutt May 16, 2024 11:54
@yebai yebai marked this pull request as ready for review May 16, 2024 11:54
@torfjelde
Member

Having a look @sunxd3 👍

end
if tracetype === VarInfo
# Sample initial values.
vi = default_varinfo(rng, model, spl)

This doesn't quite make sense because default_varinfo can also return SimpleVarInfo

@willtebbutt willtebbutt left a comment
Member

This is cool -- thanks for doing this.

More feedback:

  • the tests introduce a lot of copy + paste. Could you please reformat to avoid the copy + paste by using the looping option available for @testset? i.e. apply the same set of tests to VarInfo and SimpleVarInfo.
  • please bump the patch version so that we can tag a release :)
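The looped @testset suggested above might look roughly like this — the `trace_type` keyword is the one proposed in this PR, and the test body is a placeholder:

```julia
using Test, DynamicPPL, AbstractMCMC

# Run the same set of tests against both varinfo types.
@testset "sample with trace_type = $(T)" for T in (VarInfo, SimpleVarInfo)
    model = DynamicPPL.TestUtils.DEMO_MODELS[1]
    chn = sample(model, SampleFromUniform(), 10; trace_type = T)
    @test length(chn) == 10
end
```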

@torfjelde torfjelde left a comment
Member

What is the purpose of using SimpleVarInfo{<:AbstractDict} vs. VarInfo? I can see a clear purpose if we support SimpleVarInfo{<:NamedTuple}, but this is not the case here.

IMO, for it to make sense to expose SimpleVarInfo, we need to make it possible to use the NamedTuple version, which can provide performance improvements over VarInfo; the AbstractDict version will be much, much worse than the current VarInfo performance-wise.

This should work with Turing's Inference pipeline with almost no modification

If you do a search for VarInfo throughout the Turing codebase, you'll find several explicit mentions of it, e.g. in Gibbs or MH, so it will require some effort.

Note that the default_varinfo used in the initial step was meant to be a way of exposing this experimentally (you can overload it for a specific model + sampler combo), but I do agree that this is annoying vs. simply passing it as an argument.
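For context, the experimental overloading route mentioned above would look roughly like this — the signature is assumed from DynamicPPL's src/sampler.jl at the time, and `my_model_fn` / `MySamplerType` are hypothetical names:

```julia
using DynamicPPL, Random, OrderedCollections

# Opt a specific model + sampler combination into SimpleVarInfo by
# overloading default_varinfo for that combination only.
function DynamicPPL.default_varinfo(
    rng::Random.AbstractRNG,
    model::Model{typeof(my_model_fn)},
    sampler::Sampler{<:MySamplerType},
)
    return SimpleVarInfo(OrderedDict())
end
```

This works without any keyword argument, but as noted it is awkward compared to simply passing the desired trace type at the call site.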

sunxd3 and others added 3 commits May 16, 2024 18:11
Co-authored-by: Will Tebbutt <wt0881@my.bristol.ac.uk>
@sunxd3
Member Author

sunxd3 commented May 16, 2024

@torfjelde is my understanding of SimpleVarInfo with NamedTuple correct at #606 (comment)?

@sunxd3
Member Author

sunxd3 commented May 16, 2024

we need to make it possible to use the NamedTuple version

I agree.

Other than performance, I thought SimpleVarInfo is also less error-prone for AD (correct me if wrong), but I am unsure whether the AbstractDict version of SimpleVarInfo works better than VarInfo.

@torfjelde
Member

is my understanding of SimpleVarInfo with NamedTuple correct at #606 (comment)?

Yep! If you don't "seed" SimpleVarInfo{<:NamedTuple} with the correct values, then it will only be sensible for models containing only varnames of the form VarName{sym,typeof(identity)}. Now that we have "debugging" capabilities, we could "check" this, but that would of course not be 100% reliable.
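To illustrate the distinction being drawn here, a sketch of the two kinds of varnames (outputs intentionally omitted):

```julia
using DynamicPPL

# A varname with the identity optic, as produced by `x ~ Normal()`;
# this is the only kind an unseeded NamedTuple-backed SimpleVarInfo
# can handle, since the symbol alone determines the NamedTuple field.
vn_plain = @varname(x)

# An indexed varname, as produced by `x[1] ~ Normal()`; storing this
# requires the container for `x` to already exist, hence the seeding.
vn_indexed = @varname(x[1])
```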

Other than performance, I thought SimpleVarInfo is also less error-prone for AD (correct me if wrong)

Not really. Or rather, I don't think VarInfo is particularly error-prone for AD either. SimpleVarInfo{<:NamedTuple} might be for some of the more recent AD backends, e.g. Tapir.jl and Enzyme.jl, but only because it improves type-stability (which is not the case for SimpleVarInfo{<:Dict}).

@sunxd3
Member Author

sunxd3 commented May 17, 2024

Let me look into it and see if I can make the NamedTuple variant of SimpleVarInfo work, or at least write down clear TODOs.

@torfjelde
Member

Let me look into it and see if I can make the NamedTuple variant of SimpleVarInfo work, or at least write down clear TODOs

Lovely:) And I don't mean to be negative about this btw. I'm just not a huge fan of adding additional kwargs, etc. unless there's a clear reason to use them (because otherwise nobody will ever use them, and they end up riddled with uncaught bugs precisely because nobody uses them). So if we're going to add an additional kwarg to use SimpleVarInfo, we should simultaneously prove that it has utility, i.e.:

  1. Make it work with SimpleVarInfo{<:NamedTuple}, because that definitively has utility.
  2. Test it properly in Turing.jl and make sure that SimpleVarInfo is indeed used, and that nowhere do we implicitly convert to VarInfo, before we merge this into DynamicPPL.jl.

@sunxd3
Member Author

sunxd3 commented May 17, 2024

Sounds good. I started this as a quick-and-dirty prototype; it definitely needs more work before we can justify complicating the interface.

@yebai yebai closed this Aug 6, 2024
@yebai yebai deleted the sunxd/add_simiplevarinfo_option branch August 6, 2024 18:28
@yebai
Member

yebai commented Aug 6, 2024

This is no longer necessary since Tapir now works with all tracing types.
