-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from TensorBFS/treesa
Expression-rewriting-based simulated annealing
- Loading branch information
Showing
5 changed files
with
388 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,5 +7,6 @@ using OMEinsum: NestedEinsum | |
|
||
include("kahypar.jl") | ||
include("sa.jl") | ||
include("treesa.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
using OMEinsum.ContractionOrder: ContractionTree, log2sumexp2 | ||
|
||
export optimize_tree | ||
|
||
"""
    ExprInfo

Per-node payload of an [`ExprTree`](@ref): `out_dims` holds the integer labels
of the node's output tensor.
"""
struct ExprInfo
    out_dims::Vector{Int}
end

"""
    LeafNode

Leaf of a contraction tree: the identifier of an input tensor (`tensorid`)
together with its integer `labels`.
"""
struct LeafNode
    tensorid::Int
    labels::Vector{Int}
end

"""
    ExprTree

Binary contraction tree. Each child is either another `ExprTree` or a
[`LeafNode`](@ref); `info` stores the output labels of this contraction.
The struct is mutable so that children and `info` can be reassigned in place.
"""
mutable struct ExprTree
    left::Union{ExprTree,LeafNode}
    right::Union{ExprTree,LeafNode}
    info::ExprInfo
end
"""
    print_expr(io::IO, expr::ExprTree, level=0)

Write a nested, parenthesised listing of `expr` to `io`. Each nesting `level`
is indented by two spaces; the closing line shows the node's output labels.
"""
function print_expr(io::IO, expr::ExprTree, level=0)
    pad = " "^(2*level)
    print(io, pad, "(\n")
    print_expr(io, expr.left, level+1)
    # Line breaks must go to `io` as well; writing them to stdout would
    # garble the output whenever `io` is a buffer or a file.
    print(io, "\n")
    print_expr(io, expr.right, level+1)
    print(io, "\n")
    print(io, pad, ") := ", expr.info.out_dims)
    return nothing
end
# Leaf printing: indented labels followed by the tensor id in parentheses.
function print_expr(io::IO, expr::LeafNode, level=0)
    indent = " "^(2*level)
    print(io, indent, expr.labels, " ($(expr.tensorid))")
end

# Both the compact and the `text/plain` display use the nested listing.
function Base.show(io::IO, expr::ExprTree)
    print_expr(io, expr, 0)
end
function Base.show(io::IO, ::MIME"text/plain", expr::ExprTree)
    show(io, expr)
end
|
||
"""
    optimize_tree(code::NestedEinsum, size_dict; sc_target=32, βs=0.1:0.1:10, ntrials=2, niters=100, sc_weight=1.0)

Improve the contraction order of `code` by simulated annealing on its
expression tree and return the result as a new `NestedEinsum`.

`size_dict` maps each label of `code` to its size. The keyword arguments are
forwarded unchanged to `optimize_tree_sa`.
"""
function optimize_tree(code::NestedEinsum, size_dict; sc_target=32, βs=0.1:0.1:10, ntrials=2, niters=100, sc_weight=1.0)
    # label => integer id, and the reverse lookup used to translate back
    label2int = _label_dict(OMEinsum.flatten(code))
    int2label = Dict([id=>lb for (lb, id) in label2int])
    log2_sizes = map(1:length(label2int)) do id
        log2.(size_dict[int2label[id]])
    end
    start = _exprtree(code, label2int)
    optimized = optimize_tree_sa(start, log2_sizes;
        sc_target=sc_target, βs=βs, ntrials=ntrials, niters=niters, sc_weight=sc_weight)
    return NestedEinsum(optimized, int2label)
end
|
||
# Direct children of a node; a leaf has none.
function siblings(t::ExprTree)
    return Any[t.left, t.right]
end
function siblings(::LeafNode)
    return Any[]
end

# Recursive copies: the label vectors are duplicated, not shared.
function Base.copy(t::ExprTree)
    info = ExprInfo(copy(t.info.out_dims))
    return ExprTree(copy(t.left), copy(t.right), info)
end
function Base.copy(t::LeafNode)
    return LeafNode(t.tensorid, copy(t.labels))
end

# Integer labels carried by a node (output labels for an inner node).
function labels(t::ExprTree)
    return t.info.out_dims
end
function labels(t::LeafNode)
    return t.labels
end

# Largest label appearing anywhere in the tree; 0 if there are no labels.
function maxlabel(t::ExprTree)
    own = labels(t)
    top = isempty(own) ? 0 : maximum(own)
    return max(top, maxlabel(t.left), maxlabel(t.right))
end
function maxlabel(t::LeafNode)
    own = labels(t)
    return isempty(own) ? 0 : maximum(own)
end
# Structural equality. Two trees are equal when their children are equal and
# their output labels agree as sets; leaves are compared by tensor id only.
Base.:(==)(t1::ExprTree, t2::ExprTree) = _equal(t1, t2)
Base.:(==)(t1::ExprInfo, t2::ExprInfo) = _equal(t1.out_dims, t2.out_dims)
_equal(t1::ExprTree, t2::ExprTree) = _equal(t1.left, t2.left) && _equal(t1.right, t2.right) && t1.info == t2.info
_equal(t1::LeafNode, t2::LeafNode) = t1.tensorid == t2.tensorid
_equal(t1::Vector, t2::Vector) = Set(t1) == Set(t2)
_equal(a, b) = false

# `hash` must agree with `==`, otherwise equal trees land in different buckets
# of a `Dict`/`Set` and `unique` fails to merge them. The hashes below use
# exactly the data the equality above compares.
Base.hash(t::ExprInfo, h::UInt) = hash(Set(t.out_dims), h)
Base.hash(t::ExprTree, h::UInt) = _hash(t, h)
_hash(t::ExprTree, h::UInt) = _hash(t.left, _hash(t.right, hash(t.info, h)))
_hash(t::LeafNode, h::UInt) = hash(t.tensorid, h)
|
||
"""
    optimize_tree_sa(tree::ExprTree, log2_sizes; βs, niters, ntrials, sc_target, sc_weight)

Run `ntrials` independent annealing runs starting from copies of `tree` and
return the best tree found (the input `tree` itself if no run improves on it).

Each run performs `niters` sweeps of `optimize_subtree!` for every inverse
temperature in `βs`. A candidate wins when its space complexity is lower, or
equal with a lower time complexity. A warning is emitted when the best space
complexity still exceeds `sc_target`.
"""
function optimize_tree_sa(tree::ExprTree, log2_sizes; βs, niters, ntrials, sc_target, sc_weight)
    best_tree = tree
    best_tc, best_sc = tree_timespace_complexity(tree, log2_sizes)
    for _ = 1:ntrials
        ctree = copy(tree)
        # No `@inbounds` here: the loop body does no array indexing of its own.
        for β in βs, _ = 1:niters
            optimize_subtree!(ctree, β, log2_sizes, sc_target, sc_weight)  # single sweep
        end
        tc, sc = tree_timespace_complexity(ctree, log2_sizes)
        if sc < best_sc || (sc == best_sc && tc < best_tc)
            best_tree, best_tc, best_sc = ctree, tc, sc
        end
    end
    # The original message reported the two values under swapped names.
    @debug "best space complexity = $best_sc time complexity = $best_tc"
    if best_sc > sc_target
        @warn "target space complexity not found, got: $best_sc, with time complexity $best_tc."
    end
    return best_tree
end
|
||
# `size_vec` holds plain sizes; the tree routines work on their base-2 logarithms.
function OMEinsum.timespace_complexity(tree::ExprTree, size_vec)
    log2_sizes = log2.(size_vec)
    return tree_timespace_complexity(tree, log2_sizes)
end
# A leaf costs no contraction time (`-Inf` in log2 scale); its space is the
# sum of the log2 sizes of its labels. A scalar tensor (no labels) occupies
# log2-space zero. The guard is needed because `sum(f, itr)` throws on an
# empty collection.
function tree_timespace_complexity(tree::LeafNode, log2_sizes)
    sc = isempty(tree.labels) ? zero(eltype(log2_sizes)) : sum(i->log2_sizes[i], tree.labels)
    return -Inf, sc
end
# Time complexities of the two subtrees and of this contraction are combined
# with `log2sumexp2`; the space complexity is the maximum over all three.
function tree_timespace_complexity(tree::ExprTree, log2_sizes)
    left_tc, left_sc = tree_timespace_complexity(tree.left, log2_sizes)
    right_tc, right_sc = tree_timespace_complexity(tree.right, log2_sizes)
    here_tc, here_sc = tcsc(labels(tree.left), labels(tree.right), labels(tree), log2_sizes)
    total_tc = log2sumexp2([here_tc, left_tc, right_tc])
    peak_sc = max(here_sc, left_sc, right_sc)
    return total_tc, peak_sc
end
"""
    tcsc(ix1, ix2, iy, log2_sizes)

Return `(tc, sc)` for contracting two tensors with labels `ix1` and `ix2` into
an output with labels `iy`, both in log2 scale.

`sc` is the sum of `log2_sizes` over `iy`. `tc` adds to it the sizes of the
labels shared by `ix1` and `ix2` that do not appear in `iy`. An empty `iy`
gives a zero of the element type of `log2_sizes`, so the return type does not
depend on whether the output is a scalar.
"""
function tcsc(ix1, ix2, iy, log2_sizes)
    sc = isempty(iy) ? zero(eltype(log2_sizes)) : sum(i->log2_sizes[i], iy)
    tc = sc
    # Note: assumes the labels in `ix1` are unique
    for l in ix1
        if l ∈ ix2 && l ∉ iy
            tc += log2_sizes[l]
        end
    end
    return tc, sc
end
|
||
# Translate the labels of `code` to integers and build a random tree from them.
function random_exprtree(@nospecialize(code::EinCode{ixs, iy})) where {ixs, iy}
    labelmap = _label_dict(code)
    int_ixs = [[labelmap[l] for l in ix] for ix in ixs]
    int_iy = [labelmap[l] for l in iy]
    return random_exprtree(int_ixs, int_iy, length(labelmap))
end
|
||
# Build a random binary contraction tree over the tensors `ixs` with output
# labels `iy`. `outercount[l]` starts as the number of times label `l` occurs
# in `iy`; `allcount[l]` additionally counts its occurrences in every tensor.
function random_exprtree(ixs::Vector{Vector{Int}}, iy::Vector{Int}, nedge::Int)
    outercount = zeros(Int, nedge)
    for l in iy
        outercount[l] += 1
    end
    allcount = copy(outercount)
    for ix in ixs, l in ix
        allcount[l] += 1
    end
    return _random_exprtree(ixs, outercount, allcount, Ref(0))
end
# Recursive worker of `random_exprtree`. `outercount` counts label occurrences
# outside the current group of tensors, `allcount` counts them everywhere, and
# `k` is a shared counter used to number the leaves in creation order.
function _random_exprtree(ixs::Vector{Vector{Int}}, outercount::Vector{Int}, allcount::Vector{Int}, k)
    ntensors = length(ixs)
    if ntensors == 1
        k[] += 1
        return LeafNode(k[], ixs[1])
    end
    mask = rand(Bool, ntensors)
    if all(mask) || !any(mask)  # prevent invalid partition
        flip = rand(1:ntensors)
        mask[flip] = !mask[flip]
    end
    # A label is an output of this group when it occurs both outside
    # (count non-zero) and inside (count differs from the total).
    out_dims = filter(i -> outercount[i] != allcount[i] && outercount[i] != 0, 1:length(outercount))
    outer_left, outer_right = copy(outercount), copy(outercount)
    for (goes_left, ix) in zip(mask, ixs)
        # a tensor placed in one half lies outside the other half
        target = goes_left ? outer_right : outer_left
        for l in ix
            target[l] += 1
        end
    end
    left = _random_exprtree(ixs[mask], outer_left, allcount, k)
    right = _random_exprtree(ixs[(!).(mask)], outer_right, allcount, k)
    return ExprTree(left, right, ExprInfo(out_dims))
end
|
||
"""
    optimize_subtree!(tree, β, log2_sizes, sc_target, sc_weight)

Perform one annealing sweep over `tree` in place and return it.

A random rewrite rule from `ruleset(tree)` is proposed at the current node and
accepted with probability `exp(-β*dE)`. `dE` is the change in time complexity,
plus `sc_weight` times the change in space complexity when the space
complexity before or after the move exceeds `sc_target`. The sweep then
recurses into both children. Nodes without applicable rules are left untouched.
"""
function optimize_subtree!(tree, β, log2_sizes, sc_target, sc_weight)
    rst = ruleset(tree)
    isempty(rst) && return tree
    rule = rand(rst)
    # Space complexity of this node's output is the *sum* of the log2 sizes
    # of its labels, as in `tcsc`. `log2sumexp2` would instead give the log
    # of the sum of the sizes, which underestimates the tensor size.
    out = labels(tree)
    sc = isempty(out) ? 0.0 : sum(l->log2_sizes[l], out)
    dtc, dsc = tcsc_diff(tree, rule, log2_sizes)
    #log2(α*RW + tc) is the original `tc` term, which also optimizes read-write overheads.
    dE = (max(sc, sc+dsc) > sc_target ? sc_weight : 0) * dsc + dtc
    if rand() < exp(-β*dE)
        update_tree!(tree, rule)
    end
    for subtree in siblings(tree)
        optimize_subtree!(subtree, β, log2_sizes, sc_target, sc_weight)
    end
    return tree
end
|
||
# Applicable rewrite rules for a node, as a range of rule numbers.
# Rules 1-2 need the left child to be an inner node, rules 3-4 the right one.
ruleset(::LeafNode) = 1:-1
@inline function ruleset(tree::ExprTree)
    left_is_tree = tree.left isa ExprTree
    right_is_tree = tree.right isa ExprTree
    left_is_tree && right_is_tree && return 1:4
    left_is_tree && return 1:2
    right_is_tree && return 3:4
    return 1:0
end
|
||
# Change in (time, space) complexity caused by applying `rule` at `tree`.
# Each rule is expressed as "keep `kept`, exchange `moved` with `outer`".
function tcsc_diff(tree::ExprTree, rule, log2_sizes)
    kept, moved, outer = if rule == 1      # (a,b), c -> (a,c),b
        tree.left.left, tree.left.right, tree.right
    elseif rule == 2                       # (a,b), c -> (c,b),a
        tree.left.right, tree.left.left, tree.right
    elseif rule == 3                       # a,(b,c) -> b,(a,c)
        tree.right.right, tree.right.left, tree.left
    else                                   # a,(b,c) -> c,(b,a)
        tree.right.left, tree.right.right, tree.left
    end
    return abcacb(labels(kept), labels(moved), labels(outer), labels(tree), log2_sizes)
end
|
||
# Complexity difference between contracting ((a,c),b) and ((a,b),c), both
# producing output labels `d`. Returns `(Δtc, Δsc)`.
function abcacb(a, b, c, d, log2_sizes)
    before = _tcsc_merge(a, b, c, d, log2_sizes)
    after = _tcsc_merge(a, c, b, d, log2_sizes)
    return after[1] - before[1], after[2] - before[2]
end
|
||
# Time and space complexity of the two-step contraction ((a,b),c) -> d.
# The intermediate tensor keeps the labels of `a` and `b` that are still
# needed, i.e. those occurring in `c` or in the output `d`.
function _tcsc_merge(a, b, c, d, log2_sizes)
    # The original built `ab` with two loops and then overwrote it with this
    # expression; only the set expression ever took effect.
    ab = (a ∪ b) ∩ (c ∪ d)
    tcl, scl = tcsc(a, b, ab, log2_sizes)
    tcr, scr = tcsc(ab, c, d, log2_sizes)
    return log2sumexp2([tcl, tcr]), max(scl, scr)
end
|
||
# Apply rewrite `rule` to `tree` in place: exchange one grandchild with the
# opposite child, then recompute the output labels of the rebuilt inner node
# as the labels of its children that are still needed by the parent's output
# or by the other child.
function update_tree!(tree::ExprTree, rule::Int)
    if rule == 1      # (a,b), c -> (a,c),b
        inner = tree.left
        inner.right, tree.right = tree.right, inner.right
        inner.info = ExprInfo((labels(inner.left) ∪ labels(inner.right)) ∩ (labels(tree) ∪ labels(tree.right)))
    elseif rule == 2  # (a,b), c -> (c,b),a
        inner = tree.left
        inner.left, tree.right = tree.right, inner.left
        inner.info = ExprInfo((labels(inner.left) ∪ labels(inner.right)) ∩ (labels(tree) ∪ labels(tree.right)))
    elseif rule == 3  # a,(b,c) -> b,(a,c)
        inner = tree.right
        inner.left, tree.left = tree.left, inner.left
        inner.info = ExprInfo((labels(inner.left) ∪ labels(inner.right)) ∩ (labels(tree) ∪ labels(tree.left)))
    else              # a,(b,c) -> c,(b,a)
        inner = tree.right
        inner.right, tree.left = tree.left, inner.right
        inner.info = ExprInfo((labels(inner.left) ∪ labels(inner.right)) ∩ (labels(tree) ∪ labels(tree.left)))
    end
    return tree
end
|
||
# Map every distinct label of `code` to an integer, numbered in order of
# first appearance (input tensors first, then the output).
function _label_dict(@nospecialize(code::EinCode{ixs, iy})) where {ixs, iy}
    every_label = vcat(collect.(ixs)..., collect(iy))
    distinct = unique(every_label)
    return Dict(zip(distinct, 1:length(distinct)))
end
|
||
# Convert a binary `NestedEinsum` into an `ExprTree`, translating labels to
# integers through `labels`. Integer arguments become leaves; nested
# arguments are converted recursively.
ExprTree(code::NestedEinsum) = _exprtree(code, _label_dict(OMEinsum.flatten(code)))
function _exprtree(code::NestedEinsum, labels)
    @assert length(code.args) == 2
    children = map(enumerate(code.args)) do (k, arg)
        if arg isa Int
            LeafNode(arg, [labels[l] for l in OMEinsum.getixs(code.eins)[k]])
        else
            _exprtree(arg, labels)
        end
    end
    out_dims = [labels[l] for l in OMEinsum.getiy(code.eins)]
    return ExprTree(children..., ExprInfo(out_dims))
end
|
||
# Convert an `ExprTree` back into a `NestedEinsum`. `lbs` translates integer
# labels to the labels of the result; without it the integers are kept.
OMEinsum.NestedEinsum(expr::ExprTree) = _nestedeinsum(expr, 1:maxlabel(expr))
OMEinsum.NestedEinsum(expr::ExprTree, labelmap) = _nestedeinsum(expr, labelmap)
function _nestedeinsum(tree::ExprTree, lbs)
    relabel(node) = (map(l -> lbs[l], labels(node))...,)
    eins = EinCode((relabel(tree.left), relabel(tree.right)), relabel(tree))
    children = (_nestedeinsum(tree.left, lbs), _nestedeinsum(tree.right, lbs))
    return NestedEinsum(children, eins)
end
_nestedeinsum(tree::LeafNode, lbs) = tree.tensorid
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,7 @@ end | |
@testset "sa" begin | ||
include("sa.jl") | ||
end | ||
|
||
@testset "treesa" begin | ||
include("treesa.jl") | ||
end |
Oops, something went wrong.