From bd4baf154e29168d26ec5b04b1b29a30db255f09 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 14 Oct 2024 13:40:41 +0100 Subject: [PATCH 1/3] Fix treatment of gid in merge(::Metadata) --- src/varinfo.jl | 8 +++++++- test/varinfo.jl | 13 +++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 8727796bc..4b229d828 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -490,7 +490,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - gids = metadata_right.gids # NOTE: giving precedence to `metadata_right` + gids = Set{Selector}[] orders = Int[] flags = Dict{String,BitVector}() # Initialize the `flags`. @@ -520,6 +520,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) dist_right = getdist(metadata_right, vn) # Give precedence to `metadata_right`. push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders`: giving precedence to `metadata_right` push!(orders, getorder(metadata_right, vn)) # `flags` @@ -539,6 +541,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_left = getdist(metadata_left, vn) push!(dists, dist_left) + gid = metadata_left.gids[getidx(metadata_left, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_left, vn)) # `flags` @@ -557,6 +561,8 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) # `dists` dist_right = getdist(metadata_right, vn) push!(dists, dist_right) + gid = metadata_right.gids[getidx(metadata_right, vn)] + push!(gids, gid) # `orders` push!(orders, getorder(metadata_right, vn)) # `flags` diff --git a/test/varinfo.jl b/test/varinfo.jl index 65f849dda..88439425a 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -694,6 +694,19 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] @test DynamicPPL.istrans(varinfo_merged, @varname(x)) end + + # The below used to error, testing to avoid regression. + @testset "merge gids" begin + gidset_left = Set([Selector(1)]) + vi_left = VarInfo() + vi_left = push!!(vi_left, @varname(x), 1.0, Normal(), gidset_left) + gidset_right = Set([Selector(2)]) + vi_right = VarInfo() + vi_right = push!!(vi_right, @varname(y), 2.0, Normal(), gidset_right) + varinfo_merged = merge(vi_left, vi_right) + @test DynamicPPL.getgid(varinfo_merged, @varname(x)) == gidset_left + @test DynamicPPL.getgid(varinfo_merged, @varname(y)) == gidset_right + end end @testset "VarInfo with selectors" begin From d804ef159c9efb030ab2aab7dd9aa7f33a38bc27 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 14 Oct 2024 17:13:50 +0100 Subject: [PATCH 2/3] Allowing pushing new symbols to TypedVarInfo --- src/varinfo.jl | 28 ++++++++++++++++++++++------ test/varinfo.jl | 12 ++++++++++++ 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4b229d828..13674555f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1832,14 +1832,31 @@ function BangBang.push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) if vi isa UntypedVarInfo - @assert ~(vn in keys(vi)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" + @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to TypedVarInfo of syms $(syms(vi)) with dist=$dist, gid=$gidset" + end + + sym = getsym(vn) + if vi isa TypedVarInfo && ~haskey(vi.metadata, sym) + # The NamedTuple doesn't have an entry for this variable, let's add one. + val = tovec(r) + md = Metadata( + Dict(vn => 1), + [vn], + [1:length(val)], + val, + [dist], + [gidset], + [get_num_produce(vi)], + Dict{String,BitVector}("trans" => [false], "del" => [false]), + ) + vi = Accessors.@set vi.metadata[sym] = md + else + meta = getmetadata(vi, vn) + push!(meta, vn, r, dist, gidset, get_num_produce(vi)) end - meta = getmetadata(vi, vn) - push!(meta, vn, r, dist, gidset, get_num_produce(vi)) - return vi end @@ -1870,7 +1887,6 @@ function Base.push!(meta::Metadata, vn, r, dist, gidset, num_produce) push!(meta.orders, num_produce) push!(meta.flags["del"], false) push!(meta.flags["trans"], false) - return meta end diff --git a/test/varinfo.jl b/test/varinfo.jl index 88439425a..308f6d5b7 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -154,6 +154,18 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_varinfo!(vi) test_varinfo!(empty!!(TypedVarInfo(vi))) end + + @testset "push!! to TypedVarInfo" begin + vn_x = @varname x + vn_y = @varname y + untyped_vi = VarInfo() + untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1), Selector()) + typed_vi = TypedVarInfo(untyped_vi) + typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1), Selector()) + @test typed_vi[vn_x] == 1.0 + @test typed_vi[vn_y] == 2.0 + end + @testset "setgid!" begin vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata From 78f12c5c30f3db88293be952c240f56cc6fb1da7 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 15 Oct 2024 14:56:29 +0100 Subject: [PATCH 3/3] Bump patch version to 0.30.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f995d7359..eab8c362c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30" +version = "0.30.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"