diff --git a/3rdparty/cutlass b/3rdparty/cutlass index afa177220367..ad7b2f5e84fc 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index d11777e8514a..b74ce4c8dfe0 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -58,19 +58,27 @@ if(USE_CUDA AND USE_CUTLASS) set(TVM_CUTLASS_RUNTIME_SRCS "") if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a") - list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu) - list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu) list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu) - list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu) + endif() + if (CMAKE_CUDA_ARCHITECTURES MATCHES "100a") + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu) endif() if(TVM_CUTLASS_RUNTIME_SRCS) add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS}) - target_compile_options(tvm_cutlass_objs PRIVATE $<$:--expt-relaxed-constexpr>) + target_compile_options(tvm_cutlass_objs PRIVATE $<$:-lineinfo --expt-relaxed-constexpr>) target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include ) + target_link_libraries(tvm_cutlass_objs PRIVATE tvm_ffi_header) target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=) + # Note: enable this to get more detailed logs for cutlass kernels + # target_compile_definitions(tvm_cutlass_objs PRIVATE CUTLASS_DEBUG_TRACE_LEVEL=2) list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$>") endif() diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh new file mode 100644 index 000000000000..ebb8f58a6b18 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cuh @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include 
+#include 
+#include 
+#include 
+
+#include "cutlass/bfloat16.h"
+#include "cutlass/half.h"
+
+namespace tvm {
+namespace runtime {
+
+template <int arch, typename ElementA, typename ElementB, typename ElementC>
+struct CutlassGroupGemm;
+
+template <int arch>
+void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
+                                 NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+
+  if (DataType(x->dtype) == DataType::Float(16)) {
+    CHECK(DataType(weight->dtype) == DataType::Float(16));
+    CHECK(DataType(out->dtype) == DataType::Float(16));
+    using Dtype = cutlass::half_t;
+    CutlassGroupGemm<arch, Dtype, Dtype, Dtype>::run(
+        static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
+  } else if (DataType(x->dtype) == DataType::BFloat(16)) {
+    CHECK(DataType(weight->dtype) == DataType::BFloat(16));
+    CHECK(DataType(out->dtype) == DataType::BFloat(16));
+    using Dtype = cutlass::bfloat16_t;
+    CutlassGroupGemm<arch, Dtype, Dtype, Dtype>::run(
+        static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
+  }
+}
+
+} // namespace runtime
+} // namespace tvm
diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
new file mode 100644
index 000000000000..f38664915d35
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +inline size_t aligned(size_t value, size_t alignment = 16) { + return (value + alignment - 1) / alignment * alignment; +} + +template +struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch +}; + +template +struct MMA2SMConfig { + using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>; + using ClusterShape = Shape<_2, _2, _1>; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch +}; + +template +struct CutlassGroupGemmRunner { + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements + // (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ScaleType = std::variant; + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + // Different configs for 1SM and 2SM MMA kernel + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, OperatorClass, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, + AlignmentC, typename ScheduleConfig::EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, + AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape, + typename ScheduleConfig::ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename 
ScheduleConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C, + ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, + typename ProblemShape::UnderlyingProblemShape* problem_sizes_host, + StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D, + uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha, + ScaleType beta, cudaStream_t stream) { + typename Gemm::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + [&]() { + ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type"; + if (std::holds_alternative(alpha)) { + fusion_args.alpha = std::get(alpha); + fusion_args.beta = std::get(beta); + } else if (std::holds_alternative(alpha)) { + fusion_args.alpha_ptr = std::get(alpha); + fusion_args.beta_ptr = std::get(beta); + } else { + LOG(FATAL) << "Unsupported alpha and beta type"; + throw; + } + }(); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + arguments = typename Gemm::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_groups, problem_sizes, problem_sizes_host}, + {ptr_A, stride_A, ptr_B, stride_B}, + {fusion_args, ptr_C, stride_C, ptr_D, stride_D}, + hw_info}; + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +__global__ void prepare_group_gemm_arguments( + const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A, + StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out, + int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) { + int group_id = threadIdx.x; + if (group_id >= num_groups) return; + int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1]; + ptr_A[group_id] = x + prev_rows * k; + ptr_B[group_id] = weight + group_id * k * n; + ptr_D[group_id] = out + prev_rows * n; + problem_sizes[group_id] = {static_cast(indptr[group_id] - prev_rows), static_cast(n), + static_cast(k)}; + stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{}); +} + +template +void cutlass_group_gemm_sm100(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, + int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, + std::variant alpha, + std::variant beta, ElementC* out, + cudaStream_t stream) { + // Note: We use MMA2SMConfig for now. It can be changed to MMA1SMConfig if needed. 
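+  // The workspace buffer passed in is carved up on the host side before launching the
+  // argument-preparation kernel: the pointer arrays for A/B/D, the per-group problem shapes,
+  // and the per-group strides are laid out back to back (each sub-array 16-byte aligned),
+  // and whatever remains past a 256-byte-aligned offset is handed to CUTLASS as its internal
+  // workspace. As a rough illustration (an estimate, not a measurement): with num_groups = 64
+  // the argument section needs only a few KB, so nearly all of the recommended 4MB workspace
+  // is left for CUTLASS itself.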
+ using Runner = CutlassGroupGemmRunner, ElementA, ElementB, ElementC>; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + + Runner runner; + std::ptrdiff_t offset = 0; + const ElementA** ptr_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementA*) * num_groups); + const ElementB** ptr_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementB*) * num_groups); + ElementC** ptr_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementC*) * num_groups); + typename ProblemShape::UnderlyingProblemShape* problem_sizes = + reinterpret_cast(workspace + offset); + offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups); + StrideA* stride_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideA) * num_groups); + StrideB* stride_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideB) * num_groups); + StrideC* stride_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideC) * num_groups); + prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes, + stride_A, stride_B, stride_D, x, + weight, out, indptr, n, k, num_groups); + offset = aligned(offset, 256); + runner.run_group_gemm(ptr_A, ptr_B, const_cast(ptr_D), ptr_D, problem_sizes, + nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset, + workspace_size - offset, num_groups, alpha, beta, stream); +} diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh similarity index 96% rename from src/runtime/contrib/cutlass/group_gemm_runner.cuh rename to src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh index a3c52e27a9d5..38e1beb2b8f4 100644 --- a/src/runtime/contrib/cutlass/group_gemm_runner.cuh +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh @@ -169,11 +169,11 @@ __global__ void prepare_group_gemm_arguments( } template -void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, - int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, - std::variant alpha, - std::variant beta, ElementC* out, - cudaStream_t stream) { +void cutlass_group_gemm_sm90(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, + int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, + std::variant alpha, + std::variant beta, ElementC* out, + cudaStream_t stream) { using Runner = CutlassGroupGemmRunner; using StrideA = typename Runner::StrideA; using StrideB = typename Runner::StrideB; diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu new file mode 100644 index 000000000000..29efcbe088ae --- /dev/null +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "fp16_group_gemm.cuh" +#include "fp16_group_gemm_runner_sm100.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace tvm { +namespace runtime { + +template +struct CutlassGroupGemm<100, ElementA, ElementB, ElementC> { + static void run(ElementA* A, ElementB* B, int64_t* indptr, uint8_t* workspace, int workspace_size, + int N, int K, int num_groups, float alpha, float beta, ElementC* C, + cudaStream_t stream) { + cutlass_group_gemm_sm100( + A, B, indptr, workspace, workspace_size, N, K, num_groups, alpha, beta, C, stream); + } +}; + +void tvm_cutlass_group_gemm_sm100(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, + NDArray out) { + tvm_cutlass_group_gemm_impl<100>(x, weight, indptr, workspace, out); +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm100); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu similarity index 60% rename from src/runtime/contrib/cutlass/fp16_group_gemm.cu rename to src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu index dffe7dc4ffed..93a03a0675b2 100644 --- a/src/runtime/contrib/cutlass/fp16_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu @@ -19,14 +19,28 @@ #include #include -#include #include +#include #include -#include "group_gemm_runner.cuh" +#include "fp16_group_gemm.cuh" +#include "fp16_group_gemm_runner_sm90.cuh" + +namespace tvm { +namespace runtime { #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) +template +struct CutlassGroupGemm<90, ElementA, ElementB, ElementC> { + static void run(ElementA* A, ElementB* B, int64_t* indptr, uint8_t* workspace, int workspace_size, + int N, int K, int num_groups, float alpha, float beta, ElementC* C, + cudaStream_t stream) { + cutlass_group_gemm_sm90(A, B, indptr, workspace, workspace_size, + N, K, num_groups, alpha, beta, C, stream); + } +}; + template <> struct KernelTraits { using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; @@ -34,36 +48,21 @@ struct KernelTraits { using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster }; -namespace tvm { -namespace runtime { +template <> +struct KernelTraits { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size + using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster +}; -template void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, NDArray out) { - // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. - // Recommened size is 4MB. 
- static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); - cudaStream_t stream = static_cast(func().cast()); - CHECK_EQ(x->ndim, 2); - CHECK_EQ(weight->ndim, 3); - CHECK_EQ(indptr->ndim, 1); - CHECK_EQ(workspace->ndim, 1); - CHECK_EQ(out->ndim, 2); - int num_groups = weight->shape[0]; - int n = weight->shape[1]; - int k = weight->shape[2]; - float alpha = 1.0f; - float beta = 0.0f; - cutlass_group_gemm(static_cast(x->data), static_cast(weight->data), - static_cast(indptr->data), static_cast(workspace->data), - workspace->shape[0], n, k, num_groups, alpha, beta, - static_cast(out->data), stream); + tvm_cutlass_group_gemm_impl<90>(x, weight, indptr, workspace, out); } -TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") - .set_body_typed(tvm_cutlass_group_gemm_sm90); +TVM_FFI_REGISTER_GLOBAL("cutlass.group_gemm").set_body_typed(tvm_cutlass_group_gemm_sm90); + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED } // namespace runtime } // namespace tvm - -#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu b/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu deleted file mode 100644 index 5164958afeb5..000000000000 --- a/src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include - -#include "../cublas/cublas_utils.h" -#include "blockwise_scaled_gemm_runner.cuh" - -#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) - -namespace tvm { -namespace runtime { - -void tvm_cutlass_fp8_blockwise_scaled_gemm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b, - NDArray workspace, int64_t block_size_0, - int64_t block_size_1, NDArray out) { - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _1, _1>; - - // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. - // Recommened size is 4MB. 
- const auto get_stream_func = tvm::ffi::Function::GetGlobal("runtime.get_cuda_stream"); - ICHECK(get_stream_func.has_value()); - cudaStream_t stream = static_cast((*get_stream_func)().cast()); - - CHECK_GE(a->ndim, 2); - CHECK_EQ(scales_a->ndim, a->ndim); - CHECK_EQ(b->ndim, 2); - CHECK_EQ(scales_b->ndim, 2); - CHECK_EQ(workspace->ndim, 1); - CHECK_EQ(out->ndim, a->ndim); - int64_t m = 1; - for (int64_t i = 0; i < a->ndim - 1; ++i) { - m *= a->shape[i]; - } - int64_t n = b->shape[0]; - CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now."; - int64_t k = a->shape[a->ndim - 1]; - - // scales_a is col-major of (*a_shape[:-1], k / block_size) - CHECK_EQ(scales_a->shape[0] * block_size_1, k); - for (int64_t i = 1; i < scales_a->ndim; ++i) { - CHECK_EQ(scales_a->shape[i], a->shape[i - 1]); - } - // scales_b is col-major of (k / block_size, n / block_size) - CHECK_EQ(scales_b->shape[0] * block_size_0, n); - CHECK_EQ(scales_b->shape[1] * block_size_1, k); - - using tvm::runtime::DataType; - CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); - - if (DataType(out->dtype) == DataType::Float(16)) { - cutlass_fp8_blockwise_scaled_gemm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { - cutlass_fp8_blockwise_scaled_gemm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, stream); - } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); - } -} - -void tvm_cutlass_fp8_blockwise_scaled_bmm(NDArray a, NDArray b, NDArray scales_a, NDArray scales_b, - NDArray workspace, int64_t block_size_0, - int64_t block_size_1, NDArray out) { - using TileShape = Shape<_128, _128, _128>; - using ClusterShape = Shape<_1, _1, _1>; - - // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. - // Recommened size is 4MB. 
- const auto get_stream_func = tvm::ffi::Function::GetGlobal("runtime.get_cuda_stream"); - ICHECK(get_stream_func.has_value()); - cudaStream_t stream = static_cast((*get_stream_func)().cast()); - - CHECK_EQ(a->ndim, 3); - CHECK_EQ(scales_a->ndim, 3); - CHECK_EQ(b->ndim, 3); - CHECK_EQ(scales_b->ndim, 3); - CHECK_EQ(workspace->ndim, 1); - CHECK_EQ(out->ndim, 3); - int64_t batch_size = a->shape[0]; - int64_t m = a->shape[1]; - int64_t n = b->shape[1]; - CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now."; - int64_t k = a->shape[2]; - CHECK_EQ(b->shape[0], batch_size); - CHECK_EQ(scales_a->shape[0], batch_size); - CHECK_EQ(scales_b->shape[0], batch_size); - CHECK_EQ(out->shape[0], batch_size); - - // scales_a is col-major of (batch_size, m, k / block_size) - CHECK_EQ(scales_a->shape[1] * block_size_1, k); - CHECK_EQ(scales_a->shape[2], m); - // scales_b is col-major of (k / block_size, n / block_size) - CHECK_EQ(scales_b->shape[1] * block_size_0, n); - CHECK_EQ(scales_b->shape[2] * block_size_1, k); - - using tvm::runtime::DataType; - CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); - CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); - CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); - CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); - - if (DataType(out->dtype) == DataType::Float(16)) { - cutlass_fp8_blockwise_scaled_bmm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream); - } else if (DataType(out->dtype) == DataType::BFloat(16)) { - cutlass_fp8_blockwise_scaled_bmm( - static_cast(a->data), static_cast(b->data), - static_cast(scales_a->data), static_cast(scales_b->data), - static_cast(out->data), static_cast(workspace->data), - workspace->shape[0] * DataType(workspace->dtype).bytes(), m, n, k, batch_size, stream); - } else { - LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); - } -} - -TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_gemm); -TVM_FFI_REGISTER_GLOBAL("cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn") - .set_body_typed(tvm_cutlass_fp8_blockwise_scaled_bmm); - -} // namespace runtime -} // namespace tvm - -#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu similarity index 98% rename from src/runtime/contrib/cutlass/fp8_group_gemm.cu rename to src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu index 62a91dec1809..686a6ebcffeb 100644 --- a/src/runtime/contrib/cutlass/fp8_group_gemm.cu +++ b/src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu @@ -23,7 +23,7 @@ #include #include -#include "group_gemm_runner.cuh" +#include "fp16_group_gemm_runner_sm90.cuh" #if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh new file mode 100644 index 000000000000..4ecca5f1d8a9 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm.cuh @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +namespace tvm { +namespace runtime { + +template +struct CutlassFP8GroupwiseGemm; + +template +void tvm_cutlass_fp8_groupwise_scaled_gemm_impl(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + // Workspace is used for storing device-side gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + static tvm::ffi::Function get_stream_func = + tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(get_stream_func().cast()); + + CHECK_GE(a->ndim, 2); + CHECK_EQ(scales_a->ndim, a->ndim); + CHECK_EQ(b->ndim, 2); + CHECK_EQ(scales_b->ndim, 2); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, a->ndim); + int64_t m = 1; + for (int64_t i = 0; i < a->ndim - 1; ++i) { + m *= a->shape[i]; + } + int64_t n = b->shape[0]; + CHECK_EQ(a->shape[a->ndim - 1], b->shape[1]) << "Only col-major B is supported now."; + int64_t k = a->shape[a->ndim - 1]; + + // scales_a is col-major of (*a_shape[:-1], k / block_size) + CHECK_EQ(scales_a->shape[0] * block_size_1, k); + for (int64_t i = 1; i < scales_a->ndim; ++i) { + CHECK_EQ(scales_a->shape[i], a->shape[i - 1]); + } + // scales_b is col-major of (k / block_size, n / block_size) + CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[0]); + CHECK_EQ(scales_b->shape[1] * block_size_1, k); + + using tvm::runtime::DataType; + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); + CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); + CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + + if (DataType(out->dtype) == DataType::Float(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, 1, stream); + } else if (DataType(out->dtype) == DataType::BFloat(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, 1, stream); + } else { + LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + } +} + +template +void tvm_cutlass_fp8_groupwise_scaled_bmm_impl(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + // Workspace is used for storing 
device-side gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + static tvm::ffi::Function get_stream_func = + tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream"); + cudaStream_t stream = static_cast(get_stream_func().cast()); + + CHECK_EQ(a->ndim, 3); + CHECK_EQ(scales_a->ndim, 3); + CHECK_EQ(b->ndim, 3); + CHECK_EQ(scales_b->ndim, 3); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, 3); + int64_t batch_size = a->shape[0]; + int64_t m = a->shape[1]; + int64_t n = b->shape[1]; + CHECK_EQ(a->shape[2], b->shape[2]) << "Only col-major B is supported now."; + int64_t k = a->shape[2]; + CHECK_EQ(b->shape[0], batch_size); + CHECK_EQ(scales_a->shape[0], batch_size); + CHECK_EQ(scales_b->shape[0], batch_size); + CHECK_EQ(out->shape[0], batch_size); + + // scales_a is col-major of (batch_size, m, k / block_size) + CHECK_EQ(scales_a->shape[1] * block_size_1, k); + CHECK_EQ(scales_a->shape[2], m); + // scales_b is col-major of (k / block_size, n / block_size) + CHECK_EQ(scales_b->shape[1] * block_size_0, n); + CHECK_EQ(scales_b->shape[2] * block_size_1, k); + + using tvm::runtime::DataType; + CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN()); + CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32)); + CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32)); + CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8)); + + if (DataType(out->dtype) == DataType::Float(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, batch_size, stream); + } else if (DataType(out->dtype) == DataType::BFloat(16)) { + CutlassFP8GroupwiseGemm::run(static_cast(a->data), + static_cast(b->data), + static_cast(scales_a->data), + static_cast(scales_b->data), + static_cast(out->data), + static_cast(workspace->data), + workspace->shape[0] * DataType(workspace->dtype).bytes(), m, + n, k, batch_size, stream); + } else { + LOG(FATAL) << "Unsupported output dtype: " << DataType(out->dtype); + } +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh new file mode 100644 index 000000000000..95fc578fd43f --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using tvm::runtime::NDArray; + +template +struct CutlassFP8ScaledGroupwiseGemmRunnerSM100 { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using LayoutD = LayoutC; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // MMA type + using ElementAccumulator = float; // Element Accumulator will also be our scale factor type + using ElementCompute = float; + using ElementBlockScale = float; + + static constexpr int ScaleGranularityM = 1; + static constexpr int ScaleGranularityN = 128; + static constexpr int ScaleGranularityK = 128; + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, UMMA::Major::MN, UMMA::Major::K>; + + using LayoutSFA = + decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix operand + using LayoutSFB = + decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix operand + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, + LayoutC, AlignmentC, ElementD, LayoutC, AlignmentD, + cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA, + cute::tuple, AlignmentA, ElementB, cute::tuple, + AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelScheduleSm100Blockwise>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, + void>; // Default to ClusterLaunchControl (CLC) based tile scheduler + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using 
StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + + void run_gemm(const ElementA* a_ptr, const ElementB* b_ptr, const ElementBlockScale* scales_a_ptr, + const ElementBlockScale* scales_b_ptr, ElementD* o_ptr, int m, int n, int k, int l, + uint8_t* workspace, int64_t workspace_size, cudaStream_t stream) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + StrideA stride_a = + cute::make_stride(static_cast(k), Int<1>{}, static_cast(m * k)); + StrideB stride_b = + cute::make_stride(static_cast(k), Int<1>{}, static_cast(n * k)); + StrideD stride_d = + cute::make_stride(static_cast(n), Int<1>{}, static_cast(m * n)); + auto layout_scales_a = ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, l)); + auto layout_scales_b = ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, l)); + + typename Gemm::Arguments arguments = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, l}, + {a_ptr, stride_a, b_ptr, stride_b, scales_a_ptr, + layout_scales_a, scales_b_ptr, layout_scales_b}, + {{}, o_ptr, stride_d, o_ptr, stride_d}, + hw_info}; + + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +void cutlass_fp8_groupwise_scaled_mm_sm100(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementD* out, + uint8_t* workspace, int64_t workspace_size, int64_t m, + int64_t n, int64_t k, int64_t l, cudaStream_t stream) { + using Runner = CutlassFP8ScaledGroupwiseGemmRunnerSM100; + Runner runner; + runner.run_gemm(a, b, scales_a, scales_b, out, m, n, k, l, workspace, workspace_size, stream); +} diff --git a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh similarity index 75% rename from src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh rename to src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh index f520bf815a94..5ec9ed083916 100644 --- a/src/runtime/contrib/cutlass/blockwise_scaled_gemm_runner.cuh +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh @@ -58,7 +58,7 @@ using tvm::runtime::NDArray; template -struct CutlassFP8ScaledBlockwiseGemmRunner { +struct CutlassFP8GroupwiseScaledGemmRunner { using ElementAccumulator = float; using ElementCompute = float; using ElementBlockScale = float; @@ -149,53 +149,14 @@ struct CutlassFP8ScaledBlockwiseGemmRunner { template -void cutlass_fp8_blockwise_scaled_gemm(ElementA* a, ElementB* b, ElementBlockScale* scales_a, - ElementBlockScale* scales_b, ElementD* out, - uint8_t* workspace, int64_t workspace_size, int64_t m, - int64_t n, int64_t k, cudaStream_t stream) { +void cutlass_fp8_groupwise_scaled_mm_sm90(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementD* out, + uint8_t* workspace, int64_t workspace_size, int64_t m, + int64_t n, int64_t k, int64_t l, cudaStream_t stream) { if (k > 3 * n) { using SchedulerType = cutlass::gemm::StreamKScheduler; using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; - using StrideA = typename Runner::StrideA; - using StrideB = typename Runner::StrideB; - using StrideD = typename Runner::StrideD; - - Runner runner; - StrideA stride_a = 
cute::make_stride(k, Int<1>{}, m * k); - StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k); - StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n); - ProblemShape problem_size{static_cast(m), static_cast(n), static_cast(k), 1}; - runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d, - workspace, workspace_size, stream); - } else { - using SchedulerType = cutlass::gemm::PersistentScheduler; - using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; - using StrideA = typename Runner::StrideA; - using StrideB = typename Runner::StrideB; - using StrideD = typename Runner::StrideD; - - Runner runner; - StrideA stride_a = cute::make_stride(k, Int<1>{}, m * k); - StrideB stride_b = cute::make_stride(k, Int<1>{}, n * k); - StrideD stride_d = cute::make_stride(n, Int<1>{}, m * n); - ProblemShape problem_size{static_cast(m), static_cast(n), static_cast(k), 1}; - runner.run_gemm(a, b, scales_a, scales_b, out, &problem_size, &stride_a, &stride_b, &stride_d, - workspace, workspace_size, stream); - } -} - -template -void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScale* scales_a, - ElementBlockScale* scales_b, ElementD* out, - uint8_t* workspace, int64_t workspace_size, int64_t m, - int64_t n, int64_t k, int64_t l, cudaStream_t stream) { - if (k > 3 * n) { - using SchedulerType = cutlass::gemm::StreamKScheduler; - using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; + CutlassFP8GroupwiseScaledGemmRunner; using StrideA = typename Runner::StrideA; using StrideB = typename Runner::StrideB; using StrideD = typename Runner::StrideD; @@ -211,7 +172,7 @@ void cutlass_fp8_blockwise_scaled_bmm(ElementA* a, ElementB* b, ElementBlockScal } else { using SchedulerType = cutlass::gemm::PersistentScheduler; using Runner = - CutlassFP8ScaledBlockwiseGemmRunner; + CutlassFP8GroupwiseScaledGemmRunner; using StrideA = typename Runner::StrideA; using StrideB = typename Runner::StrideB; using StrideD = typename Runner::StrideD; diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu new file mode 100644 index 000000000000..ffa3ae6653e6 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include "../cublas/cublas_utils.h" +#include "fp8_groupwise_scaled_gemm.cuh" +#include "fp8_groupwise_scaled_gemm_runner_sm100.cuh" + +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + +namespace tvm { +namespace runtime { + +template +struct CutlassFP8GroupwiseGemm<100, TileShape, ClusterShape, ElementA, ElementB, ElementC, + ElementBlockScale> { + static void run(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementC* out, uint8_t* workspace, + int64_t workspace_size, int64_t m, int64_t n, int64_t k, int64_t l, + cudaStream_t stream) { + cutlass_fp8_groupwise_scaled_mm_sm100( + a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream); + } +}; + +void tvm_cutlass_fp8_groupwise_scaled_gemm_sm100(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_gemm_impl<100, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +void tvm_cutlass_fp8_groupwise_scaled_bmm_sm100(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_bmm_impl<100, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm100); +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm100); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_SM100_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu new file mode 100644 index 000000000000..e445e97da364 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include "../cublas/cublas_utils.h" +#include "fp8_groupwise_scaled_gemm.cuh" +#include "fp8_groupwise_scaled_gemm_runner_sm90.cuh" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +namespace tvm { +namespace runtime { + +template +struct CutlassFP8GroupwiseGemm<90, TileShape, ClusterShape, ElementA, ElementB, ElementC, + ElementBlockScale> { + static void run(ElementA* a, ElementB* b, ElementBlockScale* scales_a, + ElementBlockScale* scales_b, ElementC* out, uint8_t* workspace, + int64_t workspace_size, int64_t m, int64_t n, int64_t k, int64_t l, + cudaStream_t stream) { + cutlass_fp8_groupwise_scaled_mm_sm90( + a, b, scales_a, scales_b, out, workspace, workspace_size, m, n, k, l, stream); + } +}; + +void tvm_cutlass_fp8_groupwise_scaled_gemm_sm90(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_gemm_impl<90, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +void tvm_cutlass_fp8_groupwise_scaled_bmm_sm90(NDArray a, NDArray b, NDArray scales_a, + NDArray scales_b, NDArray workspace, + int64_t block_size_0, int64_t block_size_1, + NDArray out) { + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; + tvm_cutlass_fp8_groupwise_scaled_bmm_impl<90, TileShape, ClusterShape>( + a, b, scales_a, scales_b, workspace, block_size_0, block_size_1, out); +} + +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_gemm_sm90); +TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn") + .set_body_typed(tvm_cutlass_fp8_groupwise_scaled_bmm_sm90); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh new file mode 100644 index 000000000000..19c6b699aa95 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + CHECK(error == cutlass::Status::kSuccess) \ + << "Got cutlass error: " << cutlassGetStatusString(error); \ + } + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +inline size_t aligned(size_t value, size_t alignment = 16) { + return (value + alignment - 1) / alignment * alignment; +} + +template +struct CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100 { + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = float; + using ElementCompute = float; + + static constexpr int ScaleGranularityM = 1; + static constexpr int ScaleGranularityN = 128; + static constexpr int ScaleGranularityK = 128; + using ScaleConfig = + cutlass::detail::Sm100BlockwiseScaleConfig; + + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise2SmSm100; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, + LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule>::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementA, + cute::tuple, AlignmentA, ElementB, cute::tuple, + AlignmentB, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, + const ElementBlockScale** ptr_scales_a, + const ElementBlockScale** ptr_scales_b, const ElementC** ptr_C, + ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, + typename ProblemShape::UnderlyingProblemShape* problem_sizes_host, + 
StrideA* stride_A, StrideB* stride_B, LayoutSFA* layout_scales_a, + LayoutSFB* layout_scales_b, StrideC* stride_C, StrideD* stride_D, + uint8_t* workspace, int64_t workspace_size, int num_groups, + cudaStream_t stream) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_groups, problem_sizes, problem_sizes_host}, + {ptr_A, stride_A, ptr_B, stride_B, ptr_scales_a, + layout_scales_a, ptr_scales_b, layout_scales_b}, + {{}, ptr_C, stride_C, ptr_D, stride_D}, + hw_info}; + auto& fusion_args = arguments.epilogue.thread; + fusion_args.alpha = 1.0f; + fusion_args.beta = 0.0f; + + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run(stream)); + } +}; + +template +__global__ void prepare_group_gemm_arguments( + const ElementA** ptr_A, const ElementB** ptr_B, const ElementBlockScale** ptr_scales_a, + const ElementBlockScale** ptr_scales_b, ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A, + StrideB* stride_B, LayoutSFA* layout_scales_a, LayoutSFB* layout_scales_b, StrideC* stride_D, + const ElementA* a, const ElementB* b, const ElementBlockScale* scales_a, + const ElementBlockScale* scales_b, ElementC* out, int64_t* indptr, int64_t n, int64_t k, + int num_groups) { + int group_id = threadIdx.x; + if (group_id >= num_groups) return; + int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1]; + ptr_A[group_id] = a + prev_rows * k; + ptr_B[group_id] = b + group_id * k * n; + ptr_D[group_id] = out + prev_rows * n; + ptr_scales_a[group_id] = scales_a + prev_rows * ((k + 127) / 128); + ptr_scales_b[group_id] = scales_b + group_id * ((k + 127) / 128) * ((n + 127) / 128); + int64_t m = indptr[group_id] - prev_rows; + problem_sizes[group_id] = {static_cast(m), static_cast(n), static_cast(k)}; + stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{}); + stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{}); + layout_scales_a[group_id] = ScaleConfig::tile_atom_to_shape_SFA( + make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); + layout_scales_b[group_id] = ScaleConfig::tile_atom_to_shape_SFB( + make_shape(static_cast(m), static_cast(n), static_cast(k), 1)); +} + +template +void cutlass_fp8_groupwise_scaled_group_gemm_sm100( + ElementA* a, ElementB* b, const ElementBlockScale* scales_a, const ElementBlockScale* scales_b, + int64_t* indptr, uint8_t* workspace, int64_t workspace_size, int64_t n, int64_t k, + int64_t num_groups, ElementC* out, cudaStream_t stream) { + using TileShape = Shape<_256, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; + using Runner = + CutlassFP8ScaledGroupwiseGroupGemmRunnerSM100; + using ScaleConfig = typename Runner::ScaleConfig; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + using LayoutSFA = typename Runner::LayoutSFA; + using LayoutSFB = typename Runner::LayoutSFB; + + Runner runner; + std::ptrdiff_t offset = 0; + const ElementA** ptr_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementA*) * num_groups); + const ElementB** ptr_B = 
+  offset += aligned(sizeof(ElementB*) * num_groups);
+  const ElementBlockScale** ptr_scales_a =
+      reinterpret_cast<const ElementBlockScale**>(workspace + offset);
+  offset += aligned(sizeof(ElementBlockScale*) * num_groups);
+  const ElementBlockScale** ptr_scales_b =
+      reinterpret_cast<const ElementBlockScale**>(workspace + offset);
+  offset += aligned(sizeof(ElementBlockScale*) * num_groups);
+  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
+  offset += aligned(sizeof(ElementC*) * num_groups);
+  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
+      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
+  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
+  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
+  offset += aligned(sizeof(StrideA) * num_groups);
+  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
+  offset += aligned(sizeof(StrideB) * num_groups);
+  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
+  offset += aligned(sizeof(StrideC) * num_groups);
+  LayoutSFA* layout_scales_a = reinterpret_cast<LayoutSFA*>(workspace + offset);
+  offset += aligned(sizeof(LayoutSFA) * num_groups);
+  LayoutSFB* layout_scales_b = reinterpret_cast<LayoutSFB*>(workspace + offset);
+  offset += aligned(sizeof(LayoutSFB) * num_groups);
+  prepare_group_gemm_arguments<ScaleConfig>
+      <<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_scales_a, ptr_scales_b, ptr_D, problem_sizes,
+                                     stride_A, stride_B, layout_scales_a, layout_scales_b, stride_D,
+                                     a, b, scales_a, scales_b, out, indptr, n, k, num_groups);
+  offset = aligned(offset, 256);
+  runner.run_group_gemm(ptr_A, ptr_B, ptr_scales_a, ptr_scales_b,
+                        const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes, nullptr,
+                        stride_A, stride_B, layout_scales_a, layout_scales_b, stride_D, stride_D,
+                        workspace + offset, workspace_size - offset, num_groups, stream);
+}
diff --git a/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
new file mode 100644
index 000000000000..d13481e9dd3f
--- /dev/null
+++ b/src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include "fp8_groupwise_scaled_group_gemm_runner_sm100.cuh"
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+namespace tvm {
+namespace runtime {
+
+void tvm_fp8_groupwise_scaled_group_gemm_sm100(NDArray a, NDArray b, NDArray scales_a,
+                                               NDArray scales_b, NDArray indptr, NDArray workspace,
+                                               int64_t block_size_0, int64_t block_size_1,
+                                               NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+  CHECK_EQ(a->ndim, 2);
+  CHECK_EQ(b->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = b->shape[0];
+  int n = b->shape[1];
+  int k = b->shape[2];
+
+  CHECK_EQ(scales_a->ndim, a->ndim);
+  CHECK_EQ(scales_b->ndim, b->ndim);
+  // scales_a is row-major of (m, k / block_size)
+  CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_a->shape[1]);
+  CHECK_EQ(scales_a->shape[0], a->shape[0]);
+  // scales_b is col-major of (k / block_size, n / block_size)
+  CHECK_EQ(scales_b->shape[0], num_groups);
+  CHECK_EQ((n + block_size_0 - 1) / block_size_0, scales_b->shape[1]);
+  CHECK_EQ((k + block_size_1 - 1) / block_size_1, scales_b->shape[2]);
+
+  using tvm::runtime::DataType;
+  CHECK_EQ(DataType(a->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(b->dtype), DataType::Float8E4M3FN());
+  CHECK_EQ(DataType(scales_a->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(scales_b->dtype), DataType::Float(32));
+  CHECK_EQ(DataType(indptr->dtype), DataType::Int(64));
+  CHECK_EQ(DataType(workspace->dtype), DataType::UInt(8));
+
+  if (DataType(out->dtype) == DataType::Float(16)) {
+    using Dtype = cutlass::half_t;
+    cutlass_fp8_groupwise_scaled_group_gemm_sm100(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, static_cast<Dtype*>(out->data), stream);
+  } else if (DataType(out->dtype) == DataType::BFloat(16)) {
+    using Dtype = cutlass::bfloat16_t;
+    cutlass_fp8_groupwise_scaled_group_gemm_sm100(
+        static_cast<cutlass::float_e4m3_t*>(a->data), static_cast<cutlass::float_e4m3_t*>(b->data),
+        static_cast<float*>(scales_a->data), static_cast<float*>(scales_b->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, static_cast<Dtype*>(out->data), stream);
+  }
+}
+
+TVM_FFI_REGISTER_GLOBAL("cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn")
+    .set_body_typed(tvm_fp8_groupwise_scaled_group_gemm_sm100);
+
+}  // namespace runtime
+}  // namespace tvm
+
+#endif  // CUTLASS_ARCH_MMA_SM100_SUPPORTED
diff --git a/src/target/tag.cc b/src/target/tag.cc
index f6e2307b75e1..0df0d8d2c7af 100644
--- a/src/target/tag.cc
+++ b/src/target/tag.cc
@@ -161,6 +161,8 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536)
     .with_config("l2_cache_size_bytes", 41943040);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536)
     .with_config("l2_cache_size_bytes", 52428800);
+TVM_REGISTER_CUDA_TAG("nvidia/nvidia-b100", "sm_100a", 49152, 65536)
+    .with_config("l2_cache_size_bytes", 52428800);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536);
 TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536);
diff --git a/tests/python/contrib/test_cutlass_gemm.py b/tests/python/contrib/test_cutlass_gemm.py
index 7c259e6f7d6d..33f7ef1160a1 100644
--- a/tests/python/contrib/test_cutlass_gemm.py
+++ b/tests/python/contrib/test_cutlass_gemm.py
@@ -44,8 +44,8 @@ def verify_group_gemm(
     def get_ref_data():
         assert M % num_groups == 0
         M_per_group = M // num_groups
-        a_np = get_random_ndarray((M, K), "float16")
-        b_np = get_random_ndarray((num_groups, N, K), "float16")
+        a_np = get_random_ndarray((M, K), x_dtype)
+        b_np = get_random_ndarray((num_groups, N, K), weight_dtype)
         indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group
1).astype("int64") * M_per_group c_np = np.concatenate( [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)], @@ -76,7 +76,7 @@ def to_numpy_dtype(dtype): @tvm.testing.requires_cuda_compute_version(9) def test_group_gemm_sm90(): verify_group_gemm( - "cutlass.group_gemm_fp16_sm90", + "cutlass.group_gemm", 8, 128, 128, @@ -116,6 +116,24 @@ def test_group_gemm_sm90(): ) +@tvm.testing.requires_cutlass +@tvm.testing.requires_cuda_compute_version(10) +def test_group_gemm_sm100(): + verify_group_gemm( + "cutlass.group_gemm", + 8, + 128, + 128, + 4, + "bfloat16", + "bfloat16", + "bfloat16", + False, + rtol=1e-2, + atol=1e-3, + ) + + def rowwise_quant_fp8_e4m3(shape: Tuple[int, int], block_size: Tuple[int, int], dtype: str): x_full_np = (np.random.rand(*shape) * 2 - 1).astype(dtype) x_scale_shape = ( @@ -283,14 +301,14 @@ def blockwise_bmm( @tvm.testing.requires_cutlass @tvm.testing.requires_cuda_compute_version(9) -def test_fp8_e4m3_blockwise_scaled_gemm(): +def test_fp8_e4m3_groupwise_scaled_gemm(): M = 16 N = 4608 K = 896 block_size = (128, 128) assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of 128 - func_name = "cutlass.blockwise_scaled_gemm_e4m3fn_e4m3fn" + func_name = "cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn" gemm_func = tvm.get_global_func(func_name, allow_missing=True) if gemm_func is None: print(f"Skipped as {func_name} is not available") @@ -316,7 +334,7 @@ def test_fp8_e4m3_blockwise_scaled_gemm(): @tvm.testing.requires_cutlass @tvm.testing.requires_cuda_compute_version(9) -def test_fp8_e4m3_blockwise_scaled_bmm(): +def test_fp8_e4m3_groupwise_scaled_bmm(): B = 16 M = 40 N = 512 @@ -324,7 +342,7 @@ def test_fp8_e4m3_blockwise_scaled_bmm(): block_size = (128, 128) assert N % 128 == 0 and K % 128 == 0 # Only support N/K are multiple of 128 - func_name = "cutlass.blockwise_scaled_bmm_e4m3fn_e4m3fn" + func_name = "cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn" gemm_func = tvm.get_global_func(func_name, allow_missing=True) if gemm_func is None: print(f"Skipped as {func_name} is not available")