Skip to content

Commit

Permalink
Merge branch 'dev-4.0.0' into dev-4.0.0-get-extra-or-default
Browse files Browse the repository at this point in the history
  • Loading branch information
wouterwln authored Mar 19, 2024
2 parents db172d4 + b99d6a7 commit b225c30
Show file tree
Hide file tree
Showing 12 changed files with 582 additions and 57 deletions.
2 changes: 1 addition & 1 deletion src/backends/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ GraphPPL.interface_aliases(::DefaultBackend, _) = GraphPPL.StaticInterfaceAliase
# And throws an error for `Composite` nodes since those has to be called with named arguments anyway
default_parametrization(::DefaultBackend, ::Atomic, fform::F, rhs::Tuple) where {F} = (in = rhs,)
default_parametrization(::DefaultBackend, ::Composite, fform::F, rhs) where {F} =
error("Composite nodes always have to be initialized with named arguments")
error("Composite nodes always have to be initialized with named arguments")
11 changes: 10 additions & 1 deletion src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,17 @@ using Static
using NamedTupleTools
using Dictionaries

import Base: put!, haskey, gensym, getindex, getproperty, setproperty!, setindex!, vec, iterate
import Base: put!, haskey, gensym, getindex, getproperty, setproperty!, setindex!, vec, iterate, showerror, Exception
import MetaGraphsNext.Graphs: neighbors, degree

export as_node, as_variable, as_context, savegraph, loadgraph

struct NotImplementedError <: Exception
message::String
end

showerror(io::IO, e::NotImplementedError) = print(io, "NotImplementedError: " * e.message)

struct Broadcasted
name::Symbol
end
Expand Down Expand Up @@ -188,6 +194,7 @@ unroll(something) = something
__proxy_unroll(something) = something
__proxy_unroll(proxy::ProxyLabel) = __proxy_unroll(proxy, proxy.index, proxy.proxied)
__proxy_unroll(proxy::ProxyLabel, index, proxied) = __safegetindex(__proxy_unroll(proxied), index)
__proxy_unroll(proxy::ProxyLabel, index::NTuple{N, UnitRange}, proxied) where {N} = __safegetindex(__proxy_unroll(proxied), index)

__safegetindex(something, index::FunctionalIndex) = Base.getindex(something, index)
__safegetindex(something, index::Tuple) = Base.getindex(something, index...)
Expand Down Expand Up @@ -884,11 +891,13 @@ check_variate_compatability(node::NodeLabel, index) =
error("Cannot call single random variable on the left-hand-side by an indexed statement")

check_variate_compatability(label::GraphPPL.ProxyLabel, index) = check_variate_compatability(unroll(label), index)
check_variate_compatability(label::GraphPPL.ProxyLabel, index...) = check_variate_compatability(unroll(label), index)

check_variate_compatability(node::ResizableArray{NodeLabel, V, N}, index::Vararg{Int, N}) where {V, N} = isassigned(node, index...)
check_variate_compatability(node::ResizableArray{NodeLabel, V, N}, index::Vararg{Int, M}) where {V, N, M} =
error("Index of length $(length(index)) not possible for $N-dimensional vector of random variables")


check_variate_compatability(node::ResizableArray{NodeLabel, V, N}, index::Nothing) where {V, N} =
error("Cannot call vector of random variables on the left-hand-side by an unindexed statement")

Expand Down
151 changes: 126 additions & 25 deletions src/plugins/variational_constraints/variational_constraints_engine.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import Base: showerror, Exception

struct UnresolvableFactorizationConstraintError <: Exception
message::String
end

Base.showerror(io::IO, e::UnresolvableFactorizationConstraintError) = println(io, "Unresolvable factorization constraint: " * e.message)



const VariationalConstraintsFactorizationIndicesKey = NodeDataExtraKey{:factorization_constraint_indices, Tuple}()
const VariationalConstraintsFactorizationBitSetKey = NodeDataExtraKey{:factorization_constraint_bitset, BoundedBitSetTuple}()
const VariationalConstraintsMarginalFormConstraintKey = NodeDataExtraKey{:marginal_form_constraint, Any}()
const VariationalConstraintsMessagesFormConstraintKey = NodeDataExtraKey{:messages_form_constraint, Any}()


"""
CombinedRange{L, R}
Expand All @@ -19,7 +29,6 @@ end
Base.firstindex(range::CombinedRange) = range.from
Base.lastindex(range::CombinedRange) = range.to
Base.in(item, range::CombinedRange) = firstindex(range) <= item <= lastindex(range)
Base.in(item::NTuple{N, Int} where {N}, range::CombinedRange) = CartesianIndex(item...) firstindex(range):lastindex(range)
Base.length(range::CombinedRange) = lastindex(range) - firstindex(range) + 1

Base.show(io::IO, range::CombinedRange) = print(io, repr(range.from), ":", repr(range.to))
Expand Down Expand Up @@ -83,6 +92,7 @@ __factorization_split_merge_range(a::Int, b::Int) = SplittedRange(a, b)
__factorization_split_merge_range(a::FunctionalIndex, b::Int) = SplittedRange(a, b)
__factorization_split_merge_range(a::Int, b::FunctionalIndex) = SplittedRange(a, b)
__factorization_split_merge_range(a::FunctionalIndex, b::FunctionalIndex) = SplittedRange(a, b)
__factorization_split_merge_range(a::NTuple{N, Int}, b::NTuple{N, Int}) where {N} = throw(NotImplementedError("q(var[firstindex])..q(var[lastindex]) for index dimension $N (constraint specified with $a and $b as endpoints)"))
__factorization_split_merge_range(a::Any, b::Any) = error("Cannot merge $(a) and $(b) indexes in `factorization_split`")

"""
Expand Down Expand Up @@ -390,16 +400,57 @@ Base.in(nodedata::NodeData, var::ResolvedIndexedVariable) = in(nodedata, getprop

Base.in(
nodedata::NodeData, properties::VariableNodeProperties, var::ResolvedIndexedVariable{T} where {T <: Union{Int, NTuple{N, Int} where N}}
) = Base.in(nodedata, properties, var, index(properties))

Base.in(
nodedata::NodeData,
properties::VariableNodeProperties,
var::ResolvedIndexedVariable{T} where {T <: Union{Int, NTuple{N, Int} where N}},
i::Union{Int, Nothing}
) = (getname(var) == getname(properties)) && (index(var) == index(properties)) && (getcontext(var) == getcontext(nodedata))

Base.in(
nodedata::NodeData,
properties::VariableNodeProperties,
var::ResolvedIndexedVariable{T} where {T <: Union{Int, NTuple{N, Int} where N}},
i::NTuple{M, Int} where {M}
) =
(getname(properties) == getname(var)) &&
(flattened_index(getcontext(var)[getname(var)], i) index(var)) &&
(getcontext(var) == getcontext(nodedata))

Base.in(nodedata::NodeData, properties::VariableNodeProperties, var::ResolvedIndexedVariable{T} where {T <: Nothing}) =
(getname(var) == getname(properties)) && (getcontext(var) == getcontext(nodedata))

Base.in(
nodedata::NodeData,
properties::VariableNodeProperties,
var::ResolvedIndexedVariable{T} where {T <: Union{SplittedRange, CombinedRange, UnitRange}}
) = (getname(properties) == getname(var)) && (index(properties) index(var)) && (getcontext(var) == getcontext(nodedata))
) = Base.in(nodedata, properties, var, index(properties))

Base.in(
nodedata::NodeData,
properties::VariableNodeProperties,
var::ResolvedIndexedVariable{T} where {T <: Union{SplittedRange, CombinedRange, UnitRange}},
i::NTuple{N, Int} where {N}
) =
(getname(properties) == getname(var)) &&
(flattened_index(getcontext(var)[getname(var)], i) index(var)) &&
(getcontext(var) == getcontext(nodedata))

Base.in(
nodedata::NodeData,
properties::VariableNodeProperties,
var::ResolvedIndexedVariable{T} where {T <: Union{SplittedRange, CombinedRange, UnitRange}},
i::Int
) = (getname(properties) == getname(var)) && (i index(var)) && (getcontext(var) == getcontext(nodedata))

Base.in(
nodedata::NodeData,
properties::VariableNodeProperties,
var::ResolvedIndexedVariable{T} where {T <: Union{SplittedRange, CombinedRange, UnitRange}},
i::Nothing
) = false

struct ResolvedConstraintLHS{V}
variables::V
Expand Down Expand Up @@ -582,33 +633,77 @@ end

get_constraint_names(constraint::NTuple{N, Tuple} where {N}) = map(entry -> GraphPPL.getname.(entry), constraint)

function __resolve(data::NodeData)
return __resolve(data, getproperties(data))
__resolve_index_consistency(model, labels, findex::Int, lindex::Int) = (findex, lindex)
function __resolve_index_consistency(model, labels, findex::NTuple{N, Int}, lindex::NTuple{N, Int}) where {N}
differing_indices = findall(map(indices -> indices[1] != indices[2], zip(findex, lindex)))
if length(differing_indices) == 1 && first(differing_indices) == N
full_array = model[first(labels[1])].context[first(labels[1]).name] # This black magic line gets the full array of the sliced variable that we need to acces. It accesses it through the context which is saved in the nodedata.
return flattened_index(full_array, findex), flattened_index(full_array, lindex)
else
throw(
NotImplementedError(
"Congratulations, you tried to define a factorization constraint for a >2 dimensional random variable where there is either more than one differing index between the endpoints of the constraint, or you've sliced the random variable in more than 1 dimension. We've thought about this
edge case but don't know how we can resolve this, let alone efficiently. Please open an issue on GitHub if you need this feature, or consider changing your model definition. Furthermore, PR's are always welcome!"
)
)
end
end

function __resolve(data::NodeData, properties::VariableNodeProperties)
return ResolvedIndexedVariable(getname(properties), index(properties), getcontext(data))
function __resolve(model::Model, label::NodeLabel)
data = model[label]
return __resolve(model, data, getproperties(data), index(getproperties(data)))
end

function __resolve(data::AbstractArray{T} where {T <: NodeData})
firstdata = first(data)
lastdata = last(data)
if getname(getproperties(firstdata)) != getname(getproperties(lastdata))
error("Cannot resolve factorization constraint for $(getname(getproperties(firstdata))) and $(getname(getproperties(lastdata))).")
function __resolve(::Model, data::NodeData, properties::VariableNodeProperties, i::Union{Nothing, Int})
# The variable is either a single variable or in a vector, then we don't really care.
return ResolvedIndexedVariable(getname(properties), i, getcontext(data))
end

function __resolve(model::Model, data::NodeData, properties::VariableNodeProperties, i::NTuple{N, Int} where {N})
# The variable is either a single variable or in a vector, then we don't really care.
full_array = getcontext(data)[getname(properties)]
return ResolvedIndexedVariable(getname(properties), flattened_index(full_array, i), getcontext(data))
end

function __resolve(model::Model, labels::AbstractArray{T, 1}) where {T <: NodeLabel}
fdata = model[first(labels)]
ldata = model[last(labels)]
if getname(getproperties(fdata)) != getname(getproperties(ldata))
throw(UnresolvableFactorizationConstraintError("Cannot resolve factorization constraint for $(getname(getproperties(fdata))) and $(getname(getproperties(ldata)))."))
end
# If we make a slice of a matrix in the constraints, we end up here (for example, q(x[1], x[2]) = q(x[1])q(x[2]) for matrix valued x).
# Then `index(getproperties(fdata))` and `index(getproperties(ldata))` will be `Tuple`, and we need to resolve this to a single `Int` in the dimension in which they differ
findex = index(getproperties(fdata))
lindex = index(getproperties(ldata))
findex, lindex = __resolve_index_consistency(model, labels, findex, lindex)

return ResolvedIndexedVariable(getname(getproperties(fdata)), CombinedRange(findex, lindex), getcontext(fdata))
end

function __resolve(model::Model, labels::AbstractArray{T, N} where {T <: NodeLabel}) where {N}
findex, flabel = firstwithindex(labels)
lindex, llabel = lastwithindex(labels)

fdata = model[flabel]
ldata = model[llabel]

# We have to test whether or not the `ResizableArray` of labels passed is a slice. If it is, we throw because the constraint is unresolvable
if CartesianIndex(index(getproperties(fdata))) != findex || CartesianIndex(index(getproperties(ldata))) != lindex
throw(UnresolvableFactorizationConstraintError(lazy"Did you pass a slice of the variable to a submodel ($(getname(getproperties(fdata)))), and then tried to factorize it? These partial factorization constraints cannot be resolved and are not supported."))
end

return ResolvedIndexedVariable(
getname(getproperties(firstdata)),
CombinedRange(index(getproperties(firstdata)), index(getproperties(lastdata))),
getcontext(firstdata)
getname(getproperties(fdata)),
CombinedRange(flattened_index(labels, findex.I), flattened_index(labels, lindex.I)),
getcontext(fdata)
)
end

function resolve(model::Model, context::Context, variable::IndexedVariable{<:SplittedRange})
global_label = unroll(context[getname(variable)])
resolved_indices = __factorization_specification_resolve_index(index(variable), global_label)
global_node_data = model[global_label[firstindex(resolved_indices):lastindex(resolved_indices)]]
firstdata = first(global_node_data)
lastdata = last(global_node_data)
firstdata = model[global_label[firstindex(resolved_indices)]]
lastdata = model[global_label[lastindex(resolved_indices)]]
if getname(getproperties(firstdata)) != getname(getproperties(lastdata))
error("Cannot resolve factorization constraint for $(getname(getproperties(firstdata))) and $(getname(getproperties(lastdata))).")
end
Expand All @@ -621,14 +716,20 @@ end

function resolve(model::Model, context::Context, variable::IndexedVariable{Nothing})
global_label = unroll(context[getname(variable)])
global_node_data = model[global_label]
return __resolve(global_node_data)
return __resolve(model, global_label)
end

function resolve(model::Model, context::Context, variable::IndexedVariable)
global_label = unroll(context[getname(variable)])[index(variable)]
global_node_data = model[global_label]
return __resolve(global_node_data)
global_label = unroll(context[getname(variable)])[index(variable)...]
return __resolve(model, global_label)
end

resolve(model::Model, context::Context, variable::IndexedVariable{CombinedRange{NTuple{N, Int}, NTuple{N, Int}}}) where {N} =
throw(UnresolvableFactorizationConstraintError("Cannot resolve factorization constraint for a combined range of dimension > 2."))

function resolve(model::Model, context::Context, variable::IndexedVariable{<:CombinedRange})
global_label = unroll(context[getname(variable)])[firstindex(index(variable)):lastindex(index(variable))]
return __resolve(model, global_label)
end

function resolve(model::Model, context::Context, constraint::FactorizationConstraint)
Expand Down Expand Up @@ -712,10 +813,10 @@ function is_decoupled_one_linked(links, unlinked::NodeData, constraint::Resolved
else
# Perhaps, this is possible to resolve automatically, but that would required
# quite some difficult graph traversal logic, so for now we just throw an error
error(lazy"""
Cannot resolve factorization constraint $(constraint) for an anonymous variable connected to variables $(join(links, ',')).
throw(UnresolvableFactorizationConstraintError(lazy"""
Cannot resolve factorization constraint $(constraint) for an anonymous variable connected to variables $(join(links, ',')).
As a workaround specify the name and the factorization constraint for the anonymous variable explicitly.
""")
"""))
end
end

Expand Down
77 changes: 73 additions & 4 deletions src/resizable_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function ResizableArray(array::AbstractVector{T}) where {T}
return ResizableArray{V, Vector{T}, get_recursive_depth(array)}(array)
end

ResizableArray(A::AbstractArray) = ResizableArray([A[:, i] for i in 1:size(A, 2)])
# ResizableArray(A::AbstractArray{T, N}) where {T, N} = ResizableArray{T, typeof(A), N}(A)

function make_recursive_vector(::Type{T}, ::Val{1}) where {T}
return T[]
Expand Down Expand Up @@ -109,6 +109,11 @@ function getindex(array::ResizableArray{T, V, N}, index::UnitRange) where {T, V,
return ResizableArray(array.data[index])
end

function getindex(array::ResizableArray{T, V, N}, index::Vararg{UnitRange}) where {T, V, N}
return ResizableArray(recursive_getindex(Val(length(index)), array.data, index...))
end


function getindex(array::ResizableArray{T, V, N}, index::Vararg{Int}) where {T, V, N}
@assert N >= length(index) "Invalid index $(index) for $(array) of shape $(size(array)))"
return recursive_getindex(Val(length(index)), array.data, index...)
Expand All @@ -122,10 +127,20 @@ function recursive_getindex(::Val{1}, array::Vector, index)
return array[index]
end

function recursive_getindex(::Val{1}, array::Vector, index::UnitRange)
return array[index]
end


function recursive_getindex(::Val{N}, array::Vector{V}, findex, index...) where {N, V}
return recursive_getindex(Val(N - 1), array[findex], index...)
end


function recursive_getindex(::Val{N}, array::Vector{V}, findex::UnitRange, index...) where {N, V}
return [recursive_getindex(Val(N - 1), array[i], index...) for i in findex]
end

function Base.show(io::IO, array::ResizableArray{T, V, N}) where {T, V, N}
print(io, "ResizableArray{$T,$N}(")
show(io, array.data)
Expand All @@ -144,10 +159,64 @@ end

Base.iterate(array::ResizableArray{T, V, N}, state = 1) where {T, V, N} = iterate(array.data, state)

function Base.map(f, array::ResizableArray{T, V, N}) where {T, V, N}
result = map(f, array.data)
return ResizableArray(result)
end

__length(array::ResizableArray{T, V, N}) where {T, V, N} = __recursive_length(Val(N), array.data)

function __recursive_length(::Val{N}, array) where {N}
if length(array) == 0
return 0
end
return sum((arr) -> __recursive_length(Val(N - 1), arr), array)
end

__recursive_length(::Val{1}, array) = length(array) == 0 ? 0 : sum((x) -> isassigned(array, x), 1:length(array))

function flattened_index(array::ResizableArray{T, V, N}, index::NTuple{N, Int}) where {T, V, N}
return __flattened_index(Val(N), array.data, index...)
end

flattened_index(array::ResizableArray{T, V, 1}, index::Int) where {T, V} = __flattened_index(Val(1), array.data, index)

function __flattened_index(::Val{1}, array::Vector{T}, index) where {T}
if isassigned(array, index)
return index
else
return sum((x) -> isassigned(array, x), 1:index)
end
end

function __flattened_index(::Val{N}, array::Vector{V}, findex, index...) where {N, V}
if findex == 1
return __flattened_index(Val(N - 1), array[findex], index...)
else
return sum(i -> __recursive_length(Val(N - 1), array[i]), 1:(findex - 1)) + __flattened_index(Val(N - 1), array[findex], index...)
end
end

function Base.first(array::ResizableArray{T, V, N}) where {T, V, N}
for index in Tuple.(CartesianIndices(size(array)))
if isassigned(array, index...)::Bool
return array[index...]
for index in CartesianIndices(size(array)) #TODO improve performance of this function since it uses splatting
if isassigned(array, index.I...)::Bool
return array[index.I...]
end
end
end

function firstwithindex(array::ResizableArray{T, V, N}) where {T, V, N}
for index in CartesianIndices(size(array)) #TODO improve performance of this function since it uses splatting
if isassigned(array, index.I...)::Bool
return (index, array[index.I...])
end
end
end

function lastwithindex(array::ResizableArray{T, V, N}) where {T, V, N}
for index in reverse(CartesianIndices(reverse(size(array)))) #TODO improve performance of this function since it uses splatting
if isassigned(array, reverse(index.I)...)::Bool
return (index, array[reverse(index.I)...])
end
end
end
4 changes: 2 additions & 2 deletions test/graph_construction_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -745,10 +745,10 @@ end
ydata = rand(10)
prior = Beta(1, 1)

model = create_model(coin_model_priors(prior = prior)) do model, context
model = create_model(coin_model_priors(prior = prior)) do model, context
return (; y = getorcreate!(model, context, NodeCreationOptions(kind = :data), :y, LazyIndex(ydata)))
end

@test length(collect(filter(as_node(Bernoulli), model))) === 10
@test length(collect(filter(as_node(prior), model))) === 1
end
end
Loading

0 comments on commit b225c30

Please sign in to comment.