Skip to content

Tree enumerator #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "GraphTensorNetworks"
uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
version = "0.2.1"
version = "0.2.2"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
Expand Down Expand Up @@ -35,9 +36,10 @@ Polynomials = "2.0"
Primes = "0.5"
Requires = "1"
SIMDTypes = "0.1"
StatsBase = "0.33"
TropicalNumbers = "0.4, 0.5"
Viznet = "0.3"
StatsBase = "0.33"
AbstractTrees = "0.3"
julia = "1"

[extras]
Expand Down
1 change: 1 addition & 0 deletions docs/src/ref.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Polynomials.Polynomial
TruncatedPoly
Max2Poly
ConfigEnumerator
TreeConfigEnumerator
ConfigSampler
```

Expand Down
11 changes: 11 additions & 0 deletions examples/IndependentSet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,17 @@ Compose.compose(context(),
# One can use [`ConfigsAll`](@ref) to enumerate all sets satisfying the problem constraint.
all_independent_sets = solve(problem, ConfigsAll())[]

# It is often difficult to store all configurations in a vector.
# A more clever way to store the data is using the [`TreeConfigEnumerator`](@ref) format.
all_independent_sets_tree = solve(problem, ConfigsAll(; tree_storage=true))[]

# The results encode the configurations in the sum-product-tree format. One can count and enumerate them explicitly by typing
length(all_independent_sets_tree)

#

collect(all_independent_sets_tree)

# To save/read a set of configuration to disk, one can type the following
filename = tempname()

Expand Down
4 changes: 2 additions & 2 deletions examples/MaximalIS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ counting_min_maximal_independent_set = solve(problem, CountingMin())[]
# ##### finding all maximal independent set
maximal_configs = solve(problem, ConfigsAll())[]

all(c->is_maximal_independent_set(g, i), maximal_configs)
all(c->is_maximal_independent_set(graph, c), maximal_configs)

#

Expand All @@ -90,7 +90,7 @@ cliques = maximal_cliques(complement(graph))

# ##### finding minimum maximal independent set
# It is the [`ConfigsMin`](@ref) property in the program.
minimum_maximal_configs = solve(problem, ConfigsMin())[]
minimum_maximal_configs = solve(problem, ConfigsMin())[].c

imgs2 = ntuple(k->show_graph(graph;
locs=locations, scale=0.25,
Expand Down
2 changes: 1 addition & 1 deletion src/GraphTensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export GreedyMethod, TreeSA, SABipartite, KaHyParBipartite, MergeVectors, MergeG
# Algebras
export StaticBitVector, StaticElementVector, @bv_str
export is_commutative_semiring
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler
export Max2Poly, TruncatedPoly, Polynomial, Tropical, CountingTropical, StaticElementVector, Mod, ConfigEnumerator, onehotv, ConfigSampler, TreeConfigEnumerator
export CountingTropicalF64, CountingTropicalF32, TropicalF64, TropicalF32

# Lower level APIs
Expand Down
196 changes: 185 additions & 11 deletions src/arithematics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ using Polynomials: Polynomial
using TropicalNumbers: Tropical, CountingTropical
using Mods, Primes
using Base.Cartesian
import AbstractTrees: children, printnode, print_tree

@enum TreeTag LEAF SUM PROD ZERO

# pirate
Base.abs(x::Mod) = x
Expand Down Expand Up @@ -275,20 +278,179 @@ Base.one(::Type{ConfigSampler{N,S,C}}) where {N,S,C} = ConfigSampler{N,S,C}(zero
Base.zero(::ConfigSampler{N,S,C}) where {N,S,C} = zero(ConfigSampler{N,S,C})
Base.one(::ConfigSampler{N,S,C}) where {N,S,C} = one(ConfigSampler{N,S,C})

# A patch to make `Polynomial{ConfigEnumerator}` work
function Base.:*(a::Int, y::ConfigEnumerator)
a == 0 && return zero(y)
a == 1 && return y
error("multiplication between int and config enumerator is not defined.")
# tree config enumerator
"""
TreeConfigEnumerator{N,S,C}

Configuration enumerator encoded in a tree, it is the most natural representation given by a sum-product network
and is often more memory efficient than putting the configurations in a vector.
`N`, `S` and `C` are type parameters from the [`StaticElementVector`](@ref){N,S,C}.

Fields
-----------------------
* `tag` is one of `ZERO`, `LEAF`, `SUM`, `PROD`.
* `data` is the element stored in a `LEAF` node.
* `left` and `right` are two operands of a `SUM` or `PROD` node.

Example
------------------------
```jldoctest; setup=:(using GraphTensorNetworks)
julia> s = TreeConfigEnumerator(bv"00111")
00111


julia> q = TreeConfigEnumerator(bv"10000")
10000


julia> x = s + q
+
├─ 00111
└─ 10000


julia> y = x * x
*
├─ +
│ ├─ 00111
│ └─ 10000
└─ +
├─ 00111
└─ 10000


julia> collect(y)
4-element Vector{StaticBitVector{5, 1}}:
00111
10111
10111
10000

julia> zero(s)



julia> one(s)
00000


```
"""
struct TreeConfigEnumerator{N,S,C}
tag::TreeTag
data::StaticElementVector{N,S,C}
left::TreeConfigEnumerator{N,S,C}
right::TreeConfigEnumerator{N,S,C}
TreeConfigEnumerator(tag::TreeTag, left::TreeConfigEnumerator{N,S,C}, right::TreeConfigEnumerator{N,S,C}) where {N,S,C} = new{N,S,C}(tag, zero(StaticElementVector{N,S,C}), left, right)
function TreeConfigEnumerator(data::StaticElementVector{N,S,C}) where {N,S,C}
new{N,S,C}(LEAF, data)
end
function TreeConfigEnumerator{N,S,C}(tag::TreeTag) where {N,S,C}
@assert tag === ZERO
return new{N,S,C}(tag)
end
end
function Base.:*(a::Int, y::ConfigSampler)
a == 0 && return zero(y)
a == 1 && return y
error("multiplication between int and config sampler is not defined.")

# AbstractTree APIs
function children(t::TreeConfigEnumerator)
if isdefined(t, :left)
if isdefined(t, :right)
return [t.left, t.right]
else
return [t.left]
end
else
if isdefined(t, :right)
return [t.right]
else
return typeof(t)[]
end
end
end
function printnode(io::IO, t::TreeConfigEnumerator)
if t.tag === LEAF
print(io, t.data)
elseif t.tag === ZERO
print(io, "")
elseif t.tag === SUM
print(io, "+")
else # PROD
print(io, "*")
end
end

function Base.length(x::TreeConfigEnumerator)
if x.tag === SUM
return length(x.left) + length(x.right)
elseif x.tag === PROD
return length(x.left) * length(x.right)
elseif x.tag === ZERO
return 0
else
return 1
end
end

function num_nodes(x::TreeConfigEnumerator)
x.tag == ZERO && return 1
x.tag == LEAF && return 1
return num_nodes(x.left) + num_nodes(x.right) + 1
end

function Base.:(==)(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
return Set(collect(x)) == Set(collect(y))
end

Base.show(io::IO, t::TreeConfigEnumerator) = print_tree(io, t)

function Base.collect(x::TreeConfigEnumerator{N,S,C}) where {N,S,C}
if x.tag == ZERO
return StaticElementVector{N,S,C}[]
elseif x.tag == LEAF
return StaticElementVector{N,S,C}[x.data]
elseif x.tag == SUM
return vcat(collect(x.left), collect(x.right))
else # PROD
return vec([reduce((x,y)->x|y, si) for si in Iterators.product(collect(x.left), collect(x.right))])
end
end

function Base.:+(x::TreeConfigEnumerator{N,S,C}, y::TreeConfigEnumerator{N,S,C}) where {N,S,C}
TreeConfigEnumerator(SUM, x, y)
end

function Base.:*(x::TreeConfigEnumerator{L,S,C}, y::TreeConfigEnumerator{L,S,C}) where {L,S,C}
TreeConfigEnumerator(PROD, x, y)
end

Base.zero(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator{N,S,C}(ZERO)
Base.one(::Type{TreeConfigEnumerator{N,S,C}}) where {N,S,C} = TreeConfigEnumerator(zero(StaticElementVector{N,S,C}))
Base.zero(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = zero(TreeConfigEnumerator{N,S,C})
Base.one(::TreeConfigEnumerator{N,S,C}) where {N,S,C} = one(TreeConfigEnumerator{N,S,C})
# todo, check siblings too?
function Base.iszero(t::TreeConfigEnumerator)
if t.TAG == SUM
iszero(t.left) && iszero(t.right)
elseif t.TAG == ZERO
true
elseif t.TAG == LEAF
false
else
iszero(t.left) || iszero(t.right)
end
end

# A patch to make `Polynomial{ConfigEnumerator}` work
for T in [:ConfigEnumerator, :ConfigSampler, :TreeConfigEnumerator]
@eval function Base.:*(a::Int, y::$T)
a == 0 && return zero(y)
a == 1 && return y
error("multiplication between int and `$(typeof(y))` is not defined.")
end
end

# convert from counting type to bitstring type
for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler)]
for (F,TP) in [(:set_type, :ConfigEnumerator), (:sampler_type, :ConfigSampler), (:treeset_type, :TreeConfigEnumerator)]
@eval begin
function $F(::Type{T}, n::Int, nflavor::Int) where {OT, K, T<:TruncatedPoly{K,C,OT} where C}
TruncatedPoly{K, $F(n,nflavor),OT}
Expand All @@ -312,12 +474,24 @@ end

# utilities for creating onehot vectors
onehotv(::Type{ConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigEnumerator([onehotv(StaticElementVector{N,S,C}, i, v)])
onehotv(::Type{TreeConfigEnumerator{N,S,C}}, i::Integer, v) where {N,S,C} = TreeConfigEnumerator(onehotv(StaticElementVector{N,S,C}, i, v))
onehotv(::Type{ConfigSampler{N,S,C}}, i::Integer, v) where {N,S,C} = ConfigSampler(onehotv(StaticElementVector{N,S,C}, i, v))
# just to make matrix transpose work
Base.transpose(c::ConfigEnumerator) = c
Base.copy(c::ConfigEnumerator) = ConfigEnumerator(copy(c.data))
Base.transpose(c::TreeConfigEnumerator) = c
function Base.copy(c::TreeConfigEnumerator)
if c.tag == LEAF
TreeConfigEnumerator(c.data)
elseif c.tag == ZERO
TreeConfigEnumerator(c.tag)
else
TreeConfigEnumerator(c.tag, c.left, c.right)
end
end

# Handle boolean, this is a patch for CUDA matmul
for TYPE in [:ConfigEnumerator, :ConfigSampler, :TruncatedPoly]
for TYPE in [:ConfigEnumerator, :ConfigSampler, :TruncatedPoly, :TreeConfigEnumerator]
@eval Base.:*(a::Bool, y::$TYPE) = a ? y : zero(y)
@eval Base.:*(y::$TYPE, a::Bool) = a ? y : zero(y)
end
1 change: 0 additions & 1 deletion src/bounding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ function backward_tropical(mode, ixs, @nospecialize(xs::Tuple), iy, @nospecializ
mask .= inv.(einsum(EinCode(nixs, niy), nxs, size_dict)) .<= xs[i] .* Tropical(largest_k(mode)-1+1e-12)
push!(masks, mask)
elseif mode isa SingleConfig
A = zeros(eltype(xs[i]), size(xs[i]))
A = einsum(EinCode(nixs, niy), nxs, size_dict)
push!(masks, onehotmask(A, xs[i]))
else
Expand Down
Loading