Skip to content

Commit

Permalink
[TIR] Enhance VerifyGPUCode (apache#6194)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and Trevor Morris committed Aug 26, 2020
1 parent dbfed2e commit d62baa6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 19 deletions.
48 changes: 29 additions & 19 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,14 @@ class GPUCodeVerifier : public StmtExprVisitor {
public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
int64_t max_thread_z, int64_t max_vector_bytes) {
int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
max_thread_x_ = static_cast<size_t>(max_thread_x);
max_thread_y_ = static_cast<size_t>(max_thread_y);
max_thread_z_ = static_cast<size_t>(max_thread_z);
max_vthread_ = static_cast<size_t>(max_vthread);
max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);

Reset_();
Expand Down Expand Up @@ -78,7 +79,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
visited_shared_buffers_.insert(op->node.as<VarNode>());
}
StmtVisitor::VisitStmt_(op);
} else if (op->attr_key == attr::thread_extent) {
} else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
if (nest_level_ == 0) {
// enter a new kernel, reset statistics
Reset_();
Expand All @@ -88,9 +89,10 @@ class GPUCodeVerifier : public StmtExprVisitor {
const auto* extent = op->value.as<IntImmNode>();
CHECK(extent);

// record the number of threads in a block
std::string name = var.get()->name_hint;
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") {
// record the number of threads in a block
if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" ||
name == "vthread") {
size_t length = static_cast<size_t>(extent->value);
if (!visited_threads_.count(name)) {
visited_threads_.insert(name);
Expand All @@ -105,6 +107,8 @@ class GPUCodeVerifier : public StmtExprVisitor {
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
thread_z_extent_ = length;
} else if (name == "vthread") {
valid_ &= length <= max_vthread_;
}
} else {
// the thread should be bound to axes with the same length
Expand Down Expand Up @@ -134,25 +138,28 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
}

void VisitStmt_(const ForNode* op) {
if (op->loop_var->name_hint == "vthread.s") {
const auto* extent = op->extent.as<IntImmNode>();
CHECK(extent);

valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
}

StmtVisitor::VisitStmt_(op);
}

// Visit a load expression: when the load's dtype has more than one lane, the
// total width (lanes * bytes) is compared against max_vector_bytes_ and the
// result is folded into valid_; the base visitor then continues traversal.
// NOTE(review): this span comes from a rendered diff whose +/- markers were
// lost, so the RampNode-guarded check and the unguarded check below are
// presumably the removed and added sides of the same hunk rather than two
// checks meant to coexist -- confirm against the repository.
void VisitExpr_(const LoadNode* op) {
// Currently not able to check out: If the index expression failed
// to be simplified to a RampNode
if (op->index->IsInstance<RampNode>()) {
if (op->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
}
if (op->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
}
ExprVisitor::VisitExpr_(op);
}

// Visit a store statement: when the dtype inspected below has more than one
// lane, its total width (lanes * bytes) is compared against max_vector_bytes_
// and the result is folded into valid_; the base visitor then continues.
// NOTE(review): the width here is computed from op->index->dtype, whereas the
// load visitor uses the load expression's own dtype (op->dtype). Confirm
// whether the dtype of the stored value was intended instead of the index.
// NOTE(review): this span comes from a rendered diff whose +/- markers were
// lost, so the RampNode-guarded check and the unguarded check below are
// presumably the removed and added sides of the same hunk -- confirm.
void VisitStmt_(const StoreNode* op) {
// Currently not able to check out: If the index expression failed
// to be simplified to a RampNode
if (op->index->IsInstance<RampNode>()) {
if (op->index->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
max_vector_bytes_;
}
if (op->index->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
max_vector_bytes_;
}
StmtVisitor::VisitStmt_(op);
}
Expand All @@ -173,7 +180,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t max_local_memory_per_block_;
size_t max_shared_memory_per_block_;
size_t max_threads_per_block_;
size_t max_thread_x_, max_thread_y_, max_thread_z_;
size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
size_t max_vector_bytes_;

bool valid_{true};
Expand All @@ -198,6 +205,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
int64_t max_thread_x = INT64_MAX;
int64_t max_thread_y = INT64_MAX;
int64_t max_thread_z = INT64_MAX;
int64_t max_vthread = INT64_MAX;
int64_t max_vector_bytes = INT64_MAX;

for (auto iter : constraints) {
Expand All @@ -214,6 +222,8 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
max_thread_y = val->value;
} else if (iter.first == "max_thread_z") {
max_thread_z = val->value;
} else if (iter.first == "max_vthread") {
max_vthread = val->value;
} else if (iter.first == "max_vector_bytes") {
max_vector_bytes = val->value;
} else {
Expand All @@ -223,7 +233,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {

return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
max_vector_bytes);
max_vthread, max_vector_bytes);
}

TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_tir_analysis_verify_gpu_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,40 @@ def test_vectorize():
tvm.lower(s, [A, B])
assert not valid[0]

def test_vthread():
    """Check that the GPU code verifier enforces the ``max_vthread`` limit.

    A (1024, 16) copy kernel binds its second axis to a virtual thread of
    extent 16. For each available target and each lowering phase, the build
    must verify with ``max_vthread=16`` and fail verification with
    ``max_vthread=15``.
    """
    N = 1024

    A = te.placeholder((N, 16), name='A')
    B = te.compute((N, 16), lambda i, j: A[i, j])

    s = te.create_schedule([B.op])

    stage = s[B]
    stage.bind(stage.op.axis[0], te.thread_axis("blockIdx.x"))
    stage.bind(stage.op.axis[1], te.thread_axis("vthread"))

    # (vthread limit, whether verification is expected to succeed)
    cases = ((16, True), (15, False))

    for target in ['opencl', 'cuda']:
        if tvm.context(target).exist:
            # One result holder per target; the verify pass writes into it.
            valid = [None]

            for phase in [1, 2]:
                for limit, expect_ok in cases:
                    verify_pass = get_verify_pass(valid, max_vthread=limit)
                    pass_config = {"tir.add_lower_pass": [(phase, verify_pass)]}
                    with tvm.transform.PassContext(config=pass_config):
                        tvm.build(s, [A, B], target)
                    assert bool(valid[0]) is expect_ok


if __name__ == "__main__":
    # Run every test in this file, in declaration order, when executed as a script.
    for run_case in (
        test_local_memory,
        test_shared_memory,
        test_num_thread,
        test_multiple_kernels,
        test_wrong_bind,
        test_vectorize,
        test_vthread,
    ):
        run_case()

0 comments on commit d62baa6

Please sign in to comment.