From 04b04724ebc66613ef702dfd3b85330dfacb1569 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Mon, 21 Mar 2022 23:05:05 -0400 Subject: [PATCH 1/6] iddict to dict --- Project.toml | 3 +- docs/src/performancetips.md | 10 +-- docs/src/ref.md | 2 + src/GraphTensorNetworks.jl | 4 +- src/arithematics.jl | 28 ++++---- src/fileio.jl | 125 ++++++++++++++++++++++++++++++++++++ src/interfaces.jl | 85 ------------------------ test/fileio.jl | 53 +++++++++++++++ test/interfaces.jl | 39 ----------- test/runtests.jl | 4 ++ 10 files changed, 209 insertions(+), 144 deletions(-) create mode 100644 src/fileio.jl create mode 100644 test/fileio.jl diff --git a/Project.toml b/Project.toml index b1200ca0..40c348b8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "GraphTensorNetworks" uuid = "0978c8c2-34f6-49c7-9826-ea2cc20dabd2" authors = ["GiggleLiu and contributors"] -version = "0.2.8" +version = "0.2.9" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -21,6 +21,7 @@ Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SIMDTypes = "94e857df-77ce-4151-89e5-788b33177be4" +Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TropicalNumbers = "b3a74e9c-7526-4576-a4eb-79c0d4c32334" Viznet = "52a3aca4-6234-47fd-b74a-806bdf78ede9" diff --git a/docs/src/performancetips.md b/docs/src/performancetips.md index 983956c8..c409579e 100644 --- a/docs/src/performancetips.md +++ b/docs/src/performancetips.md @@ -15,8 +15,8 @@ Key word argument `optimizer` decides the contraction order optimizer of the ten Here, we choose the `TreeSA` optimizer to optimize the tensor network contraciton tree, it is a local search based algorithm. It is one of the state of the art tensor network contraction order optimizers, one may check [arXiv: 2108.05665](https://arxiv.org/abs/2108.05665) to learn more about the algorithm. Other optimizers include -* [`GreedyMethod`](@ref) (default, fastest in searching speed but worse in contraction order) -* [`TreeSA`](@ref) +* [`GreedyMethod`](@ref) (default, fastest in searching speed but worst in contraction complexity) +* [`TreeSA`](@ref) (often best in contraction complexity, supports slicing) * [`KaHyParBipartite`](@ref) * [`SABipartite`](@ref) @@ -32,8 +32,8 @@ julia> timespacereadwrite_complexity(problem) ``` The return values are `log2` of the the number of iterations, the number elements in the largest tensor during contraction and the number of read-write operations to tensor elements. -In this example, the number of `+` and `*` operations are both `\sim 2^{21.9}` -and the number of read-write operations are `\sim 2^{20}`. +In this example, the number of `+` and `*` operations are both ``\sim 2^{21.9}`` +and the number of read-write operations are ``\sim 2^{20}``. The largest tensor size is ``2^17``, one can check the element size by typing ```julia julia> sizeof(TropicalF64) @@ -190,4 +190,4 @@ CUDA backended properties are * [`CountingAll`](@ref) * [`CountingMax`](@ref) * [`GraphPolynomial`](@ref) -* [`SingleConfigMax`](@ref) \ No newline at end of file +* [`SingleConfigMax`](@ref) diff --git a/docs/src/ref.md b/docs/src/ref.md index 16005d9f..b2f5c036 100644 --- a/docs/src/ref.md +++ b/docs/src/ref.md @@ -98,6 +98,8 @@ StaticBitVector StaticElementVector save_configs load_configs +save_sumproduct +load_sumproduct @bv_str onehotv diff --git a/src/GraphTensorNetworks.jl b/src/GraphTensorNetworks.jl index 1a7dd091..25e48b8e 100644 --- a/src/GraphTensorNetworks.jl +++ b/src/GraphTensorNetworks.jl @@ -6,6 +6,7 @@ using TropicalNumbers using OMEinsum using OMEinsum: timespace_complexity, getixsv using Graphs, Random +using DelimitedFiles, Serialization # OMEinsum export timespace_complexity, timespacereadwrite_complexity, @ein_str, getixsv, getiyv @@ -48,7 +49,7 @@ export is_matching export solve, SizeMax, SizeMin, CountingAll, CountingMax, CountingMin, GraphPolynomial, SingleConfigMax, SingleConfigMin, ConfigsAll, ConfigsMax, ConfigsMin # Utilities -export save_configs, load_configs, hamming_distribution +export save_configs, load_configs, hamming_distribution, save_sumproduct, load_sumproduct # Visualization export show_graph, spring_layout @@ -64,6 +65,7 @@ include("configurations.jl") include("graphs.jl") include("bounding.jl") include("visualize.jl") +include("fileio.jl") include("interfaces.jl") include("deprecate.jl") include("multiprocessing.jl") diff --git a/src/arithematics.jl b/src/arithematics.jl index 15924622..31e78986 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -297,7 +297,6 @@ function collect_geq!(res, A, B, mB, low) K = length(A) k = 1 # TODO: we should tighten mA, mB later! Ak = A[K-k+1] - Bq = B[K-mB+1] l = 0 for q = K-mB+1:-1:1 Bq = B[K-q+1] @@ -587,29 +586,32 @@ function printnode(io::IO, t::SumProductTree{ET}) where {ET} end end -# it must be mutable, otherwise the `IdDict` trick for computing the length does not work. -Base.length(x::SumProductTree) = _length!(x, IdDict{typeof(x), Int}()) +# it must be mutable, otherwise, objectid might be slow serialization might fail. +# IdDict is much slower than Dict, it is useless. +Base.length(x::SumProductTree) = _length!(x, Dict{UInt, Float64}()) function _length!(x, d) - haskey(d, x) && return d[x] + id = objectid(x) + haskey(d, id) && return d[id] if x.tag === SUM l = _length!(x.left, d) + _length!(x.right, d) - d[x] = l + d[id] = l return l elseif x.tag === PROD l = _length!(x.left, d) * _length!(x.right, d) - d[x] = l + d[id] = l return l elseif x.tag === ZERO - return 0 + return 0.0 else - return 1 + return 1.0 end end -num_nodes(x::SumProductTree) = _num_nodes(x, IdDict{typeof(x), Int}()) +num_nodes(x::SumProductTree) = _num_nodes(x, Dict{UInt, Int}()) function _num_nodes(x, d) - haskey(d, x) && return 0 + id = objectid(x) + haskey(d, id) && return 0 if x.tag == ZERO || x.tag == ONE res = 1 elseif x.tag == LEAF @@ -617,7 +619,7 @@ function _num_nodes(x, d) else res = _num_nodes(x.left, d) + _num_nodes(x.right, d) + 1 end - d[x] = res + d[id] = res return res end @@ -708,12 +710,12 @@ true function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET} # get length dict res = fill(_data_one(ET), nsamples) - d = IdDict{typeof(t), Int}() + d = Dict{UInt, Float64}() sample_descend!(res, t, d) return res end -function sample_descend!(res::AbstractVector, t::SumProductTree, d::IdDict) +function sample_descend!(res::AbstractVector, t::SumProductTree, d::Dict) length(res) == 0 && return res if t.tag == LEAF res .|= Ref(t.data) diff --git a/src/fileio.jl b/src/fileio.jl new file mode 100644 index 00000000..9bfa6232 --- /dev/null +++ b/src/fileio.jl @@ -0,0 +1,125 @@ +""" + save_configs(filename, data::ConfigEnumerator; format=:binary) + +Save configurations `data` to file `filename`. The format is `:binary` or `:text`. +""" +function save_configs(filename, data::ConfigEnumerator{N,S,C}; format::Symbol=:binary) where {N,S,C} + if format == :binary + write(filename, raw_matrix(data)) + elseif format == :text + writedlm(filename, plain_matrix(data)) + else + error("format must be `:binary` or `:text`, got `:$format`") + end +end + +""" + load_configs(filename; format=:binary, bitlength=nothing, nflavors=2) + +Load configurations from file `filename`. The format is `:binary` or `:text`. +If the format is `:binary`, the bitstring length `bitlength` must be specified, +`nflavors` specifies the degree of freedom. +""" +function load_configs(filename; bitlength=nothing, format::Symbol=:binary, nflavors=2) + if format == :binary + bitlength === nothing && error("you need to specify `bitlength` for reading configurations from binary files.") + S = ceil(Int, log2(nflavors)) + C = _nints(bitlength, S) + return _from_raw_matrix(StaticElementVector{bitlength,S,C}, reshape(reinterpret(UInt64, read(filename)),C,:)) + elseif format == :text + return from_plain_matrix(readdlm(filename); nflavors=nflavors) + else + error("format must be `:binary` or `:text`, got `:$format`") + end +end + +function raw_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C} + m = zeros(UInt64, C, length(x)) + @inbounds for i=1:length(x), j=1:C + m[j,i] = x.data[i].data[j] + end + return m +end +function plain_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C} + m = zeros(UInt8, N, length(x)) + @inbounds for i=1:length(x), j=1:N + m[j,i] = x.data[i][j] + end + return m +end + +function from_raw_matrix(m; bitlength, nflavors=2) + S = ceil(Int,log2(nflavors)) + C = size(m, 1) + T = StaticElementVector{bitlength,S,C} + @assert bitlength*S <= C*64 + _from_raw_matrix(T, m) +end +function _from_raw_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C} + data = zeros(StaticElementVector{N,S,C}, size(m, 2)) + @inbounds for i=1:size(m, 2) + data[i] = StaticElementVector{N,S,C}(NTuple{C,UInt64}(view(m,:,i))) + end + return ConfigEnumerator(data) +end +function from_plain_matrix(m::Matrix; nflavors=2) + S = ceil(Int,log2(nflavors)) + N = size(m, 1) + C = _nints(N, S) + T = StaticElementVector{N,S,C} + _from_plain_matrix(T, m) +end +function _from_plain_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C} + data = zeros(StaticElementVector{N,S,C}, size(m, 2)) + @inbounds for i=1:size(m, 2) + data[i] = convert(StaticElementVector{N,S,C}, view(m, :, i)) + end + return ConfigEnumerator(data) +end + +# convert to Matrix +Base.Matrix(ce::ConfigEnumerator) = plain_matrix(ce) +Base.Vector(ce::StaticElementVector) = collect(ce) + +########## saving tree #################### +""" + save_sumproduct(filename, t::SumProductTree) + +Serialize a sum-product tree into a file. +""" +save_sumproduct(filename::String, t::SumProductTree) = serialize(filename, dict_serialize_tree!(t, Dict{UInt,Any}())) + +""" + load_sumproduct(filename) + +Deserialize a sum-product tree from a file. +""" +load_sumproduct(filename::String) = dict_deserialize_tree(deserialize(filename)...) + +function dict_serialize_tree!(t::SumProductTree, d::Dict) + id = objectid(t) + if !haskey(d, id) + if t.tag === GraphTensorNetworks.LEAF || t.tag === GraphTensorNetworks.ZERO || t.tag == GraphTensorNetworks.ONE + d[id] = t + else + d[id] = (t.tag, objectid(t.left), objectid(t.right)) + dict_serialize_tree!(t.left, d) + dict_serialize_tree!(t.right, d) + end + end + return id, d +end + +function dict_deserialize_tree(id::UInt, d::Dict) + @assert haskey(d, id) + content = d[id] + if content isa SumProductTree + return content + else + (tag, left, right) = content + t = SumProductTree(tag, dict_deserialize_tree(left, d), dict_deserialize_tree(right, d)) + d[id] = t + return t + end +end + diff --git a/src/interfaces.jl b/src/interfaces.jl index e4f29bb0..dd4e5fe3 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -313,91 +313,6 @@ It is a shorthand of `solve(problem, CountingMax(); usecuda=false)`. """ max_size_count(m::GraphProblem; usecuda=false) = (r = sum(solve(m, CountingMax(); usecuda=usecuda)); (Int(r.n), Int(r.c))) -using DelimitedFiles - -""" - save_configs(filename, data::ConfigEnumerator; format=:binary) - -Save configurations `data` to file `filename`. The format is `:binary` or `:text`. -""" -function save_configs(filename, data::ConfigEnumerator{N,S,C}; format::Symbol=:binary) where {N,S,C} - if format == :binary - write(filename, raw_matrix(data)) - elseif format == :text - writedlm(filename, plain_matrix(data)) - else - error("format must be `:binary` or `:text`, got `:$format`") - end -end - -""" - load_configs(filename; format=:binary, bitlength=nothing, nflavors=2) - -Load configurations from file `filename`. The format is `:binary` or `:text`. -If the format is `:binary`, the bitstring length `bitlength` must be specified, -`nflavors` specifies the degree of freedom. -""" -function load_configs(filename; bitlength=nothing, format::Symbol=:binary, nflavors=2) - if format == :binary - bitlength === nothing && error("you need to specify `bitlength` for reading configurations from binary files.") - S = ceil(Int, log2(nflavors)) - C = _nints(bitlength, S) - return _from_raw_matrix(StaticElementVector{bitlength,S,C}, reshape(reinterpret(UInt64, read(filename)),C,:)) - elseif format == :text - return from_plain_matrix(readdlm(filename); nflavors=nflavors) - else - error("format must be `:binary` or `:text`, got `:$format`") - end -end - -function raw_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C} - m = zeros(UInt64, C, length(x)) - @inbounds for i=1:length(x), j=1:C - m[j,i] = x.data[i].data[j] - end - return m -end -function plain_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C} - m = zeros(UInt8, N, length(x)) - @inbounds for i=1:length(x), j=1:N - m[j,i] = x.data[i][j] - end - return m -end - -function from_raw_matrix(m; bitlength, nflavors=2) - S = ceil(Int,log2(nflavors)) - C = size(m, 1) - T = StaticElementVector{bitlength,S,C} - @assert bitlength*S <= C*64 - _from_raw_matrix(T, m) -end -function _from_raw_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C} - data = zeros(StaticElementVector{N,S,C}, size(m, 2)) - @inbounds for i=1:size(m, 2) - data[i] = StaticElementVector{N,S,C}(NTuple{C,UInt64}(view(m,:,i))) - end - return ConfigEnumerator(data) -end -function from_plain_matrix(m::Matrix; nflavors=2) - S = ceil(Int,log2(nflavors)) - N = size(m, 1) - C = _nints(N, S) - T = StaticElementVector{N,S,C} - _from_plain_matrix(T, m) -end -function _from_plain_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C} - data = zeros(StaticElementVector{N,S,C}, size(m, 2)) - @inbounds for i=1:size(m, 2) - data[i] = convert(StaticElementVector{N,S,C}, view(m, :, i)) - end - return ConfigEnumerator(data) -end - -# convert to Matrix -Base.Matrix(ce::ConfigEnumerator) = plain_matrix(ce) -Base.Vector(ce::StaticElementVector) = collect(ce) - ########## memory estimation ############### """ estimate_memory(problem, property; T=Float64) diff --git a/test/fileio.jl b/test/fileio.jl new file mode 100644 index 00000000..2a3feba1 --- /dev/null +++ b/test/fileio.jl @@ -0,0 +1,53 @@ +using GraphTensorNetworks, Graphs, Test + +@testset "save load" begin + M = 10 + fname = tempname() + m = ConfigEnumerator([StaticBitVector(rand(Bool, 300)) for i=1:M]) + bm = GraphTensorNetworks.plain_matrix(m) + rm = GraphTensorNetworks.raw_matrix(m) + m1 = GraphTensorNetworks.from_raw_matrix(rm; bitlength=300, nflavors=2) + m2 = GraphTensorNetworks.from_plain_matrix(bm; nflavors=2) + @test m1 == m + @test m2 == m + save_configs(fname, m; format=:binary) + @test_throws ErrorException load_configs("_test.bin"; format=:binary) + ma = load_configs(fname; format=:binary, bitlength=300, nflavors=2) + @test ma == m + + fname = tempname() + save_configs(fname, m; format=:text) + mb = load_configs(fname; format=:text, nflavors=2) + @test mb == m + + M = 10 + m = ConfigEnumerator([StaticElementVector(3, rand(0:2, 300)) for i=1:M]) + bm = GraphTensorNetworks.plain_matrix(m) + rm = GraphTensorNetworks.raw_matrix(m) + m1 = GraphTensorNetworks.from_raw_matrix(rm; bitlength=300, nflavors=3) + m2 = GraphTensorNetworks.from_plain_matrix(bm; nflavors=3) + @test m1 == m + @test m2 == m + @test Matrix(m) == bm + @test Vector(m.data[1]) == bm[:,1] + + fname = tempname() + save_configs(fname, m; format=:binary) + @test_throws ErrorException load_configs(fname; format=:binary) + ma = load_configs(fname; format=:binary, bitlength=300, nflavors=3) + @test ma == m + + fname = tempname() + save_configs(fname, m; format=:text) + mb = load_configs(fname; format=:text, nflavors=3) + @test mb == m +end + +@testset "save load tree" begin + fname = tempname() + tree = solve(IndependentSet(smallgraph(:petersen)), ConfigsAll(; tree_storage=true))[] + save_sumproduct(fname, tree) + ma = load_sumproduct(fname) + @test ma == tree +end + diff --git a/test/interfaces.jl b/test/interfaces.jl index 9d9aa844..eed54ba6 100644 --- a/test/interfaces.jl +++ b/test/interfaces.jl @@ -44,45 +44,6 @@ using Graphs, Test end end -@testset "save load" begin - M = 10 - m = ConfigEnumerator([StaticBitVector(rand(Bool, 300)) for i=1:M]) - bm = GraphTensorNetworks.plain_matrix(m) - rm = GraphTensorNetworks.raw_matrix(m) - m1 = GraphTensorNetworks.from_raw_matrix(rm; bitlength=300, nflavors=2) - m2 = GraphTensorNetworks.from_plain_matrix(bm; nflavors=2) - @test m1 == m - @test m2 == m - save_configs("_test.bin", m; format=:binary) - @test_throws ErrorException load_configs("_test.bin"; format=:binary) - ma = load_configs("_test.bin"; format=:binary, bitlength=300, nflavors=2) - @test ma == m - - save_configs("_test.txt", m; format=:text) - mb = load_configs("_test.txt"; format=:text, nflavors=2) - @test mb == m - - M = 10 - m = ConfigEnumerator([StaticElementVector(3, rand(0:2, 300)) for i=1:M]) - bm = GraphTensorNetworks.plain_matrix(m) - rm = GraphTensorNetworks.raw_matrix(m) - m1 = GraphTensorNetworks.from_raw_matrix(rm; bitlength=300, nflavors=3) - m2 = GraphTensorNetworks.from_plain_matrix(bm; nflavors=3) - @test m1 == m - @test m2 == m - @test Matrix(m) == bm - @test Vector(m.data[1]) == bm[:,1] - - save_configs("_test.bin", m; format=:binary) - @test_throws ErrorException load_configs("_test.bin"; format=:binary) - ma = load_configs("_test.bin"; format=:binary, bitlength=300, nflavors=3) - @test ma == m - - save_configs("_test.txt", m; format=:text) - mb = load_configs("_test.txt"; format=:text, nflavors=3) - @test mb == m -end - @testset "slicing" begin g = Graphs.smallgraph("petersen") gp = IndependentSet(g; optimizer=TreeSA(nslices=5, ntrials=1)) diff --git a/test/runtests.jl b/test/runtests.jl index 1a84d207..673fa210 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,10 @@ end include("visualize.jl") end +@testset "fileio" begin + include("fileio.jl") +end + @testset "multiprocessing" begin include("multiprocessing.jl") end From 4dfd4569e4018d71a49628c7686228f15c7ad381 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 22 Mar 2022 00:34:34 -0400 Subject: [PATCH 2/6] save remove stack --- src/arithematics.jl | 112 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 93 insertions(+), 19 deletions(-) diff --git a/src/arithematics.jl b/src/arithematics.jl index 31e78986..596b05d6 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -608,6 +608,71 @@ function _length!(x, d) end end +# # loop version +# function _length!(x, d) +# rootid = objectid(x) +# t_stack = [x] +# # update dict +# while !isempty(t_stack) +# x = t_stack[end] +# id = objectid(x) +# if haskey(d, id) +# pop!(t_stack) +# else +# if x.tag === SUM +# idl = objectid(x.left) +# if haskey(d, idl) +# idr = objectid(x.right) +# if haskey(d, idr) +# @inbounds d[id] = d[idl] + d[idr] +# pop!(t_stack) +# else +# push!(t_stack, x.right) +# end +# else +# push!(t_stack, x.left) +# end +# elseif x.tag === PROD +# idl = objectid(x.left) +# if haskey(d, idl) +# idr = objectid(x.right) +# if haskey(d, idr) +# @inbounds d[id] = d[idl] * d[idr] +# pop!(t_stack) +# else +# push!(t_stack, x.right) +# end +# else +# push!(t_stack, x.left) +# end +# elseif x.tag === ZERO +# d[id] = 0.0 +# pop!(t_stack) +# else +# d[id] = 1.0 +# pop!(t_stack) +# end +# end +# end +# return d[rootid] +# end + +function _find_branch(x, d) + if x.tag === ZERO + return true, 0.0 + elseif x.tag === ONE || x.tag === LEAF + return true, 1.0 + else + idl = objectid(x.left) + if haskey(d, idl) + return true, d[idl] + else + return false, 0.0 + end + end +end + + num_nodes(x::SumProductTree) = _num_nodes(x, Dict{UInt, Int}()) function _num_nodes(x, d) id = objectid(x) @@ -716,27 +781,36 @@ function generate_samples(t::SumProductTree{ET}, nsamples::Int) where {ET} end function sample_descend!(res::AbstractVector, t::SumProductTree, d::Dict) - length(res) == 0 && return res - if t.tag == LEAF - res .|= Ref(t.data) - elseif t.tag == SUM - ratio = _length!(t.left, d)/_length!(t, d) - nleft = 0 - for _ = 1:length(res) - if rand() < ratio - nleft += 1 + res_stack = Any[res] + t_stack = [t] + while !isempty(t_stack) && !isempty(res_stack) + t = pop!(t_stack) + res = pop!(res_stack) + if t.tag == LEAF + res .|= Ref(t.data) + elseif t.tag == SUM + ratio = _length!(t.left, d)/_length!(t, d) + nleft = 0 + for _ = 1:length(res) + if rand() < ratio + nleft += 1 + end end + shuffle!(res) # shuffle the `res` to avoid biased sampling, very important. + push!(res_stack, view(res,1:nleft)) + push!(res_stack, view(res,nleft+1:length(res))) + push!(t_stack, t.left) + push!(t_stack, t.right) + elseif t.tag == PROD + push!(res_stack, res) + push!(res_stack, res) + push!(t_stack, t.left) + push!(t_stack, t.right) + elseif t.tag == ZERO + error("Meet zero when descending.") + else + # pass for 1 end - shuffle!(res) # shuffle the `res` to avoid biased sampling, very important. - sample_descend!(view(res,1:nleft), t.left, d) - sample_descend!(view(res,nleft+1:length(res)), t.right, d) - elseif t.tag == PROD - sample_descend!(res, t.right, d) - sample_descend!(res, t.left, d) - elseif t.tag == ZERO - error("Meet zero when descending.") - else - # pass for 1 end return res end From 9b218b62bfc57bfa890cb3d1c13a8bf790eb70e0 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 22 Mar 2022 00:43:45 -0400 Subject: [PATCH 3/6] fix sampling --- src/arithematics.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/arithematics.jl b/src/arithematics.jl index 596b05d6..4063b626 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -797,10 +797,14 @@ function sample_descend!(res::AbstractVector, t::SumProductTree, d::Dict) end end shuffle!(res) # shuffle the `res` to avoid biased sampling, very important. - push!(res_stack, view(res,1:nleft)) - push!(res_stack, view(res,nleft+1:length(res))) - push!(t_stack, t.left) - push!(t_stack, t.right) + if nleft >= 1 + push!(res_stack, view(res,1:nleft)) + push!(t_stack, t.left) + end + if length(res) > nleft + push!(res_stack, view(res,nleft+1:length(res))) + push!(t_stack, t.right) + end elseif t.tag == PROD push!(res_stack, res) push!(res_stack, res) From c75aeea4c49e7a2d05c1d7a9554d8f462e22904e Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 22 Mar 2022 00:44:54 -0400 Subject: [PATCH 4/6] clean up non-recursive _length --- src/arithematics.jl | 49 --------------------------------------------- 1 file changed, 49 deletions(-) diff --git a/src/arithematics.jl b/src/arithematics.jl index 4063b626..e45c2241 100644 --- a/src/arithematics.jl +++ b/src/arithematics.jl @@ -608,55 +608,6 @@ function _length!(x, d) end end -# # loop version -# function _length!(x, d) -# rootid = objectid(x) -# t_stack = [x] -# # update dict -# while !isempty(t_stack) -# x = t_stack[end] -# id = objectid(x) -# if haskey(d, id) -# pop!(t_stack) -# else -# if x.tag === SUM -# idl = objectid(x.left) -# if haskey(d, idl) -# idr = objectid(x.right) -# if haskey(d, idr) -# @inbounds d[id] = d[idl] + d[idr] -# pop!(t_stack) -# else -# push!(t_stack, x.right) -# end -# else -# push!(t_stack, x.left) -# end -# elseif x.tag === PROD -# idl = objectid(x.left) -# if haskey(d, idl) -# idr = objectid(x.right) -# if haskey(d, idr) -# @inbounds d[id] = d[idl] * d[idr] -# pop!(t_stack) -# else -# push!(t_stack, x.right) -# end -# else -# push!(t_stack, x.left) -# end -# elseif x.tag === ZERO -# d[id] = 0.0 -# pop!(t_stack) -# else -# d[id] = 1.0 -# pop!(t_stack) -# end -# end -# end -# return d[rootid] -# end - function _find_branch(x, d) if x.tag === ZERO return true, 0.0 From de161989d11fc5f80d248e844ccbed4ec764c5ea Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 22 Mar 2022 00:53:36 -0400 Subject: [PATCH 5/6] fix sampling --- examples/IndependentSet.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/IndependentSet.jl b/examples/IndependentSet.jl index 3d488553..5ba27230 100644 --- a/examples/IndependentSet.jl +++ b/examples/IndependentSet.jl @@ -135,10 +135,12 @@ all_independent_sets_tree = solve(problem, ConfigsAll(; tree_storage=true))[] # The results encode the configurations in the sum-product-tree format. One can count and enumerate them explicitly by typing length(all_independent_sets_tree) -# +# Then one can use `Base.collect` function to create a [`ConfigEnumerator`](@ref) or use [`generate_samples`](@ref) to generate samples from it. collect(all_independent_sets_tree) +generate_samples(all_independent_sets_tree, 10) + # One can use [`save_configs`](@ref) and [`load_configs`](@ref) to save and read a set of configuration to disk. filename = tempname() From 03732b7bb14693048a3d74da75c6d15299a74770 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 22 Mar 2022 01:02:21 -0400 Subject: [PATCH 6/6] fix docs --- docs/src/performancetips.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/performancetips.md b/docs/src/performancetips.md index c409579e..1c2db004 100644 --- a/docs/src/performancetips.md +++ b/docs/src/performancetips.md @@ -136,7 +136,7 @@ julia> lineplot(hamming_distribution(samples, samples)) ``` ## Multiprocessing -Submodule `GraphTensorNetworks.SimpleMutiprocessing` provides a function [`multiprocess_run`](@ref) function for simple multi-processing jobs. +Submodule `GraphTensorNetworks.SimpleMutiprocessing` provides a function [`GraphTensorNetworks.SimpleMultiprocessing.multiprocess_run`](@ref) function for simple multi-processing jobs. Suppose we want to find the independence polynomial for multiple graphs with 4 processes. We can create a file, e.g. named `run.jl` with the following content