improve tensor intrin
LeiWang1999 committed Feb 7, 2024
1 parent 3891f2a commit 7df35c3
Showing 3 changed files with 145 additions and 26 deletions.
167 changes: 143 additions & 24 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -16,27 +16,29 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""Intrinsics for tensorization on NVIDIA GPU."""
from typing import Dict, Optional, Tuple, Literal
from typing import Dict, Optional, Tuple, Literal, List

from tvm._ffi import register_func
from tvm.runtime import convert
from tvm.script import tir as T
from tvm.tir.function import PrimFunc
from tvm.tir import Cast, IntImm, TensorIntrin


def shared_16x16_to_mma_32x8_smoothlayout(i, j):
return (i * 2 + j // 8, j % 8)


def shared_16x32_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)


def shared_32x16_to_mma_32x16_smoothlayout(i, j):
return (i * 2 + j // 16, j % 16)


def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
row = (thread_id % 16)
row = thread_id % 16
col = 8 * (thread_id // 16) + local_id % 8
return row, col

@@ -47,6 +49,18 @@ def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
return row, col


def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
row = thread_id % 16
col = local_id + (thread_id // 16) * 16
return row, col


def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
row = (thread_id // 16) * 8 + (thread_id % 8)
col = local_id + 16 * ((thread_id % 16) // 8)
return row, col
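
A minimal sanity check for the two new 32x16 -> 16x32 mappings above (a standalone sketch, not part of this commit): enumerating every (thread_id, local_id) pair of a warp should cover the 16x32 shared tile exactly once for both the A and B variants.

    # Standalone sketch: each lane holds 16 elements for k_dim == 32, and the
    # mapping should be a bijection onto the 16x32 shared-memory tile.
    def _check_ldmatrix_layout(mapping):
        seen = set()
        for thread_id in range(32):        # one warp
            for local_id in range(16):     # elements per lane
                row, col = mapping(thread_id, local_id)
                assert 0 <= row < 16 and 0 <= col < 32
                seen.add((row, col))
        assert len(seen) == 16 * 32        # no element missed or duplicated

    _check_ldmatrix_layout(ldmatrix_32x16_to_shared_16x32_layout_a)
    _check_ldmatrix_layout(ldmatrix_32x16_to_shared_16x32_layout_b)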


def shared_16x16_to_mma_32x8_layout(i, j):
thread_id = 4 * (i % 8) + (j % 8) // 2
return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
@@ -90,7 +104,7 @@ def get_ldmatrix_intrin(
matrix_name: Literal["A", "B"],
transposed: bool,
shared_scope: str = "shared",
propagate_layout: bool = False
propagate_layout: bool = False,
):
local_size = (M_DIM * k_dim) // WARP_SIZE
smem_offset = None
@@ -276,17 +290,19 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

LDMATRIX_f16_B_TRANS_SMOOTH_INTRIN = "mma_ldmatrix_f16_b_trans_smooth"
TensorIntrin.register(
LDMATRIX_f16_B_TRANS_SMOOTH_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True, "shared", True)
LDMATRIX_f16_B_TRANS_SMOOTH_INTRIN,
*get_ldmatrix_intrin(16, "float16", "B", True, "shared", True),
)

LDMATRIX_f16_A_DYN_INTRIN = "mma_ldmatrix_f16_a_dyn"
TensorIntrin.register(
LDMATRIX_f16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn")
)

LDMATRIX_f16_A_DYN_SMOOTH_INTRIN = "mma_ldmatrix_f16_a_dyn_smooth"
LDMATRIX_f16_A_DYN_SMOOTH_INTRIN = "mma_ldmatrix_f16_a_smooth_dyn"
TensorIntrin.register(
LDMATRIX_f16_A_DYN_SMOOTH_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn", True)
LDMATRIX_f16_A_DYN_SMOOTH_INTRIN,
*get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn", True),
)

LDMATRIX_f16_B_DYN_INTRIN = "mma_ldmatrix_f16_b_dyn"
@@ -306,7 +322,8 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:

LDMATRIX_f16_B_TRANS_SMOOTH_DYN_INTRIN = "mma_ldmatrix_f16_b_trans_smooth_dyn"
TensorIntrin.register(
LDMATRIX_f16_B_TRANS_SMOOTH_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True, "shared.dyn", True)
LDMATRIX_f16_B_TRANS_SMOOTH_DYN_INTRIN,
*get_ldmatrix_intrin(16, "float16", "B", True, "shared.dyn", True),
)

LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_i8_a"
@@ -319,23 +336,85 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))


LDMATRIX_i8_A_SMOOTH_INTRIN = "mma_ldmatrix_i8_a_smooth"
TensorIntrin.register(
LDMATRIX_i8_A_SMOOTH_INTRIN, *get_ldmatrix_intrin(32, "int8", "A", False, "shared", True)
)

LDMATRIX_i8_B_SMOOTH_INTRIN = "mma_ldmatrix_i8_b_smooth"
TensorIntrin.register(
LDMATRIX_i8_B_SMOOTH_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", False, "shared", True)
)

LDMATRIX_i8_B_TRANS_SMOOTH_INTRIN = "mma_ldmatrix_i8_b_trans_smooth"
TensorIntrin.register(
LDMATRIX_i8_B_TRANS_SMOOTH_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True, "shared", True)
)

LDMATRIX_i8_A_DYN_INTRIN = "mma_ldmatrix_i8_a_dyn"
TensorIntrin.register(
LDMATRIX_i8_A_DYN_INTRIN, *get_ldmatrix_intrin(32, "int8", "A", False, "shared.dyn")
)

LDMATRIX_i8_B_DYN_INTRIN = "mma_ldmatrix_i8_b_dyn"
TensorIntrin.register(
LDMATRIX_i8_B_DYN_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", False, "shared.dyn")
)

LDMATRIX_i8_B_TRANS_DYN_INTRIN = "mma_ldmatrix_i8_b_trans_dyn"
TensorIntrin.register(
LDMATRIX_i8_B_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True, "shared.dyn")
)


LDMATRIX_i8_A_SMOOTH_DYN_INTRIN = "mma_ldmatrix_i8_a_smooth_dyn"
TensorIntrin.register(
LDMATRIX_i8_A_SMOOTH_DYN_INTRIN,
*get_ldmatrix_intrin(32, "int8", "A", False, "shared.dyn", True),
)

LDMATRIX_i8_B_SMOOTH_DYN_INTRIN = "mma_ldmatrix_i8_b_smooth_dyn"
TensorIntrin.register(
LDMATRIX_i8_B_SMOOTH_DYN_INTRIN,
*get_ldmatrix_intrin(32, "int8", "B", False, "shared.dyn", True),
)

LDMATRIX_i8_B_TRANS_SMOOTH_DYN_INTRIN = "mma_ldmatrix_i8_b_trans_smooth_dyn"
TensorIntrin.register(
LDMATRIX_i8_B_TRANS_SMOOTH_DYN_INTRIN,
*get_ldmatrix_intrin(32, "int8", "B", True, "shared.dyn", True),
)
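
A rough usage sketch for the smooth-layout int8 ldmatrix intrinsics registered above (the block and loop names are hypothetical; it assumes a tir.Schedule `sch` over an int8 GEMM whose A-side shared.dyn-to-warp copy is already tiled to 16x32):

    from tvm.tir.tensor_intrin.cuda import LDMATRIX_i8_A_SMOOTH_DYN_INTRIN

    # Assumed block name for the A copy from shared.dyn into the warp fragment.
    block = sch.get_block("A_shared_dyn_warp")
    loop_i, loop_j = sch.get_loops(block)[-2:]   # the 16 x 32 tile loops
    sch.tensorize(loop_i, LDMATRIX_i8_A_SMOOTH_DYN_INTRIN)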


def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed, smooth_a=False, smooth_b=False):
local_size = (M_DIM * k_dim) // WARP_SIZE
local_size_out = (M_DIM * N_DIM) // 32

index_map_C = shared_16x16_to_mma_32x8_layout

if k_dim == 16:
index_map_A = shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
index_map_B = shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
index_map_A = (
shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
)
index_map_B = (
shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
)
mma_prefix = "m16n8k16"
elif k_dim == 32 and b_transposed:
index_map_A = shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
index_map_B = shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
index_map_A = (
shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
)
index_map_B = (
shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
)
mma_prefix = "m16n8k32"
elif k_dim == 32 and not b_transposed:
index_map_A = shared_16x32_to_mma_32x16_layout if smooth_a else shared_16x32_to_mma_32x16_layout
index_map_B = shared_32x16_to_mma_32x16_layout if smooth_b else shared_32x16_to_mma_32x16_layout
index_map_A = (
shared_16x32_to_mma_32x16_layout if smooth_a else shared_16x32_to_mma_32x16_layout
)
index_map_B = (
shared_32x16_to_mma_32x16_layout if smooth_b else shared_32x16_to_mma_32x16_layout
)
mma_prefix = "m16n8k32"
else:
assert False
@@ -515,10 +594,15 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True))

MMA_f16f16f16_TRANS_SMOOTH_B_INTRIN = "mma_f16f16f16_trans_b_smooth_b"
TensorIntrin.register(MMA_f16f16f16_TRANS_SMOOTH_B_INTRIN, *get_mma_intrin(16, "float16", False, True, False, True))
TensorIntrin.register(
MMA_f16f16f16_TRANS_SMOOTH_B_INTRIN, *get_mma_intrin(16, "float16", False, True, False, True)
)

MMA_f16f16f16_SMOOTH_A_TRANS_SMOOTH_B_INTRIN = "mma_f16f16f16_smooth_a_trans_b_smooth_b"
TensorIntrin.register(MMA_f16f16f16_SMOOTH_A_TRANS_SMOOTH_B_INTRIN, *get_mma_intrin(16, "float16", False, True, True, True))
TensorIntrin.register(
MMA_f16f16f16_SMOOTH_A_TRANS_SMOOTH_B_INTRIN,
*get_mma_intrin(16, "float16", False, True, True, True),
)

MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a"
TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False))
@@ -534,6 +618,17 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b"
TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True))

MMA_i8i8i32_TRANS_B_SMOOTH_B_INTRIN = "mma_i8i8i32_trans_b_smooth_b"
TensorIntrin.register(
MMA_i8i8i32_TRANS_B_SMOOTH_B_INTRIN, *get_mma_intrin(32, "int32", False, True, False, True)
)

MMA_i8i8i32_SMOOTH_A_TRANS_B_SMOOTH_B_INTRIN = "mma_i8i8i32_smooth_a_trans_b_smooth_b"
TensorIntrin.register(
MMA_i8i8i32_SMOOTH_A_TRANS_B_SMOOTH_B_INTRIN,
*get_mma_intrin(32, "int32", False, True, True, True),
)
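
Similarly, a rough sketch of tensorizing the warp-level int8 compute block with one of the MMA intrinsics registered above (again, the schedule `sch` and the block name are assumptions, not part of this commit):

    from tvm.tir.tensor_intrin.cuda import MMA_i8i8i32_TRANS_B_SMOOTH_B_INTRIN

    block = sch.get_block("C_warp_mma")     # assumed name of the warp-level GEMM block
    loop_m = sch.get_loops(block)[-3]       # outermost of the 16 x 16 x 32 tile loops
    sch.tensorize(loop_m, MMA_i8i8i32_TRANS_B_SMOOTH_B_INTRIN)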


def get_mma_fill_intrin(dtype, local_size):
zero = IntImm("int32", 0).astype(dtype)
@@ -657,7 +752,6 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
row, col = T.meta_var(index_map_rev(tx, local_id))
C[row, col] = C_warp[tx, local_id]


return mma_store_desc, mma_store_impl


@@ -708,14 +802,19 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:

MMA_store_16x16_f16_shared_INTRIN = "mma_store_16x16_f16_shared_"
TensorIntrin.register(
MMA_store_16x16_f16_shared_INTRIN, *get_mma_store_intrin("float16", 8, "shared", False)
MMA_store_16x16_f16_shared_INTRIN, *get_mma_store_intrin("float16", 8, "shared", True)
)

MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
TensorIntrin.register(
MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global", True)
)

MMA_store_16x16_i32_shared_INTRIN = "mma_store_16x16_i32_shared_"
TensorIntrin.register(
MMA_store_16x16_i32_shared_INTRIN, *get_mma_store_intrin("int32", 8, "shared", True)
)


def get_mma_intrin_group(
load_scope: Literal["shared", "shared.dyn"],
@@ -750,10 +849,10 @@ def get_mma_intrin_group(
trans_b : bool
Whether the input matrix B is transposed.
smooth_a: bool
Whether to assume the propagated layout of A is smooth.
smooth_b: bool
Whether to assume the propagated layout of B is smooth.
@@ -801,7 +900,9 @@ def get_mma_intrin_group(
trans_b_str = trans_b + "_b" if trans_b != "" else ""
smooth_a_str = smooth_a + "_a" if smooth_a != "" else ""
smooth_b_str = smooth_b + "_b" if smooth_b != "" else ""
compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{smooth_a_str}{trans_b_str}{smooth_b_str}"
compute_intrin = (
f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{smooth_a_str}{trans_b_str}{smooth_b_str}"
)

# e.g. mma_store_16x16_f32_shared_dyn_simple_
store_scope = store_scope.replace(".", "_")
@@ -811,11 +912,28 @@ def get_mma_intrin_group(

index_map_c = shared_16x16_to_mma_32x8_layout
if in_dtype == "f16":
index_map_a = shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
index_map_b = shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
index_map_a = (
shared_16x16_to_mma_32x8_smoothlayout if smooth_a else shared_16x16_to_mma_32x8_layout
)
index_map_b = (
shared_16x16_to_mma_32x8_smoothlayout if smooth_b else shared_16x16_to_mma_32x8_layout
)
elif in_dtype == "i8":
index_map_a = (
shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
)
index_map_b = (
shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
)
else:
raise ValueError(f"Unsupported in_dtype: {in_dtype}")

# micro kernel size, the order is [m, n, k]
micro_kernel: List[int]
if in_dtype == "f16":
micro_kernel = [16, 16, 16]
elif in_dtype == "i8":
index_map_a = shared_16x32_to_mma_32x16_smoothlayout if smooth_a else shared_16x32_to_mma_32x16_layout
index_map_b = shared_16x32_to_mma_32x16_smoothlayout if smooth_b else shared_16x32_to_mma_32x16_layout
micro_kernel = [16, 16, 32]
else:
raise ValueError(f"Unsupported in_dtype: {in_dtype}")

@@ -826,6 +944,7 @@ def get_mma_intrin_group(
"compute": compute_intrin,
"store": store_intrin,
"index_map": [index_map_a, index_map_b, index_map_c],
"micro_kernel": micro_kernel,
}
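
A rough consumer-side sketch of the new "micro_kernel" entry (keyword names follow the visible signature; the exact defaults may differ): callers can now read the m/n/k tile shape from the returned group instead of hard-coding it per dtype.

    group = get_mma_intrin_group(
        load_scope="shared.dyn",
        store_scope="shared.dyn",
        in_dtype="int8",
        out_dtype="int32",
        trans_a=False,
        trans_b=True,
        smooth_a=True,
        smooth_b=True,
    )
    micro_m, micro_n, micro_k = group["micro_kernel"]   # expected [16, 16, 32] for int8
    compute_intrin = group["compute"]   # e.g. "mma_i8i8i32_smooth_a_trans_b_smooth_b"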


2 changes: 1 addition & 1 deletion src/arith/ir_mutator_with_analyzer.cc
@@ -52,7 +52,7 @@ Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const Array<Pr
if (non_trivial_only) {
(simplified).MutateByApply([&](const PrimExpr& e) {
auto s = this->analyzer_->Simplify(e);
return s->IsInstance<IntImmNode>() ? s : e;
return s->IsInstance<IntImmNode>() ? e : s;
});
for (int i = 0; i < n; ++i) {
if (simplified[i]->IsInstance<IntImmNode>() && indices[i]->IsInstance<VarNode>()) {
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/layout_transformation.cc
@@ -803,7 +803,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
*buffer = new_buffer_;
*indices = index_map_->MapIndices(*indices, &index_simplifier_);
*indices = this->IterMapSimplifyWithContext(*indices, true);
*indices = this->IterMapSimplifyWithContext(*indices, true);
}

using Parent = arith::IRMutatorWithAnalyzer;
