Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Materialize anonymous variables later only if needed #164

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ext/GraphPPLDistributionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,8 @@ end
end
end

# Special cases
GraphPPLDistributionsExt.distributions_ext_input_interfaces(::Type{<:Distributions.InverseWishart}) = GraphPPL.StaticInterfaces((:df, :scale))
GraphPPLDistributionsExt.distributions_ext_interfaces(::Type{<:Distributions.InverseWishart}) = GraphPPL.StaticInterfaces((:out, :df, :scale))

end
85 changes: 49 additions & 36 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@
function Base.show(io::IO, properties::VariableNodeProperties)
print(io, "name = ", properties.name, ", index = ", properties.index)
if !isnothing(properties.link)
print(io, "(linked to ", node.link, ")")
print(io, ", linked to ", properties.link)

Check warning on line 474 in src/graph_engine.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_engine.jl#L474

Added line #L474 was not covered by tests
end
end

Expand All @@ -480,23 +480,24 @@

Data associated with a factor node in a probabilistic graphical model.
"""
struct FactorNodeProperties
struct FactorNodeProperties{D}
fform::Any
neighbors::Vector{Tuple{NodeLabel, EdgeLabel, Any}}
neighbors::Vector{Tuple{NodeLabel, EdgeLabel, D}}
end

FactorNodeProperties(; fform, neighbors = Tuple{NodeLabel, EdgeLabel}[]) = FactorNodeProperties(fform, neighbors)
FactorNodeProperties(; fform, neighbors = Tuple{NodeLabel, EdgeLabel, NodeData}[]) = FactorNodeProperties(fform, neighbors)

is_factor(::FactorNodeProperties) = true
is_variable(::FactorNodeProperties) = false

function Base.convert(::Type{FactorNodeProperties}, fform, options::NodeCreationOptions)
return FactorNodeProperties(fform = fform, neighbors = get(options, :neighbors, Tuple{NodeLabel, EdgeLabel}[]))
return FactorNodeProperties(fform = fform, neighbors = get(options, :neighbors, Tuple{NodeLabel, EdgeLabel, NodeData}[]))
end

fform(properties::FactorNodeProperties) = properties.fform
neighbors(properties::FactorNodeProperties) = properties.neighbors
addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data::F) where {F} = push!(properties.neighbors, (variable, edge, data))
addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data) =
push!(properties.neighbors, (variable, edge, data))
neighbor_data(properties::FactorNodeProperties) = Iterators.map(neighbor -> neighbor[3], neighbors(properties))

function Base.show(io::IO, properties::FactorNodeProperties)
Expand All @@ -513,7 +514,7 @@
"""
struct NodeData
context :: Context
properties :: Union{VariableNodeProperties, FactorNodeProperties}
properties :: Union{VariableNodeProperties, FactorNodeProperties{NodeData}}
extra :: UnorderedDictionary{Symbol, Any}
end

Expand Down Expand Up @@ -543,13 +544,13 @@
"""
struct NodeDataExtraKey{K, T} end

function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where K
function hasextra(node::NodeData, key::NodeDataExtraKey{K}) where {K}
return haskey(node.extra, K)
end
function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T}
function getextra(node::NodeData, key::NodeDataExtraKey{K, T})::T where {K, T}
return getindex(node.extra, K)::T
end
function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T}
function setextra!(node::NodeData, key::NodeDataExtraKey{K}, value::T) where {K, T}
return insert!(node.extra, K, value)
end

Expand Down Expand Up @@ -765,7 +766,7 @@
struct Atomic <: NodeType end

NodeType(::Type) = Atomic()
NodeType(::F) where {F<:Function} = Atomic()
NodeType(::F) where {F <: Function} = Atomic()

abstract type NodeBehaviour end

Expand Down Expand Up @@ -1293,7 +1294,7 @@
# Returns
- `missing_interfaces`: A `Vector` of the missing interfaces.
"""
function missing_interfaces(fform, val, known_interfaces::NamedTuple)
function missing_interfaces(fform::F, val, known_interfaces::NamedTuple) where {F}
return missing_interfaces(interfaces(fform, val), StaticInterfaces(keys(known_interfaces)))
end

Expand All @@ -1303,12 +1304,12 @@
return StaticInterfaces(filter(interface -> interface ∉ present_interfaces, all_interfaces))
end

function prepare_interfaces(fform, lhs_interface, rhs_interfaces::NamedTuple)
function prepare_interfaces(fform::F, lhs_interface, rhs_interfaces::NamedTuple) where {F}
missing_interface = missing_interfaces(fform, static(length(rhs_interfaces)) + static(1), rhs_interfaces)
return prepare_interfaces(missing_interface, fform, lhs_interface, rhs_interfaces)
end

function prepare_interfaces(::StaticInterfaces{I}, fform, lhs_interface, rhs_interfaces::NamedTuple) where {I}
function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_interfaces::NamedTuple) where {I, F}
@assert length(I) == 1 lazy"Expected only one missing interface, got $I of length $(length(I)) (node $fform with interfaces $(keys(rhs_interfaces)))))"
missing_interface = first(I)
return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...))
Expand All @@ -1320,8 +1321,8 @@
return map(materialize_interface, interfaces)
end

default_parametrization(::Atomic, fform, rhs::Tuple) = (in = rhs,)
default_parametrization(::Composite, fform, rhs) = error("Composite nodes always have to be initialized with named arguments")
default_parametrization(::Atomic, fform::F, rhs::Tuple) where {F} = (in = rhs,)
default_parametrization(::Composite, fform::F, rhs) where {F} = error("Composite nodes always have to be initialized with named arguments")

# maybe change name

Expand All @@ -1348,25 +1349,20 @@
return make_node!(model, ctx, EmptyNodeCreationOptions, fform, lhs_interfaces, rhs_interfaces)
end

# Special case which should materialize anonymous variable
function make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::AnonymousVariable, rhs_interfaces) where {F}
lhs_materialized = materialize_anonymous_variable!(lhs_interface, fform, rhs_interfaces)::NodeLabel
return make_node!(model, ctx, options, fform, lhs_materialized, rhs_interfaces)
end

make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} =
make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} =
make_node!(NodeType(fform), model, ctx, options, fform, lhs_interface, rhs_interfaces)

#if it is composite, we assume it should be materialized and it is stochastic
make_node!(nodetype::Composite, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} =
make_node!(True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces)
make_node!(
nodetype::Composite, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces
) where {F} = make_node!(True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces)

# If a node is an object and not a function, we materialize it as a stochastic atomic node
make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces::Nothing) where {F} =
make_node!(model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces::Nothing) where {F} =
make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, NamedTuple{}())

# If node is Atomic, check stochasticity
make_node!(::Atomic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} =
make_node!(::Atomic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} =
make_node!(Atomic(), NodeBehaviour(fform), model, ctx, options, fform, lhs_interface, rhs_interfaces)

#If a node is deterministic, we check if there are any NodeLabel objects in the rhs_interfaces (direct check if node should be materialized)
Expand All @@ -1379,7 +1375,8 @@
fform::F,
lhs_interface,
rhs_interfaces
) where {F} = make_node!(contains_nodelabel(rhs_interfaces), atomic, deterministic, model, ctx, options, fform, lhs_interface, rhs_interfaces)
) where {F} =
make_node!(contains_nodelabel(rhs_interfaces), atomic, deterministic, model, ctx, options, fform, lhs_interface, rhs_interfaces)

# If the node should not be materialized (if it's Atomic, Deterministic and contains no NodeLabel objects), we return the function evaluated at the interfaces
make_node!(
Expand All @@ -1404,7 +1401,7 @@
fform::F,
lhs_interface,
rhs_interfaces::NamedTuple
) where {F} = (nothing, fform(; rhs_interfaces...))
) where {F} = (nothing, fform(; rhs_interfaces...))

make_node!(
::False,
Expand All @@ -1419,8 +1416,9 @@
) where {F} = (nothing, fform(rhs_interfaces.args...; rhs_interfaces.kwargs...))

# If a node is Stochastic, we always materialize.
make_node!(::Atomic, ::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces) where {F} =
make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces)
make_node!(
::Atomic, ::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface, rhs_interfaces
) where {F} = make_node!(True(), Atomic(), Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces)

# If we have to materialize but lhs_interface is nothing, we create a variable for it
function make_node!(
Expand All @@ -1440,6 +1438,21 @@
return make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_node, rhs_interfaces)
end

function make_node!(
materialize::True,
node_type::NodeType,
behaviour::NodeBehaviour,
model::Model,
ctx::Context,
options::NodeCreationOptions,
fform::F,
lhs_interface::AnonymousVariable,
rhs_interfaces
) where {F}
lhs_materialized = materialize_anonymous_variable!(lhs_interface, fform, rhs_interfaces)::NodeLabel
return make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_materialized, rhs_interfaces)
end

# If we have to materialize but the rhs_interfaces argument is not a NamedTuple, we convert it
make_node!(
materialize::True,
Expand Down Expand Up @@ -1473,7 +1486,7 @@
fform::F,
lhs_interface::Union{NodeLabel, ProxyLabel},
rhs_interfaces::MixedArguments
) where {F} = error("MixedArguments not supported for rhs_interfaces when node has to be materialized")
) where {F} = error("MixedArguments not supported for rhs_interfaces when node has to be materialized")

make_node!(
materialize::True,
Expand All @@ -1485,7 +1498,7 @@
fform::F,
lhs_interface::Union{NodeLabel, ProxyLabel},
rhs_interfaces::Tuple{}
) where {F} = make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_interface, NamedTuple{}())
) where {F} = make_node!(materialize, node_type, behaviour, model, ctx, options, fform, lhs_interface, NamedTuple{}())

make_node!(
materialize::True,
Expand All @@ -1497,7 +1510,7 @@
fform::F,
lhs_interface::Union{NodeLabel, ProxyLabel},
rhs_interfaces::Tuple
) where {F} = error(lazy"Composite node $fform cannot should be called with explicitly naming the interface names")
) where {F} = error(lazy"Composite node $fform cannot should be called with explicitly naming the interface names")

make_node!(
materialize::True,
Expand All @@ -1509,7 +1522,7 @@
fform::F,
lhs_interface::Union{NodeLabel, ProxyLabel},
rhs_interfaces::NamedTuple
) where {F} = make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1))
) where {F} = make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1))

"""
make_node!
Expand All @@ -1535,7 +1548,7 @@
fform::F,
lhs_interface::Union{NodeLabel, ProxyLabel},
rhs_interfaces::NamedTuple
) where {F}
) where {F}
fform = factor_alias(fform, Val(keys(rhs_interfaces)))
interfaces = materialze_interfaces(prepare_interfaces(fform, lhs_interface, rhs_interfaces))
nodeid, _, _ = materialize_factor_node!(model, context, options, fform, interfaces)
Expand Down
45 changes: 45 additions & 0 deletions test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,3 +593,48 @@ end
@test length(collect(filter(as_variable(:x), model))) === 10
@test length(collect(filter(as_variable(:y), model))) === 10
end

@testitem "Anonymous variables should not be created from arithmetical operations on pure constants" begin
using Distributions, LinearAlgebra

import GraphPPL: create_model, getorcreate!, NodeCreationOptions, LazyIndex, variable_nodes, getproperties, is_random, getname

@model function mv_iid_inverse_wishart_known_mean(y, d)
m ~ MvNormal(zeros(d + 1 - 1 + 1 - 1), Matrix(Diagonal(ones(d + 1 - 1 + 1 - 1))))
C ~ InverseWishart(d + 1, Matrix(Diagonal(ones(d))))

for i in eachindex(y)
y[i] ~ MvNormal(m, C)
end
end

ydata = rand(10)

for d in 1:3
model = create_model(mv_iid_inverse_wishart_known_mean(d = d)) do model, ctx
y = getorcreate!(model, ctx, NodeCreationOptions(kind = :data), :y, LazyIndex(ydata))
return (y = y,)
end

variable_nodes(model) do label, nodedata
properties = getproperties(nodedata)
if is_random(properties)
# Shouldn't be any anonymous variables here
@test getname(properties) ∈ (:C, :m)
end
end

@test length(collect(filter(as_node(MvNormal), model))) === 11
@test length(collect(filter(as_node(InverseWishart), model))) === 1
@test length(collect(filter(as_node(Matrix), model))) === 0
@test length(collect(filter(as_node(Diagonal), model))) === 0
@test length(collect(filter(as_node(ones), model))) === 0
@test length(collect(filter(as_node(+), model))) === 0
@test length(collect(filter(as_node(-), model))) === 0
@test length(collect(filter(as_node(sum), model))) === 0
@test length(collect(filter(as_variable(:C), model))) === 1
@test length(collect(filter(as_variable(:m), model))) === 1


end
end
Loading