CUDA device API & VerifyGPUCode pass update (apache#5898)
* Add kMaxRegistersPerBlock device API attribute for CUDA

* Add vectorize check to verify_gpu_code

* Lint fix

* Cast fix
jcf94 authored and zhiics committed Jul 2, 2020
1 parent 6e4813f commit 2134265
Showing 8 changed files with 72 additions and 12 deletions.
3 changes: 2 additions & 1 deletion include/tvm/runtime/device_api.h
@@ -44,7 +44,8 @@ enum DeviceAttrKind : int {
   kMaxClockRate = 6,
   kMultiProcessorCount = 7,
   kMaxThreadDimensions = 8,
-  kGcnArch = 9
+  kMaxRegistersPerBlock = 9,
+  kGcnArch = 10
 };
 
 /*! \brief Number of bytes each allocation must align to */
4 changes: 4 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
@@ -92,6 +92,10 @@ class CUDADeviceAPI final : public DeviceAPI {
       *rv = ss.str();
       return;
     }
+    case kMaxRegistersPerBlock: {
+      CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxRegistersPerBlock, ctx.device_id));
+      break;
+    }
     case kGcnArch:
       return;
   }
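Usage sketch (not part of this commit): with a runtime rebuilt from this change, the new attribute can be read from Python through the internal TVMContext._GetDeviceAttr helper, passing the raw attribute id 9 from the updated enum. A dedicated Python property is assumed not to exist; nothing in this diff adds one.

import tvm

# Sketch only: query kMaxRegistersPerBlock (attribute id 9 in the updated enum)
# via the internal helper. This commit adds the C++/CUDA side; it does not add
# a friendlier Python property, so the raw attribute id is used here.
ctx = tvm.gpu(0)
if ctx.exist:
    max_regs = ctx._GetDeviceAttr(ctx.device_type, ctx.device_id, 9)
    print("max registers per block:", max_regs)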
4 changes: 3 additions & 1 deletion src/runtime/metal/metal_device_api.mm
@@ -64,7 +64,9 @@
     case kMaxThreadDimensions:
       return;
     case kExist:
-      break;
+      return;
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch:
       return;
   }
2 changes: 2 additions & 0 deletions src/runtime/opencl/opencl_device_api.cc
@@ -107,6 +107,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       *rv = ss.str();
       break;
     }
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch:
       return;
   }
2 changes: 2 additions & 0 deletions src/runtime/rocm/rocm_device_api.cc
@@ -102,6 +102,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
       *rv = ss.str();
       return;
     }
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch: {
       hipDeviceProp_t prop;
       ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan.cc
@@ -413,6 +413,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
       *rv = ss.str();
       break;
     }
+    case kMaxRegistersPerBlock:
+      return;
     case kGcnArch:
       return;
   }
42 changes: 32 additions & 10 deletions src/tir/analysis/verify_gpu_code.cc
@@ -33,20 +33,22 @@
 namespace tvm {
 namespace tir {
 
-class GPUCodeVerifier : public StmtVisitor {
+class GPUCodeVerifier : public StmtExprVisitor {
  public:
   bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
               int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
-              int64_t max_thread_z) {
+              int64_t max_thread_z, int64_t max_vector_bytes) {
     max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
     max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
     max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
     max_thread_x_ = static_cast<size_t>(max_thread_x);
     max_thread_y_ = static_cast<size_t>(max_thread_y);
     max_thread_z_ = static_cast<size_t>(max_thread_z);
+    max_vector_bytes_ = static_cast<size_t>(max_vector_bytes);
 
     Reset_();
 
+    // TODO(jcf94): Add support of detecting CUDA Misaligned Address error
     this->VisitStmt(stmt);
 
     return valid_;
@@ -62,6 +64,9 @@ class GPUCodeVerifier : public StmtVisitor {
       size_t size = static_cast<size_t>(op->constant_allocation_size());
       shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
     }
+    if (op->dtype.lanes() > 1) {
+      valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+    }
   }
 
   void VisitStmt_(const AttrStmtNode* op) final {
@@ -129,6 +134,17 @@ class GPUCodeVerifier : public StmtVisitor {
     }
   }
 
+  void VisitExpr_(const LoadNode* op) {
+    // Currently not able to check out: If the index expression failed
+    // to be simplified to a RampNode
+    if (op->index->IsInstance<RampNode>()) {
+      if (op->dtype.lanes() > 1) {
+        valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
+      }
+    }
+    ExprVisitor::VisitExpr_(op);
+  }
+
  private:
   int nest_level_{0};
 
@@ -146,6 +162,7 @@ class GPUCodeVerifier : public StmtVisitor {
   size_t max_shared_memory_per_block_;
   size_t max_threads_per_block_;
   size_t max_thread_x_, max_thread_y_, max_thread_z_;
+  size_t max_vector_bytes_;
 
   bool valid_{true};
 
@@ -169,27 +186,32 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
   int64_t max_thread_x = INT64_MAX;
   int64_t max_thread_y = INT64_MAX;
   int64_t max_thread_z = INT64_MAX;
+  int64_t max_vector_bytes = INT64_MAX;
 
   for (auto iter : constraints) {
     const IntImmNode* val = iter.second.as<IntImmNode>();
-    if (iter.first == "max_local_memory_per_block")
+    if (iter.first == "max_local_memory_per_block") {
       max_local_memory_per_block = val->value;
-    else if (iter.first == "max_shared_memory_per_block")
+    } else if (iter.first == "max_shared_memory_per_block") {
       max_shared_memory_per_block = val->value;
-    else if (iter.first == "max_threads_per_block")
+    } else if (iter.first == "max_threads_per_block") {
       max_threads_per_block = val->value;
-    else if (iter.first == "max_thread_x")
+    } else if (iter.first == "max_thread_x") {
       max_thread_x = val->value;
-    else if (iter.first == "max_thread_y")
+    } else if (iter.first == "max_thread_y") {
      max_thread_y = val->value;
-    else if (iter.first == "max_thread_z")
+    } else if (iter.first == "max_thread_z") {
       max_thread_z = val->value;
-    else
+    } else if (iter.first == "max_vector_bytes") {
+      max_vector_bytes = val->value;
+    } else {
       LOG(FATAL) << "Invalid check item: " << iter.first;
+    }
   }
 
   return verifier.Verify(func->body, max_local_memory_per_block, max_shared_memory_per_block,
-                         max_threads_per_block, max_thread_x, max_thread_y, max_thread_z);
+                         max_threads_per_block, max_thread_x, max_thread_y, max_thread_z,
+                         max_vector_bytes);
 }
 
 TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);
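Usage sketch (not part of this commit), assuming the Python wrapper tvm.tir.analysis.verify_gpu_code mirrors the global registered above and that tvm.lower returns an IRModule in this TVM version. A 64-lane float32 access is 256 bytes wide, so a 16-byte limit should make the verifier return False:

import tvm
from tvm import te

# Sketch: exercise the new "max_vector_bytes" constraint on a lowered PrimFunc.
# The schedule mirrors the test added below; 64 float32 lanes = 256 bytes.
n = 1024
A = te.placeholder((n, n), name="A")
B = te.compute((n, n), lambda i, j: A[i, j], name="B")
s = te.create_schedule(B.op)
i, j = s[B].op.axis
s[B].bind(i, te.thread_axis("blockIdx.x"))
jo, ji = s[B].split(j, factor=64)
s[B].bind(jo, te.thread_axis("threadIdx.x"))
s[B].vectorize(ji)

mod = tvm.lower(s, [A, B], name="main")
ok = tvm.tir.analysis.verify_gpu_code(
    mod["main"], {"max_threads_per_block": 1024, "max_vector_bytes": 16})
print("verify_gpu_code passed:", ok)  # expected False under the 16-byte limit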
25 changes: 25 additions & 0 deletions tests/python/unittest/test_tir_analysis_verify_gpu_code.py
@@ -208,10 +208,35 @@ def test_wrong_bind():
             tvm.build(s, [A, B], target)
         assert not valid[0]
 
+def test_vectorize():
+    N = 1024
+
+    A = te.placeholder((N, N), name='A')
+    B = te.compute((N, N), lambda i, j: A[i, j])
+
+    s = te.create_schedule([B.op])
+
+    i, j = s[B].op.axis
+
+    s[B].bind(i, te.thread_axis("blockIdx.x"))
+    jo, ji = s[B].split(j, factor=64)
+    s[B].bind(jo, te.thread_axis("threadIdx.x"))
+    s[B].vectorize(ji)
+
+    for target in ['opencl', 'cuda']:
+        if not tvm.context(target).exist:
+            continue
+
+        valid = [None]
+        with tvm.transform.PassContext(config={"tir.add_lower_pass": [
+                (2, get_verify_pass(valid, max_vector_bytes=16))]}):
+            tvm.lower(s, [A, B])
+        assert not valid[0]
 
 if __name__ == "__main__":
     test_local_memory()
     test_shared_memory()
     test_num_thread()
     test_multiple_kernels()
     test_wrong_bind()
+    test_vectorize()
