iddict to dict #31

Merged: 6 commits, Mar 22, 2022
3 changes: 2 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "GraphTensorNetworks"
uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2"
authors = ["GiggleLiu <cacate0129@gmail.com> and contributors"]
version = "0.2.8"
version = "0.2.9"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -21,6 +21,7 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334"
Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9"
12 changes: 6 additions & 6 deletions docs/src/performancetips.md
@@ -15,8 +15,8 @@ Keyword argument `optimizer` decides the contraction order optimizer of the tensor network.
Here, we choose the `TreeSA` optimizer, a local-search-based algorithm, to optimize the tensor network contraction tree.
It is one of the state-of-the-art tensor network contraction order optimizers; one may check [arXiv: 2108.05665](https://arxiv.org/abs/2108.05665) to learn more about the algorithm.
Other optimizers include (a usage sketch follows this list):
* [`GreedyMethod`](@ref) (default, fastest in searching speed but worse in contraction order)
* [`TreeSA`](@ref)
* [`GreedyMethod`](@ref) (default, fastest in searching speed but worst in contraction complexity)
* [`TreeSA`](@ref) (often best in contraction complexity, supports slicing)
* [`KaHyParBipartite`](@ref)
* [`SABipartite`](@ref)
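
A minimal usage sketch (assuming the `IndependentSet` problem constructor from the examples; the `TreeSA` keyword values below are illustrative rather than tuned):

```julia
using GraphTensorNetworks, Graphs

graph = random_regular_graph(120, 3)
# `TreeSA` performs a local search over contraction trees;
# `ntrials` and `niters` trade optimization time for contraction quality
problem = IndependentSet(graph; optimizer=TreeSA(ntrials=1, niters=10))
```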

@@ -32,8 +32,8 @@ julia> timespacereadwrite_complexity(problem)
```

The return values are `log2` of the number of iterations, the number of elements in the largest tensor during contraction, and the number of read-write operations to tensor elements.
In this example, the number of `+` and `*` operations are both `\sim 2^{21.9}`
and the number of read-write operations are `\sim 2^{20}`.
In this example, the number of `+` and `*` operations are both ``\sim 2^{21.9}``
and the number of read-write operations are ``\sim 2^{20}``.
The largest tensor size is ``2^{17}``; one can check the element size by typing
```julia
julia> sizeof(TropicalF64)
@@ -136,7 +136,7 @@ julia> lineplot(hamming_distribution(samples, samples))
```

## Multiprocessing
Submodule `GraphTensorNetworks.SimpleMutiprocessing` provides a function [`multiprocess_run`](@ref) function for simple multi-processing jobs.
Submodule `GraphTensorNetworks.SimpleMultiprocessing` provides the [`GraphTensorNetworks.SimpleMultiprocessing.multiprocess_run`](@ref) function for simple multi-processing jobs.
Suppose we want to find the independence polynomial for multiple graphs with 4 processes.
We can create a file, e.g. named `run.jl` with the following content

@@ -190,4 +190,4 @@ CUDA-backed properties are (a usage sketch follows this list)
* [`CountingAll`](@ref)
* [`CountingMax`](@ref)
* [`GraphPolynomial`](@ref)
* [`SingleConfigMax`](@ref)
* [`SingleConfigMax`](@ref)
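
A GPU usage sketch (this assumes a functional CUDA setup and that `solve` accepts a `usecuda` keyword to route the contraction to the GPU):

```julia
using CUDA, GraphTensorNetworks, Graphs

graph = random_regular_graph(120, 3)
problem = IndependentSet(graph; optimizer=TreeSA())
# `SingleConfigMax` is one of the CUDA-backed properties listed above
solve(problem, SingleConfigMax(); usecuda=true)
```
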
2 changes: 2 additions & 0 deletions docs/src/ref.md
@@ -98,6 +98,8 @@ StaticBitVector
StaticElementVector
save_configs
load_configs
save_sumproduct
load_sumproduct
@bv_str
onehotv

4 changes: 3 additions & 1 deletion examples/IndependentSet.jl
@@ -135,10 +135,12 @@ all_independent_sets_tree = solve(problem, ConfigsAll(; tree_storage=true))[]
# The results encode the configurations in the sum-product-tree format. One can count and enumerate them explicitly by typing
length(all_independent_sets_tree)

#
# Then one can use the `Base.collect` function to create a [`ConfigEnumerator`](@ref), or use [`generate_samples`](@ref) to generate samples from it.

collect(all_independent_sets_tree)

generate_samples(all_independent_sets_tree, 10)

# One can use [`save_configs`](@ref) and [`load_configs`](@ref) to write a set of configurations to disk and read them back.
filename = tempname()
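
# A minimal sketch of the save/load round trip (based on the `save_configs`/`load_configs`
# signatures defined in `src/fileio.jl`; `graph` is assumed to be the example's graph, so
# `nv(graph)` gives the required bit length):
configs = collect(all_independent_sets_tree)
save_configs(filename, configs; format=:binary)
loaded = load_configs(filename; format=:binary, bitlength=nv(graph))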

4 changes: 3 additions & 1 deletion src/GraphTensorNetworks.jl
@@ -6,6 +6,7 @@ using TropicalNumbers
using OMEinsum
using OMEinsum: timespace_complexity, getixsv
using Graphs, Random
using DelimitedFiles, Serialization

# OMEinsum
export timespace_complexity, timespacereadwrite_complexity, @ein_str, getixsv, getiyv
@@ -48,7 +49,7 @@ export is_matching
export solve, SizeMax, SizeMin, CountingAll, CountingMax, CountingMin, GraphPolynomial, SingleConfigMax, SingleConfigMin, ConfigsAll, ConfigsMax, ConfigsMin

# Utilities
export save_configs, load_configs, hamming_distribution
export save_configs, load_configs, hamming_distribution, save_sumproduct, load_sumproduct

# Visualization
export show_graph, spring_layout
@@ -64,6 +65,7 @@ include("configurations.jl")
include("graphs.jl")
include("bounding.jl")
include("visualize.jl")
include("fileio.jl")
include("interfaces.jl")
include("deprecate.jl")
include("multiprocessing.jl")
95 changes: 63 additions & 32 deletions src/arithematics.jl
@@ -297,7 +297,6 @@ function collect_geq!(res, A, B, mB, low)
K = length(A)
k = 1 # TODO: we should tighten mA, mB later!
Ak = A[K-k+1]
Bq = B[K-mB+1]
l = 0
for q = K-mB+1:-1:1
Bq = B[K-q+1]
@@ -587,37 +586,56 @@ function printnode(io::IO, t::SumProductTree{ET}) where {ET}
end
end

# it must be mutable, otherwise the `IdDict` trick for computing the length does not work.
Base.length(x::SumProductTree) = _length!(x, IdDict{typeof(x), Int}())
# it must be mutable; otherwise, `objectid` might be slow and serialization might fail.
# `IdDict` is much slower than `Dict`, so there is no point in using it here.
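# note: counts are accumulated as `Float64` rather than `Int`, presumably so that huge
# configuration counts do not overflow (they become approximate beyond 2^53)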
Base.length(x::SumProductTree) = _length!(x, Dict{UInt, Float64}())

function _length!(x, d)
haskey(d, x) && return d[x]
id = objectid(x)
haskey(d, id) && return d[id]
if x.tag === SUM
l = _length!(x.left, d) + _length!(x.right, d)
d[x] = l
d[id] = l
return l
elseif x.tag === PROD
l = _length!(x.left, d) * _length!(x.right, d)
d[x] = l
d[id] = l
return l
elseif x.tag === ZERO
return 0
return 0.0
else
return 1
return 1.0
end
end

num_nodes(x::SumProductTree) = _num_nodes(x, IdDict{typeof(x), Int}())
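# returns (found, value): ZERO/ONE/LEAF nodes have trivial counts; otherwise the memoized
# length of the left child is looked up by `objectid`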
function _find_branch(x, d)
if x.tag === ZERO
return true, 0.0
elseif x.tag === ONE || x.tag === LEAF
return true, 1.0
else
idl = objectid(x.left)
if haskey(d, idl)
return true, d[idl]
else
return false, 0.0
end
end
end


num_nodes(x::SumProductTree) = _num_nodes(x, Dict{UInt, Int}())
function _num_nodes(x, d)
haskey(d, x) && return 0
id = objectid(x)
haskey(d, id) && return 0
if x.tag == ZERO || x.tag == ONE
res = 1
elseif x.tag == LEAF
res = 1
else
res = _num_nodes(x.left, d) + _num_nodes(x.right, d) + 1
end
d[x] = res
d[id] = res
return res
end

@@ -708,33 +726,46 @@ true
function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET}
# get length dict
res = fill(_data_one(ET), nsamples)
d = IdDict{typeof(t), Int}()
d = Dict{UInt, Float64}()
sample_descend!(res, t, d)
return res
end

function sample_descend!(res::AbstractVector, t::SumProductTree, d::IdDict)
length(res) == 0 && return res
if t.tag == LEAF
res .|= Ref(t.data)
elseif t.tag == SUM
ratio = _length!(t.left, d)/_length!(t, d)
nleft = 0
for _ = 1:length(res)
if rand() < ratio
nleft += 1
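# iterative descent: explicit stacks of (sample buffer, tree node) pairs replace the
# recursive calls, avoiding deep recursion on large trees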
function sample_descend!(res::AbstractVector, t::SumProductTree, d::Dict)
res_stack = Any[res]
t_stack = [t]
while !isempty(t_stack) && !isempty(res_stack)
t = pop!(t_stack)
res = pop!(res_stack)
if t.tag == LEAF
res .|= Ref(t.data)
elseif t.tag == SUM
ratio = _length!(t.left, d)/_length!(t, d)
nleft = 0
for _ = 1:length(res)
if rand() < ratio
nleft += 1
end
end
shuffle!(res) # shuffle the `res` to avoid biased sampling, very important.
if nleft >= 1
push!(res_stack, view(res,1:nleft))
push!(t_stack, t.left)
end
if length(res) > nleft
push!(res_stack, view(res,nleft+1:length(res)))
push!(t_stack, t.right)
end
elseif t.tag == PROD
push!(res_stack, res)
push!(res_stack, res)
push!(t_stack, t.left)
push!(t_stack, t.right)
elseif t.tag == ZERO
error("Meet zero when descending.")
else
# pass for 1
end
shuffle!(res) # shuffle the `res` to avoid biased sampling, very important.
sample_descend!(view(res,1:nleft), t.left, d)
sample_descend!(view(res,nleft+1:length(res)), t.right, d)
elseif t.tag == PROD
sample_descend!(res, t.right, d)
sample_descend!(res, t.left, d)
elseif t.tag == ZERO
error("Meet zero when descending.")
else
# pass for 1
end
return res
end
125 changes: 125 additions & 0 deletions src/fileio.jl
@@ -0,0 +1,125 @@
"""
save_configs(filename, data::ConfigEnumerator; format=:binary)

Save configurations `data` to file `filename`. The format is `:binary` or `:text`.
"""
function save_configs(filename, data::ConfigEnumerator{N,S,C}; format::Symbol=:binary) where {N,S,C}
if format == :binary
write(filename, raw_matrix(data))
elseif format == :text
writedlm(filename, plain_matrix(data))
else
error("format must be `:binary` or `:text`, got `:$format`")
end
end

"""
load_configs(filename; format=:binary, bitlength=nothing, nflavors=2)

Load configurations from file `filename`. The format is `:binary` or `:text`.
If the format is `:binary`, the bitstring length `bitlength` must be specified;
`nflavors` specifies the number of degrees of freedom per element.
"""
function load_configs(filename; bitlength=nothing, format::Symbol=:binary, nflavors=2)
if format == :binary
bitlength === nothing && error("you need to specify `bitlength` for reading configurations from binary files.")
S = ceil(Int, log2(nflavors))
C = _nints(bitlength, S)
return _from_raw_matrix(StaticElementVector{bitlength,S,C}, reshape(reinterpret(UInt64, read(filename)),C,:))
elseif format == :text
return from_plain_matrix(readdlm(filename); nflavors=nflavors)
else
error("format must be `:binary` or `:text`, got `:$format`")
end
end

function raw_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C}
m = zeros(UInt64, C, length(x))
@inbounds for i=1:length(x), j=1:C
m[j,i] = x.data[i].data[j]
end
return m
end
function plain_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C}
m = zeros(UInt8, N, length(x))
@inbounds for i=1:length(x), j=1:N
m[j,i] = x.data[i][j]
end
return m
end

function from_raw_matrix(m; bitlength, nflavors=2)
S = ceil(Int,log2(nflavors))
C = size(m, 1)
T = StaticElementVector{bitlength,S,C}
@assert bitlength*S <= C*64
_from_raw_matrix(T, m)
end
function _from_raw_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C}
data = zeros(StaticElementVector{N,S,C}, size(m, 2))
@inbounds for i=1:size(m, 2)
data[i] = StaticElementVector{N,S,C}(NTuple{C,UInt64}(view(m,:,i)))
end
return ConfigEnumerator(data)
end
function from_plain_matrix(m::Matrix; nflavors=2)
S = ceil(Int,log2(nflavors))
N = size(m, 1)
C = _nints(N, S)
T = StaticElementVector{N,S,C}
_from_plain_matrix(T, m)
end
function _from_plain_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C}
data = zeros(StaticElementVector{N,S,C}, size(m, 2))
@inbounds for i=1:size(m, 2)
data[i] = convert(StaticElementVector{N,S,C}, view(m, :, i))
end
return ConfigEnumerator(data)
end

# convert to Matrix
Base.Matrix(ce::ConfigEnumerator) = plain_matrix(ce)
Base.Vector(ce::StaticElementVector) = collect(ce)

########## saving tree ####################
"""
save_sumproduct(filename, t::SumProductTree)

Serialize a sum-product tree into a file.
"""
save_sumproduct(filename::String, t::SumProductTree) = serialize(filename, dict_serialize_tree!(t, Dict{UInt,Any}()))

"""
load_sumproduct(filename)

Deserialize a sum-product tree from a file.
"""
load_sumproduct(filename::String) = dict_deserialize_tree(deserialize(filename)...)
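
# `dict_serialize_tree!` flattens the tree into a `Dict` keyed by `objectid`, so shared
# subtrees are stored only once; `dict_deserialize_tree` rebuilds the nodes from that table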

function dict_serialize_tree!(t::SumProductTree, d::Dict)
id = objectid(t)
if !haskey(d, id)
if t.tag === GraphTensorNetworks.LEAF || t.tag === GraphTensorNetworks.ZERO || t.tag == GraphTensorNetworks.ONE
d[id] = t
else
d[id] = (t.tag, objectid(t.left), objectid(t.right))
dict_serialize_tree!(t.left, d)
dict_serialize_tree!(t.right, d)
end
end
return id, d
end

function dict_deserialize_tree(id::UInt, d::Dict)
@assert haskey(d, id)
content = d[id]
if content isa SumProductTree
return content
else
(tag, left, right) = content
t = SumProductTree(tag, dict_deserialize_tree(left, d), dict_deserialize_tree(right, d))
d[id] = t
return t
end
end
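
# Round-trip usage sketch (illustrative; `tree` stands for any `SumProductTree`, e.g. the
# result of a `ConfigsAll(; tree_storage=true)` solve):
#
#     fname = tempname()
#     save_sumproduct(fname, tree)
#     tree2 = load_sumproduct(fname)
#     length(tree2) == length(tree)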
