[TIR][REFACTOR] RewriteForTensorCore -> te/schedule #5379
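Scope of the change, as reflected in the diff below: the `RewriteForTensorCore` pass moves out of `tir/ir_pass` and into te schedule postprocessing as `SchedulePostProcRewriteForTensorCore`, and the driver gains a `form_irmodule` helper that covers the full schedule-to-IRModule lowering. A minimal sketch of how the refactored path is exercised — the schedule, names, and shapes here are illustrative assumptions, not part of the PR:

```python
import tvm
from tvm import te

# Toy schedule; tvm.lower() drives form_irmodule(), which now runs
# SchedulePostProcRewriteForTensorCore as a te postprocessing step
# (per the pass below, it is a no-op unless the current target is CUDA).
n = te.size_var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)
mod = tvm.lower(s, [A, B], name="add_one")  # returns a tvm.IRModule
print(mod["add_one"])
```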

Merged · 2 commits · Apr 20, 2020 (changes shown from 1 commit)
49 changes: 32 additions & 17 deletions include/tvm/te/schedule_pass.h
@@ -34,6 +34,23 @@
namespace tvm {
namespace te {

+/*!
+ * \brief To automatically inline the element-wise operations.
+ *
+ * \param sch The schedule to be inlined.
+ */
+void AutoInlineElemWise(Schedule sch);
+
+/*!
+ * \brief To automatically inline operations with injective writes
+ *   (i.e. writes without reduction or sequential loops). Note
+ *   that in this case, guarantees about contiguity, transpose, stride,
+ *   alignment and memory footprint in general do not hold.
+ *
+ * \param sch The schedule to be inlined.
+ */
+TVM_DLL void AutoInlineInjective(Schedule sch);
+
/*!
 * \brief Infer the bound of all iteration variables related to the schedule.
*
@@ -55,6 +72,21 @@ Map<IterVar, Range> InferBound(const Schedule& sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool debug_keep_trivial_loop);


+/*!
+ * \brief Try to modify the AST generated by ScheduleOps to support TensorCore.
+ *
+ * \param stmt The stmt to be transformed.
+ * \param schedule The original schedule.
+ * \param extern_buffer Map specifying the external
+ *        buffer assignment of inputs and outputs.
+ * \return The transformed stmt.
+ */
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer);

/*!
* \brief Postprocessing the Stmt generated by ScheduleOps to create
* a PrimFunc that can then be used for further TIR optimizations.
@@ -75,23 +107,6 @@ PrimFunc SchedulePostProcToPrimFunc(Array<ObjectRef> arg_list,
Stmt body,
Optional<Map<Tensor, Buffer>> bindings);

-/*!
- * \brief To automatically inline the element-wise operations.
- *
- * \param sch The schedule to be inlined.
- */
-void AutoInlineElemWise(Schedule sch);
-
-/*!
- * \brief To automatically inline operations with injective writes
- *   (i.e. writes without reduction or sequential loops). Note
- *   that in this case, guarantees about contiguity, transpose, stride,
- *   alignment and memory footprint in general do not hold.
- *
- * \param sch The schedule to be inlined.
- */
-TVM_DLL void AutoInlineInjective(Schedule sch);

} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_PASS_H_
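The header above now groups the pre-lowering schedule passes in one place. A sketch of how these entry points are driven from Python, assuming the standard `tvm.te.schedule` FFI bindings for the declarations above:

```python
import tvm
from tvm import te
from tvm.te import schedule

A = te.placeholder((16, 16), name="A")
B = te.compute((16, 16), lambda i, j: A[i, j] * 2.0, name="B")  # injective
C = te.compute((16, 16), lambda i, j: B[i, j] + 1.0, name="C")
s = te.create_schedule(C.op)

schedule.AutoInlineInjective(s)         # inline the injective producer B
s = s.normalize()
bounds = schedule.InferBound(s)         # Map<IterVar, Range>
stmt = schedule.ScheduleOps(s, bounds)  # initial loop nest as a Stmt
print(stmt)
```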
13 changes: 0 additions & 13 deletions include/tvm/tir/ir_pass.h
@@ -164,19 +164,6 @@ Stmt Inline(Stmt stmt,
Array<Var> args,
PrimExpr body);

-/*!
- * \brief Try to modify the AST to support TensorCore.
- *
- * \param stmt The stmt to be transformed.
- * \param schedule The original schedule.
- * \param extern_buffer Map specifying the external
- *        buffer assignment of inputs and outputs.
- * \return The transformed stmt.
- */
-Stmt RewriteForTensorCore(Stmt stmt,
-                          te::Schedule schedule,
-                          Map<te::Tensor, Buffer> extern_buffer);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
45 changes: 27 additions & 18 deletions python/tvm/driver/build_module.py
@@ -84,23 +84,43 @@ def get_binds(args, compact=False, binds=None):
return binds, arg_list


-def form_body(sch):
+def form_irmodule(sch, args, name, binds):
"""According to the given schedule, form a function.
Parameters
----------
sch : tvm.te.schedule.Schedule
The given scheduler to form the raw body
The given scheduler to form the raw body
args : list of Buffer or Tensor or Var
The argument lists to the function.
name : str
The name of result function.
binds : dict of :any:`Tensor` to :any:`Buffer`, optional
The binds information
Returns
-------
The body formed according to the given schedule
"""
    # normalize schedule first
+    cfg = BuildConfig.current()
    sch = sch.normalize()
    bounds = schedule.InferBound(sch)
    stmt = schedule.ScheduleOps(sch, bounds)
-    return stmt

+    compact = ir_pass.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, binds)

+    stmt = schedule.SchedulePostProcRewriteForTensorCore(stmt, sch, binds)
+    func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

+    func = func.with_attr("global_symbol", name)
+    if cfg.restricted_func:
+        func = func.with_attr("tir.noalias", True)
+    return tvm.IRModule({name: func})


def _wrap_as_prim_func_pass(flist, name):
@@ -166,24 +186,13 @@ def lower(sch,

    # Phase 0
    if isinstance(sch, schedule.Schedule):
-        stmt = form_body(sch)
-
-        for f in lower_phase0:
-            stmt = f(stmt)
-
-        compact = ir_pass.VerifyCompactBuffer(stmt)
-        binds, arg_list = get_binds(args, compact, binds)
-        stmt = ir_pass.RewriteForTensorCore(stmt, sch, binds)
-
-        # Start the new style pass manager.
-        func = schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
-        func = func.with_attr("global_symbol", name)
-        if cfg.restricted_func:
-            func = func.with_attr("tir.noalias", True)
-        mod = tvm.IRModule({name: func})
+        mod = form_irmodule(sch, args, name, binds)
    else:
        mod = sch

    # Phase 1
    pass_list = [
+        _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
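For reference, a sketch of calling the new helper directly — import path assumed from the diff; `lower()` does the equivalent in Phase 0:

```python
import tvm
from tvm import te
from tvm.driver.build_module import form_irmodule

A = te.placeholder((128,), name="A")
B = te.compute((128,), lambda i: A[i] * 2.0, name="B")
s = te.create_schedule(B.op)
mod = form_irmodule(s, [A, B], "double", binds=None)  # returns tvm.IRModule
print(mod["double"])
```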
6 changes: 4 additions & 2 deletions python/tvm/te/hybrid/__init__.py
@@ -30,7 +30,7 @@
# 2. Support multi-level HalideIR
import inspect
import tvm._ffi
-from tvm.driver.build_module import form_body
+import tvm.te.schedule
from tvm._ffi.base import decorate

from .module import HybridModule
@@ -87,8 +87,10 @@ def build(sch, inputs, outputs, name="hybrid_func"):
    The built result is wrapped in a HybridModule.
The usage of HybridModule is roughly the same as normal TVM-built modules.
"""
+    sch = sch.normalize()
+    bounds = tvm.te.schedule.InferBound(sch)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)

-    stmt = form_body(sch)
    src = _Dump(stmt, inputs, outputs, name)

    return HybridModule(src, name)
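The hybrid build path now inlines the normalize/InferBound/ScheduleOps prefix instead of importing `form_body` from the driver. A sketch of the path this touches, with an illustrative hybrid script (assumed usage of the `te.hybrid` API):

```python
import tvm
from tvm import te

@te.hybrid.script
def add_one(a):
    b = output_tensor(a.shape, a.dtype)
    for i in range(a.shape[0]):
        b[i] = a[i] + 1.0
    return b

A = te.placeholder((10,), name="A", dtype="float32")
B = add_one(A)
s = te.create_schedule(B.op)
module = te.hybrid.build(s, [A], [B])  # runs the lowering prefix shown above
```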
src/tir/pass/tensor_core.cc → src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc (file renamed)
@@ -18,9 +18,12 @@
*/

/*!
- * \file tensor_core.cc
+ * \file schedule_postproc_rewrite_for_tensor_core.cc
+ *
+ * \brief Rewrite the Stmt generated by ScheduleOps
+ *        to accommodate TensorCore.
 */
-// IR Passes for TensorCore CodeGen
+#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/te/operation.h>
@@ -32,12 +35,11 @@
#include <tvm/target/target.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../runtime/thread_storage_scope.h"

namespace tvm {
-namespace tir {
+namespace te {

using namespace te;
using runtime::StorageRank;
@@ -86,10 +88,10 @@ class MMAMatcher: public StmtVisitor {
}

void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::pragma_tensor_core) {
+    if (op->attr_key == tir::attr::pragma_tensor_core) {
tensor_core_on_ = true;
StmtVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
} else {
@@ -414,18 +416,18 @@ class BufferAnalyser : public StmtExprVisitor {
}

void VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::thread_extent) {
+    if (op->attr_key == tir::attr::thread_extent) {
if (const IntImmNode* value = op->value.as<IntImmNode>()) {
thread_extent_.insert(
std::make_pair(
op->node.as<IterVarNode>()->var->name_hint,
value->value));
}
StmtExprVisitor::VisitStmt_(op);
-    } else if (op->attr_key == attr::realize_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope) {
storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::buffer_dim_align) {
+    } else if (op->attr_key == tir::attr::buffer_dim_align) {
te::Tensor tensor = Downcast<te::Tensor>(op->node);
const CallNode* tuple = op->value.as<CallNode>();
CHECK(tuple && tuple->is_intrinsic(intrinsic::tvm_tuple));
@@ -850,7 +852,7 @@ class TensorCoreIRMutator : public StmtExprMutator {

Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    if (op->attr_key == attr::realize_scope) {
+    if (op->attr_key == tir::attr::realize_scope) {
auto node = op->node.as<te::OperationNode>();
if (node != nullptr) {
if (!frag_reg_.count(node->name)) {
@@ -1186,9 +1188,10 @@ class TensorCoreIRMutator : public StmtExprMutator {
int warp_threads_y_{-1};
};

-Stmt RewriteForTensorCore(Stmt stmt,
-                          Schedule schedule,
-                          Map<Tensor, Buffer> extern_buffer) {
+Stmt SchedulePostProcRewriteForTensorCore(
+    Stmt stmt,
+    Schedule schedule,
+    Map<Tensor, Buffer> extern_buffer) {
// Check if current lower target is CUDA
auto target = tvm::Target::Current(true);
if (target.defined() && target->target_name != "cuda") {
@@ -1223,5 +1226,13 @@
return TensorCoreIRMutator(schedule_analyser, buffer_analyser)(std::move(stmt));
}

-}  // namespace tir
+TVM_REGISTER_GLOBAL("schedule.SchedulePostProcRewriteForTensorCore")
+.set_body_typed([](Stmt stmt,
+                   Schedule schedule,
+                   Map<te::Tensor, Buffer> extern_buffer) {
+  return SchedulePostProcRewriteForTensorCore(
+      stmt, schedule, extern_buffer);
+});
+
+}  // namespace te
} // namespace tvm
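With the registration moved as above, the pass stays reachable from Python through the generic FFI lookup; a quick check using the standard `tvm._ffi` API:

```python
import tvm

f = tvm._ffi.get_global_func("schedule.SchedulePostProcRewriteForTensorCore")
assert f is not None  # registered by the file above once it is compiled in
```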
8 changes: 0 additions & 8 deletions src/tir/pass/ffi_api.cc
@@ -75,14 +75,6 @@ TVM_REGISTER_GLOBAL("ir_pass.Substitute")
}
});

TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed
([](const Stmt& stmt,
const te::Schedule& schedule,
const Map<te::Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});

TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ExprUseVar(args[0].operator PrimExpr(), args[1].operator Var());
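Conversely, the old `ir_pass` global removed here no longer resolves; a quick check with the same FFI helper:

```python
import tvm

old = tvm._ffi.get_global_func("ir_pass.RewriteForTensorCore",
                               allow_missing=True)
assert old is None  # removed by this PR
```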