diff --git a/src/sampler.jl b/src/sampler.jl index b2fc6f4ec..5648b2400 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -60,10 +60,17 @@ function AbstractMCMC.step( model::Model, sampler::Union{SampleFromUniform,SampleFromPrior}, state=nothing; + trace_type=VarInfo, kwargs..., ) - vi = VarInfo() - model(rng, vi, sampler) + if trace_type === VarInfo + vi = VarInfo() + model(rng, vi, sampler) + elseif trace_type === SimpleVarInfo + vi = last(evaluate!!(model, rng, SimpleVarInfo{Float64}(OrderedDict()), sampler)) + else + throw(ArgumentError("Unknown trace type: $trace_type")) + end return vi, nothing end @@ -97,23 +104,47 @@ end # initial step: general interface for resuming and function AbstractMCMC.step( - rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... + rng::Random.AbstractRNG, + model::Model, + spl::Sampler; + initial_params=nothing, + trace_type=VarInfo, + kwargs..., ) - # Sample initial values. - vi = default_varinfo(rng, model, spl) - - # Update the parameters if provided. - if initial_params !== nothing - vi = initialize_parameters!!(vi, initial_params, spl, model) - - # Update joint log probability. - # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 - # and https://github.com/TuringLang/Turing.jl/issues/1563 - # to avoid that existing variables are resampled - vi = last(evaluate!!(model, vi, DefaultContext())) - end + if trace_type === VarInfo + # Sample initial values. + vi = default_varinfo(rng, model, spl) + + # Update the parameters if provided. + if initial_params !== nothing + vi = initialize_parameters!!(vi, initial_params, spl, model) + + # Update joint log probability. + # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 + # and https://github.com/TuringLang/Turing.jl/issues/1563 + # to avoid that existing variables are resampled + vi = last(evaluate!!(model, vi, DefaultContext())) + end + + return initialstep(rng, model, spl, vi; initial_params, kwargs...) + elseif trace_type === SimpleVarInfo + vi = last( + DynamicPPL.evaluate!!( + model, + SimpleVarInfo{Float64}(OrderedDict()), + SamplingContext(rng, SampleFromPrior(), DefaultContext()), + ), + ) + + if initial_params !== nothing + vi = initialize_parameters!!(vi, initial_params, spl, model) + vi = last(evaluate!!(model, vi, DefaultContext())) + end - return initialstep(rng, model, spl, vi; initial_params, kwargs...) + return initialstep(rng, model, spl, vi; initial_params, kwargs...) + else + throw(ArgumentError("Unknown trace type: $trace_type")) + end end """ diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a6b907701..5c3ce8651 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -540,6 +540,8 @@ function dot_assume( return value, lp, vi end +updategid!(vi::SimpleOrThreadSafeSimple, vn::VarName, spl::Sampler) = nothing + # NOTE: We don't implement `settrans!!(vi, trans, vn)`. function settrans!!(vi::SimpleVarInfo, trans) return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) diff --git a/test/sampler.jl b/test/sampler.jl index b52a9c921..336f89011 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -11,25 +11,39 @@ N = 1_000 chains = sample(model, SampleFromPrior(), N; progress=false) + chains_svi = sample( + model, SampleFromPrior(), N; progress=false, trace_type=SimpleVarInfo + ) @test chains isa Vector{<:VarInfo} @test length(chains) == N + @test chains_svi isa Vector{<:SimpleVarInfo} + @test length(chains_svi) == N # Expected value of ``X`` where ``X ~ N(2, ...)`` is 2. @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15 + @test mean(vi[@varname(m)] for vi in chains_svi) ≈ 2 atol = 0.15 # Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3. @test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2 + @test mean(vi[@varname(s)] for vi in chains_svi) ≈ 3 atol = 0.2 chains = sample(model, SampleFromUniform(), N; progress=false) + chains_svi = sample( + model, SampleFromUniform(), N; progress=false, trace_type=SimpleVarInfo + ) @test chains isa Vector{<:VarInfo} @test length(chains) == N + @test chains_svi isa Vector{<:SimpleVarInfo} + @test length(chains_svi) == N # `m` is Gaussian, i.e. no transformation is used, so it # should have a mean equal to its prior, i.e. 2. @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 + @test mean(vi[@varname(m)] for vi in chains_svi) ≈ 2 atol = 0.1 # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 + @test mean(vi[@varname(s)] for vi in chains_svi) ≈ 1.8 atol = 0.1 end @testset "init" begin @@ -37,18 +51,23 @@ @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS N = 1000 chain_init = sample(model, SampleFromUniform(), N; progress=false) + chain_init_svi = sample( + model, SampleFromUniform(), N; progress=false, trace_type=SimpleVarInfo + ) - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") + for chain in (chain_init, chain_init_svi) + for vn in keys(first(chain)) + if AbstractPPL.subsumes(@varname(s), vn) + # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. + dist = InverseGamma(2, 3) + b = DynamicPPL.link_transform(dist) + @test mean(mean(b(vi[vn])) for vi in chain) ≈ 0 atol = 0.11 + elseif AbstractPPL.subsumes(@varname(m), vn) + # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. + @test mean(mean(vi[vn]) for vi in chain) ≈ 0 atol = 0.11 + else + error("Unknown variable name: $vn") + end end end end @@ -85,8 +104,18 @@ sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) chain = sample(model, sampler, 1; initial_params=0.2, progress=false) + chain_svi = sample( + model, + sampler, + 1; + initial_params=0.2, + progress=false, + trace_type=SimpleVarInfo, + ) @test chain[1].metadata.p.vals == [0.2] @test getlogp(chain[1]) == lptrue + @test chain_svi[1][@varname(p)] == 0.2 + @test getlogp(chain_svi[1]) == lptrue # parallel sampling chains = sample( @@ -103,6 +132,21 @@ @test getlogp(c[1]) == lptrue end + chains_svi = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=fill(0.2, 10), + progress=false, + trace_type=SimpleVarInfo, + ) + for c in chains_svi + @test c[1][@varname(p)] == 0.2 + @test getlogp(c[1]) == lptrue + end + # model with two variables: initialization s = 4, m = -1 @model function twovars() s ~ InverseGamma(2, 3) @@ -114,6 +158,17 @@ @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogp(chain[1]) == lptrue + chain_svi = sample( + model, + sampler, + 1; + initial_params=[4, -1], + progress=false, + trace_type=SimpleVarInfo, + ) + @test chain_svi[1][@varname(s)] == 4 + @test chain_svi[1][@varname(m)] == -1 + @test getlogp(chain_svi[1]) == lptrue # parallel sampling chains = sample( @@ -131,10 +186,36 @@ @test getlogp(c[1]) == lptrue end + chains_svi = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=fill([4, -1], 10), + progress=false, + trace_type=SimpleVarInfo, + ) + for c in chains_svi + @test c[1][@varname(s)] == 4 + @test c[1][@varname(m)] == -1 + @test getlogp(c[1]) == lptrue + end + # set only m = -1 chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] + chain_svi = sample( + model, + sampler, + 1; + initial_params=[missing, -1], + progress=false, + trace_type=SimpleVarInfo, + ) + @test !ismissing(chain_svi[1][@varname(s)]) + @test chain_svi[1][@varname(m)] == -1 # parallel sampling chains = sample( @@ -150,26 +231,74 @@ @test !ismissing(c[1].metadata.s.vals[1]) @test c[1].metadata.m.vals == [-1] end + chains_svi = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=fill([missing, -1], 10), + progress=false, + trace_type=SimpleVarInfo, + ) + for c in chains_svi + @test !ismissing(c[1][@varname(s)]) + @test c[1][@varname(m)] == -1 + end # specify `initial_params=nothing` Random.seed!(1234) chain1 = sample(model, sampler, 1; progress=false) + chain1_svi = sample(model, sampler, 1; progress=false, trace_type=SimpleVarInfo) Random.seed!(1234) chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) + chain2_svi = sample( + model, + sampler, + 1; + initial_params=nothing, + progress=false, + trace_type=SimpleVarInfo, + ) @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals + @test chain1_svi[1][@varname(m)] == chain2_svi[1][@varname(m)] + @test chain1_svi[1][@varname(s)] == chain2_svi[1][@varname(s)] # parallel sampling Random.seed!(1234) chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) + chains1_svi = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + progress=false, + trace_type=SimpleVarInfo, + ) Random.seed!(1234) chains2 = sample( model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false ) + chains2_svi = sample( + model, + sampler, + MCMCThreads(), + 1, + 10; + initial_params=nothing, + progress=false, + trace_type=SimpleVarInfo, + ) for (c1, c2) in zip(chains1, chains2) @test c1[1].metadata.m.vals == c2[1].metadata.m.vals @test c1[1].metadata.s.vals == c2[1].metadata.s.vals end + for (c1, c2) in zip(chains1_svi, chains2_svi) + @test c1[1][@varname(m)] == c2[1][@varname(m)] + @test c1[1][@varname(s)] == c2[1][@varname(s)] + end end end end