Skip to content

Commit

Permalink
Merge pull request #9 from TensorBFS/treesa
Browse files Browse the repository at this point in the history
Expression rewriting based simulated annealing
  • Loading branch information
GiggleLiu authored Sep 14, 2021
2 parents e9c03aa + 494f9ba commit 7ad371c
Show file tree
Hide file tree
Showing 5 changed files with 388 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/OMEinsumContractionOrders.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ using OMEinsum: NestedEinsum

include("kahypar.jl")
include("sa.jl")
include("treesa.jl")

end
10 changes: 5 additions & 5 deletions src/sa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,23 +75,23 @@ function bipartite_sc(bipartiter::SABipartite, adj::SparseMatrixCSC, vertices, l
end
accept && update_state!(state, adjt, vertices, idxi, sc_ti, sc_tinew, newloss)
end
tc, sc1, sc2 = timespace_complexity_singlestep(state, adj, vertices, log2_sizes)
tc, sc1, sc2 = timespace_complexity_singlestep(state.config, adj, vertices, log2_sizes)
@assert state.group_scs [sc1, sc2] # sanity check
if maximum(state.group_scs) <= max(bipartiter.sc_target, maximum(best.group_scs)) && (maximum(best.group_scs) >= bipartiter.sc_target || state.loss[] < best.loss[])
best = state
end
end
best_tc, = timespace_complexity_singlestep(best, adj, vertices, log2_sizes)
best_tc, = timespace_complexity_singlestep(best.config, adj, vertices, log2_sizes)
@debug "best loss = $(round(best.loss[]; digits=3)) space complexities = $(best.group_scs) time complexity = $(best_tc) groups_sizes = $(best.group_sizes)"
if maximum(best.group_scs) > bipartiter.sc_target
@warn "target space complexity not found, got: $(maximum(best.group_scs)), with time complexity $best_tc."
end
return vertices[findall(==(1), best.config)], vertices[findall(==(2), best.config)]
end

function timespace_complexity_singlestep(state, adj, group, log2_sizes)
g1 = group[findall(==(1), state.config)]
g2 = group[findall(==(2), state.config)]
function timespace_complexity_singlestep(config, adj, group, log2_sizes)
g1 = group[findall(==(1), config)]
g2 = group[findall(==(2), config)]
d1 = sum(adj[g1,:], dims=1)
d2 = sum(adj[g2,:], dims=1)
dall = sum(adj, dims=1)
Expand Down
259 changes: 259 additions & 0 deletions src/treesa.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
using OMEinsum.ContractionOrder: ContractionTree, log2sumexp2

export optimize_tree

# Metadata stored on every inner node of an expression tree.
# `out_dims` holds the integer labels of the node; it is what `labels(::ExprTree)` returns.
struct ExprInfo
out_dims::Vector{Int}
end
# Leaf of an expression tree, standing for one input tensor.
# `tensorid` is the integer identifying the tensor (it is what the leaf converts back to
# in `_nestedeinsum`); `labels` are the tensor's integer index labels.
struct LeafNode
tensorid::Int
labels::Vector{Int}
end

# Binary inner node of an expression tree; each child is another `ExprTree` or a `LeafNode`.
# Declared mutable because `update_tree!` reassigns `left`, `right` and `info` in place.
mutable struct ExprTree
left::Union{ExprTree,LeafNode}
right::Union{ExprTree,LeafNode}
info::ExprInfo
end
"""
    print_expr(io::IO, expr, level=0)

Write an indented rendering of `expr` to `io`, using two spaces of indentation per
nesting `level`. An `ExprTree` is printed as a parenthesised pair of children followed by
`:=` and its output labels; a `LeafNode` is printed as its labels followed by its tensor
id in parentheses.
"""
function print_expr(io::IO, expr::ExprTree, level=0)
    pad = " "^(2 * level)
    print(io, pad, "(\n")
    print_expr(io, expr.left, level + 1)
    # The separators must go to `io` too; a bare `print("\n")` would write to `stdout`
    # and corrupt the output whenever `io` is a buffer or a file.
    print(io, "\n")
    print_expr(io, expr.right, level + 1)
    print(io, "\n")
    print(io, pad, ") := ", expr.info.out_dims)
    return nothing
end
function print_expr(io::IO, expr::LeafNode, level=0)
    return print(io, " "^(2 * level), expr.labels, " ($(expr.tensorid))")
end

# Both the compact and the `text/plain` display use the same multi-line rendering.
Base.show(io::IO, expr::ExprTree) = print_expr(io, expr, 0)
Base.show(io::IO, ::MIME"text/plain", expr::ExprTree) = show(io, expr)

"""
    optimize_tree(code::NestedEinsum, size_dict; sc_target=32, βs=0.1:0.1:10, ntrials=2, niters=100, sc_weight=1.0)

Optimize the contraction order of `code` by simulated annealing on its expression tree
and return the result as a new `NestedEinsum`.

The labels of `code` are mapped to consecutive integers, `size_dict` (indexed by the
original labels) is converted to a vector of log2 sizes, and all keyword arguments are
forwarded unchanged to `optimize_tree_sa`. The optimized tree is converted back using
the inverse label map.
"""
function optimize_tree(code::NestedEinsum, size_dict; sc_target=32, βs=0.1:0.1:10, ntrials=2, niters=100, sc_weight=1.0)
    label_to_int = _label_dict(OMEinsum.flatten(code))
    int_to_label = Dict([idx => lbl for (lbl, idx) in label_to_int])
    nlabels = length(label_to_int)
    log2_sizes = [log2.(size_dict[int_to_label[idx]]) for idx = 1:nlabels]
    initial = _exprtree(code, label_to_int)
    annealed = optimize_tree_sa(
        initial, log2_sizes;
        sc_target=sc_target, βs=βs, ntrials=ntrials, niters=niters, sc_weight=sc_weight,
    )
    return NestedEinsum(annealed, int_to_label)
end

# Children of a node; a leaf has none.
siblings(t::ExprTree) = Any[t.left, t.right]
siblings(::LeafNode) = Any[]

# Recursive copies: no label vector is shared between the original and the copy.
Base.copy(t::ExprTree) = ExprTree(copy(t.left), copy(t.right), ExprInfo(copy(t.info.out_dims)))
Base.copy(t::LeafNode) = LeafNode(t.tensorid, copy(t.labels))

# Integer labels attached to a node (output labels for inner nodes).
labels(t::ExprTree) = t.info.out_dims
labels(t::LeafNode) = t.labels

# Largest label appearing anywhere in the (sub)tree; 0 when there is no label at all.
maxlabel(t::ExprTree) = max(isempty(labels(t)) ? 0 : maximum(labels(t)), maxlabel(t.left), maxlabel(t.right))
maxlabel(t::LeafNode) = isempty(labels(t)) ? 0 : maximum(labels(t))

# Structural equality: leaves are compared by tensor id only, label vectors are compared
# as sets (order and multiplicity are ignored), and inner nodes compare child by child.
Base.:(==)(t1::ExprTree, t2::ExprTree) = _equal(t1, t2)
Base.:(==)(t1::ExprInfo, t2::ExprInfo) = _equal(t1.out_dims, t2.out_dims)
_equal(t1::ExprTree, t2::ExprTree) = _equal(t1.left, t2.left) && _equal(t1.right, t2.right) && t1.info == t2.info
_equal(t1::LeafNode, t2::LeafNode) = t1.tensorid == t2.tensorid
_equal(t1::Vector, t2::Vector) = Set(t1) == Set(t2)
_equal(a, b) = false

# `hash` must agree with the custom `==` above, otherwise equal trees land in different
# `Dict`/`Set` buckets. The hash therefore uses exactly what `_equal` looks at: leaf
# tensor ids and the *set* of output labels.
Base.hash(t::ExprInfo, h::UInt) = hash(Set(t.out_dims), hash(:ExprInfo, h))
Base.hash(t::ExprTree, h::UInt) = _treehash(t, h)
function _treehash(t::ExprTree, h::UInt)
    h = _treehash(t.left, hash(:ExprTree, h))
    h = _treehash(t.right, h)
    return hash(Set(t.info.out_dims), h)
end
_treehash(t::LeafNode, h::UInt) = hash(t.tensorid, hash(:LeafNode, h))

"""
    optimize_tree_sa(tree::ExprTree, log2_sizes; βs, niters, ntrials, sc_target, sc_weight)

Run `ntrials` independent annealing runs starting from copies of `tree` and return the
best tree found.

Each run performs, for every inverse temperature `β` in `βs`, `niters` sweeps of
`optimize_subtree!`. A run replaces the current best when its space complexity is
smaller, or equal with a smaller time complexity. The input `tree` itself (not a copy)
is returned when no run improves on it. A warning is logged when the best space
complexity still exceeds `sc_target`.
"""
function optimize_tree_sa(tree::ExprTree, log2_sizes; βs, niters, ntrials, sc_target, sc_weight)
    best_tree = tree
    best_tc, best_sc = tree_timespace_complexity(tree, log2_sizes)
    for _ = 1:ntrials
        ctree = copy(tree)
        for β in βs, _ = 1:niters
            optimize_subtree!(ctree, β, log2_sizes, sc_target, sc_weight)  # single sweep
        end
        tc, sc = tree_timespace_complexity(ctree, log2_sizes)
        if sc < best_sc || (sc == best_sc && tc < best_tc)
            best_tree, best_tc, best_sc = ctree, tc, sc
        end
    end
    # The labels used to be swapped here: `best_sc` is the space complexity and
    # `best_tc` the time complexity.
    @debug "best space complexity = $best_sc time complexity = $best_tc"
    if best_sc > sc_target
        @warn "target space complexity not found, got: $best_sc, with time complexity $best_tc."
    end
    return best_tree
end

"""
    OMEinsum.timespace_complexity(tree::ExprTree, size_vec)

Time and space complexity of `tree`, where `size_vec[i]` is the (linear, not log2)
size of label `i`. The sizes are converted with `log2` and passed to
`tree_timespace_complexity`.
"""
function OMEinsum.timespace_complexity(tree::ExprTree, size_vec)
    log2_sizes = log2.(size_vec)
    return tree_timespace_complexity(tree, log2_sizes)
end
"""
    tree_timespace_complexity(tree, log2_sizes)

Return `(tc, sc)` for `tree`, both in log2 units, with `log2_sizes[l]` the log2 size of
label `l`.

- For a `LeafNode`, `tc` is `-Inf` and `sc` is the sum of the log2 sizes of its labels
  (`0` for a tensor without labels).
- For an `ExprTree`, `tc` combines the cost of this node's own step (from `tcsc`) with
  the children's `tc` through `log2sumexp2`, and `sc` is the maximum over this node and
  its children.
"""
function tree_timespace_complexity(tree::LeafNode, log2_sizes)
    # `sum(f, itr)` throws on an empty collection, so a scalar tensor (no labels) needs
    # the same guard that `tcsc` applies to empty output labels.
    sc = isempty(tree.labels) ? 0 : sum(i -> log2_sizes[i], tree.labels)
    return -Inf, sc
end
function tree_timespace_complexity(tree::ExprTree, log2_sizes)
    tcl, scl = tree_timespace_complexity(tree.left, log2_sizes)
    tcr, scr = tree_timespace_complexity(tree.right, log2_sizes)
    tc, sc = tcsc(labels(tree.left), labels(tree.right), labels(tree), log2_sizes)
    return log2sumexp2([tc, tcl, tcr]), max(sc, scl, scr)
end
"""
    tcsc(ix1, ix2, iy, log2_sizes)

Return `(tc, sc)` in log2 units for a single pairwise step with input labels `ix1`,
`ix2` and output labels `iy`; `log2_sizes[l]` is the log2 size of label `l`.

`sc` is the sum of the log2 sizes of the output labels (`0` when `iy` is empty). `tc`
is `sc` plus the log2 size of every label of `ix1` that also occurs in `ix2` but not in
`iy`. Labels in `ix1` are assumed to be unique.
"""
function tcsc(ix1, ix2, iy, log2_sizes)
    sc = isempty(iy) ? 0 : sum(i -> log2_sizes[i], iy)
    tc = sc
    # Note: assuming labels in `ix1` being unique
    for l in ix1
        # shared between both inputs and absent from the output
        if (l in ix2) && !(l in iy)
            tc += log2_sizes[l]
        end
    end
    return tc, sc
end

"""
    random_exprtree(code::EinCode)

Build a random expression tree for `code`. Labels are first mapped to integers with
`_label_dict`; the integer-labelled inputs and output are then handed to the
vector-based `random_exprtree` method.
"""
function random_exprtree(@nospecialize(code::EinCode{ixs, iy})) where {ixs, iy}
    label_to_int = _label_dict(code)
    int_ixs = [[label_to_int[l] for l in ix] for ix in ixs]
    int_iy = [label_to_int[l] for l in iy]
    return random_exprtree(int_ixs, int_iy, length(label_to_int))
end

"""
    random_exprtree(ixs::Vector{Vector{Int}}, iy::Vector{Int}, nedge::Int)

Build a random expression tree over the input tensors `ixs` with output labels `iy`,
where labels are integers in `1:nedge`.

Two occurrence counters are prepared for `_random_exprtree`: one counting how often
each label occurs in `iy` only, and one counting occurrences in `iy` and in all inputs.
Leaf tensor ids are numbered from a counter that starts at 0.
"""
function random_exprtree(ixs::Vector{Vector{Int}}, iy::Vector{Int}, nedge::Int)
    outer = zeros(Int, nedge)
    for lbl in iy
        outer[lbl] += 1
    end
    # The total count starts from the output occurrences and adds every input occurrence.
    total = copy(outer)
    for lbl in Iterators.flatten(ixs)
        total[lbl] += 1
    end
    return _random_exprtree(ixs, outer, total, Ref(0))
end
"""
    _random_exprtree(ixs, outercount, allcount, k)

Recursively split the tensors `ixs` into two random non-empty groups and return the
resulting tree.

`outercount[l]` is the number of occurrences of label `l` outside the current group and
`allcount[l]` the number of occurrences overall. `k` is a mutable counter (`k[]`) used
to number the leaves in the order they are created.
"""
function _random_exprtree(ixs::Vector{Vector{Int}}, outercount::Vector{Int}, allcount::Vector{Int}, k)
    ntensors = length(ixs)
    if ntensors == 1
        k[] += 1
        return LeafNode(k[], ixs[1])
    end

    goes_left = rand(Bool, ntensors)
    if all(goes_left) || !any(goes_left)  # prevent invalid partition
        flip = rand(1:ntensors)
        goes_left[flip] = !goes_left[flip]
    end

    # A label is an output of this node when it occurs both inside the group
    # (outer count differs from the total) and outside of it (outer count non-zero).
    out_dims = filter(l -> outercount[l] != allcount[l] && outercount[l] != 0, 1:length(outercount))

    # Seen from one half, the labels of the other half count as "outside".
    outer_for_left, outer_for_right = copy(outercount), copy(outercount)
    for (ix, left) in zip(ixs, goes_left)
        target = left ? outer_for_right : outer_for_left
        for l in ix
            target[l] += 1
        end
    end

    # The left subtree is built first so leaves are numbered in the same order.
    left_tree = _random_exprtree(ixs[goes_left], outer_for_left, allcount, k)
    right_tree = _random_exprtree(ixs[(!).(goes_left)], outer_for_right, allcount, k)
    return ExprTree(left_tree, right_tree, ExprInfo(out_dims))
end

"""
    optimize_subtree!(tree, β, log2_sizes, sc_target, sc_weight)

Perform one annealing sweep over `tree` in place and return it.

At each node with a non-empty `ruleset`, one rewrite rule is drawn at random and its
time/space complexity change `(dtc, dsc)` is evaluated with `tcsc_diff`. The energy
change is `dtc`, plus `sc_weight * dsc` when the node's space complexity before or
after the move exceeds `sc_target`. The move is accepted with probability
`exp(-β * dE)` (always when `dE <= 0`), after which the sweep descends into the
children. Nodes with an empty rule set are left untouched and not descended into.
"""
function optimize_subtree!(tree, β, log2_sizes, sc_target, sc_weight)
    rst = ruleset(tree)
    isempty(rst) && return tree

    rule = rand(rst)
    out = labels(tree)
    # Space complexity is the log2 of the tensor size, i.e. the *sum* of the log2
    # dimension sizes (the same quantity `tcsc` computes). `log2sumexp2` would give
    # the log2 of the sum of the dimension sizes, which is not a tensor size.
    sc = isempty(out) ? 0.0 : sum(l -> log2_sizes[l], out)
    dtc, dsc = tcsc_diff(tree, rule, log2_sizes)
    #log2(α*RW + tc) is the original `tc` term, which also optimizes read-write overheads.
    dE = (max(sc, sc + dsc) > sc_target ? sc_weight : 0) * dsc + dtc
    if rand() < exp(-β * dE)
        update_tree!(tree, rule)
    end
    for subtree in siblings(tree)
        optimize_subtree!(subtree, β, log2_sizes, sc_target, sc_weight)
    end
    return tree
end

# Rewrite rules applicable at a node, as a range of rule numbers understood by
# `tcsc_diff` and `update_tree!`. Rules 1-2 need the left child to be an inner node,
# rules 3-4 need the right child to be one; a leaf has no applicable rule.
ruleset(::LeafNode) = 1:-1
@inline function ruleset(tree::ExprTree)
    left_inner = tree.left isa ExprTree
    right_inner = tree.right isa ExprTree
    left_inner && right_inner && return 1:4
    left_inner && return 1:2
    right_inner && return 3:4
    return 1:0
end

"""
    tcsc_diff(tree::ExprTree, rule, log2_sizes)

Change in `(time, space)` complexity that applying rewrite `rule` at `tree` would
cause, computed by `abcacb`. The rules are

1. `(a,b), c -> (a,c), b`
2. `(a,b), c -> (c,b), a`
3. `a, (b,c) -> b, (a,c)`
4. `a, (b,c) -> c, (b,a)` (also used for any other value of `rule`)
"""
function tcsc_diff(tree::ExprTree, rule, log2_sizes)
    if rule == 1 || rule == 2
        # the pair being rewritten is the left child; the right child is swapped in
        pair, outsider = tree.left, tree.right
        kept, moved = rule == 1 ? (pair.left, pair.right) : (pair.right, pair.left)
    else
        # the pair being rewritten is the right child; the left child is swapped in
        pair, outsider = tree.right, tree.left
        kept, moved = rule == 3 ? (pair.right, pair.left) : (pair.left, pair.right)
    end
    return abcacb(labels(kept), labels(moved), labels(outsider), labels(tree), log2_sizes)
end

"""
    abcacb(a, b, c, d, log2_sizes)

Difference in `(time, space)` complexity between contracting `(a, c)` first and then
`b`, and contracting `(a, b)` first and then `c`, with `d` the final output labels.
Both costs come from `_tcsc_merge`; the result is "after minus before".
"""
function abcacb(a, b, c, d, log2_sizes)
    before = _tcsc_merge(a, b, c, d, log2_sizes)
    after = _tcsc_merge(a, c, b, d, log2_sizes)
    return after[1] - before[1], after[2] - before[2]
end

"""
    _tcsc_merge(a, b, c, d, log2_sizes)

`(time, space)` complexity, in log2 units, of contracting `a` with `b` and then the
result with `c`, where `d` are the final output labels.

The intermediate result keeps the labels of `a` or `b` that are still needed, i.e.
those that also occur in `c` or `d`. The two step costs from `tcsc` are combined with
`log2sumexp2` for time and `max` for space.
"""
function _tcsc_merge(a, b, c, d, log2_sizes)
    # (a ∪ b) ∩ (c ∪ d). The earlier hand-written loops built the same vector and were
    # then overwritten by this expression, so only the set form is kept.
    ab = intersect(union(a, b), union(c, d))
    tcl, scl = tcsc(a, b, ab, log2_sizes)
    tcr, scr = tcsc(ab, c, d, log2_sizes)
    return log2sumexp2([tcl, tcr]), max(scl, scr)
end

# Output labels of the rewritten inner node `node`: every label of its two children
# that is still needed outside, i.e. that occurs in the parent's output or in the
# sibling subtree.
function _rewritten_info(node::ExprTree, parent::ExprTree, sibling)
    inside = union(labels(node.left), labels(node.right))
    outside = union(labels(parent), labels(sibling))
    return ExprInfo(intersect(inside, outside))
end

"""
    update_tree!(tree::ExprTree, rule::Int)

Apply rewrite `rule` to `tree` in place and return `tree`.

1. `(a,b), c -> (a,c), b`
2. `(a,b), c -> (c,b), a`
3. `a, (b,c) -> b, (a,c)`
4. `a, (b,c) -> c, (b,a)` (also used for any other value of `rule`)

Two subtrees are exchanged and the output labels of the rewritten child are
recomputed; the output labels of `tree` itself are left unchanged.
"""
function update_tree!(tree::ExprTree, rule::Int)
    if rule == 1  # (a,b), c -> (a,c),b
        b, c = tree.left.right, tree.right
        tree.left.right = c
        tree.right = b
        tree.left.info = _rewritten_info(tree.left, tree, tree.right)
    elseif rule == 2  # (a,b), c -> (c,b),a
        a, c = tree.left.left, tree.right
        tree.left.left = c
        tree.right = a
        tree.left.info = _rewritten_info(tree.left, tree, tree.right)
    elseif rule == 3  # a,(b,c) -> b,(a,c)
        a, b = tree.left, tree.right.left
        tree.left = b
        tree.right.left = a
        tree.right.info = _rewritten_info(tree.right, tree, tree.left)
    else  # a,(b,c) -> c,(b,a)
        a, c = tree.left, tree.right.right
        tree.left = c
        tree.right.right = a
        tree.right.info = _rewritten_info(tree.right, tree, tree.left)
    end
    return tree
end

# Map every distinct label of `code` to an integer. Labels are numbered 1, 2, ... in
# order of first appearance, scanning the inputs first and the output last.
function _label_dict(@nospecialize(code::EinCode{ixs, iy})) where {ixs, iy}
    input_labels = collect.(ixs)
    output_labels = collect(iy)
    distinct = unique(vcat(input_labels..., output_labels))
    return Dict(zip(distinct, 1:length(distinct)))
end

# Convert a binary `NestedEinsum` into an `ExprTree`, mapping labels to integers.
ExprTree(code::NestedEinsum) = _exprtree(code, _label_dict(OMEinsum.flatten(code)))

# Recursive worker: `labels` maps original labels to integers. Integer arguments of
# `code` become leaves carrying the labels of the matching input; nested arguments are
# converted recursively. Every node must have exactly two arguments.
function _exprtree(code::NestedEinsum, labels)
    @assert length(code.args) == 2
    children = map(enumerate(code.args)) do (pos, arg)
        if arg isa Int
            LeafNode(arg, [labels[l] for l in OMEinsum.getixs(code.eins)[pos]])
        else
            _exprtree(arg, labels)
        end
    end
    out_dims = [labels[l] for l in OMEinsum.getiy(code.eins)]
    return ExprTree(children..., ExprInfo(out_dims))
end

# Convert an `ExprTree` back into a `NestedEinsum`. Without a label map the integer
# labels are kept as they are (the map is the range `1:maxlabel(expr)`); otherwise
# `labelmap[i]` gives the label to use for integer `i`.
OMEinsum.NestedEinsum(expr::ExprTree) = _nestedeinsum(expr, 1:maxlabel(expr))
OMEinsum.NestedEinsum(expr::ExprTree, labelmap) = _nestedeinsum(expr, labelmap)

function _nestedeinsum(tree::ExprTree, lbs)
    # translate a vector of integer labels into a tuple of mapped labels
    relabel(ids) = (getindex.(Ref(lbs), ids)...,)
    input_labels = (relabel(labels(tree.left)), relabel(labels(tree.right)))
    eins = EinCode(input_labels, relabel(labels(tree)))
    children = (_nestedeinsum(tree.left, lbs), _nestedeinsum(tree.right, lbs))
    return NestedEinsum(children, eins)
end
# A leaf is represented by its tensor id.
_nestedeinsum(tree::LeafNode, lbs) = tree.tensorid

4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ end
# Runs the test file sa.jl inside its own test set.
@testset "sa" begin
include("sa.jl")
end

# Runs the test file treesa.jl inside its own test set.
@testset "treesa" begin
include("treesa.jl")
end
Loading

0 comments on commit 7ad371c

Please sign in to comment.