From 1c53a42d1f3bcd1f1e3b341b40c690135f8a2992 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Mon, 26 Aug 2024 20:16:02 +0000 Subject: [PATCH 1/8] upd upgrade cutlass to 3.5 upd upd upd upd upd upd upd upd upd upd upd --- include/flashinfer/gemm/group_gemm.cuh | 2 +- .../flashinfer/gemm/group_gemm_cutlass.cuh | 57 ++- include/flashinfer/gemm/group_gemm_sm90.cuh | 348 ++++++++++++++++++ python/csrc/group_gemm.cu | 3 +- python/csrc/group_gemm_sm90.cu | 67 ++++ python/flashinfer/gemm.py | 31 +- tests/test_group_gemm.py | 3 +- 7 files changed, 483 insertions(+), 28 deletions(-) create mode 100644 include/flashinfer/gemm/group_gemm_sm90.cuh create mode 100644 python/csrc/group_gemm_sm90.cu diff --git a/include/flashinfer/gemm/group_gemm.cuh b/include/flashinfer/gemm/group_gemm.cuh index 968662f9..a4259a40 100644 --- a/include/flashinfer/gemm/group_gemm.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -116,4 +116,4 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe } // namespace flashinfer -#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_ +#endif // FLASHINFER_GEMM_GROUP_GEMM_CUH_ \ No newline at end of file diff --git a/include/flashinfer/gemm/group_gemm_cutlass.cuh b/include/flashinfer/gemm/group_gemm_cutlass.cuh index a3422bef..0f71fa3d 100644 --- a/include/flashinfer/gemm/group_gemm_cutlass.cuh +++ b/include/flashinfer/gemm/group_gemm_cutlass.cuh @@ -16,11 +16,16 @@ #ifndef FLASHINFER_GROUP_GEMM_CUTLASS_CUH_ #define FLASHINFER_GROUP_GEMM_CUTLASS_CUH_ +#include +#include +#include + #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm_grouped.h" #include "cutlass/gemm/kernel/default_gemm_grouped.h" #include "cutlass/layout/matrix.h" #include "cutlass/numeric_types.h" +#include "cutlass/util/packed_stride.hpp" namespace flashinfer { @@ -41,21 +46,49 @@ struct cutlass_dtype { using type = cutlass::bfloat16_t; }; -template -__global__ void compute_cutlass_group_gemm_args(cutlass::gemm::GemmCoord* all_problems, T** ptr_x, - T** ptr_w, T** ptr_y, int64_t* ld_x, int64_t* ld_w, - int64_t* ld_y, T* x, T* w, T* y, int64_t* xy_indptr, - int64_t* w_indices, size_t d_in, size_t d_out, - bool w_column_major) { +template <> +struct cutlass_dtype<__nv_fp8_e4m3> { + using type = cutlass::float_e4m3_t; +}; + +template <> +struct cutlass_dtype<__nv_fp8_e5m2> { + using type = cutlass::float_e5m2_t; +}; + +template +__global__ void compute_sm80_cutlass_group_gemm_args( + cutlass::gemm::GemmCoord* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr, + int64_t* x_ld, int64_t* w_ld, int64_t* y_ld, DTypeIn* x, DTypeIn* w, DTypeOut* y, + int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) { int i = blockIdx.x; int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; all_problems[i] = cutlass::gemm::GemmCoord(m, n, k); - ptr_w[i] = w + (w_indices == nullptr ? i : w_indices[i]) * d_in * d_out; - ptr_x[i] = x + xy_indptr[i] * d_in; - ptr_y[i] = y + xy_indptr[i] * d_out; - ld_x[i] = k; // m * k - ld_w[i] = w_column_major ? k : n; // k * n if column major, n * k if row major - ld_y[i] = n; // m * n + w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n; + x_ptr[i] = x + xy_indptr[i] * k; + y_ptr[i] = y + xy_indptr[i] * n; + x_ld[i] = k; // m * k + w_ld[i] = w_column_major ? 
k : n; // k * n if column major, n * k if row major + y_ld[i] = n; // m * n +} + +template +__global__ void compute_sm90_cutlass_group_gemm_args( + ProblemShape* all_problems, DTypeIn** x_ptr, DTypeIn** w_ptr, DTypeOut** y_ptr, + StrideA* x_stride, StrideB* w_stride, StrideCD* y_stride, DTypeIn* x, DTypeIn* w, DTypeOut* y, + int64_t* xy_indptr, int64_t* w_indices, size_t d_in, size_t d_out, bool w_column_major) { + int i = blockIdx.x; + int m = xy_indptr[i + 1] - xy_indptr[i], k = d_in, n = d_out; + all_problems[i] = ProblemShape(m, n, k); + w_ptr[i] = w + (w_indices == nullptr ? i : w_indices[i]) * k * n; + x_ptr[i] = x + xy_indptr[i] * k; + y_ptr[i] = y + xy_indptr[i] * n; + + x_stride[i] = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + w_stride[i] = w_column_major ? cutlass::make_cute_packed_stride(StrideB{}, {k, n, 1}) + : cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + y_stride[i] = cutlass::make_cute_packed_stride(StrideCD{}, {m, n, 1}); } } // namespace group_gemm diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh new file mode 100644 index 00000000..337e0098 --- /dev/null +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -0,0 +1,348 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ +#define FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ + +#include + +#include "../allocator.h" +#include "../utils.cuh" +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "handler.cuh" + +namespace flashinfer { + +namespace group_gemm { + +using namespace cute; + +#define DISPATCH_WEIGHT_LAYOUT(is_column_major, WEIGHT_LAYOUT, ...) 
\ + if (is_column_major) { \ + using WEIGHT_LAYOUT = cutlass::layout::ColumnMajor; \ + __VA_ARGS__ \ + } else { \ + using WEIGHT_LAYOUT = cutlass::layout::RowMajor; \ + __VA_ARGS__ \ + } + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template +cudaError_t CutlassSegmentGEMMWrapper_SM80(CutlassSegmentGEMMHandler* handler, DTypeIn* x, + DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, + int64_t* w_indices_d, unsigned int batch_size, + unsigned int d_in, unsigned int d_out, + bool weight_column_major, cudaStream_t stream) { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first < 8) { + std::cerr << "CutlassSegmentGEMMWrapper_SM80 requires compute capability 8.x" << std::endl; + return cudaErrorNotSupported; + } else { + if constexpr (sizeof(DTypeIn) != 2) { + std::cerr + << "CutlassSegmentGEMMWrapper requires fp16/bf16 data type for compute capability 8.x" + << std::endl; + return cudaErrorNotSupported; + } else { + // SM80 grouped gemm + AlignedAllocator allocator(handler->GetIntWorkspace(), handler->GetIntWorkspaceSizeInBytes()); + cutlass::gemm::GemmCoord* problem_sizes_device = + allocator.aligned_alloc( + batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device"); + DTypeIn** x_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "x_data"); + DTypeIn** w_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "w_data"); + DTypeOut** y_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeOut*), 16, "y_data"); + int64_t* ld_x = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_x"); + int64_t* ld_w = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_w"); + int64_t* ld_y = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_y"); + + // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API, + // so I just use the kernel function directly, need to investigate more. 
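As a sanity check, the per-group metadata that compute_sm80_cutlass_group_gemm_args fills into the int workspace (problem shapes, pointer offsets, leading dimensions) can be mirrored on the host. The sketch below is illustrative only; group_gemm_metadata is a hypothetical helper, not a FlashInfer API, and it assumes row-major x/y and weights stored as [d_out, d_in] when column-major, matching the kernel above.

    def group_gemm_metadata(xy_indptr, w_indices, d_in, d_out, w_column_major):
        # Host-side mirror of what compute_sm80_cutlass_group_gemm_args writes per group.
        groups = []
        for i in range(len(xy_indptr) - 1):
            m, n, k = xy_indptr[i + 1] - xy_indptr[i], d_out, d_in
            groups.append({
                "problem_mnk": (m, n, k),
                "x_offset": xy_indptr[i] * k,    # x is row-major [m, k]
                "w_offset": (w_indices[i] if w_indices is not None else i) * k * n,
                "y_offset": xy_indptr[i] * n,    # y is row-major [m, n]
                "ld_x": k,
                "ld_w": k if w_column_major else n,  # [n, k] col-major vs [k, n] row-major
                "ld_y": n,
            })
        return groups

    # e.g. group_gemm_metadata([0, 1, 3, 6], None, d_in=128, d_out=256, w_column_major=True)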
+ auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args; + compute_args_kernel<<>>( + problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DTypeIn*)x, (DTypeIn*)w, + (DTypeOut*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "Failed to launch compute_sm80_cutlass_group_gemm_args kernel: " + << cudaGetErrorString(err) << std::endl; + return err; + } + + using cutlass::epilogue::thread::LinearCombination; + using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; + DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + DTypeIn, // Element A + cutlass::layout::RowMajor, // Layout A + cutlass::ComplexTransform::kNone, // + 8, // Granularity A + DTypeIn, // Element B + WEIGHT_LAYOUT, // Layout B + cutlass::ComplexTransform::kNone, // + 8, // Granularity B + DTypeOut, // Element C&D + cutlass::layout::RowMajor, // Layout C&D + float, // Element Accumulator + cutlass::arch::OpClassTensorOp, // Operator Class Tag + cutlass::arch::Sm80, // Architecture + cutlass::gemm::GemmShape<128, 128, 32>, // Thread Block Shape + cutlass::gemm::GemmShape<64, 64, 32>, // Warp Shape + cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape + cutlass::epilogue::thread::LinearCombination, // Epilogue + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, // Swizzling + // Operator + 8 // Stages + >::GemmKernel; + + using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + typename GemmGrouped::Arguments args(problem_sizes_device, batch_size, 4, epilogue_op, + x_data, w_data, y_data, y_data, ld_x, ld_w, ld_y, + ld_y); + + GemmGrouped gemm; + auto status = gemm.initialize(args, nullptr, stream); + if (status != cutlass::Status::kSuccess) { + std::ostringstream err_msg; + err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status); + throw std::runtime_error(err_msg.str()); + } + status = gemm.run(stream); + if (status != cutlass::Status::kSuccess) { + std::ostringstream err_msg; + err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status); + throw std::runtime_error(err_msg.str()); + } + }); + } + } + return cudaSuccess; +} + +template +cudaError_t CutlassSegmentGEMMWrapper_SM90( + void* float_buffer, size_t float_buffer_size_in_bytes, + void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x, + DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, + int64_t* w_indices_d, unsigned int batch_size, + unsigned int d_in, unsigned int d_out, + bool weight_column_major, cudaStream_t stream) { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first < 9) { + std::cerr << "CutlassSegmentGEMMWrapper_SM90 requires compute capability of at least 9.0" + << std::endl; + return cudaErrorNotSupported; + } else { + // Compute capability >= 9.0 + // Reference implementation + // - + // https://github.com/NVIDIA/cutlass/blob/f7b19de32c5d1f3cedfc735c2849f12b537522ee/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu + using ProblemShape = + cutlass::gemm::GroupProblemShape>; // per group + using ElementA = DTypeIn; // Element type for A matrix operand + using ElementB = DTypeIn; // Element type for B matrix operand + using ElementC = DTypeOut; // Element type for C and D matrix operands + + 
DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { + if constexpr (std::is_same_v && + sizeof(DTypeIn) == 1) { + std::ostringstream err_msg; + err_msg << "Row-major layout is not supported for fp8 data type"; + throw std::runtime_error(err_msg.str()); + } else { + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of + // elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = WEIGHT_LAYOUT; // Layout type for B matrix operand + constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of + // elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands + constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of + // elements (up to 16 bytes) + + constexpr bool is_fp8 = sizeof(DTypeIn) == 1; + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the + // intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = + typename std::conditional, + Shape<_128, _64, _64>>::type; // Threadblock-level tile size + using ClusterShape = + typename std::conditional, Shape<_2, _1, _1>>:: + type; // Shape of the threadblocks in a cluster + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = typename std::conditional< + is_fp8, cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum, + cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative>::type; // Kernel to launch + using EpilogueSchedule = + cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Reference device GEMM implementation type + using DeviceGemmReference = + cutlass::reference::device::Gemm; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + AlignedAllocator allocator(int_buffer, + int_buffer_size_in_bytes); + ProblemShape::UnderlyingProblemShape* problem_sizes_device = + allocator.aligned_alloc( + batch_size * sizeof(ProblemShape::UnderlyingProblemShape), 16, + "problem_sizes_device"); + DTypeIn** x_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 
16, "x_data"); + DTypeIn** w_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "w_data"); + DTypeOut** y_data = + allocator.aligned_alloc(batch_size * sizeof(DTypeOut*), 16, "y_data"); + StrideA* x_stride = + allocator.aligned_alloc(batch_size * sizeof(StrideA), 16, "x_stride"); + StrideB* w_stride = + allocator.aligned_alloc(batch_size * sizeof(StrideB), 16, "w_stride"); + StrideC* y_stride = + allocator.aligned_alloc(batch_size * sizeof(StrideC), 16, "y_stride"); + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename Gemm::EpilogueOutputOp::Params params; + // TODO(Zihao): support block alpha and beta + params = typename Gemm::EpilogueOutputOp::Params(/*alpha=*/ElementAccumulator(1.f), + /*beta=*/ElementAccumulator(0.f)); + + typename Gemm::Arguments arguments; + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {int(batch_size), problem_sizes_device, nullptr}, + {const_cast(x_data), x_stride, const_cast(w_data), + w_stride}, + {params, const_cast(y_data), y_stride, y_data, y_stride}, + hw_info}; + + compute_sm90_cutlass_group_gemm_args<<>>( + problem_sizes_device, x_data, w_data, y_data, x_stride, w_stride, y_stride, (DTypeIn*)x, + (DTypeIn*)w, (DTypeOut*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "Failed to launch compute_sm90_cutlass_group_gemm_args kernel: " + << cudaGetErrorString(err) << std::endl; + return err; + } + + // Initialize the gemm kernel + Gemm gemm; + + // Using the arguments, query for extra workspace required for matrix multiplication + // computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + AlignedAllocator float_allocator(float_buffer, + float_buffer_size_in_bytes); + auto workspace_ptr = float_allocator.aligned_alloc(workspace_size, 64, + "sm90_group_gemm_float_workspace"); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace_ptr)); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); // Warmup + } + }); + } + + return cudaSuccess; +} + +} // namespace group_gemm + +} // namespace flashinfer + +#endif // FLASHINFER_GEMM_GROUP_GEMM_SM90_CUH_ diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index 7954f533..528b73a2 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -48,7 +48,8 @@ torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor s weight_indices = weight_indices.to(torch::kInt64); } - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { + // TODO(Zihao): add fp8 support + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; auto status = CutlassSegmentGEMMRun( workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), diff --git a/python/csrc/group_gemm_sm90.cu b/python/csrc/group_gemm_sm90.cu new file mode 100644 index 00000000..be82d1ce --- /dev/null +++ b/python/csrc/group_gemm_sm90.cu @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024 by FlashInfer team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include "pytorch_extension_utils.h" + +using namespace flashinfer::group_gemm; + +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major) { + // TODO(Zihao): Add more checks here + CHECK_INPUT(seg_indptr); + CHECK_INPUT(x); + CHECK_INPUT(weight); + auto device = x.device(); + CHECK_EQ(seg_indptr.device(), device); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, x); // x: [sum(m_i), d_in] + CHECK_DIM(3, weight); // weight: [num_weights, d_out, d_in] if weight_column_major, [num_weights, + // d_in, d_out] otherwise + int64_t cumulative_batch_size = x.size(0); + int64_t d_out = weight_column_major ? weight.size(1) : weight.size(2); + int64_t d_in = weight_column_major ? weight.size(2) : weight.size(1); + CHECK_EQ(x.size(1), d_in); + auto y = torch::zeros({cumulative_batch_size, d_out}, x.options()); + cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index()); + seg_indptr = seg_indptr.to(torch::kInt64); + + bool weight_indices_defined = weight_indices.numel() > 0; + if (weight_indices_defined) { + CHECK_INPUT(weight_indices); + CHECK_EQ(weight_indices.device(), device); + weight_indices = weight_indices.to(torch::kInt64); + } + + // TODO(Zihao): add fp8 support + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + using cutlass_t = typename cutlass_dtype::type; + auto status = CutlassSegmentGEMMSM90Run( + float_workspace_buffer.data_ptr(), float_workspace_buffer.element_size() * float_workspace_buffer.size(0), + int_workspace_buffer.data_ptr(), int_workspace_buffer.element_size() * int_workspace_buffer.size(0), + static_cast(x.data_ptr()), static_cast(weight.data_ptr()), + static_cast(y.data_ptr()), static_cast(seg_indptr.data_ptr()), + weight_indices_defined ? 
static_cast(weight_indices.data_ptr()) : nullptr, + batch_size, d_in, d_out, weight_column_major, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "Failed to run CutlassSegmentGEMM: ", cudaGetErrorString(status)); + return true; + }); + + return y; +} \ No newline at end of file diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index e7e7a051..7a2e73ee 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -53,7 +53,7 @@ class SegmentGEMMWrapper: >>> import torch >>> from flashinfer import SegmentGEMMWrapper >>> # create a 1MB workspace buffer - >>> workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda") + >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") >>> segment_gemm = SegmentGEMMWrapper(workspace_buffer) >>> seq_lens = torch.tensor([1, 2, 3, 4], dtype=torch.int64, device="cuda") >>> # create packed input tensor (10 = 1 + 2 + 3 + 4) @@ -96,27 +96,34 @@ class SegmentGEMMWrapper: True """ - def __init__(self, workspace_buffer: torch.Tensor) -> None: + def __init__(self, float_workspace_buffer: torch.Tensor) -> None: r"""Initialize the wrapper. Parameters ---------- - workspace_buffer : torch.Tensor - The workspace buffer for the kernels, we use it to store the metadata for the segment GEMM whose - size is proportional to the number of segments (batch size), 1MB workspace is enough for most cases. + float_workspace_buffer : torch.Tensor + The workspace buffer for the kernels, we use it for storing intermediate results in cutlass + segment GEMM kernels. Encouraged size is 128MB. """ - self._workspace_buffer = workspace_buffer + self._int_workspace_buffer = torch.empty( + (1024 * 1024,), dtype=torch.int8, device=float_workspace_buffer.device + ) + self._float_workspace_buffer = float_workspace_buffer - def reset_workspace_buffer(self, new_workspace_buffer: torch.Tensor) -> None: + def reset_workspace_buffer( + self, float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor + ) -> None: r"""Reset the workspace buffer. Parameters ---------- - new_workspace_buffer : torch.Tensor - The new workspace buffer, the device of the new workspace buffer should - be the same as the device of the input tensors. + float_workspace_buffer : torch.Tensor + The new float workspace buffer for the kernels. + int_workspace_buffer : torch.Tensor + The new int workspace buffer for the kernels. 
""" - self._workspace_buffer = new_workspace_buffer + self._float_workspace_buffer = float_workspace_buffer + self._int_workspace_buffer = int_workspace_buffer def run( self, @@ -194,7 +201,7 @@ def run( # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) return get_gemm_module().cutlass_segment_gemm( - self._workspace_buffer, + self._int_workspace_buffer, seg_indptr, weight_indices, x, diff --git a/tests/test_group_gemm.py b/tests/test_group_gemm.py index 96c48fb8..fb35c0a3 100644 --- a/tests/test_group_gemm.py +++ b/tests/test_group_gemm.py @@ -84,8 +84,7 @@ def test_segment_gemm( ), ), rtol=1e-3, - atol=1e-3, - msg="assertion failed at batch {}".format(i), + atol=1e-3 ) else: torch.testing.assert_close( From e725cec360b92def4585f870316aced641aff8ae Mon Sep 17 00:00:00 2001 From: xsling Date: Mon, 7 Oct 2024 22:39:32 +0000 Subject: [PATCH 2/8] onlyl fp16 for sm80 --- python/csrc/group_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index 528b73a2..1d941b43 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -49,7 +49,7 @@ torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor s } // TODO(Zihao): add fp8 support - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; auto status = CutlassSegmentGEMMRun( workspace_buffer.data_ptr(), workspace_buffer.element_size() * workspace_buffer.size(0), From 331b8c9ad680faa44aec8d1989d33f410e7abea7 Mon Sep 17 00:00:00 2001 From: xsling Date: Mon, 7 Oct 2024 23:03:35 +0000 Subject: [PATCH 3/8] upd --- flashinfer-aot/setup.py | 1 + include/flashinfer/gemm/group_gemm.cuh | 2 +- include/flashinfer/gemm/group_gemm_sm90.cuh | 98 --------------------- python/csrc/group_gemm.cu | 1 - python/flashinfer/jit/__init__.py | 5 +- python/flashinfer/jit/env.py | 5 +- 6 files changed, 8 insertions(+), 104 deletions(-) diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py index 396cf334..3cc15c27 100644 --- a/flashinfer-aot/setup.py +++ b/flashinfer-aot/setup.py @@ -355,6 +355,7 @@ def __init__(self, *args, **kwargs) -> None: include_dirs = [ str(root.resolve() / "include"), str(root.resolve() / "3rdparty" / "cutlass" / "include"), # for group gemm + str(root.resolve() / "3rdparty" / "cutlass" / "tools" / "util" / "include"), ] extra_compile_args = { "cxx": [ diff --git a/include/flashinfer/gemm/group_gemm.cuh b/include/flashinfer/gemm/group_gemm.cuh index a4259a40..20fca551 100644 --- a/include/flashinfer/gemm/group_gemm.cuh +++ b/include/flashinfer/gemm/group_gemm.cuh @@ -53,7 +53,7 @@ cudaError_t CutlassSegmentGEMMRun(void* workspace_buffer, size_t workspace_buffe // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API, // so I just use the kernel function directly, need to investigate more. 
- auto compute_args_kernel = compute_cutlass_group_gemm_args; + auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args; compute_args_kernel<<>>( problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DType*)x, (DType*)w, (DType*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index 337e0098..b60c9e3b 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -41,7 +41,6 @@ #include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/device/tensor_fill.h" -#include "handler.cuh" namespace flashinfer { @@ -71,103 +70,6 @@ using namespace cute; } \ } -template -cudaError_t CutlassSegmentGEMMWrapper_SM80(CutlassSegmentGEMMHandler* handler, DTypeIn* x, - DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, - int64_t* w_indices_d, unsigned int batch_size, - unsigned int d_in, unsigned int d_out, - bool weight_column_major, cudaStream_t stream) { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first < 8) { - std::cerr << "CutlassSegmentGEMMWrapper_SM80 requires compute capability 8.x" << std::endl; - return cudaErrorNotSupported; - } else { - if constexpr (sizeof(DTypeIn) != 2) { - std::cerr - << "CutlassSegmentGEMMWrapper requires fp16/bf16 data type for compute capability 8.x" - << std::endl; - return cudaErrorNotSupported; - } else { - // SM80 grouped gemm - AlignedAllocator allocator(handler->GetIntWorkspace(), handler->GetIntWorkspaceSizeInBytes()); - cutlass::gemm::GemmCoord* problem_sizes_device = - allocator.aligned_alloc( - batch_size * sizeof(cutlass::gemm::GemmCoord), 16, "problem_sizes_device"); - DTypeIn** x_data = - allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "x_data"); - DTypeIn** w_data = - allocator.aligned_alloc(batch_size * sizeof(DTypeIn*), 16, "w_data"); - DTypeOut** y_data = - allocator.aligned_alloc(batch_size * sizeof(DTypeOut*), 16, "y_data"); - int64_t* ld_x = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_x"); - int64_t* ld_w = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_w"); - int64_t* ld_y = allocator.aligned_alloc(batch_size * sizeof(int64_t), 16, "ld_y"); - - // NOTE(Zihao): I didn't successfully launch the kernel with cudaLaunchKernel API, - // so I just use the kernel function directly, need to investigate more. 
- auto compute_args_kernel = compute_sm80_cutlass_group_gemm_args; - compute_args_kernel<<>>( - problem_sizes_device, x_data, w_data, y_data, ld_x, ld_w, ld_y, (DTypeIn*)x, (DTypeIn*)w, - (DTypeOut*)y, xy_indptr_d, w_indices_d, d_in, d_out, weight_column_major); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - std::cerr << "Failed to launch compute_sm80_cutlass_group_gemm_args kernel: " - << cudaGetErrorString(err) << std::endl; - return err; - } - - using cutlass::epilogue::thread::LinearCombination; - using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle; - DISPATCH_WEIGHT_LAYOUT(weight_column_major, WEIGHT_LAYOUT, { - using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - DTypeIn, // Element A - cutlass::layout::RowMajor, // Layout A - cutlass::ComplexTransform::kNone, // - 8, // Granularity A - DTypeIn, // Element B - WEIGHT_LAYOUT, // Layout B - cutlass::ComplexTransform::kNone, // - 8, // Granularity B - DTypeOut, // Element C&D - cutlass::layout::RowMajor, // Layout C&D - float, // Element Accumulator - cutlass::arch::OpClassTensorOp, // Operator Class Tag - cutlass::arch::Sm80, // Architecture - cutlass::gemm::GemmShape<128, 128, 32>, // Thread Block Shape - cutlass::gemm::GemmShape<64, 64, 32>, // Warp Shape - cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape - cutlass::epilogue::thread::LinearCombination, // Epilogue - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, // Swizzling - // Operator - 8 // Stages - >::GemmKernel; - - using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp; - typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0); - using GemmGrouped = cutlass::gemm::device::GemmGrouped; - typename GemmGrouped::Arguments args(problem_sizes_device, batch_size, 4, epilogue_op, - x_data, w_data, y_data, y_data, ld_x, ld_w, ld_y, - ld_y); - - GemmGrouped gemm; - auto status = gemm.initialize(args, nullptr, stream); - if (status != cutlass::Status::kSuccess) { - std::ostringstream err_msg; - err_msg << "cutlass group_gemm.initialize failed: " << cutlassGetStatusString(status); - throw std::runtime_error(err_msg.str()); - } - status = gemm.run(stream); - if (status != cutlass::Status::kSuccess) { - std::ostringstream err_msg; - err_msg << "cutlass group_gemm.run failed: " << cutlassGetStatusString(status); - throw std::runtime_error(err_msg.str()); - } - }); - } - } - return cudaSuccess; -} - template cudaError_t CutlassSegmentGEMMWrapper_SM90( void* float_buffer, size_t float_buffer_size_in_bytes, diff --git a/python/csrc/group_gemm.cu b/python/csrc/group_gemm.cu index 1d941b43..7954f533 100644 --- a/python/csrc/group_gemm.cu +++ b/python/csrc/group_gemm.cu @@ -48,7 +48,6 @@ torch::Tensor CutlassSegmentGEMM(torch::Tensor workspace_buffer, torch::Tensor s weight_indices = weight_indices.to(torch::kInt64); } - // TODO(Zihao): add fp8 support DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; auto status = CutlassSegmentGEMMRun( diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index bbd85729..2f9acc0e 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -25,7 +25,7 @@ FLASHINFER_GEN_SRC_DIR, FLASHINFER_INCLUDE_DIR, FLASHINFER_CSRC_DIR, - CUTLASS_INCLUDE_DIR, + CUTLASS_INCLUDE_DIRS, ) from .activation import get_act_and_mul_cu_str, gen_act_and_mul_cu from .attention import ( @@ -135,9 +135,8 @@ def load_cuda_ops( if extra_include_paths is None: 
extra_include_paths = [ FLASHINFER_INCLUDE_DIR, - CUTLASS_INCLUDE_DIR, FLASHINFER_CSRC_DIR, - ] + ] + CUTLASS_INCLUDE_DIRS return torch_cpp_ext.load( name, list(map(lambda _: str(_), sources)), diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py index 65b47eed..c51bb8bb 100644 --- a/python/flashinfer/jit/env.py +++ b/python/flashinfer/jit/env.py @@ -23,4 +23,7 @@ _project_root = pathlib.Path(__file__).resolve().parent.parent.parent FLASHINFER_INCLUDE_DIR = _project_root / "include" FLASHINFER_CSRC_DIR = _project_root / "csrc" -CUTLASS_INCLUDE_DIR = _project_root / "3rdparty" / "cutlass" / "include" +CUTLASS_INCLUDE_DIRS = [ + _project_root / "3rdparty" / "cutlass" / "include", + _project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include" +] From 9558b9ec5262fa76fb8fcc7057699b678cf414fc Mon Sep 17 00:00:00 2001 From: xsling Date: Tue, 8 Oct 2024 06:34:43 +0000 Subject: [PATCH 4/8] upd --- .../csrc_aot/flashinfer_ops_sm90.cu | 155 ++++++++++++++++++ flashinfer-aot/setup.py | 15 +- include/flashinfer/gemm/group_gemm_sm90.cuh | 4 +- python/csrc/flashinfer_gemm_ops_sm90.cu | 31 ++++ python/flashinfer/gemm.py | 6 +- python/flashinfer/jit/__init__.py | 23 ++- 6 files changed, 225 insertions(+), 9 deletions(-) create mode 100644 flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu create mode 100644 python/csrc/flashinfer_gemm_ops_sm90.cu diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu new file mode 100644 index 00000000..0f8c1cbf --- /dev/null +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include + +void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, + torch::Tensor append_indptr, std::optional paged_kv_cache, + std::optional paged_k_cache, + std::optional paged_v_cache, torch::Tensor kv_indices, + torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, + unsigned int layout); + +std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, + torch::Tensor s_b); + +void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, + torch::Tensor s_other, std::optional mask = std::nullopt); + +std::vector merge_states(torch::Tensor v, torch::Tensor s); + +torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, + bool deterministic); + +std::vector top_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_p_arr, + double top_p_val, bool deterministic); + +std::vector top_k_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, + unsigned int top_k_val, bool deterministic); + +std::vector min_p_sampling_from_probs(torch::Tensor probs, + torch::Tensor uniform_samples, + std::optional maybe_min_p_arr, + double min_p_val, bool deterministic); + +std::vector top_k_top_p_sampling_from_probs( + torch::Tensor probs, torch::Tensor uniform_samples, + std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic); + +torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, + double top_p_val); + +torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, + unsigned int top_k_val); + +std::vector chain_speculative_sampling( + torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, + torch::Tensor target_probs, std::optional maybe_output_accepted_token_num, + std::optional maybe_output_emitted_token_num, bool deterministic); + +void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); + +void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, + double eps); + +void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); + +void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, + double eps); + +void silu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); + +void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); + +void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); + +void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta, float low_freq_factor, float high_freq_factor, + float old_context_length); + +std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, + torch::Tensor offsets, bool interleave, float rope_scale, + float rope_theta); + +std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, + torch::Tensor indptr, torch::Tensor offsets, + bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, + float 
old_context_length); + +torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); + +torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, + torch::Tensor output_indptr, const std::string& bitorder); + +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, + torch::Tensor& A_scale, torch::Tensor& B_scale); + +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); + m.def("merge_state", &merge_state, "Merge two self-attention states"); + m.def("merge_state_in_place", &merge_state_in_place, + "Merge another self-attention state in-place."); + m.def("merge_states", &merge_states, "Merge multiple self-attention states"); + m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); + m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, + "Top-k sampling from probabilities"); + m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, + "Min-p sampling from probabilities"); + m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, + "Top-p sampling from probabilities"); + m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, + "Top-k and top-p sampling from probabilities"); + m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); + m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); + m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); + m.def("chain_speculative_sampling", &chain_speculative_sampling, + "Speculative sampling from sequence of probabilities"); + m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); + m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); + m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); + m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, + "Gemma Fused add root mean square normalization"); + m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); + m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); + m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); + m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); + m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, + "Apply Llama 3.1 style RoPE in-place"); + m.def("apply_rope", &apply_rope, "Apply RoPE"); + m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); + m.def("packbits", &packbits, "GPU packbits operator"); + m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); + m.def("cutlass_segment_gemms_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); +} \ No newline at end of file diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py index 3cc15c27..0bad49e7 100644 --- a/flashinfer-aot/setup.py +++ b/flashinfer-aot/setup.py @@ -336,6 +336,18 @@ def remove_unwanted_pytorch_nvcc_flags(): except ValueError: pass +def get_gemm_src_files(): + cuda_major, _ = get_cuda_version() + if cuda_major < 9: + return [ + "csrc/group_gemm.cu", + "csrc_aot/flashinfer_ops.cu", + ] + else: + return [ 
+ "csrc/group_gemm_sm90.cu", + "csrc_aot/flashinfer_ops_sm90.cu", + ] class NinjaBuildExtension(torch_cpp_ext.BuildExtension): def __init__(self, *args, **kwargs) -> None: @@ -386,8 +398,7 @@ def __init__(self, *args, **kwargs) -> None: "csrc/quantization.cu", "csrc/group_gemm.cu", "csrc/bmm_fp8.cu", - "csrc_aot/flashinfer_ops.cu", - ], + ] + get_gemm_src_files(), include_dirs=include_dirs, extra_compile_args=extra_compile_args, ) diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index b60c9e3b..d4907cac 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -71,7 +71,7 @@ using namespace cute; } template -cudaError_t CutlassSegmentGEMMWrapper_SM90( +cudaError_t CutlassSegmentGEMMSM90Run( void* float_buffer, size_t float_buffer_size_in_bytes, void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x, DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, @@ -80,7 +80,7 @@ cudaError_t CutlassSegmentGEMMWrapper_SM90( bool weight_column_major, cudaStream_t stream) { auto compute_capacity = GetCudaComputeCapability(); if (compute_capacity.first < 9) { - std::cerr << "CutlassSegmentGEMMWrapper_SM90 requires compute capability of at least 9.0" + std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0" << std::endl; return cudaErrorNotSupported; } else { diff --git a/python/csrc/flashinfer_gemm_ops_sm90.cu b/python/csrc/flashinfer_gemm_ops_sm90.cu new file mode 100644 index 00000000..4436228c --- /dev/null +++ b/python/csrc/flashinfer_gemm_ops_sm90.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, + torch::Tensor& A_scale, torch::Tensor& B_scale); + + +// (... Tensor x_arr, Tensor w_arr, Tensor y_arr, Tensor x_stride, Tensor weight_stride, Tensor y_stride, Tensor problem_shape ...) 
+torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cutlass_segment_gemms_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); + m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); +} \ No newline at end of file diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 7a2e73ee..1d68ee89 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -19,7 +19,7 @@ import torch from .utils import get_indptr -from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +from .jit import get_gemm_src_files, load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops from typing import Optional @@ -37,10 +37,8 @@ def get_gemm_module(): _gemm_module = load_cuda_ops( "gemm", [ - FLASHINFER_CSRC_DIR / "group_gemm.cu", FLASHINFER_CSRC_DIR / "bmm_fp8.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", - ], + ] + get_gemm_src_files(), ) return _gemm_module diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index 2f9acc0e..278a8dad 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -17,8 +17,9 @@ import os import re import logging +import subprocess import torch.utils.cpp_extension as torch_cpp_ext -from typing import List +from typing import List, Tuple from .env import ( FLASHINFER_WORKSPACE_DIR, FLASHINFER_JIT_DIR, @@ -84,6 +85,14 @@ def check_cuda_arch(): if arch < 75: raise RuntimeError("FlashInfer requires sm75+") +def get_cuda_version() -> Tuple[int, int]: + if torch_cpp_ext.CUDA_HOME is None: + nvcc = "nvcc" + else: + nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc") + txt = subprocess.check_output([nvcc, "--version"], text=True) + major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0]) + return major, minor def clear_cache_dir(): if os.path.exists(FLASHINFER_JIT_DIR): @@ -104,6 +113,18 @@ def remove_unwanted_pytorch_nvcc_flags(): except ValueError: pass +def get_gemm_src_files(): + cuda_major, _ = get_cuda_version() + if cuda_major < 9: + return [ + FLASHINFER_CSRC_DIR / "group_gemm.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", + ] + else: + return [ + FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu", + ] remove_unwanted_pytorch_nvcc_flags() From 79a6e4ea2500714f023f6dc61bbff4fa474e5a01 Mon Sep 17 00:00:00 2001 From: xsling Date: Tue, 8 Oct 2024 19:33:03 +0000 Subject: [PATCH 5/8] upd --- .../csrc_aot/flashinfer_ops_sm90.cu | 2 +- include/flashinfer/gemm/group_gemm_sm90.cuh | 1 + python/csrc/flashinfer_gemm_ops_sm90.cu | 2 +- python/flashinfer/gemm.py | 32 +++++++++++++------ python/flashinfer/jit/__init__.py | 13 ++++---- 5 files changed, 32 insertions(+), 18 deletions(-) diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu index 0f8c1cbf..bb3ba229 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu @@ -150,6 +150,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); m.def("packbits", &packbits, "GPU packbits operator"); m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); - m.def("cutlass_segment_gemms_sm90", 
&CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); } \ No newline at end of file diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index d4907cac..6596d543 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -21,6 +21,7 @@ #include "../allocator.h" #include "../utils.cuh" #include "cutlass/cutlass.h" +#include "group_gemm_cutlass.cuh" #include "cute/tensor.hpp" #include "cutlass/tensor_ref.h" diff --git a/python/csrc/flashinfer_gemm_ops_sm90.cu b/python/csrc/flashinfer_gemm_ops_sm90.cu index 4436228c..3f0b8556 100644 --- a/python/csrc/flashinfer_gemm_ops_sm90.cu +++ b/python/csrc/flashinfer_gemm_ops_sm90.cu @@ -26,6 +26,6 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch bool weight_column_major); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemms_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); } \ No newline at end of file diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 1d68ee89..3bdf02ed 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -19,7 +19,7 @@ import torch from .utils import get_indptr -from .jit import get_gemm_src_files, load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops +from .jit import get_gemm_src_files, load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops, is_sm90_capable from typing import Optional @@ -198,15 +198,27 @@ def run( if weight_indices is None: # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) - return get_gemm_module().cutlass_segment_gemm( - self._int_workspace_buffer, - seg_indptr, - weight_indices, - x, - weights, - batch_size, - weight_column_major, - ) + if is_sm90_capable: + return get_gemm_module().cutlass_segment_gemm_sm90( + self._float_workspace_buffer, + self._int_workspace_buffer, + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) + else: + return get_gemm_module().cutlass_segment_gemm( + self._int_workspace_buffer, + seg_indptr, + weight_indices, + x, + weights, + batch_size, + weight_column_major, + ) forward = run diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index 278a8dad..0d7b0ab0 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -94,6 +94,8 @@ def get_cuda_version() -> Tuple[int, int]: major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0]) return major, minor +is_sm90_capable = get_cuda_version() >= (9, 0) + def clear_cache_dir(): if os.path.exists(FLASHINFER_JIT_DIR): for file in os.listdir(FLASHINFER_JIT_DIR): @@ -114,16 +116,15 @@ def remove_unwanted_pytorch_nvcc_flags(): pass def get_gemm_src_files(): - cuda_major, _ = get_cuda_version() - if cuda_major < 9: + if is_sm90_capable: return [ - FLASHINFER_CSRC_DIR / "group_gemm.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", + FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu", ] else: return [ - FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu", + FLASHINFER_CSRC_DIR / 
"group_gemm.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", ] remove_unwanted_pytorch_nvcc_flags() From 9533d2912dc0df2251abbe13e5e9e039eddd0ce2 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 9 Oct 2024 00:01:02 +0000 Subject: [PATCH 6/8] upd --- flashinfer-aot/csrc_aot/flashinfer_ops.cu | 1 - .../csrc_aot/flashinfer_ops_sm90.cu | 129 ------------------ flashinfer-aot/setup.py | 30 ++-- python/csrc/flashinfer_gemm_ops_sm90.cu | 4 - python/flashinfer/gemm.py | 33 ++++- python/flashinfer/jit/__init__.py | 45 ++---- python/flashinfer/utils.py | 6 + python/setup.py | 1 + 8 files changed, 65 insertions(+), 184 deletions(-) diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops.cu b/flashinfer-aot/csrc_aot/flashinfer_ops.cu index ee2091bf..c9f0313f 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops.cu @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once #include void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu b/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu index bb3ba229..5140982f 100644 --- a/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu +++ b/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu @@ -13,102 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once #include -void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, - torch::Tensor append_indptr, std::optional paged_kv_cache, - std::optional paged_k_cache, - std::optional paged_v_cache, torch::Tensor kv_indices, - torch::Tensor kv_indptr, torch::Tensor kv_last_page_len, - unsigned int layout); - -std::vector merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b, - torch::Tensor s_b); - -void merge_state_in_place(torch::Tensor v, torch::Tensor s, torch::Tensor v_other, - torch::Tensor s_other, std::optional mask = std::nullopt); - -std::vector merge_states(torch::Tensor v, torch::Tensor s); - -torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples, - bool deterministic); - -std::vector top_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_p_arr, - double top_p_val, bool deterministic); - -std::vector top_k_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, - unsigned int top_k_val, bool deterministic); - -std::vector min_p_sampling_from_probs(torch::Tensor probs, - torch::Tensor uniform_samples, - std::optional maybe_min_p_arr, - double min_p_val, bool deterministic); - -std::vector top_k_top_p_sampling_from_probs( - torch::Tensor probs, torch::Tensor uniform_samples, - std::optional maybe_top_k_arr, double top_k_val, - std::optional maybe_top_p_arr, double top_p_val, bool deterministic); - -torch::Tensor top_p_renorm_probs(torch::Tensor probs, std::optional maybe_top_p_arr, - double top_p_val); - -torch::Tensor top_k_renorm_probs(torch::Tensor probs, std::optional maybe_top_k_arr, - unsigned int top_k_val); - -torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional maybe_top_k_arr, - unsigned int top_k_val); - -std::vector chain_speculative_sampling( - torch::Tensor draft_probs, torch::Tensor draft_token_ids, torch::Tensor uniform_samples, - torch::Tensor target_probs, std::optional maybe_output_accepted_token_num, - std::optional 
maybe_output_emitted_token_num, bool deterministic); - -void rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); - -void fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps); - -void gemma_rmsnorm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double eps); - -void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, - double eps); - -void silu_and_mul(torch::Tensor& out, torch::Tensor& input); - -void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); - -void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); - -void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta); - -void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta, float low_freq_factor, float high_freq_factor, - float old_context_length); - -std::vector apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor indptr, - torch::Tensor offsets, bool interleave, float rope_scale, - float rope_theta); - -std::vector apply_llama31_rope(torch::Tensor q, torch::Tensor k, - torch::Tensor indptr, torch::Tensor offsets, - bool interleave, float rope_scale, float rope_theta, - float low_freq_factor, float high_freq_factor, - float old_context_length); - -torch::Tensor packbits(torch::Tensor x, const std::string& bitorder); - -torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr, - torch::Tensor output_indptr, const std::string& bitorder); - -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale); torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, torch::Tensor weight_indices, torch::Tensor x, @@ -116,40 +22,5 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch bool weight_column_major); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("append_paged_kv_cache", &append_paged_kv_cache, "Append paged KV-Cache operator"); - m.def("merge_state", &merge_state, "Merge two self-attention states"); - m.def("merge_state_in_place", &merge_state_in_place, - "Merge another self-attention state in-place."); - m.def("merge_states", &merge_states, "Merge multiple self-attention states"); - m.def("sampling_from_probs", &sampling_from_probs, "Sample from probabilities"); - m.def("top_k_sampling_from_probs", &top_k_sampling_from_probs, - "Top-k sampling from probabilities"); - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, - "Min-p sampling from probabilities"); - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, - "Top-p sampling from probabilities"); - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, - "Top-k and top-p sampling from probabilities"); - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Renormalize probabilities by top-k mask"); - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Renormalize probabilities by top-p mask"); - m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask"); - m.def("chain_speculative_sampling", &chain_speculative_sampling, - "Speculative sampling from sequence of probabilities"); - m.def("rmsnorm", &rmsnorm, "Root mean square normalization"); - 
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization"); - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization"); - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, - "Gemma Fused add root mean square normalization"); - m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul"); - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul"); - m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul"); - m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place"); - m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace, - "Apply Llama 3.1 style RoPE in-place"); - m.def("apply_rope", &apply_rope, "Apply RoPE"); - m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE"); - m.def("packbits", &packbits, "GPU packbits operator"); - m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator"); m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); } \ No newline at end of file diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py index 0bad49e7..80768822 100644 --- a/flashinfer-aot/setup.py +++ b/flashinfer-aot/setup.py @@ -336,18 +336,6 @@ def remove_unwanted_pytorch_nvcc_flags(): except ValueError: pass -def get_gemm_src_files(): - cuda_major, _ = get_cuda_version() - if cuda_major < 9: - return [ - "csrc/group_gemm.cu", - "csrc_aot/flashinfer_ops.cu", - ] - else: - return [ - "csrc/group_gemm_sm90.cu", - "csrc_aot/flashinfer_ops_sm90.cu", - ] class NinjaBuildExtension(torch_cpp_ext.BuildExtension): def __init__(self, *args, **kwargs) -> None: @@ -384,6 +372,10 @@ def __init__(self, *args, **kwargs) -> None: "-use_fast_math", ], } + extra_compile_args_sm90 = extra_compile_args.copy() + extra_compile_args_sm90["nvcc"].extend( + "-gencode arch=compute_90a,code=sm_90a".split() + ) ext_modules = [] ext_modules.append( torch_cpp_ext.CUDAExtension( @@ -398,11 +390,23 @@ def __init__(self, *args, **kwargs) -> None: "csrc/quantization.cu", "csrc/group_gemm.cu", "csrc/bmm_fp8.cu", - ] + get_gemm_src_files(), + "csrc_aot/flashinfer_ops.cu" + ], include_dirs=include_dirs, extra_compile_args=extra_compile_args, ) ) + ext_modules.append( + torch_cpp_ext.CUDAExtension( + name="flashinfer._kernels_sm90", + sources=[ + "csrc/group_gemm_sm90.cu", + "csrc_aot/flashinfer_ops_sm90.cu", + ], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args_sm90, + ) + ) ext_modules.append( torch_cpp_ext.CUDAExtension( name="flashinfer._decode_kernels", diff --git a/python/csrc/flashinfer_gemm_ops_sm90.cu b/python/csrc/flashinfer_gemm_ops_sm90.cu index 3f0b8556..8d557bac 100644 --- a/python/csrc/flashinfer_gemm_ops_sm90.cu +++ b/python/csrc/flashinfer_gemm_ops_sm90.cu @@ -15,9 +15,6 @@ */ #include -void bmm_fp8(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& D, - torch::Tensor& A_scale, torch::Tensor& B_scale); - // (... Tensor x_arr, Tensor w_arr, Tensor y_arr, Tensor x_stride, Tensor weight_stride, Tensor y_stride, Tensor problem_shape ...) 
torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, @@ -27,5 +24,4 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); - m.def("bmm_fp8", &bmm_fp8, "BMM FP8"); } \ No newline at end of file diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index 3bdf02ed..b00d844d 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -18,12 +18,13 @@ import torch -from .utils import get_indptr -from .jit import get_gemm_src_files, load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops, is_sm90_capable +from .utils import get_indptr, get_compute_capability +from .jit import load_cuda_ops, FLASHINFER_CSRC_DIR, has_prebuilt_ops from typing import Optional _gemm_module = None +_gemm_module_sm90 = None def get_gemm_module(): @@ -38,10 +39,31 @@ def get_gemm_module(): "gemm", [ FLASHINFER_CSRC_DIR / "bmm_fp8.cu", - ] + get_gemm_src_files(), + FLASHINFER_CSRC_DIR / "group_gemm.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", + ], ) return _gemm_module +def get_gemm_sm90_module(): + print("get_gemm_sm90_module") + global _gemm_module_sm90 + if _gemm_module_sm90 is None: + if has_prebuilt_ops: + from . import _kernels_sm90 + + _gemm_module_sm90 = _kernels_sm90 + else: + _gemm_module_sm90 = load_cuda_ops( + "gemm_sm90", + [ + FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu", + ], + extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"], + ) + return _gemm_module_sm90 + class SegmentGEMMWrapper: r"""Wrapper for segment GEMM kernels. 
@@ -198,8 +220,9 @@ def run( if weight_indices is None: # create an empty CPU tensor as placeholder weight_indices = torch.empty(0, dtype=torch.int64) - if is_sm90_capable: - return get_gemm_module().cutlass_segment_gemm_sm90( + major, _ = get_compute_capability(x.device) + if major >= 9: + return get_gemm_sm90_module().cutlass_segment_gemm_sm90( self._float_workspace_buffer, self._int_workspace_buffer, seg_indptr, diff --git a/python/flashinfer/jit/__init__.py b/python/flashinfer/jit/__init__.py index 0d7b0ab0..e0e272b8 100644 --- a/python/flashinfer/jit/__init__.py +++ b/python/flashinfer/jit/__init__.py @@ -85,16 +85,6 @@ def check_cuda_arch(): if arch < 75: raise RuntimeError("FlashInfer requires sm75+") -def get_cuda_version() -> Tuple[int, int]: - if torch_cpp_ext.CUDA_HOME is None: - nvcc = "nvcc" - else: - nvcc = os.path.join(torch_cpp_ext.CUDA_HOME, "bin/nvcc") - txt = subprocess.check_output([nvcc, "--version"], text=True) - major, minor = map(int, re.findall(r"release (\d+)\.(\d+),", txt)[0]) - return major, minor - -is_sm90_capable = get_cuda_version() >= (9, 0) def clear_cache_dir(): if os.path.exists(FLASHINFER_JIT_DIR): @@ -115,17 +105,6 @@ def remove_unwanted_pytorch_nvcc_flags(): except ValueError: pass -def get_gemm_src_files(): - if is_sm90_capable: - return [ - FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu", - ] - else: - return [ - FLASHINFER_CSRC_DIR / "group_gemm.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops.cu", - ] remove_unwanted_pytorch_nvcc_flags() @@ -133,22 +112,24 @@ def get_gemm_src_files(): def load_cuda_ops( name: str, sources: List[str], - extra_cflags: List[str] = ["-O3", "-Wno-switch-bool"], - extra_cuda_cflags: List[str] = [ + extra_cflags: List[str] = [], + extra_cuda_cflags: List[str] = [], + extra_ldflags=None, + extra_include_paths=None, + verbose=False, +): + cflags = ["-O3", "-Wno-switch-bool"] + cuda_cflags = [ "-O3", "-std=c++17", "--threads", "4", - # "-Xfatbin", - # "-compress-all", "-use_fast_math", "-DFLASHINFER_ENABLE_BF16", "-DFLASHINFER_ENABLE_FP8", - ], - extra_ldflags=None, - extra_include_paths=None, - verbose=False, -): + ] + cflags += extra_cflags + cuda_cflags += extra_cuda_cflags logger.info(f"Loading JIT ops: {name}") check_cuda_arch() build_directory = FLASHINFER_JIT_DIR / name @@ -162,8 +143,8 @@ def load_cuda_ops( return torch_cpp_ext.load( name, list(map(lambda _: str(_), sources)), - extra_cflags=extra_cflags, - extra_cuda_cflags=extra_cuda_cflags, + extra_cflags=cflags, + extra_cuda_cflags=cuda_cflags, extra_ldflags=extra_ldflags, extra_include_paths=list(map(lambda _: str(_), extra_include_paths)), build_directory=build_directory, diff --git a/python/flashinfer/utils.py b/python/flashinfer/utils.py index e8e40ee5..acaed4a0 100644 --- a/python/flashinfer/utils.py +++ b/python/flashinfer/utils.py @@ -176,3 +176,9 @@ def canonicalize_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype: raise TypeError( "dtype must be a string or torch.dtype, got {}".format(type(dtype)) ) + + +def get_compute_capability(device: torch.device) -> Tuple[int, int]: + if device.type != "cuda": + raise ValueError("device must be a cuda device") + return torch.cuda.get_device_capability(device.index) diff --git a/python/setup.py b/python/setup.py index ffb3debb..52166d51 100644 --- a/python/setup.py +++ b/python/setup.py @@ -46,6 +46,7 @@ def clear_aot_config(): if __name__ == "__main__": generate_build_meta() + clear_aot_config() setuptools.setup( name="flashinfer", version=get_version(), 
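
[Editor's note, not part of the patch] PATCH 6/8 replaces the build-time SM90 gate with a runtime check on the device's compute capability and moves the Hopper grouped-GEMM kernels into a separate flashinfer._kernels_sm90 extension compiled with -gencode arch=compute_90a,code=sm_90a. The removed is_sm90_capable flag compared the nvcc release against (9, 0), which says nothing about the GPU actually present; checking the device instead avoids selecting sm_90a kernels on pre-Hopper hardware driven by a recent toolkit. Below is a minimal, hedged sketch of that dispatch pattern. The helper mirrors the get_compute_capability added to python/flashinfer/utils.py; the two stand-in kernels are placeholders for illustration, not FlashInfer entry points.

    import torch
    from typing import Tuple

    def get_compute_capability(device: torch.device) -> Tuple[int, int]:
        # Same contract as the helper added to python/flashinfer/utils.py:
        # CUDA devices only, returns the (major, minor) capability pair.
        if device.type != "cuda":
            raise ValueError("device must be a cuda device")
        return torch.cuda.get_device_capability(device.index)

    def _segment_gemm_sm80(x: torch.Tensor) -> str:
        # Placeholder for the pre-Hopper path (cutlass_segment_gemm).
        return "dispatched to the SM80 grouped GEMM"

    def _segment_gemm_sm90(x: torch.Tensor) -> str:
        # Placeholder for the Hopper path (cutlass_segment_gemm_sm90).
        return "dispatched to the SM90 grouped GEMM"

    def run(x: torch.Tensor) -> str:
        # Mirrors the major >= 9 gate introduced in SegmentGEMMWrapper.run().
        major, _ = get_compute_capability(x.device)
        return _segment_gemm_sm90(x) if major >= 9 else _segment_gemm_sm80(x)

    if __name__ == "__main__":
        if torch.cuda.is_available():
            print(run(torch.randn(8, 128, device="cuda", dtype=torch.float16)))
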
From b73d8e157a10cc9a82387bdd41ed32632b3e5d90 Mon Sep 17 00:00:00 2001 From: xsling Date: Wed, 9 Oct 2024 00:14:06 +0000 Subject: [PATCH 7/8] upd --- include/flashinfer/gemm/group_gemm_sm90.cuh | 35 ++++++++++----------- python/csrc/flashinfer_gemm_ops_sm90.cu | 13 ++++---- python/csrc/group_gemm_sm90.cu | 15 +++++---- python/flashinfer/gemm.py | 1 + python/flashinfer/jit/env.py | 2 +- 5 files changed, 34 insertions(+), 32 deletions(-) diff --git a/include/flashinfer/gemm/group_gemm_sm90.cuh b/include/flashinfer/gemm/group_gemm_sm90.cuh index 6596d543..5d660a07 100644 --- a/include/flashinfer/gemm/group_gemm_sm90.cuh +++ b/include/flashinfer/gemm/group_gemm_sm90.cuh @@ -18,30 +18,30 @@ #include -#include "../allocator.h" -#include "../utils.cuh" -#include "cutlass/cutlass.h" #include "group_gemm_cutlass.cuh" +#include "../allocator.h" +#include "../utils.cuh" #include "cute/tensor.hpp" +#include "cutlass/cutlass.h" #include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" #include "cutlass/gemm/kernel/gemm_universal.hpp" - #include "cutlass/util/command_line.h" #include "cutlass/util/distribution.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/tensor_view_io.h" #include "cutlass/util/reference/device/gemm.h" #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + namespace flashinfer { @@ -72,13 +72,12 @@ using namespace cute; } template -cudaError_t CutlassSegmentGEMMSM90Run( - void* float_buffer, size_t float_buffer_size_in_bytes, - void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x, - DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, - int64_t* w_indices_d, unsigned int batch_size, - unsigned int d_in, unsigned int d_out, - bool weight_column_major, cudaStream_t stream) { +cudaError_t CutlassSegmentGEMMSM90Run(void* float_buffer, size_t float_buffer_size_in_bytes, + void* int_buffer, size_t int_buffer_size_in_bytes, DTypeIn* x, + DTypeIn* w, DTypeOut* y, int64_t* xy_indptr_d, + int64_t* w_indices_d, unsigned int batch_size, + unsigned int d_in, unsigned int d_out, + bool weight_column_major, cudaStream_t stream) { auto compute_capacity = GetCudaComputeCapability(); if (compute_capacity.first < 9) { std::cerr << "CutlassSegmentGEMMSM90Run requires compute capability of at least 9.0" @@ -167,8 +166,7 @@ cudaError_t CutlassSegmentGEMMSM90Run( using StrideC = typename Gemm::GemmKernel::InternalStrideC; using StrideD = typename Gemm::GemmKernel::InternalStrideD; - AlignedAllocator allocator(int_buffer, - int_buffer_size_in_bytes); + AlignedAllocator allocator(int_buffer, int_buffer_size_in_bytes); ProblemShape::UnderlyingProblemShape* problem_sizes_device = allocator.aligned_alloc( batch_size * sizeof(ProblemShape::UnderlyingProblemShape), 16, @@ -224,8 +222,7 @@ cudaError_t CutlassSegmentGEMMSM90Run( size_t workspace_size = Gemm::get_workspace_size(arguments); // Allocate workspace memory - 
AlignedAllocator float_allocator(float_buffer, - float_buffer_size_in_bytes); + AlignedAllocator float_allocator(float_buffer, float_buffer_size_in_bytes); auto workspace_ptr = float_allocator.aligned_alloc(workspace_size, 64, "sm90_group_gemm_float_workspace"); diff --git a/python/csrc/flashinfer_gemm_ops_sm90.cu b/python/csrc/flashinfer_gemm_ops_sm90.cu index 8d557bac..143f9c46 100644 --- a/python/csrc/flashinfer_gemm_ops_sm90.cu +++ b/python/csrc/flashinfer_gemm_ops_sm90.cu @@ -16,12 +16,13 @@ #include -// (... Tensor x_arr, Tensor w_arr, Tensor y_arr, Tensor x_stride, Tensor weight_stride, Tensor y_stride, Tensor problem_shape ...) -torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major); +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); + m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, + "Cutlass Segment GEMM operator for SM90"); } \ No newline at end of file diff --git a/python/csrc/group_gemm_sm90.cu b/python/csrc/group_gemm_sm90.cu index be82d1ce..a218f347 100644 --- a/python/csrc/group_gemm_sm90.cu +++ b/python/csrc/group_gemm_sm90.cu @@ -19,10 +19,11 @@ using namespace flashinfer::group_gemm; -torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, - torch::Tensor weight_indices, torch::Tensor x, - torch::Tensor weight, unsigned int batch_size, - bool weight_column_major) { +torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, + torch::Tensor int_workspace_buffer, torch::Tensor seg_indptr, + torch::Tensor weight_indices, torch::Tensor x, + torch::Tensor weight, unsigned int batch_size, + bool weight_column_major) { // TODO(Zihao): Add more checks here CHECK_INPUT(seg_indptr); CHECK_INPUT(x); @@ -52,8 +53,10 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, torch DISPATCH_PYTORCH_DTYPE_TO_CTYPE(x.scalar_type(), c_type, [&] { using cutlass_t = typename cutlass_dtype::type; auto status = CutlassSegmentGEMMSM90Run( - float_workspace_buffer.data_ptr(), float_workspace_buffer.element_size() * float_workspace_buffer.size(0), - int_workspace_buffer.data_ptr(), int_workspace_buffer.element_size() * int_workspace_buffer.size(0), + float_workspace_buffer.data_ptr(), + float_workspace_buffer.element_size() * float_workspace_buffer.size(0), + int_workspace_buffer.data_ptr(), + int_workspace_buffer.element_size() * int_workspace_buffer.size(0), static_cast(x.data_ptr()), static_cast(weight.data_ptr()), static_cast(y.data_ptr()), static_cast(seg_indptr.data_ptr()), weight_indices_defined ? 
static_cast(weight_indices.data_ptr()) : nullptr, diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index b00d844d..d6f54497 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -45,6 +45,7 @@ def get_gemm_module(): ) return _gemm_module + def get_gemm_sm90_module(): print("get_gemm_sm90_module") global _gemm_module_sm90 diff --git a/python/flashinfer/jit/env.py b/python/flashinfer/jit/env.py index c51bb8bb..e3fbec81 100644 --- a/python/flashinfer/jit/env.py +++ b/python/flashinfer/jit/env.py @@ -25,5 +25,5 @@ FLASHINFER_CSRC_DIR = _project_root / "csrc" CUTLASS_INCLUDE_DIRS = [ _project_root / "3rdparty" / "cutlass" / "include", - _project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include" + _project_root / "3rdparty" / "cutlass" / "tools" / "util" / "include", ] From f48e123bc7b7e818781165f194a6f02dbba036bc Mon Sep 17 00:00:00 2001 From: xsling Date: Wed, 9 Oct 2024 00:22:16 +0000 Subject: [PATCH 8/8] upd --- .../csrc_aot/{flashinfer_ops_sm90.cu => flashinfer_sm90_ops.cu} | 0 flashinfer-aot/setup.py | 2 +- ...{flashinfer_gemm_ops_sm90.cu => flashinfer_gemm_sm90_ops.cu} | 2 +- python/flashinfer/gemm.py | 2 +- 4 files changed, 3 insertions(+), 3 deletions(-) rename flashinfer-aot/csrc_aot/{flashinfer_ops_sm90.cu => flashinfer_sm90_ops.cu} (100%) rename python/csrc/{flashinfer_gemm_ops_sm90.cu => flashinfer_gemm_sm90_ops.cu} (99%) diff --git a/flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu b/flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu similarity index 100% rename from flashinfer-aot/csrc_aot/flashinfer_ops_sm90.cu rename to flashinfer-aot/csrc_aot/flashinfer_sm90_ops.cu diff --git a/flashinfer-aot/setup.py b/flashinfer-aot/setup.py index 80768822..5c541026 100644 --- a/flashinfer-aot/setup.py +++ b/flashinfer-aot/setup.py @@ -401,7 +401,7 @@ def __init__(self, *args, **kwargs) -> None: name="flashinfer._kernels_sm90", sources=[ "csrc/group_gemm_sm90.cu", - "csrc_aot/flashinfer_ops_sm90.cu", + "csrc_aot/flashinfer_sm90_ops.cu", ], include_dirs=include_dirs, extra_compile_args=extra_compile_args_sm90, diff --git a/python/csrc/flashinfer_gemm_ops_sm90.cu b/python/csrc/flashinfer_gemm_sm90_ops.cu similarity index 99% rename from python/csrc/flashinfer_gemm_ops_sm90.cu rename to python/csrc/flashinfer_gemm_sm90_ops.cu index 143f9c46..0332eb8d 100644 --- a/python/csrc/flashinfer_gemm_ops_sm90.cu +++ b/python/csrc/flashinfer_gemm_sm90_ops.cu @@ -25,4 +25,4 @@ torch::Tensor CutlassSegmentGEMMSM90(torch::Tensor float_workspace_buffer, PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cutlass_segment_gemm_sm90", &CutlassSegmentGEMMSM90, "Cutlass Segment GEMM operator for SM90"); -} \ No newline at end of file +} diff --git a/python/flashinfer/gemm.py b/python/flashinfer/gemm.py index d6f54497..5b938765 100644 --- a/python/flashinfer/gemm.py +++ b/python/flashinfer/gemm.py @@ -59,7 +59,7 @@ def get_gemm_sm90_module(): "gemm_sm90", [ FLASHINFER_CSRC_DIR / "group_gemm_sm90.cu", - FLASHINFER_CSRC_DIR / "flashinfer_gemm_ops_sm90.cu", + FLASHINFER_CSRC_DIR / "flashinfer_gemm_sm90_ops.cu", ], extra_cuda_cflags=["-gencode", "arch=compute_90a,code=sm_90a"], )
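
[Editor's note, not part of the patch] After PATCH 8/8's renames, the SM90 JIT path builds the gemm_sm90 extension from group_gemm_sm90.cu and the renamed flashinfer_gemm_sm90_ops.cu, appending the Hopper-only gencode flag to the default nvcc flags from load_cuda_ops. A hedged sketch of the equivalent torch.utils.cpp_extension.load call is shown below; the paths are illustrative assumptions, and the real build goes through flashinfer.jit.load_cuda_ops, which also supplies the CUTLASS include directories listed in python/flashinfer/jit/env.py.

    from pathlib import Path
    import torch.utils.cpp_extension as torch_cpp_ext

    # Assumed repository-relative locations; adjust to the actual checkout.
    CSRC_DIR = Path("python/csrc")
    CUTLASS_INCLUDE_DIRS = [
        Path("3rdparty/cutlass/include"),
        Path("3rdparty/cutlass/tools/util/include"),
    ]

    gemm_sm90 = torch_cpp_ext.load(
        name="gemm_sm90",
        sources=[
            str(CSRC_DIR / "group_gemm_sm90.cu"),
            str(CSRC_DIR / "flashinfer_gemm_sm90_ops.cu"),
        ],
        extra_cflags=["-O3", "-Wno-switch-bool"],
        extra_cuda_cflags=[
            "-O3",
            "-std=c++17",
            "--threads", "4",
            "-use_fast_math",
            "-DFLASHINFER_ENABLE_BF16",
            "-DFLASHINFER_ENABLE_FP8",
            # compute_90a/sm_90a is required for the Hopper grouped GEMM kernels.
            "-gencode", "arch=compute_90a,code=sm_90a",
        ],
        extra_include_paths=[str(p) for p in CUTLASS_INCLUDE_DIRS],
        verbose=True,
    )

Keeping the sm_90a sources in their own extension (rather than folding them into the generic kernel module) means the remaining kernels can still be compiled for older architectures without the 90a-only flags.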