diff --git a/src/sampler.jl b/src/sampler.jl
index b2fc6f4ec..cfc58942e 100644
--- a/src/sampler.jl
+++ b/src/sampler.jl
@@ -142,38 +142,63 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
 """
 initialsampler(spl::Sampler) = SampleFromPrior()
 
-function initialize_parameters!!(
-    vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
+"""
+    set_values!!(varinfo::AbstractVarInfo, initial_params, spl::AbstractSampler)
+
+Return the result of setting the values in `varinfo` from `initial_params`.
+
+If `initial_params` is a vector, it must have the same length as `varinfo[spl]`,
+otherwise a `DimensionMismatch` is thrown. If it is a `NamedTuple`, its keys are
+used as variable names. Entries equal to `missing` are left unchanged.
+"""
+function set_values!!(
+    varinfo::AbstractVarInfo,
+    initial_params::AbstractVector{<:Union{Real,Missing}},
+    spl::AbstractSampler,
 )
-    @debug "Using passed-in initial variable values" initial_params
-
-    # Flatten parameters.
-    init_theta = mapreduce(vcat, initial_params) do x
-        vec([x;])
-    end
-
-    # Get all values.
-    linked = islinked(vi, spl)
-    if linked
-        vi = invlink!!(vi, spl, model)
-    end
-    theta = vi[spl]
-    length(theta) == length(init_theta) || throw(
+    flattened_param_vals = varinfo[spl]
+    length(flattened_param_vals) == length(initial_params) || throw(
         DimensionMismatch(
-            "Provided initial value size ($(length(init_theta))) doesn't match the model size ($(length(theta)))",
+            "Provided initial value size ($(length(initial_params))) doesn't match the model size ($(length(flattened_param_vals)))",
         ),
     )
 
     # Update values that are provided.
-    for i in eachindex(init_theta)
-        x = init_theta[i]
+    for i in eachindex(initial_params)
+        x = initial_params[i]
         if x !== missing
-            theta[i] = x
+            flattened_param_vals[i] = x
         end
     end
 
-    # Update in `vi`.
-    vi = setindex!!(vi, theta, spl)
+    # Update in `varinfo`.
+    return setindex!!(varinfo, flattened_param_vals, spl)
+end
+
+function set_values!!(
+    varinfo::AbstractVarInfo, initial_params::NamedTuple, spl::AbstractSampler
+)
+    initial_params = NamedTuple(k => v for (k, v) in pairs(initial_params) if v !== missing)
+    return update_values!!(
+        varinfo, initial_params, map(k -> VarName{k}(), keys(initial_params))
+    )
+end
+
+function initialize_parameters!!(
+    vi::AbstractVarInfo, initial_params, spl::AbstractSampler, model::Model
+)
+    @debug "Using passed-in initial variable values" initial_params
+
+    # `invlink` the varinfo if needed.
+    linked = islinked(vi, spl)
+    if linked
+        vi = invlink!!(vi, spl, model)
+    end
+
+    # Set the values in `vi`.
+    vi = set_values!!(vi, initial_params, spl)
+
+    # `link` again if needed.
     if linked
         vi = link!!(vi, spl, model)
     end
diff --git a/src/test_utils.jl b/src/test_utils.jl
index 72ccf6e4f..bf7be0a9a 100644
--- a/src/test_utils.jl
+++ b/src/test_utils.jl
@@ -11,19 +11,7 @@ using Bijectors: Bijectors
 using Accessors: Accessors
 
 # For backwards compat.
-using DynamicPPL: varname_leaves
-
-"""
-    update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
-
-Return instance similar to `vi` but with `vns` set to values from `vals`.
-"""
-function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
-    for vn in vns
-        vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
-    end
-    return vi
-end
+using DynamicPPL: varname_leaves, update_values!!
 
 """
     test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
diff --git a/src/utils.jl b/src/utils.jl
index 9493e1bc9..4bf652363 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -796,6 +796,18 @@ function nested_getindex(values::AbstractDict, vn::VarName)
     return child(value)
 end
 
+"""
+    update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
+
+Return instance similar to `vi` but with `vns` set to values from `vals`.
+"""
+function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
+    for vn in vns
+        vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
+    end
+    return vi
+end
+
 """
     float_type_with_fallback(x)
 
diff --git a/src/varinfo.jl b/src/varinfo.jl
index 68d36141e..903789325 100644
--- a/src/varinfo.jl
+++ b/src/varinfo.jl
@@ -892,6 +892,12 @@ Base.keys(vi::TypedVarInfo{<:NamedTuple{()}}) = VarName[]
     return expr
 end
 
+# FIXME(torfjelde): Don't use `_getvns`.
+Base.keys(vi::UntypedVarInfo, spl::AbstractSampler) = _getvns(vi, spl)
+function Base.keys(vi::TypedVarInfo, spl::AbstractSampler)
+    return mapreduce(values, vcat, _getvns(vi, spl))
+end
+
 """
     setgid!(vi::VarInfo, gid::Selector, vn::VarName)
 
diff --git a/test/sampler.jl b/test/sampler.jl
index b52a9c921..b29d3caf1 100644
--- a/test/sampler.jl
+++ b/test/sampler.jl
@@ -84,23 +84,25 @@
         model = coinflip()
         sampler = Sampler(alg)
         lptrue = logpdf(Binomial(25, 0.2), 10)
-        chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
-        @test chain[1].metadata.p.vals == [0.2]
-        @test getlogp(chain[1]) == lptrue
-
-        # parallel sampling
-        chains = sample(
-            model,
-            sampler,
-            MCMCThreads(),
-            1,
-            10;
-            initial_params=fill(0.2, 10),
-            progress=false,
-        )
-        for c in chains
-            @test c[1].metadata.p.vals == [0.2]
-            @test getlogp(c[1]) == lptrue
+        let inits = (; p=0.2)
+            chain = sample(model, sampler, 1; initial_params=inits, progress=false)
+            @test chain[1].metadata.p.vals == [0.2]
+            @test getlogp(chain[1]) == lptrue
+
+            # parallel sampling
+            chains = sample(
+                model,
+                sampler,
+                MCMCThreads(),
+                1,
+                10;
+                initial_params=fill(inits, 10),
+                progress=false,
+            )
+            for c in chains
+                @test c[1].metadata.p.vals == [0.2]
+                @test getlogp(c[1]) == lptrue
+            end
         end
 
         # model with two variables: initialization s = 4, m = -1
@@ -110,45 +112,49 @@
         end
         model = twovars()
         lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
-        chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
-        @test chain[1].metadata.s.vals == [4]
-        @test chain[1].metadata.m.vals == [-1]
-        @test getlogp(chain[1]) == lptrue
-
-        # parallel sampling
-        chains = sample(
-            model,
-            sampler,
-            MCMCThreads(),
-            1,
-            10;
-            initial_params=fill([4, -1], 10),
-            progress=false,
-        )
-        for c in chains
-            @test c[1].metadata.s.vals == [4]
-            @test c[1].metadata.m.vals == [-1]
-            @test getlogp(c[1]) == lptrue
+        for inits in ([4, -1], (; s=4, m=-1))
+            chain = sample(model, sampler, 1; initial_params=inits, progress=false)
+            @test chain[1].metadata.s.vals == [4]
+            @test chain[1].metadata.m.vals == [-1]
+            @test getlogp(chain[1]) == lptrue
+
+            # parallel sampling
+            chains = sample(
+                model,
+                sampler,
+                MCMCThreads(),
+                1,
+                10;
+                initial_params=fill(inits, 10),
+                progress=false,
+            )
+            for c in chains
+                @test c[1].metadata.s.vals == [4]
+                @test c[1].metadata.m.vals == [-1]
+                @test getlogp(c[1]) == lptrue
+            end
         end
 
         # set only m = -1
-        chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
-        @test !ismissing(chain[1].metadata.s.vals[1])
-        @test chain[1].metadata.m.vals == [-1]
-
-        # parallel sampling
-        chains = sample(
-            model,
-            sampler,
-            MCMCThreads(),
-            1,
-            10;
-            initial_params=fill([missing, -1], 10),
-            progress=false,
-        )
-        for c in chains
-            @test !ismissing(c[1].metadata.s.vals[1])
-            @test c[1].metadata.m.vals == [-1]
+        for inits in ([missing, -1], (; s=missing, m=-1), (; m=-1))
+            chain = sample(model, sampler, 1; initial_params=inits, progress=false)
+            @test !ismissing(chain[1].metadata.s.vals[1])
+            @test chain[1].metadata.m.vals == [-1]
+
+            # parallel sampling
+            chains = sample(
+                model,
+                sampler,
+                MCMCThreads(),
+                1,
+                10;
+                initial_params=fill(inits, 10),
+                progress=false,
+            )
+            for c in chains
+                @test !ismissing(c[1].metadata.s.vals[1])
+                @test c[1].metadata.m.vals == [-1]
+            end
         end
 
         # specify `initial_params=nothing`