Skip to content

Commit

Permalink
Merge pull request #157 from ReactiveBayes/vardict
Browse files Browse the repository at this point in the history
Implement `VarDict`
  • Loading branch information
bvdmitri authored Mar 1, 2024
2 parents 74fea7d + c537793 commit 9b3ac74
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 46 deletions.
90 changes: 54 additions & 36 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ struct Context
fform::Function
prefix::String
parent::Union{Context, Nothing}
submodel_counts::Dict{Any, Int}
children::Dict{FactorID, Context}
factor_nodes::Dict{FactorID, NodeLabel}
submodel_counts::UnorderedDictionary{Any, Int}
children::UnorderedDictionary{FactorID, Context}
factor_nodes::UnorderedDictionary{FactorID, NodeLabel}
individual_variables::UnorderedDictionary{Symbol, NodeLabel}
vector_variables::UnorderedDictionary{Symbol, ResizableArray{NodeLabel, Vector{NodeLabel}, 1}}
tensor_variables::UnorderedDictionary{Symbol, ResizableArray{NodeLabel}}
Expand All @@ -231,9 +231,9 @@ function Context(depth::Int, fform::Function, prefix::String, parent)
fform,
prefix,
parent,
Dict{Any, Int}(),
Dict{FactorID, Context}(),
Dict{FactorID, NodeLabel}(),
UnorderedDictionary{Any, Int}(),
UnorderedDictionary{FactorID, Context}(),
UnorderedDictionary{FactorID, NodeLabel}(),
UnorderedDictionary{Symbol, NodeLabel}(),
UnorderedDictionary{Symbol, ResizableArray{NodeLabel, Vector{NodeLabel}, 1}}(),
UnorderedDictionary{Symbol, ResizableArray{NodeLabel}}(),
Expand Down Expand Up @@ -262,7 +262,7 @@ path_to_root(context::Context) = [context, path_to_root(parent(context))...]

function generate_factor_nodelabel(context::Context, fform::Any)
if count(context, fform) == 0
context.submodel_counts[fform] = 1
set!(context.submodel_counts, fform, 1)
else
context.submodel_counts[fform] += 1
end
Expand Down Expand Up @@ -311,17 +311,13 @@ haskey(context::Context, key::Symbol) =

haskey(context::Context, key::FactorID) = haskey(context.factor_nodes, key) || haskey(context.children, key)

function Base.getindex(c::Context, key::Any)
function Base.getindex(c::Context, key::Symbol)
if haskey(c.individual_variables, key)
return c.individual_variables[key]
elseif haskey(c.vector_variables, key)
return c.vector_variables[key]
elseif haskey(c.tensor_variables, key)
return c.tensor_variables[key]
elseif haskey(c.factor_nodes, key)
return c.factor_nodes[key]
elseif haskey(c.children, key)
return c.children[key]
elseif haskey(c.proxies, key)
return c.proxies[key]
end
Expand All @@ -337,41 +333,63 @@ function Base.getindex(c::Context, key::FactorID)
throw(KeyError(key))
end

function Base.getindex(c::Context, fform, index::Int)
return c[FactorID(fform, index)]
end
Base.getindex(c::Context, fform, index::Int) = c[FactorID(fform, index)]

function Base.setindex!(c::Context, val::NodeLabel, key::Symbol)
return setindex!(c, val, key, nothing)
end
Base.setindex!(c::Context, val::NodeLabel, key::Symbol) = set!(c.individual_variables, key, val)
Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Nothing) = set!(c.individual_variables, key, val)
Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Int) = c.vector_variables[key][index] = val
Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::NTuple{N, Int64} where {N}) = c.tensor_variables[key][index...] = val
Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, 1} where {T}, key::Symbol) = set!(c.vector_variables, key, val)
Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, N} where {T, N}, key::Symbol) = set!(c.tensor_variables, key, val)
Base.setindex!(c::Context, val::ProxyLabel, key::Symbol) = set!(c.proxies, key, val)
Base.setindex!(c::Context, val::ProxyLabel, key::Symbol, index::Nothing) = set!(c.proxies, key, val)
Base.setindex!(c::Context, val::Context, key::FactorID) = set!(c.children, key, val)
Base.setindex!(c::Context, val::NodeLabel, key::FactorID) = set!(c.factor_nodes, key, val)

function Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Nothing)
return set!(c.individual_variables, key, val)
end
"""
VarDict
function Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::Int)
return c.vector_variables[key][index] = val
A recursive dictionary structure that contains all variables in a probabilistic graphical model.
Iterates over all variables in the model and their children in a linear fashion, but preserves the recursive nature of the actual model.
"""
struct VarDict{T}
variables::UnorderedDictionary{Symbol, T}
children::UnorderedDictionary{FactorID, VarDict}
end

function Base.setindex!(c::Context, val::NodeLabel, key::Symbol, index::NTuple{N, Int64}) where {N}
return c.tensor_variables[key][index...] = val
function VarDict(context::Context)
dictvariables = merge(individual_variables(context), vector_variables(context), tensor_variables(context))
dictchildren = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> VarDict(child), children(context)))
return VarDict(dictvariables, dictchildren)
end

Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, 1} where {T}, key::Symbol) = set!(c.vector_variables, key, val)
Base.setindex!(c::Context, val::ResizableArray{NodeLabel, T, N} where {T, N}, key::Symbol) = set!(c.tensor_variables, key, val)
variables(vardict::VarDict) = vardict.variables
children(vardict::VarDict) = vardict.children

function Base.setindex!(c::Context, val::ProxyLabel, key::Symbol)
return setindex!(c, val, key, nothing)
end
haskey(vardict::VarDict, key::Symbol) = haskey(vardict.variables, key)
haskey(vardict::VarDict, key::Tuple{T, Int} where {T}) = haskey(vardict.children, FactorID(first(key), last(key)))
haskey(vardict::VarDict, key::FactorID) = haskey(vardict.children, key)

Base.getindex(vardict::VarDict, key::Symbol) = vardict.variables[key]
Base.getindex(vardict::VarDict, f, index::Int) = vardict.children[FactorID(f, index)]
Base.getindex(vardict::VarDict, key::Tuple{T, Int} where {T}) = vardict.children[FactorID(first(key), last(key))]
Base.getindex(vardict::VarDict, key::FactorID) = vardict.children[key]

function Base.setindex!(c::Context, val::ProxyLabel, key::Symbol, index::Nothing)
return set!(c.proxies, key, val)
function Base.map(f, vardict::VarDict)
mapped_variables = map(f, variables(vardict))
mapped_children = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> map(f, child), children(vardict)))
return VarDict(mapped_variables, mapped_children)
end

function Base.setindex!(c::Context, val::Context, key::FactorID)
return setindex!(c.children, val, key)
function Base.filter(f, vardict::VarDict)
filtered_variables = filter(f, variables(vardict))
filtered_children = convert(UnorderedDictionary{FactorID, VarDict}, map(child -> filter(f, child), children(vardict)))
return VarDict(filtered_variables, filtered_children)
end


Base.:(==)(left::VarDict, right::VarDict) = left.variables == right.variables && left.children == right.children

"""
NodeCreationOptions(namedtuple)
Expand Down Expand Up @@ -1151,7 +1169,7 @@ function add_atomic_factor_node!(model::Model, context::Context, options::NodeCr
)

model[potential_label] = nodedata
context.factor_nodes[factornode_id] = label
context[factornode_id] = label

return label, nodedata, getproperties(nodedata)
end
Expand All @@ -1178,7 +1196,7 @@ Returns:
"""
function add_composite_factor_node!(model::Model, parent_context::Context, context::Context, node_name)
node_id = generate_factor_nodelabel(parent_context, node_name)
parent_context.children[node_id] = context
parent_context[node_id] = context
return node_id
end

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/meta/meta_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function apply_meta!(model::Model, context::Context, meta::MetaSpecification)
for meta_obj in getmetaobjects(meta)
apply_meta!(model, context, meta_obj)
end
for (factor_id, child) in children(context)
for (factor_id, child) in pairs(children(context))
if (submodel = getspecificsubmodelmeta(meta, factor_id)) !== nothing
apply_meta!(model, child, getmetaobjects(submodel))
elseif (submodel = getgeneralsubmodelmeta(meta, fform(factor_id))) !== nothing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ function apply_constraints!(
for rfc in constraints(resolved_factorization_constraints)
apply_constraints!(model, context, rfc)
end
for (factor_id, child) in children(context)
for (factor_id, child) in pairs(children(context))
if factor_id keys(specific_submodel_constraints(constraint_set))
apply_constraints!(
model, child, getconstraint(specific_submodel_constraints(constraint_set)[factor_id]), resolved_factorization_constraints
Expand Down
3 changes: 1 addition & 2 deletions test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,5 +592,4 @@ end
@test length(collect(filter(as_node(Normal), model))) === 20
@test length(collect(filter(as_variable(:x), model))) === 10
@test length(collect(filter(as_variable(:y), model))) === 10

end
end
117 changes: 112 additions & 5 deletions test/graph_engine_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ end
@test getname(last(p)) === :y
end

let p = ProxyLabel(:x, (1, ), y)
let p = ProxyLabel(:x, (1,), y)
@test_throws "Indexing a single node label `y` with an index `[1]` is not allowed" unroll(p)
end

Expand Down Expand Up @@ -786,7 +786,6 @@ end

ctx6 = Context(ctx3, secondlayer)
@test typeof(ctx6) == Context && ctx6.prefix == "test_layer_secondlayer" && length(ctx6.individual_variables) == 0 && ctx6.depth == 2

end

@testitem "haskey(::Context)" begin
Expand Down Expand Up @@ -847,11 +846,11 @@ end

ctx = Context()
@test_throws KeyError ctx[FactorID(sum, 1)]
ctx.children[FactorID(sum, 1)] = Context()
ctx[FactorID(sum, 1)] = Context()
@test ctx[FactorID(sum, 1)] == ctx.children[FactorID(sum, 1)]

@test_throws KeyError ctx[FactorID(sum, 2)]
ctx.factor_nodes[FactorID(sum, 2)] = NodeLabel(:sum, 1)
ctx[FactorID(sum, 2)] = NodeLabel(:sum, 1)
@test ctx[FactorID(sum, 2)] == ctx.factor_nodes[FactorID(sum, 2)]
end

Expand All @@ -878,6 +877,105 @@ end
@test path_to_root(inner_inner_context) == [inner_inner_context, inner_context, ctx]
end

@testitem "VarDict" begin
using GraphPPL
import GraphPPL: Context, VarDict

ctx = Context()
vardict = VarDict(ctx)
@test isa(vardict, VarDict)

using Distributions

import GraphPPL: create_model, getorcreate!, LazyIndex, NodeCreationOptions, getcontext, is_random, is_data, getproperties

@model function submodel(y, x_prev, x_next)
γ ~ Gamma(1, 1)
x_next ~ Normal(x_prev, γ)
y ~ Normal(x_next, 1)
end

@model function state_space_model_with_new(y)
x[1] ~ Normal(0, 1)
y[1] ~ Normal(x[1], 1)
for i in 2:length(y)
# `x[i]` is not defined here, so this should fail
y[i] ~ submodel(x_next = new(x[i]), x_prev = x[i - 1])
end
end

ydata = ones(10)
model = create_model(state_space_model_with_new()) do model, ctx
y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :y, LazyIndex(ydata))
return (y = y,)
end

context = getcontext(model)
vardict = VarDict(context)

@test haskey(vardict, :y)
@test haskey(vardict, :x)
for i in 1:(length(ydata) - 1)
@test haskey(vardict, (submodel, i))
@test haskey(vardict[submodel, i], )
end

@test vardict[:y] === context[:y]
@test vardict[:x] === context[:x]
@test vardict[submodel, 1] == VarDict(context[submodel, 1])

result = map(identity, vardict)
@test haskey(result, :y)
@test haskey(result, :x)
for i in 1:(length(ydata) - 1)
@test haskey(result, (submodel, i))
@test haskey(result[submodel, i], )
end

result = map(vardict) do variable
return length(variable)
end
@test haskey(result, :y)
@test haskey(result, :x)
@test result[:y] === length(ydata)
@test result[:x] === length(ydata)
for i in 1:(length(ydata) - 1)
@test result[(submodel, i)][] === 1
@test result[GraphPPL.FactorID(submodel, i)][] === 1
@test result[submodel, i][] === 1
end

# Filter only random variables
result = filter(vardict) do label
if label isa GraphPPL.ResizableArray
all(is_random.(getproperties.(model[label])))
else
return is_random(getproperties(model[label]))
end
end
@test !haskey(result, :y)
@test haskey(result, :x)
for i in 1:(length(ydata) - 1)
@test haskey(result, (submodel, i))
@test haskey(result[submodel, i], )
end

# Filter only data variables
result = filter(vardict) do label
if label isa GraphPPL.ResizableArray
all(is_data.(getproperties.(model[label])))
else
return is_data(getproperties(model[label]))
end
end
@test haskey(result, :y)
@test !haskey(result, :x)
for i in 1:(length(ydata) - 1)
@test haskey(result, (submodel, i))
@test !haskey(result[submodel, i], )
end
end

@testitem "NodeType" begin
include("model_zoo.jl")
import GraphPPL: NodeType, Composite, Atomic
Expand Down Expand Up @@ -991,7 +1089,16 @@ end

@testitem "getorcreate!" begin
using Graphs
import GraphPPL: create_model, getcontext, getorcreate!, check_variate_compatability, NodeLabel, ResizableArray, NodeCreationOptions, getproperties, is_kind
import GraphPPL:
create_model,
getcontext,
getorcreate!,
check_variate_compatability,
NodeLabel,
ResizableArray,
NodeCreationOptions,
getproperties,
is_kind

let # let block to suppress the scoping warnings
# Test 1: Creation of regular one-dimensional variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -862,4 +862,4 @@ end
return getextra(model[node], :factorization_constraint) === ((interfaces...,),)
end
end
end
end

0 comments on commit 9b3ac74

Please sign in to comment.