Skip to content

Commit

Permalink
[BUILD] Simplify after bind device type (apache#2670)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and wweic committed Mar 9, 2019
1 parent beac50c commit f96b27b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
41 changes: 37 additions & 4 deletions src/pass/make_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,13 @@ class DeviceTypeBinder: public IRMutator {
explicit DeviceTypeBinder(int device_type)
: device_type_(device_type) {}

Stmt Mutate_(const AttrStmt* op, const Stmt &s) final {
Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
if (op->attr_key == attr::device_context_type) {
if (const Variable* var = op->value.as<Variable>()) {
std::unordered_map<const Variable*, Expr> dmap;
var_ = var;
Expr value = make_const(op->value.type(), device_type_);
dmap[var] = value;
Stmt body = Substitute(s, dmap);
Stmt body = IRMutator::Mutate_(op, s);
var_ = nullptr;
std::ostringstream os;
os << "device_type need to be " << device_type_;
return AssertStmt::make(op->value == value, os.str(), body);
Expand All @@ -191,7 +191,40 @@ class DeviceTypeBinder: public IRMutator {
return IRMutator::Mutate_(op, s);
}

Stmt Mutate_(const IfThenElse* op, const Stmt& s) final {
// eager simplify if guard.
Stmt res = IRMutator::Mutate_(op, s);
op = res.as<IfThenElse>();
if (is_zero(op->condition)) {
if (op->else_case.defined()) return op->else_case;
return Evaluate::make(0);
}
if (is_one(op->condition)) {
return op->then_case;
}
return res;
}

Expr Mutate_(const NE* op, const Expr& e) final {
// eager check NE for device check
Expr res = IRMutator::Mutate_(op, e);
op = res.as<NE>();
if (ir::Equal(op->a, op->b)) {
return make_const(op->type, false);
}
return res;
}

Expr Mutate_(const Variable* op, const Expr& e) final {
if (op == var_) {
return make_const(op->type, device_type_);
} else {
return e;
}
}

public:
const Variable* var_{nullptr};
int device_type_;
};

Expand Down
5 changes: 1 addition & 4 deletions tests/python/unittest/test_codegen_c_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ def test_add():
s = tvm.create_schedule(C.op)

def check_c():
f1 = tvm.lower(s, [A, B, C], name="fadd")
fsplits = [x for x in tvm.ir_pass.SplitHostDevice(f1)]
fsplits[0] = tvm.ir_pass.LowerTVMBuiltin(fsplits[0])
mhost = tvm.codegen.build_module(fsplits[0], "c")
mhost = tvm.build(s, [A, B, C], "c", name="fadd")
temp = util.tempdir()
path_dso = temp.relpath("temp.so")
mhost.export_library(path_dso)
Expand Down

0 comments on commit f96b27b

Please sign in to comment.