-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PASS]LoopPartition #56
Conversation
bounds = tvm.schedule.InferBound(s) | ||
stmt = tvm.schedule.ScheduleOps(s, bounds) | ||
stmt = tvm.ir_pass.LoopPartition(stmt) | ||
print(stmt) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any suggestion to add assert statement and construction strong test case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can assert the result AST
something like
str(stmt.condition) == xyz
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us also add end to end test cases, to test elementwise add and gemm, to make sure the result program looks correct
include/tvm/ir_pass.h
Outdated
@@ -137,6 +137,8 @@ Stmt InjectVirtualThread(Stmt stmt); | |||
*/ | |||
Stmt LiftAllocate(Stmt stmt); | |||
|
|||
Stmt LoopPartition(Stmt stmt); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
always document functions in header
bounds = tvm.schedule.InferBound(s) | ||
stmt = tvm.schedule.ScheduleOps(s, bounds) | ||
stmt = tvm.ir_pass.LoopPartition(stmt) | ||
print(stmt) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can assert the result AST
something like
str(stmt.condition) == xyz
bounds = tvm.schedule.InferBound(s) | ||
stmt = tvm.schedule.ScheduleOps(s, bounds) | ||
stmt = tvm.ir_pass.LoopPartition(stmt) | ||
print(stmt) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let us also add end to end test cases, to test elementwise add and gemm, to make sure the result program looks correct
src/pass/loop_partition.cc
Outdated
public: | ||
explicit LoopPartitioner() {} | ||
Expr Mutate(Expr e) override { | ||
return IRMutator::Mutate(e); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to override if there is no change
src/pass/loop_partition.cc
Outdated
|
||
void Visit_(const For* op) { | ||
for (auto kv : out_vars_) { | ||
if (ExprUseVar(op->min, kv.first) || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider make ExprUseVar take the map, so it is one pass cost
src/pass/loop_partition.cc
Outdated
using IRMutator::Mutate; | ||
|
||
private: | ||
const std::vector<Partition>& ps_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
map of Node*->Partition
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
avoids the linear search per expression
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you want to remove struct Partition
, use std::unordered_map<const Node*, Interval>
directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Partition is good, if things might change in future
src/pass/loop_partition.cc
Outdated
return res; | ||
} | ||
|
||
PartitionFinder finder(op->loop_var, vars_); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
move the subsequent logic to a separate function DoPartition
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or std::vector<Partition> FindPartitions(Expr target, std::unorder_map<const Variable*, IntSet>)
? here do the real partition in fact
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is fine. The only thing is there is a check on whether you want to do partition or not, and we skipped the constant loops. So Have a function name can indicate here we really start to do partition
src/pass/loop_partition.cc
Outdated
Stmt s = simplified_stmt; | ||
|
||
if (!can_prove(true_itrv.min() == universe.min())) { | ||
Expr pre_doubt_cond = (true_itrv.min() != universe.min()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we don;t need to check this cond, since they are implicit in for loop and will be checked
src/arithmetic/int_set.h
Outdated
* \return the set after intersected | ||
*/ | ||
IntSet Intersect(const Array<IntSet>& sets); | ||
IntSet Intersect(const std::vector<IntSet>& sets); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can simply construct Array in outside and pass it in. Should be same as vector
@@ -16,17 +16,17 @@ def test_deduce(): | |||
d_s = tvm.arith.intset_interval(-3, -1) | |||
|
|||
e0 = (-b)*a+c-d | |||
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}) | |||
res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a testcase where relax set is presented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added
tvm.make.IfThenElse( | ||
(i*m+j+k < n), tvm.make.Evaluate(m), tvm.make.Evaluate(n))))) | ||
stmt = tvm.ir_pass.LoopPartition(stmt) | ||
# assert(stmt.body.first.body.body.condition.value == 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add assert is good
src/pass/loop_partition.cc
Outdated
bool success = false; | ||
PostOrderVisit(expr, [&vars, &success](const NodeRef& node) { | ||
for (const Variable* v : vars) { | ||
if (node.get() == v) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use unordered_set<Variable*>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
always think about time complexity, try use O(1) data structure when we can if it does not have to be O(n) cost
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
quite correct, I always not notice this
src/pass/loop_partition.cc
Outdated
return success; | ||
} | ||
|
||
inline bool IsConstDomain(Expr min, Expr extent) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can simply fold it in, since the expression is so simple and fit into oneline
src/pass/loop_partition.cc
Outdated
}; | ||
|
||
std::unordered_map<const Node*, Partition> | ||
FindPartitions(VarExpr target, Stmt body, std::unordered_map<const Variable*, IntSet> vars) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I am being not clear here. I do not mean FindPartition should be a function, instead the insertPartition logic should be a single function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the logic is check if there can be partition, call insertPartition
src/pass/loop_partition.cc
Outdated
return res; | ||
} | ||
|
||
Stmt s = DoPartition(op, stmt); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
merge the logic
if (!is_const(op->min) || is_const(op->extent)) {
s = DoPartition()
if (s.defined()) return s;
}
the default logic
src/pass/loop_partition.cc
Outdated
private: | ||
Stmt DoPartition(const For* op, const Stmt& stmt); | ||
|
||
std::unordered_map<const Variable*, IntSet> vars_; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dom_map_
src/pass/loop_partition.cc
Outdated
public: | ||
explicit PartitionFinder(VarExpr loop_var, | ||
const std::unordered_map<const Variable*, IntSet>& vars) | ||
: target_var_(loop_var), out_vars_(vars.size()), hint_map_(vars), relax_map_() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to put relax map in the initializer
src/pass/loop_partition.cc
Outdated
explicit PartitionFinder(VarExpr loop_var, | ||
const std::unordered_map<const Variable*, IntSet>& vars) | ||
: target_var_(loop_var), out_vars_(vars.size()), hint_map_(vars), relax_map_() { | ||
for (auto kv : vars) out_vars_.insert(kv.first); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto& kv :
src/pass/loop_partition.cc
Outdated
@@ -0,0 +1,203 @@ | |||
/*! | |||
* Copyright (c) 2016 by Contributors |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2017
A few new fixes have gone into tvm/unity that would be useful for us to grab.
No description provided.