diff --git a/Project.toml b/Project.toml index 8e2cb92e2..98459eb0b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.28.5" +version = "0.28.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index a6b907701..f4c84a5c3 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -429,22 +429,17 @@ function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) return Accessors.@set varinfo.values = _subset(varinfo.values, vns) end -function _subset(x::AbstractDict, vns) +function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} vns_present = collect(keys(x)) - vns_found = mapreduce(vcat, vns) do vn + vns_found = mapreduce(vcat, vns; init=VN[]) do vn return filter(Base.Fix1(subsumes, vn), vns_present) end - - # NOTE: This `vns` to be subsume varnames explicitly present in `x`. + C = ConstructionBase.constructorof(typeof(x)) if isempty(vns_found) - throw( - ArgumentError( - "Cannot subset `AbstractDict` with `VarName` which does not subsume any keys.", - ), - ) + return C() + else + return C(vn => x[vn] for vn in vns_found) end - C = ConstructionBase.constructorof(typeof(x)) - return C(vn => x[vn] for vn in vns_found) end function _subset(x::NamedTuple, vns) diff --git a/src/varinfo.jl b/src/varinfo.jl index 903789325..6278d260f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -264,20 +264,24 @@ function subset(varinfo::TypedVarInfo, vns::AbstractVector{<:VarName}) return VarInfo(NamedTuple{syms}(metadatas), varinfo.logp, varinfo.num_produce) end -function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) +function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName} # TODO: Should we error if `vns` contains a variable that is not in `metadata`? # For each `vn` in `vns`, get the variables subsumed by `vn`. - vns = mapreduce(vcat, vns_given) do vn + vns = mapreduce(vcat, vns_given; init=VN[]) do vn filter(Base.Fix1(subsumes, vn), metadata.vns) end indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns) - indices = Dict(vn => i for (i, vn) in enumerate(vns)) + indices = if isempty(vns) + Dict{VarName,Int}() + else + Dict(vn => i for (i, vn) in enumerate(vns)) + end # Construct new `vals` and `ranges`. vals_original = metadata.vals ranges_original = metadata.ranges # Allocate the new `vals`. and `ranges`. - vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns])) - ranges = similar(ranges_original) + vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0)) + ranges = similar(ranges_original, length(vns)) # The new range `r` for `vns[i]` is offset by `offset` and # has the same length as the original range `r_original`. # The new `indices` (from above) ensures ordering according to `vns`. @@ -311,7 +315,7 @@ function subset(metadata::Metadata, vns_given::AbstractVector{<:VarName}) ranges, vals, metadata.dists[indices_for_vns], - metadata.gids, + metadata.gids[indices_for_vns], metadata.orders[indices_for_vns], flags, ) @@ -382,7 +386,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -412,6 +416,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) dist_right = getdist(metadata_right, vn) # Give precedence to `metadata_right`. push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders`: giving precedence to `metadata_right` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -431,6 +437,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_left = getdist(metadata_left, vn) push!(dists, dist_left) + gid = metadata_left.gids[getidx(metadata_left, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_left, vn)) # `flags` @@ -449,6 +457,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_right = getdist(metadata_right, vn) push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -1594,25 +1604,40 @@ function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" end val = vectorize(dist, r) - - meta = getmetadata(vi, vn) - meta.idcs[vn] = length(meta.idcs) + 1 - push!(meta.vns, vn) - l = length(meta.vals) - n = length(val) - push!(meta.ranges, (l + 1):(l + n)) - append!(meta.vals, val) - push!(meta.dists, dist) - push!(meta.gids, gidset) - push!(meta.orders, get_num_produce(vi)) - push!(meta.flags["del"], false) - push!(meta.flags["trans"], false) + sym = getsym(vn) + if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + # The NamedTuple doesn't have an entry for this variable, let's add one. + md = Metadata( + Dict(vn => 1), + [vn], + [1:length(val)], + val, + [dist], + [gidset], + [get_num_produce(vi)], + Dict{String,BitVector}("trans" => [false], "del" => [false]), + ) + vi = Accessors.@set vi.metadata[sym] = md + else + meta = getmetadata(vi, vn) + meta.idcs[vn] = length(meta.idcs) + 1 + push!(meta.vns, vn) + l = length(meta.vals) + n = length(val) + push!(meta.ranges, (l + 1):(l + n)) + append!(meta.vals, val) + push!(meta.dists, dist) + push!(meta.gids, gidset) + push!(meta.orders, get_num_produce(vi)) + push!(meta.flags["del"], false) + push!(meta.flags["trans"], false) + end return vi end diff --git a/test/varinfo.jl b/test/varinfo.jl index 12387f6a7..ff0e7235f 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -145,6 +145,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end + + @testset "push!! to TypedVarInfo" begin + vn_x = @varname x + vn_y = @varname y + untyped_vi = VarInfo() + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + typed_vi = TypedVarInfo(untyped_vi) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + @test typed_vi[vn_x] == 1.0 + @test typed_vi[vn_y] == 2.0 + end + @testset "setgid!" begin vi = VarInfo() meta = vi.metadata @@ -511,6 +523,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) else vns_supported_standard end + + @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in + vns_supported + varinfo_subset = subset(varinfo, VarName[]) + @test isempty(varinfo_subset) + end + @testset "$(convert(Vector{VarName}, vns_subset))" for vns_subset in vns_supported varinfo_subset = subset(varinfo, vns_subset) @@ -638,6 +657,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.getdist(varinfo_merged, @varname(x)) isa Normal end + + # The below used to error, testing to avoid regression. + @testset "merge gids" begin + gidset_left = Set([Selector(1)]) + vi_left = VarInfo() + vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) + gidset_right = Set([Selector(2)]) + vi_right = VarInfo() + vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) + varinfo_merged = merge(vi_left, vi_right) + @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left + @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right + end end @testset "VarInfo with selectors" begin