From 7df35c3561f4d8df9c19054ccb184001eddcffc0 Mon Sep 17 00:00:00 2001
From: LeiWang
Date: Wed, 7 Feb 2024 13:03:26 -0400
Subject: [PATCH] improve tensor intrin

---
 python/tvm/tir/tensor_intrin/cuda.py      | 167 +++++++++++++++---
 src/arith/ir_mutator_with_analyzer.cc     |   2 +-
 .../primitive/layout_transformation.cc    |   2 +-
 3 files changed, 145 insertions(+), 26 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index de837c13f10d..a0b97659b498 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring,unused-variable
 """Intrinsics for tensorization on NVIDIA GPU."""
-from typing import Dict, Optional, Tuple, Literal
+from typing import Dict, Optional, Tuple, Literal, List

 from tvm._ffi import register_func
 from tvm.runtime import convert
@@ -24,6 +24,7 @@
 from tvm.tir.function import PrimFunc
 from tvm.tir import Cast, IntImm, TensorIntrin

+
 def shared_16x16_to_mma_32x8_smoothlayout(i, j):
     return (i * 2 + j // 8, j % 8)

@@ -31,12 +32,13 @@ def shared_16x16_to_mma_32x8_smoothlayout(i, j):
 def shared_16x32_to_mma_32x16_smoothlayout(i, j):
     return (i * 2 + j // 16, j % 16)

+
 def shared_32x16_to_mma_32x16_smoothlayout(i, j):
     return (i * 2 + j // 16, j % 16)


 def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
-    row = (thread_id % 16)
+    row = thread_id % 16
     col = 8 * (thread_id // 16) + local_id % 8
     return row, col

@@ -47,6 +49,18 @@ def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
     return row, col


+def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
+    row = thread_id % 16
+    col = local_id + (thread_id // 16) * 16
+    return row, col
+
+
+def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
+    row = (thread_id // 16) * 8 + (thread_id % 8)
+    col = local_id + 16 * ((thread_id % 16) // 8)
+    return row, col
+
+
 def shared_16x16_to_mma_32x8_layout(i, j):
     thread_id = 4 * (i % 8) + (j % 8) // 2
     return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
@@ -90,7 +104,7 @@ def get_ldmatrix_intrin(
     matrix_name: Literal["A", "B"],
     transposed: bool,
     shared_scope: str = "shared",
-    propagate_layout: bool = False
+    propagate_layout: bool = False,
 ):
     local_size = (M_DIM * k_dim) // WARP_SIZE
     smem_offset = None
@@ -276,7 +290,8 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

 LDMATRIX_f16_B_TRANS_SMOOTH_INTRIN = "mma_ldmatrix_f16_b_trans_smooth"
 TensorIntrin.register(
-    LDMATRIX_f16_B_TRANS_SMOOTH_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True, "shared", True)
+    LDMATRIX_f16_B_TRANS_SMOOTH_INTRIN,
+    *get_ldmatrix_intrin(16, "float16", "B", True, "shared", True),
 )

 LDMATRIX_f16_A_DYN_INTRIN = "mma_ldmatrix_f16_a_dyn"
@@ -284,9 +299,10 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
     LDMATRIX_f16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn")
 )

-LDMATRIX_f16_A_DYN_SMOOTH_INTRIN = "mma_ldmatrix_f16_a_dyn_smooth"
+LDMATRIX_f16_A_DYN_SMOOTH_INTRIN = "mma_ldmatrix_f16_a_smooth_dyn"
 TensorIntrin.register(
-    LDMATRIX_f16_A_DYN_SMOOTH_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn", True)
+    LDMATRIX_f16_A_DYN_SMOOTH_INTRIN,
+    *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn", True),
 )

 LDMATRIX_f16_B_DYN_INTRIN = "mma_ldmatrix_f16_b_dyn"
@@ -306,7 +322,8 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

 LDMATRIX_f16_B_TRANS_SMOOTH_DYN_INTRIN = "mma_ldmatrix_f16_b_trans_smooth_dyn"
 TensorIntrin.register(
-    LDMATRIX_f16_B_TRANS_SMOOTH_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True, "shared.dyn", True)
+    LDMATRIX_f16_B_TRANS_SMOOTH_DYN_INTRIN,
+    *get_ldmatrix_intrin(16, "float16", "B", True, "shared.dyn", True),
 )

 LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_i8_a"
@@ -319,6 +336,56 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
 TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))

+
+LDMATRIX_i8_A_SMOOTH_INTRIN = "mma_ldmatrix_i8_a_smooth"
+TensorIntrin.register(
+    LDMATRIX_i8_A_SMOOTH_INTRIN, *get_ldmatrix_intrin(32, "int8", "A", False, "shared", True)
+)
+
+LDMATRIX_i8_B_SMOOTH_INTRIN = "mma_ldmatrix_i8_b_smooth"
+TensorIntrin.register(
+    LDMATRIX_i8_B_SMOOTH_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", False, "shared", True)
+)
+
+LDMATRIX_i8_B_TRANS_SMOOTH_INTRIN = "mma_ldmatrix_i8_b_trans_smooth"
+TensorIntrin.register(
+    LDMATRIX_i8_B_TRANS_SMOOTH_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True, "shared", True)
+)
+
+LDMATRIX_i8_A_DYN_INTRIN = "mma_ldmatrix_i8_a_dyn"
+TensorIntrin.register(
+    LDMATRIX_i8_A_DYN_INTRIN, *get_ldmatrix_intrin(32, "int8", "A", False, "shared.dyn")
+)
+
+LDMATRIX_i8_B_DYN_INTRIN = "mma_ldmatrix_i8_b_dyn"
+TensorIntrin.register(
+    LDMATRIX_i8_B_DYN_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", False, "shared.dyn")
+)
+
+LDMATRIX_i8_B_TRANS_DYN_INTRIN = "mma_ldmatrix_i8_b_trans_dyn"
+TensorIntrin.register(
+    LDMATRIX_i8_B_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True, "shared.dyn")
+)
+
+
+LDMATRIX_i8_A_SMOOTH_DYN_INTRIN = "mma_ldmatrix_i8_a_smooth_dyn"
+TensorIntrin.register(
+    LDMATRIX_i8_A_SMOOTH_DYN_INTRIN,
+    *get_ldmatrix_intrin(32, "int8", "A", False, "shared.dyn", True),
+)
+
+LDMATRIX_i8_B_SMOOTH_DYN_INTRIN = "mma_ldmatrix_i8_b_smooth_dyn"
+TensorIntrin.register(
+    LDMATRIX_i8_B_SMOOTH_DYN_INTRIN,
+    *get_ldmatrix_intrin(32, "int8", "B", False, "shared.dyn", True),
+)
+
+LDMATRIX_i8_B_TRANS_SMOOTH_DYN_INTRIN = "mma_ldmatrix_i8_b_trans_smooth_dyn"
+TensorIntrin.register(
+    LDMATRIX_i8_B_TRANS_SMOOTH_DYN_INTRIN,
+    *get_ldmatrix_intrin(32, "int8", "B", True, "shared.dyn", True),
+)
+

 def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed, smooth_a=False, smooth_b=False):
     local_size = (M_DIM * k_dim) // WARP_SIZE
     local_size_out = (M_DIM * N_DIM) // 32
@@ -326,16 +393,28 @@ def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed, smooth_a=False,
     index_map_C = shared_16x16_to_mma_32x8_layout

     if k_dim == 16:
-        index_map_A = shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
-        index_map_B = shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
+        index_map_A = (
+            shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
+        )
+        index_map_B = (
+            shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
+        )
         mma_prefix = "m16n8k16"
     elif k_dim == 32 and b_transposed:
-        index_map_A = shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
-        index_map_B = shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
+        index_map_A = (
+            shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
+        )
+        index_map_B = (
+            shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
+        )
         mma_prefix = "m16n8k32"
     elif k_dim == 32 and not b_transposed:
-        index_map_A = shared_16x32_to_mma_32x16_layout if smooth_a else shared_16x32_to_mma_32x16_layout
-        index_map_B = shared_32x16_to_mma_32x16_layout if smooth_b else shared_32x16_to_mma_32x16_layout
+        index_map_A = (
+            shared_16x32_to_mma_32x16_layout if smooth_a else shared_16x32_to_mma_32x16_layout
+        )
+        index_map_B = (
+            shared_32x16_to_mma_32x16_layout if smooth_b else shared_32x16_to_mma_32x16_layout
+        )
         mma_prefix = "m16n8k32"
     else:
         assert False
@@ -515,10 +594,15 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True))

 MMA_f16f16f16_TRANS_SMOOTH_B_INTRIN = "mma_f16f16f16_trans_b_smooth_b"
-TensorIntrin.register(MMA_f16f16f16_TRANS_SMOOTH_B_INTRIN, *get_mma_intrin(16, "float16", False, True, False, True))
+TensorIntrin.register(
+    MMA_f16f16f16_TRANS_SMOOTH_B_INTRIN, *get_mma_intrin(16, "float16", False, True, False, True)
+)

 MMA_f16f16f16_SMOOTH_A_TRANS_SMOOTH_B_INTRIN = "mma_f16f16f16_smooth_a_trans_b_smooth_b"
-TensorIntrin.register(MMA_f16f16f16_SMOOTH_A_TRANS_SMOOTH_B_INTRIN, *get_mma_intrin(16, "float16", False, True, True, True))
+TensorIntrin.register(
+    MMA_f16f16f16_SMOOTH_A_TRANS_SMOOTH_B_INTRIN,
+    *get_mma_intrin(16, "float16", False, True, True, True),
+)

 MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a"
 TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False))
@@ -534,6 +618,17 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b"
 TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True))

+MMA_i8i8i32_TRANS_B_SMOOTH_B_INTRIN = "mma_i8i8i32_trans_b_smooth_b"
+TensorIntrin.register(
+    MMA_i8i8i32_TRANS_B_SMOOTH_B_INTRIN, *get_mma_intrin(32, "int32", False, True, False, True)
+)
+
+MMA_i8i8i32_SMOOTH_A_TRANS_B_SMOOTH_B_INTRIN = "mma_i8i8i32_smooth_a_trans_b_smooth_b"
+TensorIntrin.register(
+    MMA_i8i8i32_SMOOTH_A_TRANS_B_SMOOTH_B_INTRIN,
+    *get_mma_intrin(32, "int32", False, True, True, True),
+)
+

 def get_mma_fill_intrin(dtype, local_size):
     zero = IntImm("int32", 0).astype(dtype)
@@ -657,7 +752,6 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
                 row, col = T.meta_var(index_map_rev(tx, local_id))
                 C[row, col] = C_warp[tx, local_id]

-
     return mma_store_desc, mma_store_impl

@@ -708,7 +802,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:

 MMA_store_16x16_f16_shared_INTRIN = "mma_store_16x16_f16_shared_"
 TensorIntrin.register(
-    MMA_store_16x16_f16_shared_INTRIN, *get_mma_store_intrin("float16", 8, "shared", False)
+    MMA_store_16x16_f16_shared_INTRIN, *get_mma_store_intrin("float16", 8, "shared", True)
 )

 MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
@@ -716,6 +810,11 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
     MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global", True)
 )

+MMA_store_16x16_i32_shared_INTRIN = "mma_store_16x16_i32_shared_"
+TensorIntrin.register(
+    MMA_store_16x16_i32_shared_INTRIN, *get_mma_store_intrin("int32", 8, "shared", True)
+)
+

 def get_mma_intrin_group(
     load_scope: Literal["shared", "shared.dyn"],
@@ -750,10 +849,10 @@ def get_mma_intrin_group(
     trans_b : bool
         Whether the input matrix B is transposed.
-
+
     smooth_a: bool
         Whether to assume the propagated layout of A is smooth.
-
+
     smooth_b: bool
         Whether to assume the propagated layout of B is smooth.

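Note: the layout functions this patch touches are pure integer maps, so their key invariant can be checked standalone: each must be a bijection onto its tile, which is the property ldmatrix and mma.sync rely on. The snippet below is an illustrative sketch, not part of the patch; it re-states three of the maps from cuda.py and verifies exact coverage.

# Illustrative sanity check (not part of the patch). Each mapping must hit
# every element of its destination tile exactly once.
def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = thread_id % 16
    col = 8 * (thread_id // 16) + local_id % 8
    return row, col

def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
    row = (thread_id // 16) * 8 + (thread_id % 8)
    col = local_id + 16 * ((thread_id % 16) // 8)
    return row, col

def shared_16x16_to_mma_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

def check_bijection(mapping, src_shape, dst_shape):
    # Enumerate the whole source domain and require in-range, collision-free
    # coverage of the destination domain.
    seen = set()
    for a in range(src_shape[0]):
        for b in range(src_shape[1]):
            x, y = mapping(a, b)
            assert 0 <= x < dst_shape[0] and 0 <= y < dst_shape[1]
            seen.add((x, y))
    assert len(seen) == dst_shape[0] * dst_shape[1]

check_bijection(ldmatrix_32x8_to_shared_16x16_layout, (32, 8), (16, 16))     # f16 A tile
check_bijection(ldmatrix_32x16_to_shared_16x32_layout_b, (32, 16), (16, 32))  # i8 B tile
check_bijection(shared_16x16_to_mma_32x8_layout, (16, 16), (32, 8))           # mma fragment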
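The store intrinsics rely on the reverse map (index_map_rev inside get_mma_store_intrin) to walk a warp fragment back to a 16x16 tile. Continuing the snippet above, the inverse below is hand-derived for this note (mma_32x8_to_shared_16x16_layout is a hypothetical helper, not code from the patch); composing it with the forward map must be the identity.

# Hand-derived inverse of shared_16x16_to_mma_32x8_layout (hypothetical
# helper for illustration; continues the snippet above).
def mma_32x8_to_shared_16x16_layout(thread_id, local_id):
    i = 8 * ((local_id % 4) // 2) + thread_id // 4
    j = 8 * (local_id // 4) + 2 * (thread_id % 4) + local_id % 2
    return i, j

# Round-tripping every element shows the reverse walk used by the store
# intrinsics loses nothing.
for i in range(16):
    for j in range(16):
        assert mma_32x8_to_shared_16x16_layout(*shared_16x16_to_mma_32x8_layout(i, j)) == (i, j)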
@@ -801,7 +900,9 @@ def get_mma_intrin_group(
     trans_b_str = trans_b + "_b" if trans_b != "" else ""
     smooth_a_str = smooth_a + "_a" if smooth_a != "" else ""
     smooth_b_str = smooth_b + "_b" if smooth_b != "" else ""
-    compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{smooth_a_str}{trans_b_str}{smooth_b_str}"
+    compute_intrin = (
+        f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{smooth_a_str}{trans_b_str}{smooth_b_str}"
+    )

     # e.g. mma_store_16x16_f32_shared_dyn_simple_
     store_scope = store_scope.replace(".", "_")
@@ -811,11 +912,28 @@ def get_mma_intrin_group(
     index_map_c = shared_16x16_to_mma_32x8_layout

     if in_dtype == "f16":
-        index_map_a = shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
-        index_map_b = shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
+        index_map_a = (
+            shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
+        )
+        index_map_b = (
+            shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
+        )
+    elif in_dtype == "i8":
+        index_map_a = (
+            shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
+        )
+        index_map_b = (
+            shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
+        )
+    else:
+        raise ValueError(f"Unsupported in_dtype: {in_dtype}")
+
+    # micro kernel size, the order is [m, n, k]
+    micro_kernel: List[int]
+    if in_dtype == "f16":
+        micro_kernel = [16, 16, 16]
     elif in_dtype == "i8":
-        index_map_a = shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
-        index_map_b = shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
+        micro_kernel = [16, 16, 32]
     else:
         raise ValueError(f"Unsupported in_dtype: {in_dtype}")

@@ -826,6 +944,7 @@ def get_mma_intrin_group(
         "compute": compute_intrin,
         "store": store_intrin,
         "index_map": [index_map_a, index_map_b, index_map_c],
+        "micro_kernel": micro_kernel,
     }

diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc
index 08c3b0490491..2132d2e4ba39 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -52,7 +52,7 @@ Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array<PrimExpr>& indices,
     PrimExpr s = analyzer_->Simplify(e);
-    return s->IsInstance() ? s : e;
+    return s->IsInstance() ? e : s;
   });
   for (int i = 0; i < n; ++i) {
     if (simplified[i]->IsInstance() && indices[i]->IsInstance()) {
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 16b8116c105b..6c6427a90649 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -803,7 +803,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
   void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
     *buffer = new_buffer_;
     *indices = index_map_->MapIndices(*indices, &index_simplifier_);
-    *indices = this->IterMapSimplifyWithContext(*indices, true);
+    *indices = this->IterMapSimplifyWithContext(*indices, true);
   }

   using Parent = arith::IRMutatorWithAnalyzer;
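Back on the Python side, the compute intrin names assembled in get_mma_intrin_group can be reproduced standalone. The helper below is a simplified re-derivation (illustrative only; mma_compute_intrin_name is not a function in the patch, and it takes booleans where the patched code carries pre-built string fragments) that asserts it reproduces two of the names registered above.

# Simplified re-derivation of the compute-intrin naming scheme (illustration,
# not the patch's code): each flag becomes a "_trans"/"_smooth" fragment
# suffixed with its matrix side, concatenated as a-trans, a-smooth, b-trans,
# b-smooth after the dtype triple.
def mma_compute_intrin_name(in_dtype, out_dtype, trans_a, trans_b, smooth_a, smooth_b):
    trans_a_str = "_trans_a" if trans_a else ""
    smooth_a_str = "_smooth_a" if smooth_a else ""
    trans_b_str = "_trans_b" if trans_b else ""
    smooth_b_str = "_smooth_b" if smooth_b else ""
    return f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{smooth_a_str}{trans_b_str}{smooth_b_str}"

# Both names match registrations made earlier in the patch.
assert mma_compute_intrin_name("f16", "f16", False, True, False, True) == "mma_f16f16f16_trans_b_smooth_b"
assert mma_compute_intrin_name("i8", "i32", False, True, True, True) == "mma_i8i8i32_smooth_a_trans_b_smooth_b"

This mirrors how the group's "compute" key is assembled, so a consumer can predict the registry name directly from its flag set instead of consulting the TensorIntrin registry.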