diff --git a/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py b/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py index 635a31989..82019a769 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py @@ -13,35 +13,6 @@ from .quantize_ops import FP8RowwiseGroupedGemm -grouped_kernel_registry: list[str] = [ - "fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2", - "fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2", - "fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2", - "fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2", - "fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2", - "fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2", - "fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2", - "fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2", - "fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2", - "fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2", - "fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1", - "fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3", - "fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4", - "fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3", - "fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3", - "fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3", - "fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3", - "fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3", - "fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4", - "fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3", - "fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2", - "fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1", - "fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2", - "fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2", - "fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2", -] - - def main(args: Any): # Extract and format shape arguments. M = [int(m) for m in args.M.strip().split(",")] @@ -62,12 +33,12 @@ def main(args: Any): quantized_vals = group_gemm_op.quantize(A, B) # Iterate over kernels to find the most performant one. benchmark_results = [] - for kernel_name in grouped_kernel_registry: + for kernel_name in torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels(): # Do a warmup run of the kernel. output = group_gemm_op.compute(*quantized_vals, kernel_name=kernel_name) # Benchmark this kernel implementation. ms_runtime = group_gemm_op.benchmark( - *quantized_vals, use_cuda_graph=False, kernel_name=kernel_name + *quantized_vals, use_cuda_graph=True, kernel_name=kernel_name ) # Compute statistics for this kernel. 
tflops = 0 @@ -84,6 +55,7 @@ def main(args: Any): / 1e9 ) # Record results. + print(f"Kernel: {kernel_name}, ms: {ms_runtime:.4f}, TFLOPS: {tflops:.2f}") benchmark_results.append( { "kernel_name": kernel_name, @@ -92,13 +64,11 @@ def main(args: Any): "gbps": gbps, } ) - # Print all results. - print("Benchmark results:") - for result in benchmark_results: - print(f"Kernel: {result['kernel_name']}, TFLOPS: {result['tflops']}") # Report best kernel. best_kernel = min(benchmark_results, key=lambda x: x["ms_runtime"]) - print(f"Best kernel for this shape: {best_kernel['kernel_name']}") + print( + f"Best kernel for this shape: {best_kernel['kernel_name']}: {best_kernel['tflops']:.2f} TFLOPS" + ) # If specified, save all results. if args.export_csv: diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py index 3c4c6e3e3..850a8c257 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py @@ -107,14 +107,14 @@ def benchmark_grouped( B, bench_quantize=True, use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=False, + use_cuda_graph=True, ) else: ms_runtime = quantize_op.benchmark( *quantized_vals, bench_quantize=False, use_rotating_buffer_bench=use_rotating_buffer_bench, - use_cuda_graph=False, + use_cuda_graph=True, ) # Print out results for this op. diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip index efc15ea89..7474e759d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip @@ -6,6 +6,7 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include #include @@ -19,16 +20,131 @@ #include #include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" #include "kernels/fp8_rowwise_grouped_kernel_manifest.h" namespace fbgemm_gpu { +// Define useful types that are needed for various kernels. +using KernelArguments = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<2>; +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using D0DataType = float; +using D1DataType = float; +using DsDataType = ck::Tuple; +using EDataType = ck::bhalf_t; + +RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) { + // We use shape heuristics to find the best kernel. + // To do this, we divide by the size of M and find the best + // option within that grouping. 
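As a usage-level illustration of the Python change above, where the hard-coded grouped_kernel_registry is replaced by a runtime query of the registered CK instances, here is a minimal sketch. It assumes the FP8RowwiseGroupedGemm wrapper can be constructed directly, that the module path follows the file layout in this diff, and that A and B are the per-group input lists built earlier in profile_grouped_gemm.py; only calls that appear in this diff are used.

import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import FP8RowwiseGroupedGemm

def pick_best_kernel(A, B):
    # Quantize the grouped inputs once, then time every registered kernel.
    op = FP8RowwiseGroupedGemm()
    quantized = op.quantize(A, B)
    results = []
    for name in torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels():
        op.compute(*quantized, kernel_name=name)  # warmup run
        ms = op.benchmark(*quantized, use_cuda_graph=True, kernel_name=name)
        results.append((ms, name))
    # Lowest runtime wins, mirroring the best-kernel selection above.
    return min(results)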
+ if (M <= 16) { + if (N < 8192 && K <= 8192) { + return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1; + } + if (K <= 8192) { + return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; + } + return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2; + } + if (M <= 32) { + if (N < 8192 && K <= 8192) { + return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; + } + if (K <= 8192) { + return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; + } + return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2; + } + if (M <= 64) { + return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + } + if (M <= 128) { + if (N < 8192 && K <= 8192) { + return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + } + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + } + if (M <= 256) { + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + } + if (M <= 512) { + if (K <= 8192) { + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + } + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + } + // Default kernel for all other shapes. + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; +} + +__global__ void set_kernel_args_kernel( + KernelArguments* kernel_args, + ADataType* XQ, + BDataType* WQ, + D0DataType* w_scale, + D1DataType* x_scale, + EDataType* output, + int M, + int N, + int K) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + // Each kernel annoyingly can only set the kernel args for one group. + // This could only be avoided with complicated memory management. + if (idx == 0) { + // Write kernel arguments directly to memory. + KernelArguments kernel_group_args = { + XQ, WQ, {w_scale, x_scale}, output, M, N, K, K, K, {0, 0}, N}; + kernel_args[0] = kernel_group_args; + } +} + +void set_grouped_kernel_args( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector output) { + TORCH_CHECK( + XQ.size() == WQ.size() && XQ.size() == x_scale.size() && + XQ.size() == w_scale.size(), + "All inputs must have the same number of groups."); + int group_count = XQ.size(); + // We use the smallest reasonable block size since we effectively need only 1 thread. + int blockSize = 32; + int numBlocks = 1; + auto stream = at::cuda::getCurrentHIPStream().stream(); + + // Launch a kernel for each group to set kernel memory on device. + for (int i = 0; i < group_count; i++) { + int M = XQ[i].size(0); + int K = XQ[i].size(1); + int N = WQ[i].size(0); + // Launch kernel to set kernel arguments. 
+ set_kernel_args_kernel<<<numBlocks, blockSize, 0, stream>>>( + reinterpret_cast<KernelArguments*>( + reinterpret_cast<char*>(kernel_args.data_ptr()) + + (i * sizeof(KernelArguments))), + reinterpret_cast<ADataType*>(XQ[i].data_ptr()), + reinterpret_cast<BDataType*>(WQ[i].data_ptr()), + reinterpret_cast<D0DataType*>(w_scale[i].data_ptr()), + reinterpret_cast<D1DataType*>(x_scale[i].data_ptr()), + reinterpret_cast<EDataType*>(output[i].data_ptr()), + M, + N, + K); + } +} + std::vector<at::Tensor> f8f8bf16_rowwise_grouped( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional<at::TensorList> output = std::nullopt, + std::optional<std::vector<at::Tensor>> output = std::nullopt, std::optional<std::string> kernel_name = std::nullopt) { // Check that input datatypes are valid. // First confirm that there are the same number of groups in all inputs. @@ -36,6 +152,7 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped( XQ.size() == WQ.size() && XQ.size() == x_scale.size() && XQ.size() == w_scale.size(), "All inputs must have the same number of groups."); + int group_count = XQ.size(); // Iterate over inputs and check they are valid. for (at::Tensor x : XQ) { TORCH_CHECK(x.is_cuda() && x.is_contiguous()); @@ -58,39 +175,68 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped( TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32."); } - // Allocate output if needed. std::vector<at::Tensor> Y; - Y.reserve(XQ.size()); if (output.has_value()) { - TORCH_CHECK(output.value().size() == XQ.size(), "Output and input must have same number of groups."); + Y = output.value(); + TORCH_CHECK( + Y.size() == group_count, + "Output and input must have same number of groups."); // Check that output shapes are correct. - for (int i = 0; i < output.value().size(); i++) { + for (int i = 0; i < group_count; i++) { int M = XQ[i].size(0); int N = WQ[i].size(0); - int out_M = output.value()[i].size(0); - int out_N = output.value()[i].size(1); - TORCH_CHECK(M == out_M && N == out_N, "Output tensors do not have the expected shape."); - TORCH_CHECK(output.value()[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16."); - Y.push_back(output.value()[i]); + int out_M = Y[i].size(0); + int out_N = Y[i].size(1); + TORCH_CHECK( + M == out_M && N == out_N, + "Output tensors do not have the expected shape."); + TORCH_CHECK( + Y[i].dtype() == at::kBFloat16, "Output dtype must be bfloat16."); } } else { - for (int i = 0; i < XQ.size(); i++) { + for (int i = 0; i < group_count; i++) { int M = XQ[i].size(0); int N = WQ[i].size(0); Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16))); } } + // Prepare kernel arguments by copying them to the proper device location. + at::Tensor kernel_args = at::empty({1000}, XQ[0].options().dtype(at::kByte)); + set_grouped_kernel_args(XQ, WQ, x_scale, w_scale, kernel_args, Y); + // If provided a specific kernel implementation, dispatch to it. if (kernel_name.has_value()) { auto it = kernel_name_map.find(kernel_name.value()); // If not found, raise an error. - TORCH_CHECK(it != kernel_name_map.end(), "Could not find kernel " + kernel_name.value()); + TORCH_CHECK( + it != kernel_name_map.end(), + "Could not find kernel " + kernel_name.value()); // If found, always use requested kernel. - return it->second(XQ, WQ, x_scale, w_scale, Y); + return it->second(XQ, WQ, x_scale, w_scale, kernel_args, Y); + } + // Otherwise, use heuristics to find the best kernel options. + // We use the largest of each shape for heuristics.
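From the caller's side, the dispatch order implemented above (an explicit kernel_name goes through kernel_name_map; otherwise the heuristic below picks a kernel from the largest M, N, and K across groups) can be exercised as in this sketch, continuing the Python sketch earlier in this section. Whether kernel_name may be omitted from the wrapper's compute call is an assumption; the kernel string is one of the names listed in this diff.

best = "fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3"
assert best in torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels()
y_pinned = op.compute(*quantized, kernel_name=best)  # explicit name -> kernel_name_map lookup
y_auto = op.compute(*quantized)  # no name -> shape heuristic on the max M/N/K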
+ int MaxM = 0; + int MaxN = 0; + int MaxK = 0; + for (int i = 0; i < group_count; i++) { + MaxM = max(MaxM, XQ[i].size(0)); + MaxN = max(MaxN, WQ[i].size(0)); + MaxK = max(MaxK, XQ[i].size(1)); + } + RowwiseGroupedKernel selected_kernel = + rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); +} + +std::vector get_f8f8bf16_rowwise_grouped_kernels() { + /* Helper function to get the names of avaialable grouped gemm kernels.*/ + std::vector kernel_names; + for (const auto& pair : kernel_name_map) { + kernel_names.push_back(pair.first); } - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - XQ, WQ, x_scale, w_scale, Y); + return kernel_names; } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..4c99757b8 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..21e92b399 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index f9feea0a0..e24f8b6fd 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -14,7 +14,8 @@ fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_in at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::vector Y) { + at::Tensor kernel_args, + std::vector Y) { // A kernel that works well on small but not super tiny shapes. 
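The per-instance files added in this diff all share the dispatch pattern shown above: if any group's K is not a multiple of the instance's K tile, the KPadding specialization is instantiated, otherwise Default. A small language-agnostic sketch of that decision, with illustrative shapes (the instances in this diff use K tiles of 64, 128, 256, or 512):

def needs_k_padding(group_k_dims, k_tile):
    # KPadding is required when any group's reduction dim is not tile-aligned.
    return any(k % k_tile != 0 for k in group_k_dims)

needs_k_padding([5120, 7168], 128)  # False -> GemmSpecialization::Default
needs_k_padding([5120, 1000], 128)  # True  -> GemmSpecialization::KPadding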
using DeviceGemmInstance = DeviceGemmHelper< 128, @@ -34,5 +35,5 @@ fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_in ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..3e5d0f249 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 28fd933bc..9c871f64c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,74 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 128 != 0 || N % 32 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - - // This kernel seems optimal in the most purely compute bound tasks. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..f2a90ac2b --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..64f631164 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..b9ed2888b --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 686b4af0f..34a7367f5 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,74 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // The smallest kernel we have available. Works well for memory bound shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 16 != 0 || N % 32 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); - + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..0637fde60 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index e30395781..868cd8275 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,73 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // The smallest kernel we have available. Works well for memory bound shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 16 != 0 || N % 32 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); - } else{ + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..8c324b798 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..eeb3ecfa9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..2d4faaaa7 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..f5c9aa779 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..6f8da7b7e --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 32f855eda..123638a33 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,22 +1,31 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // The smallest kernel we have available. Works well for memory bound shapes. - using DeviceGemmInstance = DeviceGemmHelper< + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< 128, 16, 32, @@ -32,7 +41,32 @@ fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_int 1, 1, ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; - // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..28a14967c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..0a20e29cd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..fe794f853 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. 
+* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..6ef2ed503 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index a5ebfec50..3cd1a219c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -1,22 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -25,51 +24,49 @@ fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_in pad = true; } } - - // Dispatch based on whether padding is needed or not. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..a3b3b37ec --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..fac1cd90b --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 5ee482262..5059ba737 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,73 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A small kernel for small but not tiny shapes. - - // Check if this input needs to be padded. + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 32 != 0 || N % 16 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..1750fe915 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..cd5cc29f6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index 8bae17737..67d7d7d77 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -1,73 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that works well on small but not super tiny shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 32 != 0 || N % 64 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip index fd7b0918e..63bb549f4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -1,73 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that works well on small but not super tiny shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 32 != 0 || N % 64 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
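(Editorial aside.) The files in this stretch all use a 128-wide K tile, so the padding decision reduces to a K % 128 check per group member, and a single non-conforming member pushes the entire group onto the KPadding specialization. A small standalone illustration, with hypothetical K values chosen only for the example:

// Standalone illustration; the group K values are hypothetical.
#include <cstdio>
#include <vector>

int main() {
  const int k_block = 128;                  // K tile of the surrounding kernels
  std::vector<int> group_k = {5120, 5000};  // per-group K dims (hypothetical)
  bool pad = false;
  for (int k : group_k) {
    if (k % k_block != 0) {                 // 5000 % 128 == 8, so padding is needed
      pad = true;
    }
  }
  std::printf("needs KPadding specialization: %s\n", pad ? "yes" : "no");
  return 0;
}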
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..329e3a0e0 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..9d7908f44 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..7330ad30c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 59f33ef81..5c07c9b99 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,73 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A small kernel for small but not tiny shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 64 != 0 || N % 32 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index b68a26299..270c4b6b4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,22 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -25,51 +24,49 @@ fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i pad = true; } } - - // Dispatch based on whether padding is needed or not. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 84914d05e..c90498392 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,22 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // V5 kernel that works well on some medium shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -25,50 +24,49 @@ fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_i pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip new file mode 100644 index 000000000..9282b9d89 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index f5c172a0c..1a09ef2a4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,21 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. 
+* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" -std::vector fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - +std::vector +fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -24,51 +24,49 @@ std::vector fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64 pad = true; } } - - // Dispatch based on whether padding is needed or not. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 64, - 32, - 32, - 2, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 64, - 32, - 32, - 2, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..e05304e31 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 826510e8c..005b98e6f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,74 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 128 != 0 || N % 64 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - - // Dispatch based on whether padding is needed or not. 
if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 64, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 64, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..43440d9f0 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..59892b182 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
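(Editorial aside.) The long file names in this diff simply spell out the DeviceGemmHelper arguments used inside them. Restating the instance built just above with inline comments makes that mapping easier to scan; the per-argument labels are editorial guesses based on position, not names taken from this diff or from fp8_rowwise_grouped_common.h.

// Annotated restatement of the instance for
// fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2;
// labels are descriptive guesses.
using AnnotatedInstance = DeviceGemmHelper<
    256,              // thread block size
    16,               // M tile per block
    256,              // N tile per block
    128,              // K tile per block (drives the K % 128 padding check)
    16,               // wave tile M
    16,               // wave tile N
    1,                // wave map M
    4,                // wave map N
    S<8, 16, 1>,      // A block transfer cluster
    S<8, 16, 1>,      // B block transfer cluster
    S<1, 16, 1, 16>,  // C block transfer cluster
    S<4, 4, 1>,       // C scalars per vector
    1,                // C-shuffle M per wave per shuffle
    1,                // C-shuffle N per wave per shuffle
    ck::BlockGemmPipelineScheduler::Intrawave,
    ck::BlockGemmPipelineVersion::v2,
    ck::tensor_operation::device::GemmSpecialization::Default>;

The trailing intrawave_v2 suffix of the name covers the scheduler and pipeline-version arguments.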
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index 4aad7d97f..47be5bd05 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,22 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // This kernel works well for many medium to large shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -25,50 +24,49 @@ fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_i pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..b4c7f9344 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..c27587b65 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
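+  // A single group whose K is not a multiple of this kernel's 128-element K block is enough to require the padded instantiation for the whole grouped GEMM.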
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..5410a9083 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip index f62256540..f728fd9cf 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip @@ -1,74 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 256 != 0 || N % 224 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - - // This kernel seems optimal in the most purely compute bound tasks. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index ce62e2b34..165d3cb90 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,21 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" -std::vector fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - +std::vector +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -24,51 +24,49 @@ std::vector fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x3 pad = true; } } - - // Dispatch based on whether padding is needed or not. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index 0b0cea66a..fd63d9bef 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,74 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 256 != 0 || N % 256 != 0 || K % 64 != 0) { + if (K % 64 != 0) { pad = true; } } - - // This kernel seems optimal in the most purely compute bound tasks. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 16, - 16, - 8, - 8, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 16, - 16, - 8, - 8, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 26e822d4e..872c8d675 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,22 +1,21 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { @@ -25,51 +24,49 @@ fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_in pad = true; } } - - // Dispatch based on whether padding is needed or not. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 32, - 32, - 4, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 32, - 32, - 4, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..6a391e1da --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..3d051451a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
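+  // Only K alignment is checked: the tile-loop grouped kernel handles arbitrary M and N per group, so the older M/N padding checks are no longer needed.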
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..7815fb933 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
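+      // kernel_args carries the per-group GEMM argument structs consumed by the CK tile-loop kernel and is forwarded unchanged to the shared implementation.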
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..fc8bdf60d --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 04bf725e3..35700d809 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,74 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // A kernel that seems to work well on mid sized tensors. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 64 != 0 || N % 64 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - - // Dispatch based on whether padding is needed or not. if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 64, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3>; + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 64, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..8714c5d8e --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index a73187051..749d0a3c9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,73 +1,72 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // The smallest kernel we have available. Works well for memory bound shapes. - + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { // Check if this input needs to be padded. bool pad = false; for (int i = 0; i < XQ.size(); i++) { - int M = XQ[i].size(0); - int N = WQ[i].size(0); int K = XQ[i].size(1); - if (M % 16 != 0 || N % 16 != 0 || K % 128 != 0) { + if (K % 128 != 0) { pad = true; } } - if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..76388b8bb --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..5cc809bef --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 128 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip index 1d3bd374f..bf8b430f0 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip @@ -14,7 +14,8 @@ fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intr at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::vector Y) { + at::Tensor kernel_args, + std::vector Y) { // Secret kernel that seems good with small M but large N and K. 
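+  // Note: only the Default specialization is instantiated here (no padding branch), so this kernel should only be dispatched when K is already a multiple of 256.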
using DeviceGemmInstance = DeviceGemmHelper< 64, @@ -35,5 +36,5 @@ fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intr ck::BlockGemmPipelineVersion::v1, ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..3468432a9 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip new file mode 100644 index 000000000..1f2339131 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..2f28b951e --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..673d19a0f --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 256 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index c14cbf3f8..a832dcd1e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -14,7 +14,8 @@ fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_inte at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::vector Y) { + at::Tensor kernel_args, + std::vector Y) { // The smallest kernel we have available. Works well for memory bound shapes. 
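+  // This 32x2x1 variant keeps its single Default instantiation; the 8x8x1 variant of the same tile shape below gains a K-alignment branch like the other kernels.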
using DeviceGemmInstance = DeviceGemmHelper< 64, @@ -35,5 +36,5 @@ fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_inte ck::BlockGemmPipelineVersion::v2, ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip new file mode 100644 index 000000000..b9d76db77 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 0cb601d89..18236d9db 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,22 +1,31 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. 
- * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // The smallest kernel we have available. Works well for memory bound shapes. - using DeviceGemmInstance = DeviceGemmHelper< + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< 64, 16, 16, @@ -25,14 +34,39 @@ fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interw 16, 1, 1, - S<8, 8, 1>, - S<8, 8, 1>, + S<8, 8, 1>, + S<8, 8, 1>, S<1, 16, 1, 4>, S<4, 4, 1>, 1, 1, ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; - // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip new file mode 100644 index 000000000..17f17cee6 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. 
+ bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..e9efd64a5 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 512 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 086ab787a..6b7d7553d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,22 +1,31 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include "fp8_rowwise_grouped_common.h" std::vector fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector Y) { - // The smallest kernel we have available. Works well for memory bound shapes. - using DeviceGemmInstance = DeviceGemmHelper< + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< 64, 16, 16, @@ -32,7 +41,32 @@ fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_inter 1, 1, ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2>; - // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, Y); + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. 
+ return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } } diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip new file mode 100644 index 000000000..3de229cd4 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -0,0 +1,72 @@ +/* +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ + +#include "fp8_rowwise_grouped_common.h" + +std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y) { + // Check if this input needs to be padded. + bool pad = false; + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + if (K % 64 != 0) { + pad = true; + } + } + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h index d1a984757..fa66a1628 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h @@ -6,11 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include -#include -#include - #include #include #include @@ -22,13 +17,6 @@ #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" - #include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp" // Define commonly used types. 
@@ -136,6 +124,7 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped_impl(
     at::TensorList WQ,
     at::TensorList x_scale,
     at::TensorList w_scale,
+    at::Tensor kernel_args,
     std::vector<at::Tensor> Y) {
   // Get input information.
   int group_count = XQ.size();
@@ -144,7 +133,6 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped_impl(
   using GemmDesc = ck::tensor_operation::device::GemmDesc;
   // Create gemm shape containers.
   std::vector<GemmDesc> gemm_descs;
-  std::vector<KernelArguments> ggemm_kargs;
   // Create container for input arguments.
   std::vector A_args;
   std::vector B_args;
   std::vector C_args;
   std::vector> D_args = {};
   // Reserve space in argument arrays.
   gemm_descs.reserve(group_count);
-  ggemm_kargs.reserve(group_count);
   A_args.reserve(group_count);
   B_args.reserve(group_count);
   C_args.reserve(group_count);
@@ -162,25 +149,9 @@
     // Set the shape arguments for this gemm.
     int M = XQ[i].size(0);
     int K = XQ[i].size(1);
-    int N = WQ[i].size(1);
+    int N = WQ[i].size(0);
     GemmDesc gemm_desc = {M, N, K, K, K, N, {0, 0}};
     gemm_descs.push_back(gemm_desc);
-    // For some reason, we also need to specify kernel args (which are quite
-    // redundant to other arguments).
-    KernelArguments kernel_args = {
-        reinterpret_cast(XQ[i].data_ptr()),
-        reinterpret_cast(WQ[i].data_ptr()),
-        {reinterpret_cast(w_scale[i].data_ptr()),
-         reinterpret_cast(x_scale[i].data_ptr())},
-        reinterpret_cast(Y[i].data_ptr()),
-        M,
-        N,
-        K,
-        K,
-        K,
-        {0, 0},
-        N};
-    ggemm_kargs.push_back(kernel_args);
     // Set pointers to inputs and outputs.
     A_args.push_back(reinterpret_cast(XQ[i].data_ptr()));
     B_args.push_back(reinterpret_cast(WQ[i].data_ptr()));
@@ -207,21 +178,12 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped_impl(
       a_element_op,
       b_element_op,
       cde_element_op);
+
+  // Set gemm kernel arguments.
+  gemm.SetDeviceKernelArgs(argument, kernel_args.data_ptr());
+
   // Get hip graph stream if it exists.
   auto stream = at::cuda::getCurrentHIPStream().stream();
-  // Set up kernel arguments.
-  // Allocate device memory with pytorch for simplicity.
-  at::Tensor gemm_arg_dev_mem = at::empty(
-      gemm.GetDeviceKernelArgSize(&argument), XQ[0].options().dtype(at::kByte));
-  // Copy arguments to device memory.
-  hipMemcpyAsync(
-      gemm_arg_dev_mem.data_ptr(),
-      ggemm_kargs.data(),
-      gemm.GetDeviceKernelArgSize(&argument),
-      hipMemcpyHostToDevice,
-      stream);
-  gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.data_ptr());
-
   invoker.Run(argument, StreamConfig{stream, false});
 
   return Y;
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h
index 0471ff113..ff7de71e7 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h
@@ -6,7 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include
 #include
 #include
@@ -20,81 +19,322 @@ using RowwiseGroupedKernel = std::function<std::vector<at::Tensor>(
     at::TensorList,
     at::TensorList,
     at::TensorList,
+    at::Tensor,
     std::vector<at::Tensor>)>;
-// Default tile size.
+std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + std::vector fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); -// Large shape performance. std::vector fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); -// Jumbo tile size. 
std::vector -fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( +fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( +fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( +fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector 
+fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector @@ -103,30 +343,34 @@ fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_int at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( 
+fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector @@ -135,147 +379,431 @@ fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_int at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( 
+fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); std::vector -fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + 
at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, std::vector Y); // Map function for string name to kernel implementation for manual // specification. 
static const std::unordered_map kernel_name_map = { + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), KERNEL_NAME_MAP_ENTRY( fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - 
fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), + fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1), + fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), + fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3), + fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), 
KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3), + fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), + fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), + fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1), + fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2), + fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), KERNEL_NAME_MAP_ENTRY( fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2), + KERNEL_NAME_MAP_ENTRY( + fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2), }; diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index ed139cc9d..3bb5f4580 100644 --- 
a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -81,8 +81,9 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
     at::TensorList WQ,
     at::TensorList x_scale,
     at::TensorList w_scale,
-    std::optional<at::TensorList> output = std::nullopt,
+    std::optional<std::vector<at::Tensor>> output = std::nullopt,
+    std::optional<std::string> kernel_name = std::nullopt);
+std::vector<std::string> get_f8f8bf16_rowwise_grouped_kernels();
 at::Tensor f8f8bf16_blockwise(
     at::Tensor XQ,
     at::Tensor WQ,
@@ -174,6 +175,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
 #ifdef USE_ROCM
   m.def(
       "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]");
+  m.def("get_f8f8bf16_rowwise_grouped_kernels() -> str[]");
+  m.impl(
+      "get_f8f8bf16_rowwise_grouped_kernels",
+      get_f8f8bf16_rowwise_grouped_kernels);
 #endif
   m.def(
       "f8f8bf16_blockwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, int block_m=128, int block_n=128, int block_k=128) -> Tensor");
@@ -256,6 +261,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
 #endif
+#ifdef USE_ROCM
+  m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped);
+#endif
 }
 
 // Shape registration functions.
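
For reference, a minimal sketch of how the two operators registered above could be driven from Python on a ROCm build. The import path, group shapes, FP8 dtype, and scale layout below are illustrative assumptions; only the operator names and schemas ("get_f8f8bf16_rowwise_grouped_kernels() -> str[]" and "f8f8bf16_rowwise_grouped(...) -> Tensor[]") come from this patch.

# Usage sketch (assumptions noted inline; not part of the patch).
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed module that loads the fbgemm ops

# List every CK grouped-GEMM kernel compiled into this build (ROCm only).
kernel_names = torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels()
print(f"{len(kernel_names)} grouped kernels available")

# Build a small two-group problem with rowwise-quantized FP8 inputs.
fp8 = torch.float8_e4m3fnuz  # assumed FP8 dtype for MI300-class GPUs
shapes = [(128, 256, 512), (64, 256, 512)]  # (M, N, K) per group, illustrative
XQ = [torch.randn(m, k, device="cuda").to(fp8) for m, _, k in shapes]
WQ = [torch.randn(n, k, device="cuda").to(fp8) for _, n, k in shapes]  # WQ is [N, K]
x_scale = [torch.ones(m, device="cuda") for m, _, _ in shapes]  # per-row scales
w_scale = [torch.ones(n, device="cuda") for _, n, _ in shapes]

# Run the grouped GEMM, optionally pinning a specific kernel by name.
Y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
    XQ, WQ, x_scale, w_scale, kernel_name=kernel_names[0]
)
print([y.shape for y in Y])  # one bf16 output per group

Sweeping kernel_name over the list returned by get_f8f8bf16_rowwise_grouped_kernels is the same pattern the profiling script uses to pick the fastest kernel for a given shape.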