154 changes: 118 additions & 36 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -122,8 +122,8 @@ def get_ldmatrix_intrin(
assert (
matrix_name == "B" or not transposed
), "Now only B matrix can be transposed for int8 matmul"
assert (
k_dim == 32 and dtype == "int8"
assert k_dim == 32 and (
dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8"
), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"

if matrix_name == "B" and not transposed:
@@ -260,8 +260,37 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))

LDMATRIX_e4m3_A_INTRIN = "mma_ldmatrix_e4m3_a"
TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False))

LDMATRIX_e4m3_B_INTRIN = "mma_ldmatrix_e4m3_b"
TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False))

LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans"
TensorIntrin.register(
LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True)
)

LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a"
TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False))

LDMATRIX_e5m2_B_INTRIN = "mma_ldmatrix_e5m2_b"
TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False))

def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans"
TensorIntrin.register(
LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True)
)
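The names registered above are not invoked directly; a schedule consumes them through `Schedule.tensorize`. A minimal sketch of that usage follows (the block name `"A_warp"` and the surrounding schedule state are hypothetical, not taken from this diff):

```python
# Sketch: tensorizing a warp-level copy with one of the FP8 ldmatrix
# intrinsics registered above. Assumes `mod` already contains a block
# named "A_warp" whose loop nest matches the 16x32 e4m3 ldmatrix layout.
import tvm
from tvm import tir
from tvm.tir.tensor_intrin.cuda import LDMATRIX_e4m3_A_INTRIN

def tensorize_a_load(mod: tvm.IRModule) -> tvm.IRModule:
    sch = tir.Schedule(mod)
    block = sch.get_block("A_warp", func_name="main")
    loop, *_ = sch.get_loops(block)
    # Swap the scalar copy loop nest for the registered ldmatrix intrinsic.
    sch.tensorize(loop, LDMATRIX_e4m3_A_INTRIN)
    return sch.mod
```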


def get_mma_intrin(
k_dim,
a_dtype="float16",
b_dtype="float16",
out_dtype="float16",
a_transposed=False,
b_transposed=False,
):
local_size = (M_DIM * k_dim) // WARP_SIZE
local_size_out = (M_DIM * N_DIM) // 32

@@ -281,14 +310,17 @@ def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
else:
assert False

out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]

if out_dtype in ["float16", "float32"]:
in_dtype = "float16"
in_dtype_abbrv = "fp16"
else:
in_dtype = "int8"
in_dtype_abbrv = "int8"
dtype_abbrv = {
"float16": "fp16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
}
a_dtype_abbrv = dtype_abbrv[a_dtype]
b_dtype_abbrv = dtype_abbrv[b_dtype]
out_dtype_abbrv = dtype_abbrv[out_dtype]

def cast_to_out_dtype(v):
if out_dtype in ["float32", "int32"]:
@@ -307,15 +339,15 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a,
(WARP_SIZE, local_size),
in_dtype,
a_dtype,
align=64,
offset_factor=A_offset_factor,
scope="warp",
)
B = T.match_buffer(
b,
(WARP_SIZE, local_size),
in_dtype,
b_dtype,
align=64,
offset_factor=B_offset_factor,
scope="warp",
@@ -363,15 +395,15 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
a,
(WARP_SIZE, local_size),
in_dtype,
a_dtype,
align=64,
offset_factor=A_offset_factor,
scope="warp",
)
B = T.match_buffer(
b,
(WARP_SIZE, local_size),
in_dtype,
b_dtype,
align=64,
offset_factor=B_offset_factor,
scope="warp",
@@ -399,8 +431,8 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
mma_prefix,
"row",
"col",
in_dtype_abbrv,
in_dtype_abbrv,
a_dtype_abbrv,
b_dtype_abbrv,
out_dtype_abbrv,
A.data,
A.elem_offset + tx * lift(local_size),
@@ -418,8 +450,8 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
mma_prefix,
"row",
"col",
in_dtype_abbrv,
in_dtype_abbrv,
a_dtype_abbrv,
b_dtype_abbrv,
out_dtype_abbrv,
A.data,
A.elem_offset + tx * lift(local_size),
@@ -436,38 +468,80 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:


MMA_f16f16f32_INTRIN = "mma_f16f16f32"
TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False, False))
TensorIntrin.register(
MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, False)
)

MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b"
TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", False, True))
TensorIntrin.register(
MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, True)
)

MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a"
TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float32", True, False))
TensorIntrin.register(
MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, False)
)

MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b"
TensorIntrin.register(
MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", True, True)
MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN,
*get_mma_intrin(16, "float16", "float16", "float32", True, True),
)

MMA_f16f16f16_INTRIN = "mma_f16f16f16"
TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False, False))
TensorIntrin.register(
MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, False)
)

MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b"
TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True))
TensorIntrin.register(
MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, True)
)

MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a"
TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False))
TensorIntrin.register(
MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, False)
)

MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b"
TensorIntrin.register(
MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", True, True)
MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN,
*get_mma_intrin(16, "float16", "float16", "float16", True, True),
)

MMA_i8i8i32_INTRIN = "mma_i8i8i32"
TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False, False))
TensorIntrin.register(
MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, False)
)

MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b"
TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True))
TensorIntrin.register(
MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, True)
)

MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32"
TensorIntrin.register(
MMA_e5m2e5m2f32_INTRIN,
*get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False),
)

MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b"
TensorIntrin.register(
MMA_e5m2e5m2f32_TRANS_B_INTRIN,
*get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True),
)

MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32"
TensorIntrin.register(
MMA_e4m3e4m3f32_INTRIN,
*get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False),
)

MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b"
TensorIntrin.register(
MMA_e4m3e4m3f32_TRANS_B_INTRIN,
*get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True),
)
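As a quick sanity check (a sketch, not part of this diff), the new compute intrinsics can be looked up through the registry and their descriptions inspected:

```python
# Sketch: fetch a newly registered FP8 MMA intrinsic back from the registry.
from tvm.tir import TensorIntrin
from tvm.tir.tensor_intrin.cuda import MMA_e4m3e4m3f32_INTRIN

intrin = TensorIntrin.get(MMA_e4m3e4m3f32_INTRIN)
# `desc` is the computation the scheduler pattern-matches against;
# `impl` is the body that lowers to the ptx_mma call shown above.
print(intrin.desc.script())
```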


def get_mma_fill_intrin(dtype, local_size):
@@ -631,7 +705,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
def get_mma_intrin_group(
load_scope: Literal["shared", "shared.dyn"],
store_scope: Literal["global", "shared", "shared.dyn"],
in_dtype: Literal["float16", "int8"],
in_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
out_dtype: Literal["float16", "float32", "int32"],
trans_a: bool,
trans_b: bool,
@@ -678,13 +752,21 @@ def get_mma_intrin_group(
"""
assert load_scope in ["shared", "shared.dyn"]
assert store_scope in ["global", "shared", "shared.dyn"]
assert in_dtype in ["float16", "int8"]
assert in_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
assert out_dtype in ["float16", "float32", "int32"]

shape = "16x16"

dtype_mapping = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}
in_dtype = dtype_mapping[in_dtype]
dtype_mapping = {
"float16": "f16",
"float32": "f32",
"int8": "i8",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
"int32": "i32",
}
a_dtype = dtype_mapping[in_dtype]
b_dtype = dtype_mapping[in_dtype]
out_dtype = dtype_mapping[out_dtype]

# e.g. mma_fill_16x16_f32
@@ -694,13 +776,13 @@ def get_mma_intrin_group(
trans_a = "_trans" if trans_a else ""
trans_b = "_trans" if trans_b else ""
load_scope = "_dyn" if load_scope == "shared.dyn" else ""
load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}"
load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}"
load_a_intrin = f"mma_ldmatrix_{a_dtype}_a{trans_a}{load_scope}"
load_b_intrin = f"mma_ldmatrix_{b_dtype}_b{trans_b}{load_scope}"

# e.g. mma_f16f16f32_trans_a_trans_b
trans_a_str = trans_a + "_a" if trans_a != "" else ""
trans_b_str = trans_b + "_b" if trans_b != "" else ""
compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{trans_b_str}"
compute_intrin = f"mma_{a_dtype}{b_dtype}{out_dtype}{trans_a_str}{trans_b_str}"

# e.g. mma_store_16x16_f32_shared_dyn_simple_
store_scope = store_scope.replace(".", "_")
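A hedged usage sketch of the extended `get_mma_intrin_group` (the returned dict keys — "init", "load_a", "load_b", "compute", "store" — are assumed from the pre-existing float16/int8 paths and are not visible in this hunk):

```python
# Sketch: requesting the full e4m3 intrinsic group with a transposed B operand.
from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

group = get_mma_intrin_group(
    load_scope="shared",   # plain shared memory, so no "_dyn" suffix on the loads
    store_scope="global",
    in_dtype="e4m3_float8",
    out_dtype="float32",
    trans_a=False,
    trans_b=True,
)
# Given the name construction above, the group should reference:
#   load_a  -> "mma_ldmatrix_e4m3_a"
#   load_b  -> "mma_ldmatrix_e4m3_b_trans"
#   compute -> "mma_e4m3e4m3f32_trans_b"
```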
16 changes: 15 additions & 1 deletion src/target/source/codegen_cuda.cc
@@ -51,8 +51,12 @@ std::string GetFP8Type(DataType type) {
vec = "x2";
} else if (lanes == 4) {
vec = "x4";
} else if (lanes == 8) {
vec = "x8";
} else if (lanes == 16) {
vec = "x16";
} else {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) for FP8";
}
stream << "__nv_fp8";
std::string suffix;
@@ -147,6 +151,16 @@ std::string CodeGenCUDA::Finish() {
if (enable_fp8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
decl_stream << "#include <cuda_fp8.h>\n";
decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
decl_stream << "using fp8_e4x2_t = __nv_fp8x2_e4m3;\n";
decl_stream << "using fp8_e4x4_t = __nv_fp8x4_e4m3;\n";
decl_stream << "struct fp8_e4x8_t {\n fp8_e4_t data[8]; \n};\n";
decl_stream << "struct fp8_e4x16_t {\n fp8_e4_t data[16]; \n};\n";
decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
decl_stream << "using fp8_e5x2_t = __nv_fp8x2_e5m2;\n";
decl_stream << "using fp8_e5x4_t = __nv_fp8x4_e5m2;\n";
decl_stream << "struct fp8_e5x8_t {\n fp8_e5_t data[8]; \n};\n";
decl_stream << "struct fp8_e5x16_t {\n fp8_e5_t data[16]; \n};\n";
decl_stream << "#endif\n\n";
}
declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
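To see the new preamble end to end, one can build a small FP8 kernel and inspect the generated CUDA source. This is a sketch under assumptions not in the diff: a TVM build with CUDA enabled, an sm_89+ toolchain, and a throwaway `copy_fp8` kernel written here purely for illustration.

```python
# Sketch: an 8-lane vectorized e4m3 access should exercise the new
# lanes == 8 branch in GetFP8Type and the cuda_fp8.h aliases emitted
# by CodeGenCUDA::Finish.
import tvm
from tvm.script import tir as T

@T.prim_func
def copy_fp8(A: T.Buffer((64, 8), "e4m3_float8"), B: T.Buffer((64, 8), "e4m3_float8")):
    T.func_attr({"global_symbol": "copy_fp8", "tir.noalias": True})
    for i in T.thread_binding(64, thread="threadIdx.x"):
        for j in T.vectorized(8):
            with T.block("copy"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj]

rt_mod = tvm.build(copy_fp8, target="cuda -arch=sm_89")
cuda_src = rt_mod.imported_modules[0].get_source()
assert "fp8_e4" in cuda_src  # aliases from the decl_stream block above
```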