diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 409a1ff10a78..e1ff18bc8fb9 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -122,8 +122,8 @@ def get_ldmatrix_intrin(
         assert (
             matrix_name == "B" or not transposed
         ), "Now only B matrix can be transposed for int8 matmul"
-        assert (
-            k_dim == 32 and dtype == "int8"
+        assert k_dim == 32 and (
+            dtype == "int8" or dtype == "e4m3_float8" or dtype == "e5m2_float8"
         ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
 
         if matrix_name == "B" and not transposed:
@@ -260,8 +260,37 @@ def ldmatrix_impl(warp_handle: T.handle, shared_handle: T.handle) -> None:
 LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
 TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))
 
+LDMATRIX_e4m3_A_INTRIN = "mma_ldmatrix_e4m3_a"
+TensorIntrin.register(LDMATRIX_e4m3_A_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "A", False))
+
+LDMATRIX_e4m3_B_INTRIN = "mma_ldmatrix_e4m3_b"
+TensorIntrin.register(LDMATRIX_e4m3_B_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", False))
+
+LDMATRIX_e4m3_B_TRANS_INTRIN = "mma_ldmatrix_e4m3_b_trans"
+TensorIntrin.register(
+    LDMATRIX_e4m3_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e4m3_float8", "B", True)
+)
+
+LDMATRIX_e5m2_A_INTRIN = "mma_ldmatrix_e5m2_a"
+TensorIntrin.register(LDMATRIX_e5m2_A_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "A", False))
+
+LDMATRIX_e5m2_B_INTRIN = "mma_ldmatrix_e5m2_b"
+TensorIntrin.register(LDMATRIX_e5m2_B_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", False))
 
-def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
+LDMATRIX_e5m2_B_TRANS_INTRIN = "mma_ldmatrix_e5m2_b_trans"
+TensorIntrin.register(
+    LDMATRIX_e5m2_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "e5m2_float8", "B", True)
+)
+
+
+def get_mma_intrin(
+    k_dim,
+    a_dtype="float16",
+    b_dtype="float16",
+    out_dtype="float16",
+    a_transposed=False,
+    b_transposed=False,
+):
     local_size = (M_DIM * k_dim) // WARP_SIZE
     local_size_out = (M_DIM * N_DIM) // 32
 
@@ -281,14 +310,17 @@ def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
     else:
         assert False
 
-    out_dtype_abbrv = {"float16": "fp16", "float32": "fp32", "int32": "int32"}[out_dtype]
-
-    if out_dtype in ["float16", "float32"]:
-        in_dtype = "float16"
-        in_dtype_abbrv = "fp16"
-    else:
-        in_dtype = "int8"
-        in_dtype_abbrv = "int8"
+    dtype_abbrv = {
+        "float16": "fp16",
+        "float32": "fp32",
+        "int8": "int8",
+        "int32": "int32",
+        "e4m3_float8": "e4m3",
+        "e5m2_float8": "e5m2",
+    }
+    a_dtype_abbrv = dtype_abbrv[a_dtype]
+    b_dtype_abbrv = dtype_abbrv[b_dtype]
+    out_dtype_abbrv = dtype_abbrv[out_dtype]
 
     def cast_to_out_dtype(v):
         if out_dtype in ["float32", "int32"]:
@@ -307,7 +339,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
             a,
             (WARP_SIZE, local_size),
-            in_dtype,
+            a_dtype,
             align=64,
             offset_factor=A_offset_factor,
             scope="warp",
@@ -315,7 +347,7 @@ def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         B = T.match_buffer(
             b,
             (WARP_SIZE, local_size),
-            in_dtype,
+            b_dtype,
             align=64,
             offset_factor=B_offset_factor,
             scope="warp",
@@ -363,7 +395,7 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
             a,
             (WARP_SIZE, local_size),
-            in_dtype,
+            a_dtype,
             align=64,
             offset_factor=A_offset_factor,
             scope="warp",
@@ -371,7 +403,7 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
         B = T.match_buffer(
             b,
             (WARP_SIZE, local_size),
-            in_dtype,
+            b_dtype,
             align=64,
             offset_factor=B_offset_factor,
             scope="warp",
@@ -399,8 +431,8 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
                     mma_prefix,
                     "row",
                     "col",
-                    in_dtype_abbrv,
-                    in_dtype_abbrv,
+                    a_dtype_abbrv,
+                    b_dtype_abbrv,
                     out_dtype_abbrv,
                     A.data,
                     A.elem_offset + tx * lift(local_size),
@@ -418,8 +450,8 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
                     mma_prefix,
                     "row",
                     "col",
-                    in_dtype_abbrv,
-                    in_dtype_abbrv,
+                    a_dtype_abbrv,
+                    b_dtype_abbrv,
                     out_dtype_abbrv,
                     A.data,
                     A.elem_offset + tx * lift(local_size),
@@ -436,38 +468,80 @@ def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
 
 
 MMA_f16f16f32_INTRIN = "mma_f16f16f32"
-TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False, False))
+TensorIntrin.register(
+    MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, False)
+)
 
 MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b"
-TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", False, True))
+TensorIntrin.register(
+    MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", False, True)
+)
 
 MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a"
-TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float32", True, False))
+TensorIntrin.register(
+    MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float32", True, False)
+)
 
 MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b"
 TensorIntrin.register(
-    MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", True, True)
+    MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN,
+    *get_mma_intrin(16, "float16", "float16", "float32", True, True),
 )
 
 MMA_f16f16f16_INTRIN = "mma_f16f16f16"
-TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False, False))
+TensorIntrin.register(
+    MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, False)
+)
 
 MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b"
-TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True))
+TensorIntrin.register(
+    MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", False, True)
+)
 
 MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a"
-TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False))
+TensorIntrin.register(
+    MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", "float16", "float16", True, False)
+)
 
 MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b"
 TensorIntrin.register(
-    MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", True, True)
+    MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN,
+    *get_mma_intrin(16, "float16", "float16", "float16", True, True),
 )
 
 MMA_i8i8i32_INTRIN = "mma_i8i8i32"
-TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False, False))
+TensorIntrin.register(
+    MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, False)
+)
 
 MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b"
-TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True))
+TensorIntrin.register(
+    MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int8", "int8", "int32", False, True)
+)
+
+MMA_e5m2e5m2f32_INTRIN = "mma_e5m2e5m2f32"
+TensorIntrin.register(
+    MMA_e5m2e5m2f32_INTRIN,
+    *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, False),
+)
+
+MMA_e5m2e5m2f32_TRANS_B_INTRIN = "mma_e5m2e5m2f32_trans_b"
+TensorIntrin.register(
+    MMA_e5m2e5m2f32_TRANS_B_INTRIN,
+    *get_mma_intrin(32, "e5m2_float8", "e5m2_float8", "float32", False, True),
+)
+
+MMA_e4m3e4m3f32_INTRIN = "mma_e4m3e4m3f32"
+TensorIntrin.register(
+    MMA_e4m3e4m3f32_INTRIN,
+    *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, False),
+)
+
+MMA_e4m3e4m3f32_TRANS_B_INTRIN = "mma_e4m3e4m3f32_trans_b"
+TensorIntrin.register(
+    MMA_e4m3e4m3f32_TRANS_B_INTRIN,
+    *get_mma_intrin(32, "e4m3_float8", "e4m3_float8", "float32", False, True),
+)
 
 
 def get_mma_fill_intrin(dtype, local_size):
@@ -631,7 +705,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None:
 def get_mma_intrin_group(
     load_scope: Literal["shared", "shared.dyn"],
     store_scope: Literal["global", "shared", "shared.dyn"],
-    in_dtype: Literal["float16", "int8"],
+    in_dtype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"],
     out_dtype: Literal["float16", "float32", "int32"],
     trans_a: bool,
     trans_b: bool,
@@ -678,13 +752,21 @@ def get_mma_intrin_group(
     """
     assert load_scope in ["shared", "shared.dyn"]
     assert store_scope in ["global", "shared", "shared.dyn"]
-    assert in_dtype in ["float16", "int8"]
+    assert in_dtype in ["float16", "int8", "e4m3_float8", "e5m2_float8"]
     assert out_dtype in ["float16", "float32", "int32"]
 
     shape = "16x16"
 
-    dtype_mapping = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}
-    in_dtype = dtype_mapping[in_dtype]
+    dtype_mapping = {
+        "float16": "f16",
+        "float32": "f32",
+        "int8": "i8",
+        "e4m3_float8": "e4m3",
+        "e5m2_float8": "e5m2",
+        "int32": "i32",
+    }
+    a_dtype = dtype_mapping[in_dtype]
+    b_dtype = dtype_mapping[in_dtype]
     out_dtype = dtype_mapping[out_dtype]
 
     # e.g. mma_fill_16x16_f32
@@ -694,13 +776,13 @@ def get_mma_intrin_group(
     trans_a = "_trans" if trans_a else ""
     trans_b = "_trans" if trans_b else ""
     load_scope = "_dyn" if load_scope == "shared.dyn" else ""
-    load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}"
-    load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}"
+    load_a_intrin = f"mma_ldmatrix_{a_dtype}_a{trans_a}{load_scope}"
+    load_b_intrin = f"mma_ldmatrix_{b_dtype}_b{trans_b}{load_scope}"
 
     # e.g. mma_f16f16f32_trans_a_trans_b
     trans_a_str = trans_a + "_a" if trans_a != "" else ""
     trans_b_str = trans_b + "_b" if trans_b != "" else ""
-    compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{trans_b_str}"
+    compute_intrin = f"mma_{a_dtype}{b_dtype}{out_dtype}{trans_a_str}{trans_b_str}"
 
     # e.g. mma_store_16x16_f32_shared_dyn_simple_
     store_scope = store_scope.replace(".", "_")
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index bd2804830172..040051825119 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -51,8 +51,12 @@ std::string GetFP8Type(DataType type) {
     vec = "x2";
   } else if (lanes == 4) {
     vec = "x4";
+  } else if (lanes == 8) {
+    vec = "x8";
+  } else if (lanes == 16) {
+    vec = "x16";
   } else {
-    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8) for FP8";
+    LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) for FP8";
   }
   stream << "__nv_fp8";
   std::string suffix;
@@ -147,6 +151,16 @@ std::string CodeGenCUDA::Finish() {
   if (enable_fp8_) {
     decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)\n";
     decl_stream << "#include <cuda_fp8.h>\n";
+    decl_stream << "using fp8_e4_t = __nv_fp8_e4m3;\n";
+    decl_stream << "using fp8_e4x2_t = __nv_fp8x2_e4m3;\n";
+    decl_stream << "using fp8_e4x4_t = __nv_fp8x4_e4m3;\n";
+    decl_stream << "struct fp8_e4x8_t {\n fp8_e4_t data[8]; \n};\n";
+    decl_stream << "struct fp8_e4x16_t {\n fp8_e4_t data[16]; \n};\n";
+    decl_stream << "using fp8_e5_t = __nv_fp8_e5m2;\n";
+    decl_stream << "using fp8_e5x2_t = __nv_fp8x2_e5m2;\n";
+    decl_stream << "using fp8_e5x4_t = __nv_fp8x4_e5m2;\n";
+    decl_stream << "struct fp8_e5x8_t {\n fp8_e5_t data[8]; \n};\n";
+    decl_stream << "struct fp8_e5x16_t {\n fp8_e5_t data[16]; \n};\n";
     decl_stream << "#endif\n\n";
   }
   declare_vector_type_extensions(decl_stream, enable_fp16_, enable_fp8_);
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index ed6125e74cae..c9c15ee0cb2e 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -54,24 +54,27 @@ enum class DataType : int {
   kUInt32 = 7,
   kInt64 = 8,
   kUInt64 = 9,
-  kFloat16 = 10,
-  kBFloat16 = 11,
-  kFloat16x2 = 12,
-  kFloat32 = 13,
-  kTensorFloat32 = 14,
-  kFloat64 = 15,
-  kBit1 = 16,
-  kBit8 = 17,
-  kBit16 = 18,
-  kBit32 = 19,
-  kBit64 = 20,
+  kFloat8_e4m3 = 10,
+  kFloat8_e5m2 = 11,
+  kFloat16 = 12,
+  kBFloat16 = 13,
+  kFloat16x2 = 14,
+  kFloat32 = 15,
+  kTensorFloat32 = 16,
+  kFloat64 = 17,
+  kBit1 = 18,
+  kBit8 = 19,
+  kBit16 = 20,
+  kBit32 = 21,
+  kBit64 = 22
 };
 
-static const char* dtype_str[] = {".s4", ".u4", ".s8", ".u8", ".s16", ".u16", ".s32",
-                                  ".u32", ".s64", ".u64", ".f16", ".bf16", ".f16x2", ".f32",
-                                  ".tf32", ".f64", ".b1", ".b8", ".b16", ".b32", ".b64"};
-static const uint32_t num_bits[] = {4, 4, 8, 8, 16, 16, 32, 32, 64, 64, 16,
-                                    16, 32, 32, 32, 64, 1, 8, 16, 32, 64};
+static const char* dtype_str[] = {".s4",  ".u4",  ".s8",   ".u8",   ".s16",  ".u16",
+                                  ".s32", ".u32", ".s64",  ".u64",  ".e4m3", ".e5m2",
+                                  ".f16", ".bf16", ".f16x2", ".f32", ".tf32", ".f64",
+                                  ".b1",  ".b8",  ".b16",  ".b32",  ".b64"};
+static const uint32_t num_bits[] = {4,  4,  8,  8,  16, 16, 32, 32, 64, 64, 8,  8,
+                                    16, 16, 32, 32, 32, 64, 1,  8,  16, 32, 64};
 
 /*!
  * \brief Create PTX data type from string.
@@ -97,6 +100,10 @@ inline DataType DTypeFromString(const std::string str) {
     return DataType::kInt64;
   } else if (str == "uint64" || str == ".u64") {
     return DataType::kUInt64;
+  } else if (str == "e4m3" || str == ".e4m3") {
+    return DataType::kFloat8_e4m3;
+  } else if (str == "e5m2" || str == ".e5m2") {
+    return DataType::kFloat8_e5m2;
   } else if (str == "float16" || str == "fp16" || str == ".f16") {
     return DataType::kFloat16;
   } else if (str == "bfloat16" || str == "bf16") {
@@ -232,6 +239,10 @@ const MMAConfig valid_mma_configs[] = {
     MMAConfig(16, 8, 128, DataType::kInt4, false, true),
     MMAConfig(16, 8, 64, DataType::kUInt4, false, true),
     MMAConfig(16, 8, 128, DataType::kUInt4, false, true),
+    MMAConfig(16, 8, 32, DataType::kFloat8_e4m3, false, false),
+    MMAConfig(16, 8, 64, DataType::kFloat8_e4m3, false, true),
+    MMAConfig(16, 8, 32, DataType::kFloat8_e5m2, false, false),
+    MMAConfig(16, 8, 64, DataType::kFloat8_e5m2, false, true),
 };
 
 /*!
@@ -263,6 +274,11 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_
     case DataType::kUInt8:
       CHECK(dtype_b == DataType::kInt8 || dtype_b == DataType::kUInt8) << ab_not_match_err_str;
       break;
+    case DataType::kFloat8_e4m3:
+    case DataType::kFloat8_e5m2:
+      CHECK(dtype_b == DataType::kFloat8_e4m3 || dtype_b == DataType::kFloat8_e5m2)
+          << ab_not_match_err_str;
+      break;
     default:
       CHECK(false) << "Invalid multiplicand data types: " << DTypeToString(dtype_a)
                    << DTypeToString(dtype_b);
@@ -291,6 +307,11 @@ void CheckMMADTypeCompatible(DataType dtype_a, DataType dtype_b, DataType dtype_
       CHECK(dtype_c == DataType::kFloat64)
           << "For multiplicand data type f64, accumulator data type can only be f64.";
       break;
+    case DataType::kFloat8_e4m3:
+    case DataType::kFloat8_e5m2:
+      CHECK(dtype_c == DataType::kFloat32)
+          << "For multiplicand data type e4m3/e5m2, accumulator data type can only be f32.";
+      break;
     default:
       CHECK(false) << "Invalid multiplicand/accumulator data types: " << DTypeToString(dtype_a)
                    << DTypeToString(dtype_b) << DTypeToString(dtype_c) << ".";
@@ -371,6 +392,8 @@ inline FragAttrs GetFragAttrs(DataType dtype) {
     case DataType::kUInt4:
     case DataType::kInt8:
    case DataType::kUInt8:
+    case DataType::kFloat8_e4m3:
+    case DataType::kFloat8_e5m2:
     case DataType::kBit16:
     case DataType::kFloat16:  // .f16x2 register
     case DataType::kBFloat16:
diff --git a/tests/python/codegen/test_target_codegen_cuda_fp8.py b/tests/python/codegen/test_target_codegen_cuda_fp8.py
index c22f3f01a880..d04262a3701a 100644
--- a/tests/python/codegen/test_target_codegen_cuda_fp8.py
+++ b/tests/python/codegen/test_target_codegen_cuda_fp8.py
@@ -41,7 +41,7 @@
     ml_dtypes = None
 
 
-@tvm.testing.requires_cuda_compute_version(9)
+@tvm.testing.requires_cuda_compute_version(8, 9)
 def test_e4m3_conversions():
     dtype = "e4m3_float8"
 
@@ -86,7 +86,7 @@ def add(
     )
 
 
-@tvm.testing.requires_cuda_compute_version(9)
+@tvm.testing.requires_cuda_compute_version(8, 9)
 def test_e4m3_packing():
     length = 64
     vector_length = 4
@@ -151,7 +151,7 @@ def add(
     )
 
 
-@tvm.testing.requires_cuda_compute_version(9)
+@tvm.testing.requires_cuda_compute_version(8, 9)
 def test_e4m3_vector_conversions(native_dtype, promoted_dtype):
     vector_length = 64
 
@@ -791,7 +791,7 @@ def compiled_functions(
             dev,
         )
 
-    @tvm.testing.requires_cuda_compute_version(9)
+    @tvm.testing.requires_cuda_compute_version(8, 9)
     def test_main(self, weight_shape, model_dtype, target_str, compiled_functions):
         quant, dequant = compiled_functions
         dev = tvm.device(target_str, 0)
@@ -806,7 +806,7 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions):
         tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2)
 
 
-@tvm.testing.requires_cuda_compute_version(9)
+@tvm.testing.requires_cuda_compute_version(8, 9)
 @pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"])
 def test_const(dtype):
     @T.prim_func
@@ -821,6 +821,35 @@ def func(A: T.Buffer((4,), dtype)) -> None:
     tvm.build(mod, target="cuda")
 
 
+@tvm.testing.requires_cuda_compute_version(8, 9)
+@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"])
+@pytest.mark.parametrize("vec_len", [2, 4, 8, 16])
+def test_copy(dtype, vec_len):
+    @T.prim_func
+    def func(
+        A: T.Buffer(
+            (
+                4,
+                vec_len,
+            ),
+            dtype,
+        ),
+        B: T.Buffer(
+            (
+                4,
+                vec_len,
+            ),
+            dtype,
+        ),
+    ) -> None:
+        for tx in T.thread_binding(0, 4, "threadIdx.x"):
+            for i in T.vectorized(vec_len):
+                B[tx, i] = A[tx, i]
+
+    mod = tvm.IRModule({"main": func})
+    rtmod = tvm.build(mod, target="cuda")
+
+
 num_experts = 8
 reduce_size = 1792
 spatial_size = 4096
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index d704dc243891..390745fe9d96 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -28,6 +28,10 @@
     LDMATRIX_i8_A_INTRIN,
     LDMATRIX_i8_B_TRANS_INTRIN,
     LDMATRIX_i8_B_INTRIN,
+    LDMATRIX_e4m3_A_INTRIN,
+    LDMATRIX_e4m3_B_TRANS_INTRIN,
+    LDMATRIX_e5m2_A_INTRIN,
+    LDMATRIX_e5m2_B_TRANS_INTRIN,
     MMA_f16f16f16_INTRIN,
     MMA_f16f16f16_TRANS_B_INTRIN,
     MMA_f16f16f32_INTRIN,
@@ -37,6 +41,8 @@
     MMA_fill_16x16_i32_INTRIN,
     MMA_i8i8i32_INTRIN,
     MMA_i8i8i32_TRANS_B_INTRIN,
+    MMA_e5m2e5m2f32_TRANS_B_INTRIN,
+    MMA_e4m3e4m3f32_TRANS_B_INTRIN,
     MMA_store_16x16_f16_global_INTRIN,
     MMA_store_16x16_f32_global_INTRIN,
     MMA_store_16x16_i32_global_INTRIN,
@@ -126,6 +132,30 @@ def run_test(
         else:
             b_np = np.random.normal(size=(K, N)).astype("float16")
             c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
+    elif in_dtype in ["e4m3_float8", "e5m2_float8"]:
+        typemap = {
+            "e4m3_float8": "float8_e4m3fn",
+            "e5m2_float8": "float8_e5m2",
+        }
+        a_np = (
+            np.random.uniform(low=-5, high=5, size=(M * K))
+            .reshape((M, K))
+            .astype(typemap[in_dtype])
+        )
+        if b_transposed:
+            b_np = (
+                np.random.uniform(low=-5, high=5, size=(N * K))
+                .reshape((N, K))
+                .astype(typemap[in_dtype])
+            )
+            c_np = np.dot(a_np.astype("float32"), b_np.T.astype("float32")).astype(out_dtype)
+        else:
+            b_np = (
+                np.random.uniform(low=-5, high=5, size=(N * K))
+                .reshape((K, N))
+                .astype(typemap[in_dtype])
+            )
+            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
     else:
         a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
 
@@ -144,7 +174,7 @@ def run_test(
 
     f(a, b, c)
 
-    if out_dtype != "float16":
+    if out_dtype != "float16" and in_dtype not in ["e4m3_float8", "e5m2_float8"]:
         # The numpy reference is computed with fp32 precision (otherwise too slow).
         # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation.
         tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)
@@ -337,5 +367,91 @@ def index_map_C(i, j):
         print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
 
 
+@tvm.testing.requires_cuda_compute_version(8, 9)
+def test_e4m3e4m3f32_m16n16k32():
+    def index_map_A(i, j):
+        return (
+            i // 16,
+            j // 32,
+            *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
+        )
+
+    def index_map_C(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    k_inner = 32
+    in_dtype = "e4m3_float8"
+    out_dtype = "float32"
+    i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        True,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_A,
+        index_map_C,
+        LDMATRIX_e4m3_A_INTRIN,
+        LDMATRIX_e4m3_B_TRANS_INTRIN,
+        MMA_e4m3e4m3f32_TRANS_B_INTRIN,
+        MMA_fill_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("e4m3e4m3f32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
+
+
+@tvm.testing.requires_cuda_compute_version(8, 9)
+def test_e5m2e5m2f32_m16n16k32():
+    def index_map_A(i, j):
+        return (
+            i // 16,
+            j // 32,
+            *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32),
+        )
+
+    def index_map_C(i, j):
+        return (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        )
+
+    k_inner = 32
+    in_dtype = "e5m2_float8"
+    out_dtype = "float32"
+    i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2]
+
+    timer = run_test(
+        k_inner,
+        in_dtype,
+        out_dtype,
+        True,  # b_transposed
+        i_factors,
+        j_factors,
+        k_factors,
+        index_map_A,
+        index_map_A,
+        index_map_C,
+        LDMATRIX_e5m2_A_INTRIN,
+        LDMATRIX_e5m2_B_TRANS_INTRIN,
+        MMA_e5m2e5m2f32_TRANS_B_INTRIN,
+        MMA_fill_16x16_f32_INTRIN,
+        MMA_store_16x16_f32_global_INTRIN,
+    )
+
+    if measure_perf and timer:
+        print("e5m2e5m2f32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean)))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
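
Usage note (not part of the patch): a minimal sketch of how the FP8 intrinsics registered above are expected to be looked up, assuming a CUDA-enabled TVM build with this patch applied and that the remaining parameters of get_mma_intrin_group (beyond the six shown in the hunk) keep their defaults; the printed group contents are illustrative only.

# Sketch only: query the FP8 MMA tensor intrinsics added by this patch.
from tvm.tir import TensorIntrin
from tvm.tir.tensor_intrin.cuda import (
    MMA_e4m3e4m3f32_TRANS_B_INTRIN,
    get_mma_intrin_group,
)

# The compute intrinsic is registered under the name above and can be
# retrieved directly from the TensorIntrin registry.
intrin = TensorIntrin.get(MMA_e4m3e4m3f32_TRANS_B_INTRIN)
print(intrin.desc.script())

# Alternatively, resolve the whole ldmatrix/mma/fill/store intrinsic group
# by dtype; "e4m3_float8" is only accepted here after this patch.
group = get_mma_intrin_group(
    load_scope="shared.dyn",
    store_scope="global",
    in_dtype="e4m3_float8",
    out_dtype="float32",
    trans_a=False,
    trans_b=True,
)
print(group)  # dict mapping stage names to registered intrinsic names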