From 983ac9c0584a7731f1a9f8a9fc91b00d76fb0345 Mon Sep 17 00:00:00 2001 From: noituIover Date: Fri, 13 Sep 2019 09:04:52 +0800 Subject: [PATCH] Fix CUDA int8x4 vectorize (#3928) * Fix int8x4 vectorize * Fix gpu shared/local memory accumulate * Add test_shared_memory for int8x4 * Adjust test format * Fix cpplint --- src/codegen/codegen_cuda.cc | 13 +++- src/pass/verify_gpu_code.cc | 4 +- tests/python/unittest/test_codegen_cuda.py | 1 + .../unittest/test_pass_verify_gpu_code.py | 72 ++++++++++--------- 4 files changed, 53 insertions(+), 37 deletions(-) diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index d13b2c99c3bc..b48f647688c5 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -207,7 +207,11 @@ void CodeGenCUDA::PrintVecElemLoad( const std::string& vec, Type t, int i, std::ostream& os) { // NOLINT(*) static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < 4); - os << vec << "." << access[i]; + if (t.is_int() && t.bits() == 8) { + os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))"; + } else { + os << vec << "." << access[i]; + } } void CodeGenCUDA::PrintVecElemStore( @@ -215,7 +219,12 @@ void CodeGenCUDA::PrintVecElemStore( this->PrintIndent(); static const char access[] = {'x', 'y', 'z', 'w'}; CHECK(i >= 0 && i < 4); - stream << vec << "." << access[i] << " = " << value << ";\n"; + if (t.is_int() && t.bits() == 8) { + stream << vec << "=" << vec << " & ~(0x000000ff << " << i * 8 << ") | (" + << value << " << " << i * 8 << ");\n"; + } else { + stream << vec << "." << access[i] << " = " << value << ";\n"; + } } void CodeGenCUDA::PrintStorageSync(const Call* op) { diff --git a/src/pass/verify_gpu_code.cc b/src/pass/verify_gpu_code.cc index f1de2906385b..d5cd46b1ba09 100644 --- a/src/pass/verify_gpu_code.cc +++ b/src/pass/verify_gpu_code.cc @@ -83,10 +83,10 @@ class GPUCodeVerifier : public IRVisitor { // visit an allocation of a buffer in shared memory, record its size if (visited_local_buffers_.count(op->buffer_var.get()) != 0) { size_t size = static_cast(op->constant_allocation_size()); - local_memory_per_block_ += size * op->type.bytes(); + local_memory_per_block_ += size * op->type.bytes() * op->type.lanes(); } else if (visited_shared_buffers_.count(op->buffer_var.get()) != 0) { size_t size = static_cast(op->constant_allocation_size()); - shared_memory_per_block_ += size * op->type.bytes(); + shared_memory_per_block_ += size * op->type.bytes() * op->type.lanes(); } } diff --git a/tests/python/unittest/test_codegen_cuda.py b/tests/python/unittest/test_codegen_cuda.py index 1fb9c0abc5e8..63aaf2146ca8 100644 --- a/tests/python/unittest/test_codegen_cuda.py +++ b/tests/python/unittest/test_codegen_cuda.py @@ -52,6 +52,7 @@ def check_cuda(dtype, n, lanes): check_cuda("float32", 64, 2) check_cuda("float16", 64, 2) + check_cuda("int8", 64, 4) def test_cuda_multiply_add(): diff --git a/tests/python/unittest/test_pass_verify_gpu_code.py b/tests/python/unittest/test_pass_verify_gpu_code.py index b49b52ec46b7..bb646f40296e 100644 --- a/tests/python/unittest/test_pass_verify_gpu_code.py +++ b/tests/python/unittest/test_pass_verify_gpu_code.py @@ -24,39 +24,45 @@ def verify_pass(stmt): return verify_pass def test_shared_memory(): - N = 1024 - M = 128 - - A = tvm.placeholder((N,), name='A', dtype='float32') - B = tvm.compute((N, ), lambda i: A[i], name='B') - - s = tvm.create_schedule([B.op]) - AA = s.cache_read(A, "shared", [B]) - o, i = s[B].split(s[B].op.axis[0], M) - s[AA].compute_at(s[B], o) - s[B].bind(o, tvm.thread_axis("blockIdx.x")) - s[B].bind(i, tvm.thread_axis("threadIdx.x")) - - # shared memory usage: M * 4B - # thread usage: M - - for target in ['opencl', 'cuda']: - if not tvm.context(target).exist: - continue - valid = [None] - with tvm.build_config(**{"add_lower_pass": [ - (2, get_verify_pass(valid, - max_shared_memory_per_block=4 * M - 1, - max_threads_per_block=M))]}): - tvm.build(s, [A, B], target) - assert not valid[0] - - with tvm.build_config(**{"add_lower_pass": [ - (2, get_verify_pass(valid, - max_shared_memory_per_block=4 * M, - max_threads_per_block=M))]}): - tvm.build(s, [A, B], target) - assert valid[0] + def check_shared_memory(dtype): + N = 1024 + M = 128 + + tvm_type = tvm.datatype._TVMType(dtype) + type_size = tvm_type.bits // 8 * tvm_type.lanes + + A = tvm.placeholder((N,), name='A', dtype=dtype) + B = tvm.compute((N, ), lambda i: A[i], name='B') + + s = tvm.create_schedule([B.op]) + AA = s.cache_read(A, "shared", [B]) + o, i = s[B].split(s[B].op.axis[0], M) + s[AA].compute_at(s[B], o) + s[B].bind(o, tvm.thread_axis("blockIdx.x")) + s[B].bind(i, tvm.thread_axis("threadIdx.x")) + + # shared memory usage: M * sizeof(dtype) Bytes + # thread usage: M + + for target in ['opencl', 'cuda']: + if not tvm.context(target).exist: + continue + valid = [None] + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=type_size * M - 1, + max_threads_per_block=M))]}): + tvm.build(s, [A, B], target) + assert not valid[0] + + with tvm.build_config(**{"add_lower_pass": [ + (2, get_verify_pass(valid, + max_shared_memory_per_block=type_size * M, + max_threads_per_block=M))]}): + tvm.build(s, [A, B], target) + assert valid[0] + check_shared_memory('float32') + check_shared_memory('int8x4') def test_local_memory(): N = 1024