Skip to content

Commit

Permalink
enhance pragma to support single point copy (apache#863)
Browse files Browse the repository at this point in the history
* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op

* repare address offset for different kinds of dtype

* bc

* aaa

* aaaaa

* repare address for different dtypes

* remove nonsense files

* add whitespace of line 581

* use base alloc elem_type

* enhance the testcast of basic buffer is 64bits,32bits,16bits,8bits

* use extends[0]->type() as dtype of offset

* clear program writes

* enhance inject_copy_intin to support of pragma stmt with no loops

* fix cpplint errors

* fix cpplint error of !

* enhance detectLinearEquation to support with no loop vars

* fix cpplint errors
  • Loading branch information
libing4752 authored and tqchen committed Feb 4, 2018
1 parent 0ca5364 commit fbb472b
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 22 deletions.
31 changes: 17 additions & 14 deletions src/arithmetic/detect_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,25 +123,28 @@ class LinearEqDetector
};

Array<Expr> DetectLinearEquation(const Expr& e, const Array<Var>& vars) {
CHECK_GE(vars.size(), 1U);
Expr base = e;
Array<Expr> coeff;

for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
if (0 == vars.size()) {
coeff.push_back(make_const(Int(32), 1));
} else {
for (Var v : vars) {
LinearEqEntry ret;
if (!LinearEqDetector(v).Detect(base, &ret)) {
return Array<Expr>();
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}
coeff.push_back(ret.coeff);
base = std::move(ret.base);
}

std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>();
std::unordered_set<const Variable*> vset;
for (size_t i = vars.size(); i != 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset)) {
return Array<Expr>();
}
}
}
coeff.push_back(base);
Expand Down
26 changes: 18 additions & 8 deletions src/pass/inject_copy_intrin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class CopyIntrinInjector : public IRMutator {
private:
bool MatchCopyPattern(Stmt stmt, Stmt *out) {
Stmt body = stmt;
bool is_single_point_copy = false;

// strip the loops
std::vector<const For*> loops;
Expand All @@ -53,7 +54,10 @@ class CopyIntrinInjector : public IRMutator {
const Select* select = store->value.as<Select>();
const Cast* cast = store->value.as<Cast>();
const Load* load = store->value.as<Load>();

if (0 == loops.size()) {
is_single_point_copy = true;
CHECK(select == nullptr);
}
// for now only support true condition matching
if (select != nullptr) {
load = select->true_value.as<Load>();
Expand All @@ -74,13 +78,19 @@ class CopyIntrinInjector : public IRMutator {
arith::DetectLinearEquation(load->index, loop_vars);
if (load_strides.size() == 0 || store_strides.size() == 0) return false;
Array<Expr> dst_shape;
for (const For* op : loops) {
dst_shape.push_back(op->extent);
auto loop_var_size = loop_vars.size();
if (is_single_point_copy) {
loop_var_size = 1;
dst_shape.push_back(make_const(Int(32), 1));
} else {
for (const For* op : loops) {
dst_shape.push_back(op->extent);
}
}
Array<Expr> src_shape = dst_shape;
Array<Expr> pad_before, pad_after;
Expr pad_value;
Expr src_elem_offset = load_strides[loop_vars.size()];
Expr src_elem_offset = load_strides[loop_var_size];
if (select != nullptr) {
Array<Expr> clip_bound =
arith::DetectClipBound(select->condition, loop_vars);
Expand Down Expand Up @@ -114,15 +124,15 @@ class CopyIntrinInjector : public IRMutator {
src_elem_offset = Simplify(src_elem_offset);
}
CHECK_EQ(load_strides.size(), store_strides.size());
CHECK_EQ(load_strides.size(), loop_vars.size() + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_vars.size());
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_vars.size());
CHECK_EQ(load_strides.size(), loop_var_size + 1);
Array<Expr> src_strides(load_strides.begin(), load_strides.begin() + loop_var_size);
Array<Expr> dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size);
Buffer dst = BufferNode::make(
Var(store->buffer_var.node_),
store->value.type(),
dst_shape,
dst_strides,
store_strides[loop_vars.size()],
store_strides[loop_var_size],
store->buffer_var->name_hint,
GetStorageScope(store->buffer_var.get()),
0, 0);
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_pass_inject_copy_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ def cb(src, dst, pad_before, pad_after, pad_value):
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def test_single_point_test():
A = tvm.placeholder((1,), name='A')
B = tvm.compute((1,), lambda i:
A[i], name='B')
s = tvm.create_schedule(B.op)
s[B].pragma(B.op.axis[0], "memcpy")
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
def cb(src, dst, pad_before, pad_after, pad_value):
assert tvm.ir_pass.Simplify(src.elem_offset).value == 0
assert tvm.ir_pass.Simplify(dst.elem_offset).value == 0
assert tvm.ir_pass.Simplify(src.strides[0]).value == 1
assert tvm.ir_pass.Simplify(dst.strides[0]).value == 1
return tvm.make.Evaluate(0)
stmt = tvm.ir_pass.InjectCopyIntrin(stmt, "memcpy", cb)

def assert_expr_equal(a, b):
assert tvm.ir_pass.Simplify(a - b).value == 0

Expand Down Expand Up @@ -80,3 +99,4 @@ def cb(src, dst, pad_before, pad_after, pad_value):
test_copy2d()
test_copy_pad()
test_copy_pad_split()
test_single_point_test()

0 comments on commit fbb472b

Please sign in to comment.