diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu index 6dd6f269f3dc..820bf81dd1a0 100644 --- a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -64,11 +64,11 @@ struct IsPersistent { static const bool value = v; }; -template > +template > struct MlaSm100 { using Element = T; using ElementAcc = float; - using ElementOut = T; + using ElementOut = TOut; using TileShape = Shape<_128, _128, Shape<_512, _64>>; using TileShapeH = cute::tuple_element_t<0, TileShape>; @@ -178,7 +178,7 @@ typename T::Fmha::Arguments args_from_options( return arguments; } -template +template void runMla( at::Tensor const& out, at::Tensor const& q_nope, @@ -190,7 +190,7 @@ void runMla( double sm_scale, int64_t num_kv_splits, cudaStream_t stream) { - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; typename MlaSm100Type::Fmha fmha; auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); @@ -233,13 +233,13 @@ void sm100_cutlass_mla_decode( DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { if (in_dtype == at::ScalarType::Half) { - runMla>( + runMla>( out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::BFloat16) { - runMla>( + runMla>( out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { - runMla>( + runMla>( out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); } else { TORCH_CHECK(false, "Unsupported input data type of MLA"); @@ -253,7 +253,7 @@ void sm100_cutlass_mla_decode( int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) // which are float, so Element type here doesn't matter. - using MlaSm100Type = MlaSm100; + using MlaSm100Type = MlaSm100; // Get split kv. Requires problem shape and sm_count only. typename MlaSm100Type::Fmha::Arguments arguments; diff --git a/tests/kernels/test_cutlass_mla_decode.py b/tests/kernels/test_cutlass_mla_decode.py index 2b745b84dae6..85984324b196 100644 --- a/tests/kernels/test_cutlass_mla_decode.py +++ b/tests/kernels/test_cutlass_mla_decode.py @@ -1,96 +1,180 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +import random + import pytest import torch -import torch.nn.functional as F -from torch import Tensor import vllm._custom_ops as ops from vllm.platforms import current_platform - -if not current_platform.has_device_capability(100): - pytest.skip( - reason="Cutlass MLA Requires compute capability of 10 or above.", - allow_module_level=True) - - -def ref_mla( - out: Tensor, # (bs, num_heads, v_head_dim) - query: Tensor, # (bs, num_heads, head_dim) - kv_cache: Tensor, # (num_blocks, block_size, head_dim) - scale: float, - block_tables: Tensor, # (bs, max_num_blocks) - seq_lens: Tensor, # (bs,) -): - bs, num_heads, v_head_dim = out.shape - head_dim = query.shape[2] - - for i in range(bs): - # gather and flatten KV-cache - kv = kv_cache[ - block_tables[i]] # (max_num_blocks, block_size, head_dim) - kv = kv.view(1, -1, - head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) - v = kv[:, :, :v_head_dim] - - q = query[i].view(num_heads, 1, head_dim) - o = F.scaled_dot_product_attention(q, - kv, - v, - scale=scale, - enable_gqa=True) - out[i] = o.view(num_heads, v_head_dim) - - return out - - -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) -@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) -@pytest.mark.parametrize("bs", [1, 2, 4]) +from vllm.triton_utils import triton + + +def cal_diff(x: torch.Tensor, + y: torch.Tensor, + name: str, + use_fp8: bool = False) -> None: + x, y = x.double(), y.double() + cos_diff = 1 - 2 * (x * y).sum().item() / max( + (x * x + y * y).sum().item(), 1e-12) + if (use_fp8): + assert cos_diff < 1e-4 + else: + assert cos_diff < 1e-5 + + +CUTLASS_MLA_UNSUPPORTED_REASON = \ + "Cutlass MLA Requires compute capability of 10 or above." \ + if not current_platform.is_device_capability(100) \ + else "Cutlass MLA is supported" + + +@pytest.mark.skipif(not current_platform.has_device_capability(100), + reason=CUTLASS_MLA_UNSUPPORTED_REASON) +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1]) +@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) -@pytest.mark.parametrize("block_size", [16, 64, 128]) -def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, - varlen: bool, block_size: int): - torch.set_default_dtype(dtype) - torch.set_default_device('cuda') +@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn]) +@torch.inference_mode() +def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, + causal, varlen, torch_dtype): + device = torch.device("cuda:0") + if torch_dtype == torch.float8_e4m3fn: + init_dtype = torch.bfloat16 + else: + init_dtype = torch_dtype + torch.set_default_dtype(init_dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) torch.manual_seed(42) + random.seed(42) - d = 576 - h_q = 128 - dv = 512 + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") - q_nope_dim = 128 - q_pe_dim = 64 - scale = (q_nope_dim + q_pe_dim)**(-0.5) + use_fp8 = torch_dtype == torch.float8_e4m3fn + scale = math.sqrt(d)**(-1) + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: - seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) - seq_lens = seq_lens.clip(2).to(torch.int32) + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), + s_q) + total_seqlens = cache_seqlens.sum().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange(b * max_seqlen_pad // block_size, + dtype=torch.int32).view( + b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + blocked_v = blocked_k[..., :dv] + + init_dtype = q.dtype + if use_fp8: + fp8_dtype = torch.float8_e4m3fn + descale_q = torch.ones((1), dtype=torch.float32) + descale_k = torch.ones((1), dtype=torch.float32) + + q = q.to(fp8_dtype) + blocked_k = blocked_k.to(fp8_dtype) + blocked_v = blocked_v.to(fp8_dtype) else: - seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) - max_seq_len = seq_lens.max().item() - block_num = (max_seq_len + block_size - 1) // block_size - - # Pad block_num so that small blocks can be packed into full 128-sized - # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small - # blocks. - pack_factor = 128 // block_size - block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor - - # Amplify input values to ensure test coverage of edge cases where CUTLASS - # kernel errors occur with split_k settings. - q = torch.randn(bs, h_q, d) * 100 - block_table = torch.randint(0, - bs * block_num, (bs, block_num), - dtype=torch.int32) - - kv_cache = torch.randn(block_table.numel(), block_size, d) - - out_ref = q.new_zeros(bs, h_q, dv) - ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) - out_ans = torch.zeros_like(out_ref) - q_nope = q[:, :, :dv].clone() - q_pe = q[:, :, dv:].clone() - ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, - block_table, scale) - - torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) + descale_q = None + descale_k = None + + def cutlass_mla(): + MAX_HEADS = 128 + + q_reshaped = q.squeeze(1) + q_nope = q_reshaped[:, :, :dv].clone() + q_pe = q_reshaped[:, :, dv:].clone() + + if h_q < MAX_HEADS: + q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv)) + q_nope_padded[:, :h_q] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv)) + q_pe_padded[:, :h_q] = q_pe + q_pe = q_pe_padded + + kv_cache_flat = blocked_k.squeeze(2) + device_properties = torch.cuda.get_device_properties( + torch.device("cuda:0")) + sm_count = device_properties.multi_processor_count + workspace_size = ops.sm100_cutlass_mla_get_workspace_size( + max_seqlen * block_size, b, sm_count, num_kv_splits=1) + workspace = torch.empty(workspace_size, + device="cuda", + dtype=torch.uint8) + + out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) + + ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat, + cache_seqlens, block_table, workspace, + scale, 1) + return out_ans[:, :h_q].contiguous() + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, + dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q + blocked_k_ = (blocked_k.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_k + blocked_v_ = (blocked_v.to(torch.float) * + descale_k).to(init_dtype) if use_fp8 else blocked_v + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + out_i, lse_i = scaled_dot_product_attention( + q_[i].transpose(0, 1), + blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = out_i.transpose(0, 1) + lse[i] = lse_i + return out, lse + + out_cutlass = cutlass_mla() + out_torch, lse_torch = ref_mla() + # Extract the single token (s_q=1) slice to match cutlass output shape + out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] + cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) + + t = triton.testing.do_bench(cutlass_mla) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + + b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( + b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", + f"{bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 5cbb7346436e..c65c987c0e48 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -500,8 +500,8 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str, else: attention_backend = "FLASHMLA" - # Only FlashMLA supports fp8 - if attention_backend == "FLASHMLA": + # Only FlashMLA and CUTLASS_MLA support fp8 + if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]: supported = True else: supported = (not fp8_attention) diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 8a17d3a49278..705307d4dea3 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -108,10 +108,6 @@ def __init__( "are not implemented for " "CutlassMLAImpl") - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "CutlassMLA V1 with FP8 KV cache not yet supported") - self._use_old_cutlass_mla = False force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) if force_old_cutlass: @@ -182,11 +178,10 @@ def _sm100_cutlass_mla_decode( > 0), f"block num must be greater than 0, got {block_num}" assert block_num % (128 / PAGE_SIZE) == 0 - # TODO(kaixih@nvidia): support fp8 assert q_nope.dtype in ( - torch.float16, - torch.bfloat16, - ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}." + torch.float16, torch.bfloat16, torch.float8_e4m3fn), ( + f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got " + f"{q_nope.dtype}.") assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype assert ( seq_lens.dtype == torch.int32 @@ -195,7 +190,9 @@ def _sm100_cutlass_mla_decode( page_table.dtype == torch.int32 ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." - out = q_nope.new_empty((B_q, MAX_HEADS, D_latent)) + dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype) + else q_nope.dtype) + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype) ops.sm100_cutlass_mla_decode( out, @@ -220,9 +217,6 @@ def _sm100_forward_decode( assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") - # Adjust workspace size (if necessary) self._workspace.ensure_size(attn_metadata, self._num_kv_splits) @@ -252,8 +246,9 @@ def _old_forward_decode( assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None - if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Cutlass MLA not yet supported") + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA") B = q_nope.shape[0]