# `Slicing` records which legs of a tensor network are sliced.
# `LT` is the label type of the legs.
struct Slicing{LT}
legs::Vector{LT} # labels of the sliced legs
# Two `Slicing`s are equal when their vectors of sliced legs compare equal.
function Base.:(==)(se::Slicing, se2::Slicing)
    return se.legs == se2.legs
end

# Build a `Slicing` from a `Slicer`: every key of `s.legs` is translated through
# `inverse_map`; the value paired with each key is not used here.
function Slicing(s::Slicer, inverse_map)
    mapped_legs = [inverse_map[leg] for (leg, _) in s.legs]
    return Slicing(mapped_legs)
end
# Number of sliced legs.
function Base.length(s::Slicing)
    return length(s.legs)
end
Expand All @@ -37,6 +38,7 @@ struct SlicedEinsum{LT, Ein} <: AbstractEinsum
# Two `SlicedEinsum`s are equal when both their slicing and their einsum code are equal.
function Base.:(==)(se::SlicedEinsum, se2::SlicedEinsum)
    se.slicing == se2.slicing || return false
    return se.eins == se2.eins
end

# Iterate over tensor network slices, its iterator interface returns `slicemap` as a Dict
# slice and fill tensors with
Expand Down
136 changes: 93 additions & 43 deletions src/treesa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,19 @@ Base.@kwdef struct TreeSA{RT,IT,GM} <: CodeOptimizer
greedy_config::GM = GreedyMethod(nrepeat=1)

# `ExprInfo` stores the node information.
# * `out_dims` is the output dimensions of this tree/subtree.
# * `tensorid` specifies the tensor index for leaf nodes. It is `-1` for non-leaf nodes.
struct ExprInfo
ExprInfo(out_dims::Vector{Int}) = ExprInfo(out_dims, -1)

# `ExprTree` is the expression tree for tensor contraction (or contraction tree), it is a binary tree (including leaf nodes without siblings).
# `left` and `right` are left and right branches, they are either both specified (non-leaf) or both unspecified (leaf), see [`isleaf`](@ref) function.
# `ExprTree()` for constructing a leaf node,
# `ExprTree(left, right, info)` for constructing a non-leaf node.
mutable struct ExprTree
Expand All @@ -55,14 +62,18 @@ function print_expr(io::IO, expr::ExprTree, level=0)
print(io, " "^(2*level), ") := ", labels(expr))
# if `expr` is a leaf, it should have `left` and `right` fields both unspecified.
OMEinsum.isleaf(expr::ExprTree) = !isdefined(expr, :left), expr::ExprTree) = print_expr(io, expr, 0), ::MIME"text/plain", expr::ExprTree) = show(io, expr)
# The child branches of `t`: an empty vector for a leaf, otherwise `[left, right]`.
function siblings(t::ExprTree)
    isleaf(t) && return ExprTree[]
    return ExprTree[t.left, t.right]
end
Base.copy(t::ExprTree) = isleaf(t) ? ExprTree( : ExprTree(copy(t.left), copy(t.right), copy(
# Copy an `ExprInfo`; `out_dims` is copied so the new object does not share the vector.
function Base.copy(info::ExprInfo)
    dims = copy(info.out_dims)
    return ExprInfo(dims, info.tensorid)
end
# output tensor labels
labels(t::ExprTree) =
# find the maximum label recursively, this is a helper function for converting an expression tree back to einsum.
function maxlabel(t::ExprTree)
    lbs = labels(t)
    # largest label on this node itself, 0 when the node carries no labels
    own = isempty(lbs) ? 0 : maximum(lbs)
    isleaf(t) && return own
    # non-leaf: also take both subtrees into account
    return max(own, maxlabel(t.left), maxlabel(t.right))
end
# comparison between `ExprTree`s, mainly for testing
Base.:(==)(t1::ExprTree, t2::ExprTree) = _equal(t1, t2)
Base.:(==)(t1::ExprInfo, t2::ExprInfo) = _equal(t1.out_dims, t2.out_dims) && t1.tensorid == t2.tensorid
function _equal(t1::ExprTree, t2::ExprTree)
Expand All @@ -71,34 +82,44 @@ function _equal(t1::ExprTree, t2::ExprTree)
# Compare two vectors as sets: element order and repetitions are ignored.
function _equal(t1::Vector, t2::Vector)
    lhs, rhs = Set(t1), Set(t2)
    return lhs == rhs
end

# this is the main function
optimize_tree(code, size_dict; sc_target=20, βs=0.1:0.1:10, ntrials=2, niters=100, sc_weight=1.0, rw_weight=0.2, initializer=:greedy, greedy_method=OMEinsum.MinSpaceOut(), greedy_nrepeat=1)
Optimize the einsum contraction pattern specified by `code`, and edge sizes specified by `size_dict`. Key word arguments are
Optimize the einsum contraction pattern specified by `code`, and edge sizes specified by `size_dict`.
Check the docstring of `TreeSA` for a detailed explanation of the other input arguments.
function optimize_tree(code, size_dict; nslices::Int=0, sc_target=20, βs=0.1:0.1:10, ntrials=20, niters=100, sc_weight=1.0, rw_weight=0.2, initializer=:greedy, greedy_method=OMEinsum.MinSpaceOut(), greedy_nrepeat=1)
flatten_code = OMEinsum.flatten(code)
ninputs = length(OMEinsum.getixs(flatten_code))
if ninputs <= 2
return NestedEinsum(ntuple(i->i, ninputs), flatten_code isa DynamicEinCode ? flatten_code : DynamicEinCode(flatten_code))
# get input labels (`getixsv`) and output labels (`getiyv`) in the einsum code.
ixs, iy = getixsv(code), getiyv(code)
ninputs = length(ixs) # number of input tensors
if ninputs <= 2 # number of input tensors ≤ 2, can not be optimized
return SlicedEinsum(Slicing(eltype(iy)[]), NestedEinsum(ntuple(i->i, ninputs), DynamicEinCode(ixs, iy)))
labels = _label_dict(flatten_code) # label to int
inverse_map = Dict([v=>k for (k,v) in labels])
log2_sizes = [log2.(size_dict[inverse_map[i]]) for i=1:length(labels)]
if ntrials <= 0
###### Stage 1: preprocessing ######
labels = _label_dict(ixs, iy) # map labels to integers
inverse_map = Dict([v=>k for (k,v) in labels]) # the inverse transformation, map integers to labels
log2_sizes = [log2.(size_dict[inverse_map[i]]) for i=1:length(labels)] # use `log2` sizes in computing time
if ntrials <= 0 # no optimization at all, then 1). initialize an expression tree and 2). convert back to nested einsum.
best_tree = _initializetree(code, size_dict, initializer; greedy_method=greedy_method, greedy_nrepeat=greedy_nrepeat)
return NestedEinsum(best_tree, inverse_map)
return SlicedEinsum(Slicing(eltype(iy)[]), NestedEinsum(best_tree, inverse_map))
###### Stage 2: computing ######
# create vectors to store optimized 1). expression tree, 2). time complexities, 3). space complexities, 4). read-write complexities and 5). slicing information.
trees, tcs, scs, rws, slicers = Vector{ExprTree}(undef, ntrials), zeros(ntrials), zeros(ntrials), zeros(ntrials), Vector{Slicer}(undef, ntrials)
@threads for t = 1:ntrials
@threads for t = 1:ntrials # multi-threading on different trials, use `JULIA_NUM_THREADS=5 julia xxx.jl` for setting number of threads.
# 1). random/greedy initialize a contraction tree.
tree = _initializetree(code, size_dict, initializer; greedy_method=greedy_method, greedy_nrepeat=greedy_nrepeat)
# 2). optimize the `tree` and `slicer` in place.
slicer = Slicer(log2_sizes, nslices)
optimize_tree_sa!(tree, log2_sizes, slicer; sc_target=sc_target, βs=βs, niters=niters, sc_weight=sc_weight, rw_weight=rw_weight)
# 3). evaluate time-space-readwrite complexities.
tc, sc, rw = tree_timespace_complexity(tree, slicer.log2_sizes)
@debug "trial $t, time complexity = $tc, space complexity = $sc, read-write complexity = $rw."
trees[t], tcs[t], scs[t], rws[t], slicers[t] = tree, tc, sc, rw, slicer
###### Stage 3: postprocessing ######
# compare and choose the best solution
best_tree, best_tc, best_sc, best_rw, best_slicer = first(trees), first(tcs), first(scs), first(rws), first(slicers)
for i=2:ntrials
# A trial wins if it has a strictly smaller space complexity, or the same space complexity
# and a smaller combined cost `2^tc + rw_weight * 2^rw`.
# The incumbent's cost must use `best_rw`; using `rws[i]` on both sides cancels the
# read-write term and silently ignores `rw_weight`.
if scs[i] < best_sc || (scs[i] == best_sc && exp2(tcs[i]) + rw_weight * exp2(rws[i]) < exp2(best_tc) + rw_weight * exp2(best_rw))
Expand All @@ -109,24 +130,26 @@ function optimize_tree(code, size_dict; nslices::Int=0, sc_target=20, βs=0.1:0.
if best_sc > sc_target
@warn "target space complexity not found, got: $best_sc, with time complexity $best_tc, read-write complexity $best_rw."
# return a sliced einsum; the sliced dimensions need to be mapped back from integers to labels.
return SlicedEinsum(Slicing(best_slicer, inverse_map), NestedEinsum(best_tree, inverse_map))

# initialize a contraction tree
function _initializetree(code, size_dict, method; greedy_method, greedy_nrepeat)
flatcode = OMEinsum.flatten(code)
if method == :greedy
labels = _label_dict(flatcode) # label to int
return _exprtree(optimize_greedy(flatcode, size_dict; method=greedy_method, nrepeat=greedy_nrepeat), labels)
labels = _label_dict(code) # label to int
return _exprtree(optimize_greedy(code, size_dict; method=greedy_method, nrepeat=greedy_nrepeat), labels)
elseif method == :random
return random_exprtree(flatcode)
return random_exprtree(code)
elseif method == :specified
labels = _label_dict(flatcode) # label to int
labels = _label_dict(code) # label to int
return _exprtree(code, labels)
throw(ArgumentError("intializier `$method` is not defined!"))

# use simulated annealing to optimize a contraction tree
function optimize_tree_sa!(tree::ExprTree, log2_sizes, slicer::Slicer; βs, niters, sc_target, sc_weight, rw_weight)
@assert rw_weight >= 0
@assert sc_weight >= 0
Expand All @@ -136,19 +159,22 @@ function optimize_tree_sa!(tree::ExprTree, log2_sizes, slicer::Slicer; βs, nite
tc, sc, rw = tree_timespace_complexity(tree, log2_sizes)
"β = , tc = $tc, sc = $sc, rw = $rw"
if slicer.max_size > 0
# find legs that reduce the dimension the most
###### Stage 1: add one slice at each temperature ######
if slicer.max_size > 0 # `max_size` specifies the maximum number of sliced dimensions.
# 1). find legs that reduce the dimension the most
scs, lbs = Float64[], Vector{Int}[]
# space complexities and labels of all intermediate tensors
tensor_sizes!(tree, slicer.log2_sizes, scs, lbs)
# the set of (intermediate) tensor labels that produce the maximum space complexity
best_labels = _best_labels(scs, lbs)

# 2). slice the best not sliced label (it must appear in largest tensors)
best_not_sliced_labels = filter(x->!haskey(slicer.legs, x), best_labels)
if !isempty(best_not_sliced_labels)
best_not_sliced_label = rand(best_not_sliced_labels)
if length(slicer) < slicer.max_size
best_not_sliced_label = rand(best_not_sliced_labels) # TODO: can we have a selection rule than random selection?
if length(slicer) < slicer.max_size # if has not reached maximum number of slices, add one slice
push!(slicer, best_not_sliced_label)
#worst_sliced_labels = filter(x->haskey(slicer.legs, x), setdiff(log2_sizes, best_labels))
else # otherwise replace one slice
legs = collect(keys(slicer.legs))
score = [count(==(l), best_labels) for l in legs]
replace!(slicer, legs[argmin(score)]=>best_not_sliced_label)
Expand All @@ -159,18 +185,21 @@ function optimize_tree_sa!(tree::ExprTree, log2_sizes, slicer::Slicer; βs, nite
"after slicing: β = , tc = $tc, sc = $sc, rw = $rw"
###### Stage 2: sweep and optimize the contraction tree for `niters` times ######
for _ = 1:niters
optimize_subtree!(tree, β, slicer.log2_sizes, sc_target, sc_weight, log2rw_weight) # single sweep
return tree, slicer

# here "best" means giving maximum space complexity
# Concatenate the labels of every tensor whose space complexity is within 0.99 of the
# largest one. `scs[i]` is the space complexity of the tensor labelled by `lbs[i]`.
# Labels shared by several selected tensors appear repeatedly in the result.
function _best_labels(scs, lbs)
    threshold = maximum(scs) - 0.99
    selected = lbs[scs .> threshold]
    # Nothing selected (only possible with non-finite sizes): return a typed empty vector.
    isempty(selected) && return eltype(eltype(lbs))[]
    # `reduce(vcat, ...)` avoids splatting a collection of unknown length into `vcat`.
    return reduce(vcat, selected)
end

# find tensor sizes and their corresponding labels of all intermediate tensors
function tensor_sizes!(tree::ExprTree, log2_sizes, scs, lbs)
sc = isempty(labels(tree)) ? 0.0 : sum(i->log2_sizes[i], labels(tree))
push!(scs, sc)
Expand All @@ -180,17 +209,23 @@ function tensor_sizes!(tree::ExprTree, log2_sizes, scs, lbs)
tensor_sizes!(tree.right, log2_sizes, scs, lbs)

# the time-space-readwrite complexity of a contraction tree
# Recursively compute the `(tc, sc, rw)` triple of the contraction tree `tree`.
# `log2_sizes[l]` is the log2 size of the (integer) label `l`.
function tree_timespace_complexity(tree::ExprTree, log2_sizes)
# leaf: time and read-write terms are `-Inf`; the space term is the summed log2 size of
# the leaf's own labels (`0.0` when it has no labels).
isleaf(tree) && return (-Inf, isempty(labels(tree)) ? 0.0 : sum(i->log2_sizes[i], labels(tree)), -Inf)
# complexities of the left and right subtrees
tcl, scl, rwl = tree_timespace_complexity(tree.left, log2_sizes)
tcr, scr, rwr = tree_timespace_complexity(tree.right, log2_sizes)
# complexity of contracting the two children into this node; the last argument `true`
# asks `tcscrw` to also compute the read-write term.
tc, sc, rw = tcscrw(labels(tree.left), labels(tree.right), labels(tree), log2_sizes, true)
# time and read-write terms are combined with `fast_log2sumexp2`; the space term is the
# maximum over this node and both subtrees.
return (fast_log2sumexp2(tc, tcl, tcr), max(sc, scl, scr), fast_log2sumexp2(rw, rwl, rwr))
@inline function tcscrw(ix1, ix2, iy, log2_sizes::Vector{T}, optimize_rw) where T

# returns time complexity, space complexity and read-write complexity (0 if `compute_rw` is false)
# `ix1` and `ix2` are vectors of labels for the first and second input tensors.
# `iy` is a vector of labels for the output tensors.
# `log2_sizes` is the log2 size of labels (note labels are integers, so no dict is needed to index label sizes).
@inline function tcscrw(ix1, ix2, iy, log2_sizes::Vector{T}, compute_rw) where T
l1, l2, l3 = ix1, ix2, iy
sc1 = (!optimize_rw || isempty(l1)) ? zero(T) : sum(i->(@inbounds log2_sizes[i]), l1)
sc2 = (!optimize_rw || isempty(l2)) ? zero(T) : sum(i->(@inbounds log2_sizes[i]), l2)
sc1 = (!compute_rw || isempty(l1)) ? zero(T) : sum(i->(@inbounds log2_sizes[i]), l1)
sc2 = (!compute_rw || isempty(l2)) ? zero(T) : sum(i->(@inbounds log2_sizes[i]), l2)
sc = isempty(l3) ? zero(T) : sum(i->(@inbounds log2_sizes[i]), l3)
tc = sc
# Note: assuming labels in `l1` being unique
Expand All @@ -199,13 +234,15 @@ end
tc += log2_sizes[l]
rw = optimize_rw ? fast_log2sumexp2(sc, sc1, sc2) : 0.0
rw = compute_rw ? fast_log2sumexp2(sc, sc1, sc2) : 0.0
return tc, sc, rw

# random contraction tree
function random_exprtree(code::EinCode)
labels = _label_dict(code)
return random_exprtree([Int[labels[l] for l in ix] for ix in getixsv(code)], Int[labels[l] for l in getiyv(code)], length(labels))
ixs, iy = getixsv(code), getiyv(code)
labels = _label_dict(ixs, iy)
return random_exprtree([Int[labels[l] for l in ix] for ix in ixs], Int[labels[l] for l in iy], length(labels))

function random_exprtree(ixs::Vector{Vector{Int}}, iy::Vector{Int}, nedge::Int)
Expand Down Expand Up @@ -243,25 +280,34 @@ function _random_exprtree(ixs::Vector{Vector{Int}}, xindices, outercount::Vector
return ExprTree(_random_exprtree(ixs[mask], xindices[mask], outercount1, allcount), _random_exprtree(ixs[(!).(mask)], xindices[(!).(mask)], outercount2, allcount), info)

# optimize a contraction tree recursively
function optimize_subtree!(tree, β, log2_sizes, sc_target, sc_weight, log2rw_weight)
# find applicable local rules, at most 4 rules can be applied.
# Sometimes, not all rules are applicable because either the left or the right sibling does not have siblings.
rst = ruleset(tree)
if !isempty(rst)
# propose a random update rule, TODO: can we have a better selector?
rule = rand(rst)
optimize_rw = log2rw_weight != -Inf
# difference in time, space and read-write complexity if the selected rule is applied
tc0, tc1, dsc, rw0, rw1, subout = tcsc_diff(tree, rule, log2_sizes, optimize_rw)
dtc = optimize_rw ? fast_log2sumexp2(tc1, log2rw_weight + rw1) - fast_log2sumexp2(tc0, log2rw_weight + rw0) : tc1 - tc0
sc = _sc(tree, rule, log2_sizes)
sc = _sc(tree, rule, log2_sizes) # current space complexity

# update the loss function
dE = (max(sc, sc+dsc) > sc_target ? sc_weight : 0) * dsc + dtc
if rand() < exp(-β*dE)
if rand() < exp(-β*dE) # ACCEPT
update_tree!(tree, rule, subout)
for subtree in siblings(tree)
for subtree in siblings(tree) # RECURSE
optimize_subtree!(subtree, β, log2_sizes, sc_target, sc_weight, log2rw_weight)
# if rule ∈ [1, 2], left sibling will be updated, otherwise, right sibling will be updated.
# we need to compute space complexity for current node and the updated sibling, and return the larger one.
function _sc(tree, rule, log2_sizes)
    # rules 1 and 2 touch the left branch, every other rule touches the right branch
    touched = (rule == 1 || rule == 2) ? tree.left : tree.right
    return max(__sc(tree, log2_sizes), __sc(touched, log2_sizes))
end
__sc(tree, log2_sizes) = length(labels(tree))==0 ? 0.0 : sum(l->log2_sizes[l], labels(tree))
__sc(tree, log2_sizes) = length(labels(tree))==0 ? 0.0 : sum(l->log2_sizes[l], labels(tree)) # space complexity of current node

@inline function ruleset(tree::ExprTree)
if isleaf(tree) || (isleaf(tree.left) && isleaf(tree.right))
Expand All @@ -287,9 +333,10 @@ function tcsc_diff(tree::ExprTree, rule, log2_sizes, optimize_rw)

# compute the time complexity, space complexity and read-write complexity information for the contraction update rule "((a,b),c) -> ((a,c),b)"
function abcacb(a, b, ab, c, d, log2_sizes, optimize_rw)
tc0, sc0, rw0, ab0 = _tcsc_merge(a, b, ab, c, d, log2_sizes, optimize_rw)
ac = Int[]
tc0, sc0, rw0 = _tcsc_merge(a, b, ab, c, d, log2_sizes, optimize_rw)
ac = Int[] # labels for contraction result of (a, c)
for l in a
if l b || l d # suppose no repeated indices
push!(ac, l)
Expand All @@ -300,16 +347,18 @@ function abcacb(a, b, ab, c, d, log2_sizes, optimize_rw)
push!(ac, l)
tc1, sc1, rw1, ab1 = _tcsc_merge(a, c, ac, b, d, log2_sizes, optimize_rw)
return tc0, tc1, sc1-sc0, rw0, rw1, ab1 # Note: this tc diff does not make much sense
tc1, sc1, rw1 = _tcsc_merge(a, c, ac, b, d, log2_sizes, optimize_rw)
return tc0, tc1, sc1-sc0, rw0, rw1, ac

# compute complexity for a two-step contraction: (a, b) -> ab, (ab, c) -> d
function _tcsc_merge(a, b, ab, c, d, log2_sizes, optimize_rw)
tcl, scl, rwl = tcscrw(a, b, ab, log2_sizes, optimize_rw) # this is correct
tcr, scr, rwr = tcscrw(ab, c, d, log2_sizes, optimize_rw)
fast_log2sumexp2(tcl, tcr), max(scl, scr), (optimize_rw ? fast_log2sumexp2(rwl, rwr) : 0.0), ab
fast_log2sumexp2(tcl, tcr), max(scl, scr), (optimize_rw ? fast_log2sumexp2(rwl, rwr) : 0.0)

# apply the update rule
function update_tree!(tree::ExprTree, rule::Int, subout)
if rule == 1 # (a,b), c -> (a,c),b
b, c = tree.left.right, tree.right
Expand All @@ -335,29 +384,30 @@ function update_tree!(tree::ExprTree, rule::Int, subout)
return tree

# from label to integer.
function _label_dict(code::EinCode)
LT = OMEinsum.labeltype(code)
ixsv, iyv = getixsv(code), getiyv(code)
# map labels to integers.
_label_dict(code) = _label_dict(getixsv(code), getiyv(code))
# Map every distinct label to an integer `1, 2, ...`, numbered in order of first appearance
# when scanning the input label vectors `ixsv` one after another and then the output
# labels `iyv`. Returns a `Dict{LT,Int}`.
function _label_dict(ixsv::AbstractVector{<:AbstractVector{LT}}, iyv::AbstractVector{LT}) where LT
    # Gather the labels with `append!` instead of `vcat(ixsv..., iyv)`: splatting creates
    # one call argument per input tensor, which scales badly for large networks.
    v = LT[]
    for ix in ixsv
        append!(v, ix)
    end
    append!(v, iyv)
    unique!(v)  # keeps the first occurrence of each label, preserving order
    labels = Dict{LT,Int}(zip(v, 1:length(v)))
    return labels
end

# construct the contraction tree recursively from a nested einsum.
function ExprTree(code::NestedEinsum)
_exprtree(code, _label_dict(OMEinsum.flatten(code)))
_exprtree(code, _label_dict(code))
function _exprtree(code::NestedEinsum, labels)
@assert length(code.args) == 2
ExprTree(map(enumerate(code.args)) do (i,arg)
if isleaf(arg)
if isleaf(arg) # leaf nodes
ExprTree(ExprInfo(Int[labels[i] for i=OMEinsum.getixs(code.eins)[i]], arg.tensorindex))
res = _exprtree(arg, labels)
end..., ExprInfo(Int[labels[i] for i=OMEinsum.getiy(code.eins)]))

# convert a contraction tree back to a nested einsum
OMEinsum.NestedEinsum(expr::ExprTree) = _nestedeinsum(expr, 1:maxlabel(expr))
OMEinsum.NestedEinsum(expr::ExprTree, labelmap) = _nestedeinsum(expr, labelmap)
function _nestedeinsum(tree::ExprTree, lbs)
Expand Down
4 changes: 2 additions & 2 deletions test/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ end
ne = NestedEinsum((1,), code)
dne = NestedEinsum((1,), DynamicEinCode(code))
@test optimize_code(code, sizes, GreedyMethod()) == ne
@test optimize_code(code, sizes, TreeSA()) == dne
@test optimize_code(code, sizes, TreeSA(nslices=2)) == dne
@test optimize_code(code, sizes, TreeSA()) == SlicedEinsum(Slicing(Char[]), dne)
@test optimize_code(code, sizes, TreeSA(nslices=2)) == SlicedEinsum(Slicing(Char[]), dne)
@test optimize_code(code, sizes, KaHyParBipartite(sc_target=25)) == dne
@test optimize_code(code, sizes, SABipartite(sc_target=25)) == dne

