Skip to content

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Mar 20, 2020
1 parent 35832b4 commit 178274c
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 8 deletions.
2 changes: 1 addition & 1 deletion include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ Stmt HoistIfThenElse(Stmt stmt);
* \param stmt The stmt to do datatype rewrite
* \return Transformed stmt.
*/
Stmt DataTypeRewrite(Stmt stmt);
Stmt NarrowDataType(Stmt stmt);

/*!
* \brief Make an user callable API LoweredFunc.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def lower(sch,
# Phase 1
stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
stmt = ir_pass.StorageFlatten(stmt, binds, 64, cfg.instrument_bound_checkers)
stmt = ir_pass.DataTypeRewrite(stmt)
stmt = ir_pass.NarrowDataType(stmt)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase1:
stmt = f(stmt)
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
PrimExpr end = MinNode::make((task_id + make_const(t, 1)) * step, op->extent);
CreateSerialFor(MakeValue(begin),
MakeValue(end),
llvm::ConstantInt::getSigned(LLVMType(end.dtype()), 1),
llvm::ConstantInt::getSigned(GetLLVMType(end), 1),
op->loop_var,
op->body);
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) {
CHECK(op->for_type == ForType::Serial);
}
CreateSerialFor(MakeValue(op->min), MakeValue(op->extent),
llvm::ConstantInt::getSigned(LLVMType(op->extent.dtype()), 1),
llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1),
op->loop_var, op->body);
}

Expand Down
2 changes: 1 addition & 1 deletion src/tir/pass/ffi_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,6 @@ REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS(VerifyCompactBuffer);
REGISTER_PASS(HoistIfThenElse);
REGISTER_PASS(InferFragment)
REGISTER_PASS(DataTypeRewrite);
REGISTER_PASS(NarrowDataType);
} // namespace tir
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* \file rewrite_datatype.cc
* \file narrow_datatype.cc
* \brief narrow the datatype of indexing vars
*/

Expand All @@ -30,6 +30,28 @@
namespace tvm {
namespace tir {

// This pass narrows indexing expressions (like StoreNode::Index)
// that trivially fit into i32 to i32. Considering that i32 indices
// may be more efficient on some backends (while i64 may be more
// efficient on others, like llvm), we may want this pass when i32
// indices are more efficient.
//
// For Var v, we determine its dtype by examining all the PrimExpr
// that contains v, denoted by E = {e_0 = v, e_1, e_2, ..., e_k}.
// If all expressions in E fit into i32, then we think v can be narrowed
// to i32.
//
// To make an indexing expression i32, we must make sure that every
// component of that expression is of dtype i32. So besides Var, we
// rewrite the following inside an indexing expression
// - Var
// - IntImm
// - Cast
//
// Algorithm:
// - Use DataTypeVisitor to determine whether a Var can be narrowed or not.
// - Use DataTypeRewritter to rewrite the components of an indexing expression.

using arith::Analyzer;
using arith::IRMutatorWithAnalyzer;
using arith::ConstIntBound;
Expand Down Expand Up @@ -166,6 +188,9 @@ class DataTypeRewriter : public StmtExprMutator {
Stmt VisitStmt_(const ForNode* op) final {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<ForNode>();
CHECK(op != nullptr)
<< "Expected type to be ForNode"
<< ", but get " << s->GetTypeKey();
PrimExpr e = VisitExpr(op->loop_var);
Var var = Downcast<Var, PrimExpr>(e);
return ForNode::make(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent),
Expand All @@ -177,7 +202,13 @@ class DataTypeRewriter : public StmtExprMutator {
op->attr_key == attr::virtual_thread) {
Stmt s = StmtExprMutator::VisitStmt_(op);
op = s.as<AttrStmtNode>();
CHECK(op != nullptr)
<< "Expected type to be AttrStmtNode"
<< ", but get " << s->GetTypeKey();
const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv != nullptr)
<< "Expected type to be IterVarNode"
<< ", but get " << op->node->GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
Var var = Downcast<Var, PrimExpr>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
Expand Down Expand Up @@ -233,6 +264,9 @@ class DataTypeRewriter : public StmtExprMutator {
if (is_index_ && visitor_.vmap.find(op) != visitor_.vmap.end()) {
PrimExpr e = StmtExprMutator::VisitExpr_(op);
const CastNode* new_op = e.as<CastNode>();
CHECK(new_op != nullptr)
<< "Expected type to be CastNode"
<< ", but get " << e->GetTypeKey();
return CastNode::make(visitor_.vmap[op], new_op->value);
}
return StmtExprMutator::VisitExpr_(op);
Expand Down Expand Up @@ -298,6 +332,9 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=)
PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
PrimExpr e = StmtExprMutator::VisitExpr_(op);
op = e.as<CallNode>();
CHECK(op != nullptr)
<< "Expected type to be CallNode"
<< ", but get " << e->GetTypeKey();
if (op->call_type == CallNode::PureIntrinsic) {
if (op->name == intrinsic::tvm_if_then_else) {
return if_then_else(op->args[0], op->args[1], op->args[2]);
Expand All @@ -318,7 +355,7 @@ PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
return e;
}

Stmt DataTypeRewrite(Stmt stmt) {
Stmt NarrowDataType(Stmt stmt) {
return DataTypeRewriter()(stmt);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def lower(sch, args):
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
stmt = tvm.tir.ir_pass.DataTypeRewrite(stmt)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt)
return stmt


Expand Down

0 comments on commit 178274c

Please sign in to comment.