From d62baa6f7bb59b02cd3513e4470b57a78ce39182 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 3 Aug 2020 13:23:07 -0700 Subject: [PATCH] [TIR] Enhance VerifyGPUCode (#6194) --- src/tir/analysis/verify_gpu_code.cc | 48 +++++++++++-------- .../test_tir_analysis_verify_gpu_code.py | 30 ++++++++++++ 2 files changed, 59 insertions(+), 19 deletions(-) diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index d221dde51a02e..cce0823ca0484 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -37,13 +37,14 @@ class GPUCodeVerifier : public StmtExprVisitor { public: bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block, int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y, - int64_t max_thread_z, int64_t max_vector_bytes) { + int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) { max_local_memory_per_block_ = static_cast(max_local_memory_per_block); max_shared_memory_per_block_ = static_cast(max_shared_memory_per_block); max_threads_per_block_ = static_cast(max_threads_per_block); max_thread_x_ = static_cast(max_thread_x); max_thread_y_ = static_cast(max_thread_y); max_thread_z_ = static_cast(max_thread_z); + max_vthread_ = static_cast(max_vthread); max_vector_bytes_ = static_cast(max_vector_bytes); Reset_(); @@ -78,7 +79,7 @@ class GPUCodeVerifier : public StmtExprVisitor { visited_shared_buffers_.insert(op->node.as()); } StmtVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::thread_extent) { + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { if (nest_level_ == 0) { // enter a new kernel, reset statistics Reset_(); @@ -88,9 +89,10 @@ class GPUCodeVerifier : public StmtExprVisitor { const auto* extent = op->value.as(); CHECK(extent); - // record the number of threads in a block std::string name = var.get()->name_hint; - if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z") { + // record the number of threads in a block + if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || + name == "vthread") { size_t length = static_cast(extent->value); if (!visited_threads_.count(name)) { visited_threads_.insert(name); @@ -105,6 +107,8 @@ class GPUCodeVerifier : public StmtExprVisitor { } else if (name == "threadIdx.z") { valid_ &= length <= max_thread_z_; thread_z_extent_ = length; + } else if (name == "vthread") { + valid_ &= length <= max_vthread_; } } else { // the thread should be bound to axes with the same length @@ -134,25 +138,28 @@ class GPUCodeVerifier : public StmtExprVisitor { } } + void VisitStmt_(const ForNode* op) { + if (op->loop_var->name_hint == "vthread.s") { + const auto* extent = op->extent.as(); + CHECK(extent); + + valid_ &= static_cast(extent->value) <= max_vthread_; + } + + StmtVisitor::VisitStmt_(op); + } + void VisitExpr_(const LoadNode* op) { - // Currently not able to check out: If the index expression failed - // to be simplified to a RampNode - if (op->index->IsInstance()) { - if (op->dtype.lanes() > 1) { - valid_ &= static_cast(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_; - } + if (op->dtype.lanes() > 1) { + valid_ &= static_cast(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_; } ExprVisitor::VisitExpr_(op); } void VisitStmt_(const StoreNode* op) { - // Currently not able to check out: If the index expression failed - // to be simplified to a RampNode - if (op->index->IsInstance()) { - if (op->index->dtype.lanes() > 1) { - valid_ &= static_cast(op->index->dtype.lanes() * op->index->dtype.bytes()) <= - max_vector_bytes_; - } + if (op->index->dtype.lanes() > 1) { + valid_ &= static_cast(op->index->dtype.lanes() * op->index->dtype.bytes()) <= + max_vector_bytes_; } StmtVisitor::VisitStmt_(op); } @@ -173,7 +180,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t max_local_memory_per_block_; size_t max_shared_memory_per_block_; size_t max_threads_per_block_; - size_t max_thread_x_, max_thread_y_, max_thread_z_; + size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_; size_t max_vector_bytes_; bool valid_{true}; @@ -198,6 +205,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { int64_t max_thread_x = INT64_MAX; int64_t max_thread_y = INT64_MAX; int64_t max_thread_z = INT64_MAX; + int64_t max_vthread = INT64_MAX; int64_t max_vector_bytes = INT64_MAX; for (auto iter : constraints) { @@ -214,6 +222,8 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { max_thread_y = val->value; } else if (iter.first == "max_thread_z") { max_thread_z = val->value; + } else if (iter.first == "max_vthread") { + max_vthread = val->value; } else if (iter.first == "max_vector_bytes") { max_vector_bytes = val->value; } else { @@ -223,7 +233,7 @@ bool VerifyGPUCode(const PrimFunc& func, Map constraints) { return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block, max_threads_per_block, max_thread_x, max_thread_y, max_thread_z, - max_vector_bytes); + max_vthread, max_vector_bytes); } TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode); diff --git a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py index ece8402a77cef..2e37de49f2435 100644 --- a/tests/python/unittest/test_tir_analysis_verify_gpu_code.py +++ b/tests/python/unittest/test_tir_analysis_verify_gpu_code.py @@ -233,6 +233,35 @@ def test_vectorize(): tvm.lower(s, [A, B]) assert not valid[0] +def test_vthread(): + N = 1024 + + A = te.placeholder((N, 16), name='A') + B = te.compute((N, 16), lambda i, j: A[i, j]) + + s = te.create_schedule([B.op]) + + s[B].bind(s[B].op.axis[0], te.thread_axis("blockIdx.x")) + s[B].bind(s[B].op.axis[1], te.thread_axis("vthread")) + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + + valid = [None] + + for phase in [1, 2]: + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ + (phase, get_verify_pass(valid, max_vthread=16))]}): + tvm.build(s, [A, B], target) + assert valid[0] + + with tvm.transform.PassContext(config={"tir.add_lower_pass": [ + (phase, get_verify_pass(valid, max_vthread=15))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + if __name__ == "__main__": test_local_memory() test_shared_memory() @@ -240,3 +269,4 @@ def test_vectorize(): test_multiple_kernels() test_wrong_bind() test_vectorize() + test_vthread()