Skip to content

Commit

Permalink
[TIR] Move UnifyThreadBinding to earlier stage (#9365)
Browse files Browse the repository at this point in the history
* Move unify thread binding to earlier stage

* Unify thread binding support AttrStmt
  • Loading branch information
vinx13 authored Oct 26, 2021
1 parent 133a7dc commit 4d0cfd9
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 123 deletions.
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,10 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand Down
90 changes: 70 additions & 20 deletions src/tir/transforms/unify_thread_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../support/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

using support::StartsWith;

/*!
* \brief A mutator which searches AttrStmts of thread bindings and changes the `node` field IterVar
* of the AttrStmts, so that for one kind of thread binding, all such thread bindings use the same
Expand All @@ -41,24 +44,41 @@ class ThreadBindingUnifier : public StmtExprMutator {
static Stmt Unify(Stmt stmt) { return ThreadBindingUnifier()(std::move(stmt)); }

private:
Stmt VisitStmt_(const AttrStmtNode* attr) final {
Stmt VisitStmt_(const AttrStmtNode* op) final {
// If this AttrStmt is not thread binding attribute, return as usual.
if (attr->attr_key != attr::thread_extent && attr->attr_key != attr::virtual_thread) {
return StmtMutator::VisitStmt_(attr);
if (op->attr_key != attr::thread_extent && op->attr_key != attr::virtual_thread) {
return StmtMutator::VisitStmt_(op);
}
IterVar old_iter_var = Downcast<IterVar>(op->node);
return UnifyThreadBindingImpl(op, old_iter_var->var, old_iter_var, old_iter_var->dom);
}

Stmt VisitStmt_(const ForNode* op) final {
// If this For is not thread binding attribute, return as usual.
if (op->kind != ForKind::kThreadBinding) {
return StmtExprMutator::VisitStmt_(op);
}
return UnifyThreadBindingImpl(op, op->loop_var, op->thread_binding.value(),
Range::FromMinExtent(op->min, op->extent));
}

// Step 1. Fetch the old IterVar and the thread tag.
IterVar old_iter_var = Downcast<IterVar>(attr->node);
template <typename Node>
Stmt UnifyThreadBindingImpl(const Node* op, const Var& old_var, const IterVar& old_iter_var,
const Range& dom) {
// Step 1. Fetch the thread tag.
IterVar new_iter_var{nullptr};
const String& thread_tag = old_iter_var->thread_tag;

// Step 2: Increase `thread_block_depth_` if the thread tag starts with "blockIdx". If the
// thread block depth is 0 before the increasement, it means we are entering a new kernel, and
// therefore we need to make `thread_tag2iter_var_map_` empty, as different kernels can have
// thread axes with different extents.
if (std::string(thread_tag).substr(0, 9) == "blockIdx.") {
bool is_kernel_launch_scope = false;
int old_thread_block_depth = thread_block_depth_;
if (StartsWith(thread_tag, "blockIdx.") || !thread_block_depth_) {
if (!thread_block_depth_) {
thread_tag2iter_var_map_.clear();
is_kernel_launch_scope = true;
}
++thread_block_depth_;
}
Expand All @@ -69,31 +89,56 @@ class ThreadBindingUnifier : public StmtExprMutator {
Map<String, IterVar>::iterator it = thread_tag2iter_var_map_.find(thread_tag);
if (it != thread_tag2iter_var_map_.end()) {
new_iter_var = (*it).second;
CHECK(ana.CanProveEqual(old_iter_var->dom->extent, (*it).second->dom->extent))
ICHECK(ana.CanProveEqual(dom->min, new_iter_var->dom->min));
CHECK(ana.CanProveEqual(dom->extent, new_iter_var->dom->extent))
<< "ValueError: All loops that are bound to `" << thread_tag
<< "` should have the same extent. However, there are two loops with extent "
<< (*it).second->dom->extent << " and " << old_iter_var->dom->extent
<< ", which are not equal";
<< new_iter_var->dom->extent << " and " << dom->extent << ", which are not equal";
} else {
ObjectPtr<IterVarNode> p_new_iter_var = make_object<IterVarNode>(*old_iter_var.get());
p_new_iter_var->var = Var(thread_tag);
p_new_iter_var->dom = dom;
new_iter_var = IterVar(p_new_iter_var);
thread_tag2iter_var_map_.Set(thread_tag, new_iter_var);
launch_threads_.push_back(new_iter_var);
}

// Step 4. We will substitute the occurrences of the old variable in the old IterVar with the
// new variable in further mutation. Thus, we store the mapping entry.
var_substitution_map_.Set(old_iter_var->var, new_iter_var->var);

// Step 5. Mutate recursively, update the AttrStmt with the new IterVar, and decrease the depth
// counter if the thread tag starts with "blockIdx".
AttrStmt new_attr = Downcast<AttrStmt>(StmtMutator::VisitStmt_(attr));
ObjectPtr<AttrStmtNode> p_new_attr = CopyOnWrite(new_attr.get());
p_new_attr->node = new_iter_var;
if (std::string(thread_tag).substr(0, 9) == "blockIdx.") {
--thread_block_depth_;
var_substitution_map_.Set(old_var, new_iter_var->var);

// Step 5. Mutate recursively, update the body with the new IterVar, and restore the depth
// counter. Emit for-loops to launch threads if current statement is the outermost thread
// binding of the kernel.
Stmt new_stmt = StmtMutator::VisitStmt_(op);
auto* new_node = new_stmt.as<Node>();
ICHECK(new_node);
thread_block_depth_ = old_thread_block_depth;
if (is_kernel_launch_scope) {
return EmitLaunchThreads(new_node->body);
} else {
return new_node->body;
}
return Stmt(p_new_attr);
}

/*!
* \brief Emit loop nests representing all thread bindings of the kernel
* \param body The body of the innermost loop of the thread bindings.
* \return The loop nests of the thread bindings.
*/
Stmt EmitLaunchThreads(const Stmt& body) {
Stmt result = body;
while (!launch_threads_.empty()) {
const IterVar& thread_binding = launch_threads_.back();
// Recreate the IterVar as we don't duplicate `dom` in both For and IterVar. This is
// necessary for unit tests.
result = For(thread_binding->var, thread_binding->dom->min, thread_binding->dom->extent,
ForKind::kThreadBinding, result,
IterVar(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
thread_binding->thread_tag));
launch_threads_.pop_back();
}
return result;
}

PrimExpr VisitExpr_(const VarNode* var) final {
Expand All @@ -106,8 +151,13 @@ class ThreadBindingUnifier : public StmtExprMutator {
/*!
* \brief A mapping from a thread tag to its corresponding IterVar that is shared by all
* occurrences of the thread tag
* */
*/
Map<String, IterVar> thread_tag2iter_var_map_;
/*!
* \brief A list of IterVar corresponding to threads in current kernel. This will be used to
* generate for-loops to launch threads.
*/
Array<IterVar> launch_threads_;
/*! \brief A mapping from old variables to new variables, which is used for substitution */
Map<Var, Var> var_substitution_map_;
/*! \brief A integer counter storing the depth of thread bindings of "blockIdx.x/y/z" */
Expand Down
Loading

0 comments on commit 4d0cfd9

Please sign in to comment.