From 6be6496be16e22d2571e8af9b88d5224dc6c610f Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 29 Aug 2025 13:18:43 -0700 Subject: [PATCH 1/3] Remove old cutlass MLA kernel Signed-off-by: Matthew Bonanni --- CMakeLists.txt | 2 - csrc/attention/mla/cutlass_mla_entry.cu | 38 --- csrc/attention/mla/cutlass_mla_kernels.cu | 225 ------------------ csrc/torch_bindings.cpp | 7 - vllm/_custom_ops.py | 13 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 57 +---- 6 files changed, 4 insertions(+), 338 deletions(-) delete mode 100644 csrc/attention/mla/cutlass_mla_entry.cu delete mode 100644 csrc/attention/mla/cutlass_mla_kernels.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 009c224dc773..c11083233233 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -298,7 +298,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" - "csrc/attention/mla/cutlass_mla_entry.cu" "csrc/quantization/fp8/per_token_group_quant.cu") set_gencode_flags_for_srcs( @@ -585,7 +584,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) set(SRCS - "csrc/attention/mla/cutlass_mla_kernels.cu" "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" diff --git a/csrc/attention/mla/cutlass_mla_entry.cu b/csrc/attention/mla/cutlass_mla_entry.cu deleted file mode 100644 index 0319d1daf302..000000000000 --- a/csrc/attention/mla/cutlass_mla_entry.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA -void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale); -#endif - -void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale) { -#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA - return cutlass_mla_decode_sm100a(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale); -#endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled cutlass MLA"); -} diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu deleted file mode 100644 index 9d05d910dd81..000000000000 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" -#include "cutlass/kernel_hardware_info.h" - -#include "cutlass_extensions/common.hpp" - -#include "device/sm100_mla.hpp" -#include "kernel/sm100_mla_tile_scheduler.hpp" - -using namespace cute; -using namespace cutlass::fmha::kernel; - -template -struct MlaSm100 { - using Element = T; - using ElementAcc = float; - using ElementOut = T; - - using TileShape = Shape<_128, _128, Shape<_512, _64>>; - using TileShapeH = cute::tuple_element_t<0, TileShape>; - using TileShapeD = cute::tuple_element_t<2, TileShape>; - - // H K (D_latent D_rope) B - using ProblemShape = cute::tuple; - - using StrideQ = cute::tuple; // H D B - using StrideK = cute::tuple; // K D B - using StrideO = StrideK; // H D B - using StrideLSE = cute::tuple<_1, int>; // H B - - using TileScheduler = - std::conditional_t; - - using FmhaKernel = - cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< - TileShape, Element, ElementAcc, ElementOut, ElementAcc, TileScheduler, - /*kIsCpAsync=*/true>; - using Fmha = cutlass::fmha::device::MLA; -}; - -template -typename T::Fmha::Arguments args_from_options( - at::Tensor const& out, at::Tensor const& q_nope, at::Tensor const& q_pe, - at::Tensor const& kv_c_and_k_pe_cache, at::Tensor const& seq_lens, - at::Tensor const& page_table, double scale) { - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = q_nope.device().index(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); - - int batches = q_nope.sizes()[0]; - int page_count_per_seq = page_table.sizes()[1]; - int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; - int page_size = kv_c_and_k_pe_cache.sizes()[1]; - int max_seq_len = page_size * page_count_per_seq; - using TileShapeH = typename T::TileShapeH; - using TileShapeD = typename T::TileShapeD; - auto problem_shape = - cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); - - auto [H, K, D, B] = problem_shape; - auto [D_latent, D_rope] = D; - - using StrideQ = typename T::StrideQ; - using StrideK = typename T::StrideK; - using StrideO = typename T::StrideO; - using StrideLSE = typename T::StrideLSE; - - StrideQ stride_Q_latent = cute::make_tuple( - static_cast(D_latent), _1{}, static_cast(H * D_latent)); - StrideQ stride_Q_rope = cute::make_tuple(static_cast(D_rope), _1{}, - static_cast(H * D_rope)); - StrideK stride_C = - cute::make_tuple(static_cast(D_latent + D_rope), _1{}, - static_cast(page_size * (D_latent + D_rope))); - StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); - StrideLSE stride_LSE = cute::make_tuple(_1{}, static_cast(H)); - StrideO stride_O = cute::make_tuple(static_cast(D_latent), _1{}, - static_cast(H * D_latent)); - - using Element = typename T::Element; - using ElementOut = typename T::ElementOut; - using ElementAcc = typename T::ElementAcc; - auto Q_latent_ptr = static_cast(q_nope.data_ptr()); - auto Q_rope_ptr = static_cast(q_pe.data_ptr()); - auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); - auto scale_f = static_cast(scale); 
- typename T::Fmha::Arguments arguments{ - problem_shape, - {scale_f, Q_latent_ptr, stride_Q_latent, Q_rope_ptr, stride_Q_rope, C_ptr, - stride_C, C_ptr + D_latent, stride_C, - static_cast(seq_lens.data_ptr()), - static_cast(page_table.data_ptr()), stride_PT, page_count_total, - page_size}, - {static_cast(out.data_ptr()), stride_O, - static_cast(nullptr), stride_LSE}, - hw_info, - 1, // split_kv - nullptr, // is_var_split_kv - }; - // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute - // split_kv automatically based on batch size and sequence length to balance - // workload across available SMs. Consider using var_split_kv for manual - // control if needed. - T::Fmha::set_split_kv(arguments); - return arguments; -} - -template -void runMla(at::Tensor const& out, at::Tensor const& q_nope, - at::Tensor const& q_pe, at::Tensor const& kv_c_and_k_pe_cache, - at::Tensor const& seq_lens, at::Tensor const& page_table, - float scale, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; - typename MlaSm100Type::Fmha fmha; - auto arguments = args_from_options( - out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, scale); - size_t workspace_size = MlaSm100Type::Fmha::get_workspace_size(arguments); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(q_nope.device()); - auto workspace = torch::empty(workspace_size, workspace_options); - - CUTLASS_CHECK(fmha.can_implement(arguments)); - - CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); - - CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); -} - -void cutlass_mla_decode_sm100a(torch::Tensor const& out, - torch::Tensor const& q_nope, - torch::Tensor const& q_pe, - torch::Tensor const& kv_c_and_k_pe_cache, - torch::Tensor const& seq_lens, - torch::Tensor const& page_table, double scale) { - TORCH_CHECK(q_nope.device().is_cuda(), "q_nope must be on CUDA"); - TORCH_CHECK(q_nope.dim() == 3, "q_nope must be a 3D tensor"); - TORCH_CHECK(q_pe.dim() == 3, "q_pe must be a 3D tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dim() == 3, - "kv_c_and_k_pe_cache must be a 3D tensor"); - TORCH_CHECK(seq_lens.dim() == 1, "seq_lens must be a 1D tensor"); - TORCH_CHECK(page_table.dim() == 2, "page_table must be a 2D tensor"); - TORCH_CHECK(out.dim() == 3, "out must be a 3D tensor"); - - auto B_q_nope = q_nope.size(0); - auto H_q_nope = q_nope.size(1); - auto D_q_nope = q_nope.size(2); - auto B_q_pe = q_pe.size(0); - auto H_q_pe = q_pe.size(1); - auto D_q_pe = q_pe.size(2); - auto B_pt = page_table.size(0); - auto PAGE_NUM = page_table.size(1); - auto PAGE_SIZE = kv_c_and_k_pe_cache.size(1); - auto D_ckv = kv_c_and_k_pe_cache.size(2); - auto B_o = out.size(0); - auto H_o = out.size(1); - auto D_o = out.size(2); - - TORCH_CHECK(D_q_nope == 512, "D_q_nope must be equal to 512"); - TORCH_CHECK(D_q_pe == 64, "D_q_pe must be equal to 64"); - TORCH_CHECK(D_ckv == 576, "D_ckv must be equal to 576"); - TORCH_CHECK(H_q_nope == H_q_pe && H_q_nope == H_o && H_o == 128, - "H_q_nope, H_q_pe, and H_o must be equal to 128"); - TORCH_CHECK(PAGE_SIZE > 0 && (PAGE_SIZE & (PAGE_SIZE - 1)) == 0, - "PAGE_SIZE must be a power of 2"); - TORCH_CHECK( - B_q_nope == B_q_pe && B_q_nope == B_pt && B_q_nope == B_o, - "Batch dims must be same for page_table, q_nope and q_pe, and out"); - TORCH_CHECK(PAGE_NUM % (128 / PAGE_SIZE) == 0, - "PAGE_NUM must be divisible by 128 / PAGE_SIZE"); - TORCH_CHECK(D_o == 512, "D_o must be equal to 512"); - - TORCH_CHECK(q_nope.dtype() == at::ScalarType::Half || - 
q_nope.dtype() == at::ScalarType::BFloat16 || - q_nope.dtype() == at::ScalarType::Float8_e4m3fn, - "q_nope must be a half, bfloat16, or float8_e4m3fn tensor"); - TORCH_CHECK(kv_c_and_k_pe_cache.dtype() == q_nope.dtype() && - q_nope.dtype() == q_pe.dtype(), - "kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"); - TORCH_CHECK(seq_lens.dtype() == torch::kInt32, - "seq_lens must be a 32-bit integer tensor"); - TORCH_CHECK(page_table.dtype() == torch::kInt32, - "page_table must be a 32-bit integer tensor"); - - auto in_dtype = q_nope.dtype(); - const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); - const cudaStream_t stream = - at::cuda::getCurrentCUDAStream(q_nope.get_device()); - if (in_dtype == at::ScalarType::Half) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, - page_table, scale, stream); - } else if (in_dtype == at::ScalarType::BFloat16) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); - } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale, stream); - } else { - TORCH_CHECK(false, "Unsupported input data type of MLA"); - } -} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f22e23519831..bc096406c51a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -510,13 +510,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]"); ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); - // CUTLASS MLA decode - ops.def( - "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," - " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," - " Tensor page_table, float scale) -> ()"); - ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); - // SM100 CUTLASS MLA decode ops.def( "sm100_cutlass_mla_decode(Tensor! out, Tensor! 
lse, Tensor q_nope," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 456c6b3ba923..7f60f52a9c4d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1823,17 +1823,8 @@ def flash_mla_with_kvcache( return out, softmax_lse -def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - seq_lens: torch.Tensor, page_table: torch.Tensor, - scale: float) -> torch.Tensor: - torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, - seq_lens, page_table, scale) - return out - - -def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, - q_nope: torch.Tensor, q_pe: torch.Tensor, +def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 78af8d28f889..68bb5972259b 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -219,12 +219,13 @@ def _sm100_cutlass_mla_decode( return out, returned_lse - def _sm100_forward_decode( + def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, + layer: AttentionLayer, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None @@ -245,57 +246,3 @@ def _sm100_forward_decode( ) return o, (lse if self.need_to_return_lse_for_decode else None) - - # TODO: Currently we leave it here only for backup in case something is - # wrong with the new SM100 CUTLASS MLA kernel - def _old_forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - assert attn_metadata.decode is not None - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA") - - B = q_nope.shape[0] - - o = torch.empty((B, self.num_heads, self.kv_lora_rank), - dtype=q_nope.dtype, - device=q_nope.device) - - # Run MLA - # Clone q_nope and q_pe to make sure strides computation is correct. 
- q_nope = q_nope.clone() - q_pe = q_pe.clone() - - ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata.decode.seq_lens, - attn_metadata.decode.block_table, self.scale) - - return o - - def _forward_decode( - self, - q: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: MLACommonMetadata, - layer: AttentionLayer, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - if type(q) is tuple: - q_nope, q_pe = q - else: - q_nope, q_pe = torch.split( - q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - if self._use_old_cutlass_mla: - # TODO: Remove the old cutlass MLA kernel after more extensive - # testing - return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata), None - - return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, - attn_metadata) From fb844649376122fa70cd4296708d8b3997026675 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Fri, 29 Aug 2025 13:21:31 -0700 Subject: [PATCH 2/3] Remove FORCE_OLD_CUTLASS_MLA option Signed-off-by: Matthew Bonanni --- vllm/v1/attention/backends/mla/cutlass_mla.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 68bb5972259b..fdcdb236e622 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -109,12 +109,6 @@ def __init__( "are not implemented for " "CutlassMLAImpl") - self._use_old_cutlass_mla = False - force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) - if force_old_cutlass: - logger.warning_once("Forcing old cutlass mla kernel") - self._use_old_cutlass_mla = True - # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging # issues. In case the code hangs, use: # FORCE_NUM_KV_SPLITS=1 From d091fd2da6f491d09512d51b30746419c5e6f03a Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 16 Sep 2025 12:50:59 -0400 Subject: [PATCH 3/3] Address pre-commit Signed-off-by: Matthew Bonanni --- vllm/_custom_ops.py | 4 ++-- vllm/v1/attention/backends/mla/cutlass_mla.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7f60f52a9c4d..712295aa9288 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1823,8 +1823,8 @@ def flash_mla_with_kvcache( return out, softmax_lse -def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, - q_pe: torch.Tensor, +def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor, + q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, seq_lens: torch.Tensor, page_table: torch.Tensor, workspace: torch.Tensor, scale: float, diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index fdcdb236e622..21be17a750df 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import ClassVar, Optional +from typing import ClassVar, Optional, Union import torch @@ -215,8 +215,7 @@ def _sm100_cutlass_mla_decode( def _forward_decode( self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, + q: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, layer: AttentionLayer, @@ -224,6 +223,12 @@ def _forward_decode( assert kv_c_and_k_pe_cache.numel() > 0 assert 
attn_metadata.decode is not None + if type(q) is tuple: + q_nope, q_pe = q + else: + q_nope, q_pe = torch.split( + q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits)
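
Note on the unified _forward_decode path introduced in PATCH 3/3: it now accepts either a pre-split (q_nope, q_pe) tuple or a single fused query tensor, and splits the fused tensor along the last dimension. Below is a minimal standalone sketch of that split, assuming the head sizes enforced by the removed kernel's shape checks (kv_lora_rank=512, qk_rope_head_dim=64, 128 heads); the helper name and example shapes are illustrative only, not vLLM API.

    import torch

    KV_LORA_RANK = 512       # latent ("nope") head dim checked by the removed kernel
    QK_ROPE_HEAD_DIM = 64    # rope head dim checked by the removed kernel

    def split_decode_query(q):
        # Hypothetical helper mirroring the tuple-or-tensor handling
        # added to _forward_decode in PATCH 3/3.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [KV_LORA_RANK, QK_ROPE_HEAD_DIM], dim=-1)
        return q_nope, q_pe

    # Fused query of shape [batch, num_heads, 512 + 64]
    q = torch.randn(2, 128, KV_LORA_RANK + QK_ROPE_HEAD_DIM)
    q_nope, q_pe = split_decode_query(q)
    assert q_nope.shape[-1] == KV_LORA_RANK
    assert q_pe.shape[-1] == QK_ROPE_HEAD_DIM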