Skip to content

Give user an option to use SimpleVarInfo with sample function #606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 15 commits into from
65 changes: 48 additions & 17 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,17 @@ function AbstractMCMC.step(
model::Model,
sampler::Union{SampleFromUniform,SampleFromPrior},
state=nothing;
trace_type=VarInfo,
kwargs...,
)
vi = VarInfo()
model(rng, vi, sampler)
if trace_type === VarInfo
vi = VarInfo()
model(rng, vi, sampler)
elseif trace_type === SimpleVarInfo
vi = last(evaluate!!(model, rng, SimpleVarInfo{Float64}(OrderedDict()), sampler))
else
throw(ArgumentError("Unknown trace type: $trace_type"))
end
return vi, nothing
end

Expand Down Expand Up @@ -97,23 +104,47 @@ end

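# initial step: general interface for resuming and starting sampling; it builds the
# initial varinfo, applies `initial_params` if given, and delegates to `initialstep`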
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
initial_params=nothing,
trace_type=VarInfo,
kwargs...,
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# that prevents existing variables from being resampled.
vi = last(evaluate!!(model, vi, DefaultContext()))
end
if trace_type === VarInfo
# Sample initial values.
vi = default_varinfo(rng, model, spl)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

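This doesn't quite make sense, because `default_varinfo` can also return a `SimpleVarInfo`, so taking the `trace_type === VarInfo` branch does not guarantee that `vi` is actually a `VarInfo`.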


# Update the parameters if provided.
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
# and https://github.com/TuringLang/Turing.jl/issues/1563
# that prevents existing variables from being resampled.
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
elseif trace_type === SimpleVarInfo
vi = last(
DynamicPPL.evaluate!!(
model,
SimpleVarInfo{Float64}(OrderedDict()),
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
),
)

if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; initial_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
else
throw(ArgumentError("Unknown trace type: $trace_type"))
end
end

"""
Expand Down
2 changes: 2 additions & 0 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,8 @@ function dot_assume(
return value, lp, vi
end

# No-op method of `updategid!` for `SimpleOrThreadSafeSimple` varinfos: all three
# arguments are ignored and `nothing` is returned.
function updategid!(vi::SimpleOrThreadSafeSimple, vn::VarName, spl::Sampler)
    return nothing
end

# NOTE: We don't implement `settrans!!(vi, trans, vn)`.
function settrans!!(vi::SimpleVarInfo, trans)
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
Expand Down
151 changes: 140 additions & 11 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,44 +11,63 @@
N = 1_000

chains = sample(model, SampleFromPrior(), N; progress=false)
chains_svi = sample(
model, SampleFromPrior(), N; progress=false, trace_type=SimpleVarInfo
)
@test chains isa Vector{<:VarInfo}
@test length(chains) == N
@test chains_svi isa Vector{<:SimpleVarInfo}
@test length(chains_svi) == N

# Expected value of ``X`` where ``X ~ N(2, ...)`` is 2.
@test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.15
@test mean(vi[@varname(m)] for vi in chains_svi) ≈ 2 atol = 0.15

# Expected value of ``X`` where ``X ~ IG(2, 3)`` is 3.
@test mean(vi[@varname(s)] for vi in chains) ≈ 3 atol = 0.2
@test mean(vi[@varname(s)] for vi in chains_svi) ≈ 3 atol = 0.2

chains = sample(model, SampleFromUniform(), N; progress=false)
chains_svi = sample(
model, SampleFromUniform(), N; progress=false, trace_type=SimpleVarInfo
)
@test chains isa Vector{<:VarInfo}
@test length(chains) == N
@test chains_svi isa Vector{<:SimpleVarInfo}
@test length(chains_svi) == N

# `m` is Gaussian, i.e. no transformation is used, so it
# should have a mean equal to its prior, i.e. 2.
@test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1
@test mean(vi[@varname(m)] for vi in chains_svi) ≈ 2 atol = 0.1

# Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8.
@test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1
@test mean(vi[@varname(s)] for vi in chains_svi) ≈ 1.8 atol = 0.1
end

@testset "init" begin
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
N = 1000
chain_init = sample(model, SampleFromUniform(), N; progress=false)
chain_init_svi = sample(
model, SampleFromUniform(), N; progress=false, trace_type=SimpleVarInfo
)

for vn in keys(first(chain_init))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11
else
error("Unknown variable name: $vn")
for chain in (chain_init, chain_init_svi)
for vn in keys(first(chain))
if AbstractPPL.subsumes(@varname(s), vn)
# `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2].
dist = InverseGamma(2, 3)
b = DynamicPPL.link_transform(dist)
@test mean(mean(b(vi[vn])) for vi in chain) ≈ 0 atol = 0.11
elseif AbstractPPL.subsumes(@varname(m), vn)
# `m ~ Normal(0, sqrt(s))` and its constrained value is the same.
@test mean(mean(vi[vn]) for vi in chain) ≈ 0 atol = 0.11
else
error("Unknown variable name: $vn")
end
end
end
end
Expand Down Expand Up @@ -85,8 +104,18 @@
sampler = Sampler(alg)
lptrue = logpdf(Binomial(25, 0.2), 10)
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
chain_svi = sample(
model,
sampler,
1;
initial_params=0.2,
progress=false,
trace_type=SimpleVarInfo,
)
@test chain[1].metadata.p.vals == [0.2]
@test getlogp(chain[1]) == lptrue
@test chain_svi[1][@varname(p)] == 0.2
@test getlogp(chain_svi[1]) == lptrue

# parallel sampling
chains = sample(
Expand All @@ -103,6 +132,21 @@
@test getlogp(c[1]) == lptrue
end

chains_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=fill(0.2, 10),
progress=false,
trace_type=SimpleVarInfo,
)
for c in chains_svi
@test c[1][@varname(p)] == 0.2
@test getlogp(c[1]) == lptrue
end

# model with two variables: initialization s = 4, m = -1
@model function twovars()
s ~ InverseGamma(2, 3)
Expand All @@ -114,6 +158,17 @@
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
@test getlogp(chain[1]) == lptrue
chain_svi = sample(
model,
sampler,
1;
initial_params=[4, -1],
progress=false,
trace_type=SimpleVarInfo,
)
@test chain_svi[1][@varname(s)] == 4
@test chain_svi[1][@varname(m)] == -1
@test getlogp(chain_svi[1]) == lptrue

# parallel sampling
chains = sample(
Expand All @@ -131,10 +186,36 @@
@test getlogp(c[1]) == lptrue
end

chains_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=fill([4, -1], 10),
progress=false,
trace_type=SimpleVarInfo,
)
for c in chains_svi
@test c[1][@varname(s)] == 4
@test c[1][@varname(m)] == -1
@test getlogp(c[1]) == lptrue
end

# set only m = -1
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]
chain_svi = sample(
model,
sampler,
1;
initial_params=[missing, -1],
progress=false,
trace_type=SimpleVarInfo,
)
@test !ismissing(chain_svi[1][@varname(s)])
@test chain_svi[1][@varname(m)] == -1

# parallel sampling
chains = sample(
Expand All @@ -150,26 +231,74 @@
@test !ismissing(c[1].metadata.s.vals[1])
@test c[1].metadata.m.vals == [-1]
end
chains_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=fill([missing, -1], 10),
progress=false,
trace_type=SimpleVarInfo,
)
for c in chains_svi
@test !ismissing(c[1][@varname(s)])
@test c[1][@varname(m)] == -1
end

# specify `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(model, sampler, 1; progress=false)
chain1_svi = sample(model, sampler, 1; progress=false, trace_type=SimpleVarInfo)
Random.seed!(1234)
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
chain2_svi = sample(
model,
sampler,
1;
initial_params=nothing,
progress=false,
trace_type=SimpleVarInfo,
)
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals
@test chain1_svi[1][@varname(m)] == chain2_svi[1][@varname(m)]
@test chain1_svi[1][@varname(s)] == chain2_svi[1][@varname(s)]

# parallel sampling
Random.seed!(1234)
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
chains1_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
progress=false,
trace_type=SimpleVarInfo,
)
Random.seed!(1234)
chains2 = sample(
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
)
chains2_svi = sample(
model,
sampler,
MCMCThreads(),
1,
10;
initial_params=nothing,
progress=false,
trace_type=SimpleVarInfo,
)
for (c1, c2) in zip(chains1, chains2)
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
@test c1[1].metadata.s.vals == c2[1].metadata.s.vals
end
for (c1, c2) in zip(chains1_svi, chains2_svi)
@test c1[1][@varname(m)] == c2[1][@varname(m)]
@test c1[1][@varname(s)] == c2[1][@varname(s)]
end
end
end
end
Loading