Skip to content

Commit

Permalink
[REFACTOR][TIR] Migrate Low-level Passes to Pass Manager (apache#5198)
Browse files Browse the repository at this point in the history
* [TIR][TRANSFORM] Migrate LowerIntrin

* LowerDeviceStorageAccessInfo

* Migrate LowerWarpMemory
  • Loading branch information
tqchen authored and zhiics committed Apr 17, 2020
1 parent 8407922 commit 02db3cc
Show file tree
Hide file tree
Showing 14 changed files with 319 additions and 148 deletions.
3 changes: 3 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ class IRModule : public ObjectRef {
* \return A Relay module.
*/
TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path);

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;
};

/*!
Expand Down
26 changes: 24 additions & 2 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,33 @@ TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
const tvm::Array<tvm::PrimExpr>& required);

/*!
* \brief Create PrimFuncPass to combine context calls in the host function.
* \brief Combine context calls in the host function.
*
* \return The pass.
*/
Pass CombineContextCall();
TVM_DLL Pass CombineContextCall();

/*!
 * \brief Lower the target-specific intrinsics in each function.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerIntrin();

/*!
 * \brief Lower attached storage access information on device.
 *
 * \note Run this pass after all storage access analyses have finished.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerDeviceStorageAccessInfo();

/*!
 * \brief Lower warp memory access to low-level device related function calls.
 *
 * \return The pass.
 */
TVM_DLL Pass LowerWarpMemory();

} // namespace transform
} // namespace tir
Expand Down
37 changes: 37 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,40 @@ def CombineContextCall():
The result pass
"""
return _ffi_api.CombineContextCall()


def LowerIntrin():
    """Lower target specific intrinsic calls.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass
    """
    return _ffi_api.LowerIntrin()


def LowerDeviceStorageAccessInfo():
    """Lower attached storage access information on device.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass

    Note
    ----
    Run this pass after all storage access analyses have finished.
    """
    return _ffi_api.LowerDeviceStorageAccessInfo()


def LowerWarpMemory():
    """Lower warp memory access to low-level device related function calls.

    Returns
    -------
    fpass : tvm.ir.transform.Pass
        The result pass
    """
    return _ffi_api.LowerWarpMemory()
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,8 @@ TVM_REGISTER_GLOBAL("ir.Module_Add")
bool update = args[3];
CHECK(val->IsInstance<RelayExprNode>());

if (val->IsInstance<relay::FunctionNode>()) {
mod->Add(var, Downcast<relay::Function>(val), update);
if (val->IsInstance<BaseFuncNode>()) {
mod->Add(var, Downcast<BaseFunc>(val), update);
} else if (val->IsInstance<GlobalVarNode>()) {
GlobalVar gv = Downcast<GlobalVar>(val);
auto mod_copy = IRModule(make_object<IRModuleNode>(*mod.operator->()));
Expand Down
1 change: 1 addition & 0 deletions src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ TVM_REGISTER_GLOBAL("transform.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Pass pass = args[0];
IRModule mod = args[1];
ObjectRef ref = args[1];
*ret = pass(mod);
});

Expand Down
2 changes: 0 additions & 2 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ runtime::Module BuildForIRModule(const IRModule& module,
return (*bf)(module, target->str());
}



// convert legacy LoweredFunc to PrimFunc.
tir::PrimFunc ToPrimFunc(tir::LoweredFunc from) {
// remap args to attach type annotations.
Expand Down
111 changes: 0 additions & 111 deletions src/tir/pass/storage_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,116 +235,5 @@ StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
return it->second;
}


/*!
 * \brief Mutator that lowers the storage access information attached to buffers.
 *
 *  While visiting, every storage_scope attribute is recorded in a per-buffer
 *  table holding the parsed scope and, when the scope carries a tag, the
 *  MemoryInfo returned by GetMemoryInfo. Using that table the mutator
 *  - rewrites or drops Allocate nodes of buffers that have a MemoryInfo, and
 *  - replaces tvm_access_ptr calls by an address or by a unit-scaled offset.
 */
class StorageAccessInfoLower : public StmtExprMutator {
 public:
  /*!
   * \brief Lower an allocation of a buffer that has memory info.
   *
   *  Buffers without recorded memory info keep their (recursively mutated)
   *  Allocate. Otherwise the Allocate is rebuilt on top of info->head_address
   *  when that is defined, and replaced by its body when it is not.
   *  Allocating the same recorded buffer more than once fails a CHECK.
   */
  Stmt VisitStmt_(const AllocateNode* op) final {
    Stmt stmt = StmtExprMutator::VisitStmt_(op);
    op = stmt.as<AllocateNode>();
    auto it = storage_info_.find(op->buffer_var.get());
    if (it == storage_info_.end() || !it->second.info.defined()) {
      return stmt;
    }
    const MemoryInfo& info = it->second.info;
    ++it->second.alloc_count;
    CHECK_LE(it->second.alloc_count, 1)
        << "Double allocation of " << it->second.scope.to_string();
    if (info->head_address.defined()) {
      return AllocateNode::make(
          op->buffer_var, op->dtype, op->extents, op->condition,
          op->body, info->head_address, "nop");
    }
    // No head address: the allocation is removed and only the body is kept.
    return op->body;
  }

  /*!
   * \brief Record the scope (and memory info) announced by a storage_scope attribute.
   *
   *  The entry is stored before the attribute body is visited so that nested
   *  Allocate nodes and tvm_access_ptr calls can look it up.
   */
  Stmt VisitStmt_(const AttrStmtNode* op) final {
    if (op->attr_key == attr::storage_scope) {
      const VarNode* buf = op->node.as<VarNode>();
      // as<>() yields nullptr on a type mismatch; fail loudly instead of
      // dereferencing a null pointer.
      const StringImmNode* scope_imm = op->value.as<StringImmNode>();
      CHECK(scope_imm != nullptr)
          << "storage_scope attribute value must be a StringImm";
      const auto& scope_str = scope_imm->value;
      StorageEntry e;
      e.scope = StorageScope::make(scope_str);
      // Only tagged scopes are looked up for memory info.
      if (e.scope.tag.length() != 0) {
        e.info = GetMemoryInfo(scope_str);
        CHECK(e.info.defined()) << "Cannot find memory info of " << e.scope.to_string();
      }
      storage_info_[buf] = e;
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  /*! \brief Rewrite tvm_access_ptr intrinsic calls; other calls use the default mutation. */
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
      return MakeAccessPtr(op);
    }
    return StmtExprMutator::VisitExpr_(op);
  }

 private:
  /*!
   * \brief Lower one tvm_access_ptr call.
   *
   *  The call must have exactly five arguments; this function reads the dtype
   *  of args[0] as the element type, args[1] as the buffer variable and
   *  args[2] as the offset. Buffers with memory info are handled by
   *  MakeTaggedAccessPtr; for all other buffers the call must have handle
   *  type and becomes an address offset.
   */
  PrimExpr MakeAccessPtr(const CallNode* op) {
    // Mutate the arguments first, then work on the resulting call.
    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
    op = expr.as<CallNode>();
    CHECK_EQ(op->args.size(), 5U);
    DataType dtype = op->args[0].dtype();
    const VarNode* buffer = op->args[1].as<VarNode>();
    Var buffer_var = Downcast<Var>(op->args[1]);
    PrimExpr offset = op->args[2];
    auto it = storage_info_.find(buffer);
    if (it != storage_info_.end() && it->second.info.defined()) {
      return MakeTaggedAccessPtr(
          op->dtype, buffer_var, dtype, offset,
          it->second.info);
    }
    CHECK(op->dtype.is_handle());
    // Change to address_of
    return AddressOffset(buffer_var, dtype, offset);
  }

  /*!
   * \brief Build the access expression for a buffer that has memory info.
   *
   * \param ptr_type The result type of the original tvm_access_ptr call.
   * \param buffer_var The accessed buffer variable.
   * \param dtype The element type of the access.
   * \param offset The element offset into the buffer.
   * \param info The memory info recorded for the buffer's scope.
   * \return An address offset when ptr_type is a handle (this requires
   *         info->head_address to be defined); otherwise the offset divided
   *         by the number of elements per info->unit_bits, simplified and
   *         cast to ptr_type.
   */
  PrimExpr MakeTaggedAccessPtr(DataType ptr_type,
                               Var buffer_var,
                               DataType dtype,
                               PrimExpr offset,
                               const MemoryInfo& info) {
    if (ptr_type.is_handle()) {
      CHECK(info->head_address.defined())
          << buffer_var << " is not addressable.";
      return AddressOffset(buffer_var, dtype, offset);
    }
    int dtype_bits = dtype.bits() * dtype.lanes();
    // The access unit must hold a whole number of elements.
    CHECK_EQ(info->unit_bits % dtype_bits, 0);
    return cast(ptr_type,
                tir::Simplify(offset / make_const(
                    offset.dtype(), info->unit_bits / dtype_bits)));
  }

  /*! \brief Per-buffer record collected from storage_scope attributes. */
  struct StorageEntry {
    // The parsed storage scope of the buffer.
    StorageScope scope;
    // The memory info if any (only looked up for tagged scopes).
    MemoryInfo info;
    // Number of Allocate nodes seen for this buffer.
    int alloc_count{0};
  };
  // The storage entry of each buffer, keyed by its variable node.
  std::unordered_map<const VarNode*, StorageEntry> storage_info_;
};

/*!
 * \brief Run StorageAccessInfoLower over a statement.
 * \param stmt The statement to be lowered.
 * \return The statement produced by the mutator.
 */
Stmt LowerStorageAccessInfo(Stmt stmt) {
  StorageAccessInfoLower lowerer;
  return lowerer(std::move(stmt));
}

/*!
 * \brief Lower the storage access information in the body of a LoweredFunc.
 * \param f The function to be lowered; it is copied, not modified.
 * \return A copy of f whose body went through LowerStorageAccessInfo.
 */
LoweredFunc LowerDeviceStorageAccessInfo(LoweredFunc f) {
  const LoweredFuncNode& original = *f.operator->();
  auto lowered = make_object<LoweredFuncNode>(original);
  lowered->body = LowerStorageAccessInfo(f->body);
  return LoweredFunc(lowered);
}

} // namespace tir
} // namespace tvm
2 changes: 1 addition & 1 deletion src/tir/transforms/combine_context_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ Pass CombineContextCall() {
n->body = ContextCallCombiner().Combine(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "CombineContextCall", {});
return CreatePrimFuncPass(pass_func, 0, "tir.CombineContextCall", {});
}

TVM_REGISTER_GLOBAL("tir.transform.CombineContextCall")
Expand Down
Loading

0 comments on commit 02db3cc

Please sign in to comment.