Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DLight] Fix Matmul rule for Conv3D #17363

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 52 additions & 48 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if reduction_blocks is None:
return None

main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

# Step 0. Configs
block_size_x: int = 16
block_size_y: int = 16
Expand All @@ -382,12 +375,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
vector_size: int = 4

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
# Reindex first and than analyze the index map
main_block = reduction_blocks[0]
reindex_a = sch.reindex(main_block, ("read", 0))
reindex_b = sch.reindex(main_block, ("read", 1))
reindex_c = sch.reindex(main_block, ("write", 0))

index_maps = get_index_map(sch.get(main_block))
assert index_maps is not None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

sch.transform_layout(reindex_a, ("write", 0), a_index_map)
sch.transform_layout(reindex_b, ("write", 0), b_index_map)
sch.transform_layout(reindex_c, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Padding for dynamic shape kernels
Expand Down Expand Up @@ -508,13 +508,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if reduction_blocks is None:
return None

main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

# Start Schedule
# Step 0. Get schedule config.
# NOTE: we can analyze the config by the hardware spec in the future
Expand All @@ -539,12 +532,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
k_pad_factor = k_factors[1]

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
# Reindex first and than analyze the index map
main_block = reduction_blocks[0]
reindex_a = sch.reindex(main_block, ("read", 0))
reindex_b = sch.reindex(main_block, ("read", 1))
reindex_c = sch.reindex(main_block, ("write", 0))

index_maps = get_index_map(sch.get(main_block))
assert index_maps is not None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

sch.transform_layout(reindex_a, ("write", 0), a_index_map)
sch.transform_layout(reindex_b, ("write", 0), b_index_map)
sch.transform_layout(reindex_c, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Padding for dynamic shape kernels
Expand Down Expand Up @@ -729,13 +729,6 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
if reduction_blocks is None:
return None

main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

# Start Schedule
# Step 0. Get schedule config.
# NOTE: we can analyze the config by the hardware spec in the future
Expand All @@ -760,12 +753,19 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
k_pad_factor = k_factors[1]

# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)
# Reindex first and than analyze the index map
main_block = reduction_blocks[0]
reindex_a = sch.reindex(main_block, ("read", 0))
reindex_b = sch.reindex(main_block, ("read", 1))
reindex_c = sch.reindex(main_block, ("write", 0))

index_maps = get_index_map(sch.get(main_block))
assert index_maps is not None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps

sch.transform_layout(reindex_a, ("write", 0), a_index_map)
sch.transform_layout(reindex_b, ("write", 0), b_index_map)
sch.transform_layout(reindex_c, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 2. Padding for dynamic shape kernels
Expand Down Expand Up @@ -979,12 +979,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring

main_block = reduction_blocks[0]
block_stmt = sch.get(main_block)
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None

main_block_info = get_block_info(sch, main_block)
iter_infos = main_block_info.iters
if not get_index_map(block_stmt):
return None

# Checks if it's a inner reduction by getting the last matrix's inner Index
def is_inner_reduction(block_stmt, iter_infos):
Expand All @@ -1000,13 +999,18 @@ def is_inner_reduction(block_stmt, iter_infos):
return ret

# Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]
# Reindex first and than analyze the index map
reindex_a = sch.reindex(main_block, ("read", 0))
reindex_b = sch.reindex(main_block, ("read", 1))
reindex_c = sch.reindex(main_block, ("write", 0))

index_maps = get_index_map(sch.get(main_block))
assert index_maps is not None
matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
sch.transform_layout(block, ("write", 0), b_index_map)
block = sch.reindex(main_block, ("write", 0))
sch.transform_layout(block, ("read", 0), c_index_map)

sch.transform_layout(reindex_a, ("write", 0), a_index_map)
sch.transform_layout(reindex_b, ("write", 0), b_index_map)
sch.transform_layout(reindex_c, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)

# Step 1. Check Tensor Core support
Expand Down
118 changes: 118 additions & 0 deletions tests/python/dlight/test_gpu_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import pytest

import tvm.testing
from tvm import dlight as dl
from tvm.script import tir as T
from tvm.target import Target


class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@pytest.fixture
def transform(self):
def transform(mod):
with Target("nvidia/geforce-gtx-1080-ti"):
# Use Matmul rule for Conv for now
return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)

return transform


class TestConv3d(BaseBeforeAfter):
# fmt: off
@T.prim_func
def before(
A: T.Buffer((14308, 3, 2, 14, 14), "float16"),
W: T.Buffer((1280, 3, 2, 14, 14), "float16"),
C: T.Buffer((14308, 1280, 1, 1, 1), "float16"),
):
pad_A = T.alloc_buffer((14308, 3, 2, 14, 14), "float16")
for i0, i1, i2, i3, i4 in T.grid(14308, 3, 2, 14, 14):
with T.block("pad_A"):
v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
pad_A[v_i0, v_i1, v_i2, v_i3, v_i4] = A[v_i0, v_i1, v_i2, v_i3, v_i4]
for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(14308, 1280, 1, 1, 1, 3, 2, 14, 14):
with T.block("C"):
v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz])
with T.init():
C[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0.0)
C[v_nn, v_ff, v_yy, v_xx, v_zz] += pad_A[v_nn, v_rc, v_yy * 2 + v_ry, v_xx * 14 + v_rx, v_zz * 14 + v_rz]* W[v_ff, v_rc, v_ry, v_rx, v_rz]

@T.prim_func
def expected(A: T.Buffer((14308, 3, 2, 14, 14), "float16"), W: T.Buffer((1280, 3, 2, 14, 14), "float16"), C: T.Buffer((14308, 1280, 1, 1, 1), "float16")):
T.func_attr({"tir.is_scheduled": 1})
# with T.block("root"):
C_reindex_pad_local = T.alloc_buffer((1, 14336, 1280), "float16", scope="local")
pad_A_reindex_pad_shared = T.alloc_buffer((1, 14336, 1184), "float16", scope="shared")
W_reindex_pad_shared = T.alloc_buffer((1, 1280, 1184), "float16", scope="shared")
for ax0_ax2_0_fused in T.thread_binding(20, thread="blockIdx.y"):
for ax1_0 in T.thread_binding(448, thread="blockIdx.x"):
for ax2_1 in T.thread_binding(1, thread="vthread.y"):
for ax1_1 in T.thread_binding(1, thread="vthread.x"):
for ax2_2 in T.thread_binding(16, thread="threadIdx.y"):
for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
for ax1_3_init, ax2_3_0_init in T.grid(4, 2):
for ax2_3_1_init in T.vectorized(2):
with T.block("C_init"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_init)
v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0_init * 2 + ax2_3_1_init)
C_reindex_pad_local[0, v1, v2] = T.float16(0.0)
for ax3_0 in range(74):
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(2):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("pad_A_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(14336, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
pad_A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 14308 and v2 < 1176, A[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"):
for ax0_ax1_ax2_fused_2 in range(4):
for ax0_ax1_ax2_fused_3 in T.vectorized(2):
with T.block("W_reindex_pad_shared"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
v2 = T.axis.spatial(1184, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v2 < 1176, W[v1, v2 // 392, v2 // 196 % 2, v2 // 14 % 14, v2 % 14], T.float16(0.0))
for ax3_1, ax1_3, ax2_3_0 in T.grid(16, 4, 2):
for ax2_3_1 in T.vectorized(2):
with T.block("C_update"):
v0 = T.axis.spatial(1, 0)
v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_1 * 64 + ax2_2 * 4 + ax2_3_0 * 2 + ax2_3_1)
v3 = T.axis.reduce(1184, ax3_0 * 16 + ax3_1)
C_reindex_pad_local[0, v1, v2] = C_reindex_pad_local[0, v1, v2] + pad_A_reindex_pad_shared[0, v1, v3] * W_reindex_pad_shared[0, v2, v3]
for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
for ax2_1_1 in T.vectorized(2):
with T.block("C_reindex_pad_local"):
v0 = T.axis.spatial(1, ax0)
v1 = T.axis.spatial(14336, ax1_0 * 32 + ax1_2 * 4 + ax1)
v2 = T.axis.spatial(1280, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < 14308)
C[v1, v2, 0, 0, 0] = C_reindex_pad_local[v0, v1, v2]
# fmt: on


if __name__ == "__main__":
tvm.testing.main()
Loading