From dd87bcbd61e12369090d2c58c9522ec153a37f8c Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Wed, 11 Sep 2024 15:57:25 +0800
Subject: [PATCH] [DLight] Fix Matmul rule for Conv3D

Currently, the matmul rule for Conv3D is incorrect due to the incorrect
reindexing of the input tensor. This commit fixes the issue: the index map
passed to `transform_layout` must be calculated after the `reindex` step,
not before it.
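The same pattern is applied in each affected scheduling path. As an
illustrative summary only (it mirrors the hunks below and adds no new code),
the intended ordering is:

    # Reindex all three buffers first ...
    reindex_a = sch.reindex(main_block, ("read", 0))
    reindex_b = sch.reindex(main_block, ("read", 1))
    reindex_c = sch.reindex(main_block, ("write", 0))
    # ... and only then derive the index maps from the reindexed block,
    # so they match the buffers that transform_layout actually rewrites.
    index_maps = get_index_map(sch.get(main_block))
    assert index_maps is not None
    matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
    sch.transform_layout(reindex_a, ("write", 0), a_index_map)
    sch.transform_layout(reindex_b, ("write", 0), b_index_map)
    sch.transform_layout(reindex_c, ("read", 0), c_index_map)
    sch.transform_block_layout(main_block, matmul_index_map)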
---
 python/tvm/dlight/gpu/matmul.py      | 100 ++++++++++++-----------
 tests/python/dlight/test_gpu_conv.py | 118 +++++++++++++++++++++++++++
 2 files changed, 170 insertions(+), 48 deletions(-)
 create mode 100644 tests/python/dlight/test_gpu_conv.py

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 5fb8e2469d54..5568083982b9 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -364,13 +364,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None
 
-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Step 0. Configs
         block_size_x: int = 16
         block_size_y: int = 16
@@ -382,12 +375,19 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         vector_size: int = 4
 
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 2. Padding for dynamic shape kernels
@@ -508,13 +508,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None
 
-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Start Schedule
         # Step 0. Get schedule config.
         # NOTE: we can analyze the config by the hardware spec in the future
@@ -539,12 +532,19 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         k_pad_factor = k_factors[1]
 
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 2. Padding for dynamic shape kernels
@@ -729,13 +729,6 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         if reduction_blocks is None:
             return None
 
-        main_block = reduction_blocks[0]
-        block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
-        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-
         # Start Schedule
         # Step 0. Get schedule config.
         # NOTE: we can analyze the config by the hardware spec in the future
@@ -760,12 +753,19 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
         k_pad_factor = k_factors[1]
 
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+        # Reindex first and then analyze the index map
+        main_block = reduction_blocks[0]
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 2. Padding for dynamic shape kernels
@@ -979,12 +979,11 @@ def apply(  # pylint: disable=too-many-locals,missing-docstring
 
         main_block = reduction_blocks[0]
         block_stmt = sch.get(main_block)
-        index_maps = get_index_map(block_stmt)
-        if index_maps is None:
-            return None
 
         main_block_info = get_block_info(sch, main_block)
         iter_infos = main_block_info.iters
+        if not get_index_map(block_stmt):
+            return None
 
         # Checks if it's a inner reduction by getting the last matrix's inner Index
         def is_inner_reduction(block_stmt, iter_infos):
@@ -1000,13 +999,18 @@ def is_inner_reduction(block_stmt, iter_infos):
             return ret
 
         # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
+        # Reindex first and then analyze the index map
+        reindex_a = sch.reindex(main_block, ("read", 0))
+        reindex_b = sch.reindex(main_block, ("read", 1))
+        reindex_c = sch.reindex(main_block, ("write", 0))
+
+        index_maps = get_index_map(sch.get(main_block))
+        assert index_maps is not None
         matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
-        block = sch.reindex(main_block, ("read", 0))
-        sch.transform_layout(block, ("write", 0), a_index_map)
-        block = sch.reindex(main_block, ("read", 1))
-        sch.transform_layout(block, ("write", 0), b_index_map)
-        block = sch.reindex(main_block, ("write", 0))
-        sch.transform_layout(block, ("read", 0), c_index_map)
+
+        sch.transform_layout(reindex_a, ("write", 0), a_index_map)
+        sch.transform_layout(reindex_b, ("write", 0), b_index_map)
+        sch.transform_layout(reindex_c, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
         # Step 1. Check Tensor Core support
diff --git a/tests/python/dlight/test_gpu_conv.py b/tests/python/dlight/test_gpu_conv.py
new file mode 100644
index 000000000000..4997975dd311
--- /dev/null
+++ b/tests/python/dlight/test_gpu_conv.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("nvidia/geforce-gtx-1080-ti"):
+                # Use Matmul rule for Conv for now
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+        return transform
+
+
+class TestConv3d(BaseBeforeAfter):
+    # fmt: off
+    @T.prim_func
+    def before(
+        A: T.Buffer((14308, 3, 2, 14, 14), "float16"),
+        W: T.Buffer((1280, 3, 2, 14, 14), "float16"),
+        C: T.Buffer((14308, 1280, 1, 1, 1), "float16"),
+    ):
+        pad_A = T.alloc_buffer((14308, 3, 2, 14, 14), "float16")
+        for i0, i1, i2, i3, i4 in T.grid(14308, 3, 2, 14, 14):
+            with T.block("pad_A"):
+                v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
+                pad_A[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4]
+        for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(14308, 1280, 1, 1, 1, 3, 2, 14, 14):
+            with T.block("C"):
+                v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz])
+                with T.init():
+                    C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0)
+                C[v_nn, v_ff, v_yy, v_xx, v_zz] += pad_A[v_nn, v_rc, v_yy * 2 + v_ry, v_xx * 14 + v_rx, v_zz * 14 + v_rz] * W[v_ff, v_rc, v_ry, v_rx, v_rz]
+
+    @T.prim_func
+    def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")):
+        T.func_attr({"tir.is_scheduled": 1})
+        # with T.block("root"):
+        C_reindex_pad_local = T.alloc_buffer((1, 14336, 1280), "float16", scope="local")
+        pad_A_reindex_pad_shared = T.alloc_buffer((1, 14336, 1184), "float16", scope="shared")
+        W_reindex_pad_shared = T.alloc_buffer((1, 1280, 1184), "float16", scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(20, thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding(448, thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                                for ax1_3_init, ax2_3_0_init in T.grid(4, 2):
+                                    for ax2_3_1_init in T.vectorized(2):
+                                        with T.block("C_init"):
+                                            v0 = T.axis.spatial(1, 0)
+                                            v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
+                                            v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init)
+                                            C_reindex_pad_local[0, v1, v2] = T.float16(0.0)
+                                for ax3_0 in range(74):
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(2):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+                                                    with T.block("pad_A_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 0)
+                                                        v1 = T.axis.spatial(14336, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        pad_A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 14308 and v2 < 1176, A[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
+                                    for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in range(4):
+                                                for ax0_ax1_ax2_fused_3 in T.vectorized(2):
+                                                    with T.block("W_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 0)
+                                                        v1 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 1176, W[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
+                                    for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2):
+                                        for ax2_3_1 in T.vectorized(2):
+                                            with T.block("C_update"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
+                                                v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1)
+                                                v3 = T.axis.reduce(1184, ax3_0 * 16 + ax3_1)
+                                                C_reindex_pad_local[0, v1, v2] = C_reindex_pad_local[0, v1, v2] + pad_A_reindex_pad_shared[0, v1, v3] * W_reindex_pad_shared[0, v2, v3]
+                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+                                    for ax2_1_1 in T.vectorized(2):
+                                        with T.block("C_reindex_pad_local"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_2 * 4 + ax1)
+                                            v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
+                                            T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < 14308)
+                                            C[v1, v2, 0, 0, 0] = C_reindex_pad_local[v0, v1, v2]
+    # fmt: on
+
+
+if __name__ == "__main__":
+    tvm.testing.main()