From d7253fbc7ebdfe9ab349eb85a7db3d68260ec059 Mon Sep 17 00:00:00 2001 From: Tian Xia <74357442+Rainy-Memory@users.noreply.github.com> Date: Fri, 17 Feb 2023 09:51:59 +0800 Subject: [PATCH] [TIR] Add cp.async support for tir.if_then_else (#13966) This PR supports CUDA's cp.async ptx for un-vectorized BufferStore from a `tir.if_then_else` call and thus enables padded async copy. Co-authored-by: Junru Shao <junrushao1994@gmail.com> --- src/target/source/codegen_cuda.cc | 8 +- src/target/source/ptx.cc | 31 +++ src/target/source/ptx.h | 16 ++ src/tir/transforms/inject_ptx_async_copy.cc | 156 +++++++----- .../unittest/test_cp_async_in_if_then_else.py | 238 ++++++++++++++++++ 5 files changed, 386 insertions(+), 63 deletions(-) create mode 100644 tests/python/unittest/test_cp_async_in_if_then_else.py diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index c891ec5a28cf..9bf0109cace1 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -914,7 +914,13 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string src = this->PrintExpr(op->args[2]); std::string src_offset = this->PrintExpr(op->args[3]); std::string size = this->PrintExpr(op->args[4]); - this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); + // use size of argument list to indicate whether or not to use predicated cp.async + if (op->args.size() == 5) { + this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size); + } else { + this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size, + this->PrintExpr(op->args[5])); + } } else if (op->op.same_as(builtin::ptx_commit_group())) { this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n"; } else if (op->op.same_as(builtin::ptx_wait_group())) { diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index 886242efe08c..b5299b4e4b2a 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -659,5 +659,36 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, return asm_code; } +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value) { + std::string predicated_asm_code = R"( + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)({smem_addr})) + ); + int src_bytes = {pred_guard} ? {bytes} : 0; + __asm__ __volatile__( + "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes) + ); + } +)"; + Replacer replacer; + replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset); + replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset); + replacer.register_rule("{bytes}", bytes); + replacer.register_rule("{cg_or_ca}", bytes == "16" ? 
"cg" : "ca"); + replacer.register_rule("{pred_guard}", predicate_value); + predicated_asm_code = replacer.rewrite(predicated_asm_code); + return predicated_asm_code; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/source/ptx.h b/src/target/source/ptx.h index c811a1b9c1d6..1e49b57c1790 100644 --- a/src/target/source/ptx.h +++ b/src/target/source/ptx.h @@ -92,6 +92,22 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr, const std::string& global_ptr, const std::string& global_elem_offset, const std::string& bytes); +/*! + * \brief Print predicated ptx cp.async assembly string given parameters. + * \param shared_ptr: The pointer to the destination shared memory. + * \param shared_elem_offset: The offset into the shared memory. + * \param global_ptr: The pointer to the global memory. + * \param global_elem_offset: The offset into the global memory. + * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16. + * \param predicate_value: The value of predicate `@p`. + */ +std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr, + const std::string& shared_elem_offset, + const std::string& global_ptr, + const std::string& global_elem_offset, + const std::string& bytes, + const std::string& predicate_value); + } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc index 8ee0d054e56d..2e3c906e89c1 100644 --- a/src/tir/transforms/inject_ptx_async_copy.cc +++ b/src/tir/transforms/inject_ptx_async_copy.cc @@ -47,73 +47,105 @@ class PTXAsyncCopyInjector : public StmtMutator { return StmtMutator::VisitStmt_(attr); } + Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false, + PrimExpr predicate_value = PrimExpr()) { + if (load->buffer.scope() == "global") { + ICHECK(load->indices.size() == 1 && store->indices.size() == 1); + ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); + + const int indices_lanes = load->indices[0]->dtype.lanes(); + const int bytes = indices_lanes * load->buffer->dtype.bytes(); + + if (bytes == 4 || bytes == 8 || bytes == 16) { + auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); + auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); + ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) + << "Both store and load buffer should have a pointer type annotation."; + + int index_factor = 1; + if (dst_elem_type.value() != src_elem_type.value()) { + // The only case where src and dst have different dtypes is when the dst shared memory + // is a byte buffer generated by merging dynamic shared memory. + ICHECK(store->buffer.scope() == "shared.dyn"); + ICHECK(dst_elem_type.value() == DataType::UInt(8)); + // BufferStore/Load have the "pointer reinterpret" semantics according to their + // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, + // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; + // To replace BufferStore/Load with cp.async, we need to multiply the store index by + // the byte size of the "value" dtype, to get the correct offset into the byte buffer. 
+ index_factor = src_elem_type->bytes(); + } + + if (indices_lanes == 1) { + auto src_offset = load->indices[0]; + auto dst_offset = store->indices[0]; + Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)}; + // use arguments size to indicate whether or not to use predicated cp.async + if (predicated) { + args.push_back(predicate_value); + } + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args)); + } + + // Predicated load don't support vectorized indexing. + if (!predicated) { + // Only some vectorized indexing patterns are supported for now. + auto src_offset = [=]() -> PrimExpr { + if (load->indices[0]->IsInstance<RampNode>()) { + return load->indices[0].as<RampNode>()->base; + } + return PrimExpr(); + }(); + + auto dst_offset = [=]() -> PrimExpr { + if (store->indices[0].as<RampNode>()) { + return store->indices[0].as<RampNode>()->base; + } else if (store->indices[0].as<AddNode>()) { + // The case where the dst buffer is a byte buffer generated by merging dynamic + // shared memory. + // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] + auto* add = store->indices[0].as<AddNode>(); + if (!add->a->IsInstance<RampNode>()) return PrimExpr(); + if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr(); + return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value); + } + return PrimExpr(); + }(); + + if (src_offset.defined() && dst_offset.defined()) { + return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), + {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), + load->buffer->data, src_offset, PrimExpr(bytes)})); + } + } + } + } + return StmtMutator::VisitStmt_(store); + } + Stmt VisitStmt_(const BufferStoreNode* store) { if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) { if (auto* load = store->value.as<BufferLoadNode>()) { - if (load->buffer.scope() == "global") { - ICHECK(load->indices.size() == 1 && store->indices.size() == 1); - ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes()); - - const int indices_lanes = load->indices[0]->dtype.lanes(); - const int bytes = indices_lanes * load->buffer->dtype.bytes(); - - if (bytes == 4 || bytes == 8 || bytes == 16) { - auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation); - auto src_elem_type = GetPointerType(load->buffer->data->type_annotation); - ICHECK(dst_elem_type.has_value() && src_elem_type.has_value()) - << "Both store and load buffer should have a pointer type annotation."; - - int index_factor = 1; - if (dst_elem_type.value() != src_elem_type.value()) { - // The only case where src and dst have different dtypes is when the dst shared memory - // is a byte buffer generated by merging dynamic shared memory. - ICHECK(store->buffer.scope() == "shared.dyn"); - ICHECK(dst_elem_type.value() == DataType::UInt(8)); - // BufferStore/Load have the "pointer reinterpret" semantics according to their - // "value" dtype. Their "indices" are supposed to be applied after such pointer cast, - // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value; - // To replace BufferStore/Load with cp.async, we need to multiply the store index by - // the byte size of the "value" dtype, to get the correct offset into the byte buffer. 
- index_factor = src_elem_type->bytes(); + return InjectPTX(load, store); + } else if (auto* call = store->value.as<CallNode>()) { + // tir.if_then_else is a call to tir::builtin::if_then_else() + if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) { + if (auto* load = call->args[1].as<BufferLoadNode>()) { + // Only default value of 0 is supported since 0 is the default value used by cp.async + // ptx. @see section 9.7.8.22.3. of + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations + bool else_value_is_zero = false; + if (auto* b = call->args[2].as<BroadcastNode>()) { + if (auto* f = b->value.as<FloatImmNode>()) { + else_value_is_zero = f->value == 0.0f; + } } - - if (indices_lanes == 1) { - auto src_offset = load->indices[0]; - auto dst_offset = store->indices[0]; - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + if (auto* f = call->args[2].as<FloatImmNode>()) { + else_value_is_zero = f->value == 0.0f; } - - // Only some vectorized indexing patterns are supported for now. - auto src_offset = [=]() -> PrimExpr { - if (load->indices[0]->IsInstance<RampNode>()) { - return load->indices[0].as<RampNode>()->base; - } - return PrimExpr(); - }(); - - auto dst_offset = [=]() -> PrimExpr { - if (store->indices[0].as<RampNode>()) { - return store->indices[0].as<RampNode>()->base; - } else if (store->indices[0].as<AddNode>()) { - // The case where the dst buffer is a byte buffer generated by merging dynamic - // shared memory. - // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)] - auto* add = store->indices[0].as<AddNode>(); - if (!add->a->IsInstance<RampNode>()) return PrimExpr(); - if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr(); - return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value); - } - return PrimExpr(); - }(); - - if (src_offset.defined() && dst_offset.defined()) { - return Evaluate( - Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), - {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)), - load->buffer->data, src_offset, PrimExpr(bytes)})); + if (else_value_is_zero) { + return InjectPTX(load, store, true, call->args[0]); } } } diff --git a/tests/python/unittest/test_cp_async_in_if_then_else.py b/tests/python/unittest/test_cp_async_in_if_then_else.py new file mode 100644 index 000000000000..08de5ba34da1 --- /dev/null +++ b/tests/python/unittest/test_cp_async_in_if_then_else.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""test the correctness of inject async memory copy from an if_then_else load""" +import tvm +import numpy as np + +from tvm.script import tir as T +import tvm.testing + +expected_cuda_script = r""" +#ifdef _WIN32 + using uint = unsigned int; + using uchar = unsigned char; + using ushort = unsigned short; + using int64_t = long long; + using uint64_t = unsigned long long; +#else + #define uint unsigned int + #define uchar unsigned char + #define ushort unsigned short + #define int64_t long long + #define uint64_t unsigned long long +#endif +extern "C" __global__ void __launch_bounds__(16) main_kernel0(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { + __shared__ float A_shared[64]; + __shared__ float B_shared[64]; + A_shared[((int)threadIdx.x)] = 0.000000e+00f; + B_shared[((int)threadIdx.x)] = 0.000000e+00f; +__asm__ __volatile__("cp.async.commit_group;"); + + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + (((int)threadIdx.x) + 16))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(A + (((int)threadIdx.x) * 14))), "n"(4) + ); + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((int)threadIdx.x) + 16))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(B + (((int)threadIdx.x) * 14))), "n"(4) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + (((int)threadIdx.x) + 32))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(A + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + ); + } + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + (((int)threadIdx.x) + 32))) + ); + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2;" + :: "r"(addr), "l"((void*)(B + ((((int)threadIdx.x) * 14) + 1))), "n"(4) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + for (int i = 0; i < 13; ++i) { + bool cse_var_1 = (i < 12); + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) + ); + int src_bytes = cse_var_1 ? 4 : 0; + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)(A + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + +__asm__ __volatile__("cp.async.wait_group 5;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + i)] = (A_shared[(((i & 3) * 16) + ((int)threadIdx.x))] + B_shared[(((i & 3) * 16) + ((int)threadIdx.x))]); + __syncthreads(); + + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x)))) + ); + int src_bytes = cse_var_1 ? 
4 : 0; + __asm__ __volatile__( + "cp.async.ca.shared.global [%0], [%1], %2, %3;" + :: "r"(addr), "l"((void*)(B + (((((int)threadIdx.x) * 14) + i) + 2))), "n"(4), "r"(src_bytes) + ); + } +__asm__ __volatile__("cp.async.commit_group;"); + + } +__asm__ __volatile__("cp.async.wait_group 2;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 13)] = (A_shared[(((int)threadIdx.x) + 16)] + B_shared[(((int)threadIdx.x) + 16)]); +__asm__ __volatile__("cp.async.wait_group 1;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 14)] = (A_shared[(((int)threadIdx.x) + 32)] + B_shared[(((int)threadIdx.x) + 32)]); +__asm__ __volatile__("cp.async.wait_group 0;"); + + __syncthreads(); + C[((((int)threadIdx.x) * 16) + 15)] = (A_shared[(((int)threadIdx.x) + 48)] + B_shared[(((int)threadIdx.x) + 48)]); +} + +""" + + +generated_code = "" +support_async = True + + +@tvm.register_func +def tvm_callback_cuda_postproc(code): + global generated_code + global support_async + generated_code = code + # return a dummy code so that device < sm80 could build correctly + if not support_async: + ret = "" + for line in code.split("\n"): + ret += line + "\n" + if line.startswith('extern "C" __global__'): + break + ret += "}" + return ret + return code + + +@tvm.testing.requires_cuda +def test_cp_async_in_if_then_else(): + global support_async + arch = tvm.contrib.nvcc.get_target_compute_version() + major, _ = tvm.contrib.nvcc.parse_compute_version(arch) + if major < 8: + # At least sm80 is required + support_async = False + + @T.prim_func + def simple_compute( + A: T.Buffer((16, 14), "float32"), + B: T.Buffer((16, 14), "float32"), + C: T.Buffer((16, 16), "float32"), + ): + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 16, + annotations={ + "software_pipeline_stage": [0, 0, 3], + "software_pipeline_order": [0, 2, 1], + "software_pipeline_async_stages": [0], + }, + ): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + A_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + B_shared = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(A_shared[tx, 0]) + A_shared[tx, 0] = T.if_then_else( + 1 <= i and i < 15, A[tx, i - 1], T.float32(0), dtype="float32" + ) + with T.block(): + T.reads(B[tx, i]) + T.writes(B_shared[tx, 0]) + B_shared[tx, 0] = T.if_then_else( + 1 <= i and i < 15, B[tx, i - 1], T.float32(0), dtype="float32" + ) + with T.block(): + T.reads(A_shared[tx, 0], B_shared[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = A_shared[tx, 0] + B_shared[tx, 0] + + mod = tvm.IRModule.from_expr(simple_compute) + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + tvm.build(mod, target="cuda") + + assert generated_code == expected_cuda_script + + if not support_async: + # avoid return dummy code to other tests + support_async = True + + +if __name__ == "__main__": + test_cp_async_in_if_then_else()
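
Note (not part of the patch above): the predicated lowering only works because of how the hardware treats the optional src-size operand of cp.async. When src-size is smaller than cp-size (here 0 vs. 4), the remaining destination bytes in shared memory are zero-filled instead of being read from global memory, which is exactly why InjectPTX only accepts a tir.if_then_else whose else-value is 0. The following standalone CUDA sketch illustrates that behaviour; the kernel name, the `valid` parameter, and the 32-thread launch shape are illustrative assumptions, and it requires sm_80 or newer (e.g. nvcc -arch=sm_80).

// Illustrative sketch only (not generated by this patch); requires sm_80+.
__global__ void predicated_cp_async_demo(const float* __restrict__ src,
                                         float* __restrict__ out, int valid) {
  __shared__ float buf[32];
  // Convert the generic shared-memory pointer to a 32-bit shared-space address,
  // mirroring the cvta.to.shared sequence emitted by PrintPredicatedCpAsyncAssembly.
  unsigned int addr;
  __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void*)(buf + threadIdx.x)));
  // Predicated copy: threads whose predicate is false request 0 source bytes, so the
  // hardware writes 4 zero bytes to buf[threadIdx.x] instead of loading from src.
  int src_bytes = (static_cast<int>(threadIdx.x) < valid) ? 4 : 0;
  __asm__ __volatile__("cp.async.ca.shared.global [%0], [%1], %2, %3;"
                       :: "r"(addr), "l"((void*)(src + threadIdx.x)), "n"(4), "r"(src_bytes));
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;");
  // Each thread reads back only the element it copied itself, so no barrier is needed here.
  out[threadIdx.x] = buf[threadIdx.x];
}

Launched as predicated_cp_async_demo<<<1, 32>>>(d_src, d_out, 16), out[0..15] should hold src[0..15] and out[16..31] should be exactly 0, matching the padding the pass produces for the out-of-bounds iterations of the pipelined loop in the test above.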
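
The waits in the expected CUDA above (cp.async.wait_group 5 inside the loop, then 2, 1, 0 as the pipeline drains) follow the usual commit-group protocol: each cp.async.commit_group closes a batch of outstanding copies, and cp.async.wait_group N blocks until at most N of the most recently committed batches are still in flight. Below is a compressed double-buffered sketch of that protocol under simplifying assumptions (one input buffer, one copy per stage, illustrative names); it is not patch code and is shallower than the 3-stage pipeline the test generates.

// Double-buffered cp.async pipeline sketch (illustrative, sm_80+): the copy for
// iteration i+1 is in flight while iteration i is consumed from shared memory.
__global__ void pipelined_add_one(const float* __restrict__ a, float* __restrict__ out,
                                  int iters) {
  __shared__ float stage[2][32];
  // Issue one 4-byte async copy of a[i*32 + tid] into stage[slot][tid] and close a group.
  auto issue = [&](int slot, int i) {
    unsigned int addr;
    __asm__ __volatile__(
        "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
        : "=r"(addr)
        : "l"((void*)(&stage[slot][threadIdx.x])));
    __asm__ __volatile__("cp.async.ca.shared.global [%0], [%1], %2;"
                         :: "r"(addr), "l"((void*)(a + i * 32 + threadIdx.x)), "n"(4));
    __asm__ __volatile__("cp.async.commit_group;");
  };
  issue(0, 0);  // prologue: prefetch iteration 0
  for (int i = 0; i < iters; ++i) {
    if (i + 1 < iters) {
      issue((i + 1) & 1, i + 1);  // keep one group in flight for the next iteration
      // At most one group may remain pending, so the data for iteration i has landed.
      __asm__ __volatile__("cp.async.wait_group 1;");
    } else {
      __asm__ __volatile__("cp.async.wait_group 0;");  // epilogue: drain everything
    }
    // Each thread consumes only the element it copied itself, so this minimal sketch
    // needs no __syncthreads(); the generated code above does need barriers because
    // threads read elements copied by other threads.
    out[i * 32 + threadIdx.x] = stage[i & 1][threadIdx.x] + 1.0f;
  }
}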
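
For completeness, the {cg_or_ca} rule in PrintPredicatedCpAsyncAssembly (and in the pre-existing PrintCpAsyncAssembly) comes from the PTX ISA: cp.async.cg is only defined for 16-byte copies and caches in L2 and below, while 4- and 8-byte copies must use cp.async.ca, which caches at all levels. A minimal illustration under the same assumptions as the sketches above (illustrative names, sm_80+, not patch code):

// 16-byte (vectorized) copies may use the .cg variant, which bypasses L1; smaller
// copies must use .ca.
__global__ void cp_async_cg_demo(const float4* __restrict__ src,
                                 float4* __restrict__ out) {
  __shared__ float4 buf[32];
  unsigned int addr;
  __asm__ __volatile__(
      "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
      : "=r"(addr)
      : "l"((void*)(buf + threadIdx.x)));
  __asm__ __volatile__("cp.async.cg.shared.global [%0], [%1], %2;"
                       :: "r"(addr), "l"((void*)(src + threadIdx.x)), "n"(16));
  __asm__ __volatile__("cp.async.commit_group;");
  __asm__ __volatile__("cp.async.wait_group 0;");
  out[threadIdx.x] = buf[threadIdx.x];
}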