16 changes: 8 additions & 8 deletions csrc/attention/mla/sm100_cutlass_mla_kernel.cu
@@ -64,11 +64,11 @@ struct IsPersistent {
static const bool value = v;
};

template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
template <typename T, typename TOut, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
using ElementOut = T;
using ElementOut = TOut;

using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
@@ -178,7 +178,7 @@ typename T::Fmha::Arguments args_from_options(
return arguments;
}

template <typename Element, bool IsPaged128, typename PersistenceOption>
template <typename Element, typename ElementOut, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope,
@@ -190,7 +190,7 @@ void runMla(
double sm_scale,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
using MlaSm100Type = MlaSm100<Element, ElementOut, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);

@@ -233,13 +233,13 @@ void sm100_cutlass_mla_decode(
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
runMla<cutlass::half_t, cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
runMla<cutlass::bfloat16_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
runMla<cutlass::float_e4m3_t, cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
@@ -253,7 +253,7 @@
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
using MlaSm100Type = MlaSm100<cutlass::half_t, cutlass::half_t, true>;

// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
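The dispatch above is the core of the kernel change: a second template parameter, TOut, decouples the output element type from the input element type. Half and bfloat16 inputs keep writing output in their own dtype, while float8_e4m3 inputs now write bfloat16 output. A rough Python-side summary of that mapping (an illustrative sketch, not code from this PR):

import torch

def mla_output_dtype(in_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the runMla dispatch: Element -> ElementOut.
    if in_dtype in (torch.float16, torch.bfloat16):
        return in_dtype                # ElementOut == Element
    if in_dtype == torch.float8_e4m3fn:
        return torch.bfloat16          # fp8 input, bf16 output
    raise ValueError(f"Unsupported input data type of MLA: {in_dtype}")

assert mla_output_dtype(torch.float8_e4m3fn) == torch.bfloat16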
250 changes: 167 additions & 83 deletions tests/kernels/test_cutlass_mla_decode.py
@@ -1,96 +1,180 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import random

import pytest
import torch
import torch.nn.functional as F
from torch import Tensor

import vllm._custom_ops as ops
from vllm.platforms import current_platform

if not current_platform.has_device_capability(100):
pytest.skip(
reason="Cutlass MLA Requires compute capability of 10 or above.",
allow_module_level=True)


def ref_mla(
out: Tensor, # (bs, num_heads, v_head_dim)
query: Tensor, # (bs, num_heads, head_dim)
kv_cache: Tensor, # (num_blocks, block_size, head_dim)
scale: float,
block_tables: Tensor, # (bs, max_num_blocks)
seq_lens: Tensor, # (bs,)
):
bs, num_heads, v_head_dim = out.shape
head_dim = query.shape[2]

for i in range(bs):
# gather and flatten KV-cache
kv = kv_cache[
block_tables[i]] # (max_num_blocks, block_size, head_dim)
kv = kv.view(1, -1,
head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim)
v = kv[:, :, :v_head_dim]

q = query[i].view(num_heads, 1, head_dim)
o = F.scaled_dot_product_attention(q,
kv,
v,
scale=scale,
enable_gqa=True)
out[i] = o.view(num_heads, v_head_dim)

return out


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096])
@pytest.mark.parametrize("bs", [1, 2, 4])
from vllm.triton_utils import triton


def cal_diff(x: torch.Tensor,
y: torch.Tensor,
name: str,
use_fp8: bool = False) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
if (use_fp8):
assert cos_diff < 1e-4
else:
assert cos_diff < 1e-5


CUTLASS_MLA_UNSUPPORTED_REASON = \
"Cutlass MLA Requires compute capability of 10 or above." \
if not current_platform.is_device_capability(100) \
else "Cutlass MLA is supported"


@pytest.mark.skipif(not current_platform.has_device_capability(100),
reason=CUTLASS_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1])
@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [16, 64, 128])
def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
varlen: bool, block_size: int):
torch.set_default_dtype(dtype)
torch.set_default_device('cuda')
@pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn])
@torch.inference_mode()
def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size,
causal, varlen, torch_dtype):
device = torch.device("cuda:0")
if torch_dtype == torch.float8_e4m3fn:
init_dtype = torch.bfloat16
else:
init_dtype = torch_dtype
torch.set_default_dtype(init_dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(42)
random.seed(42)

d = 576
h_q = 128
dv = 512
print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

q_nope_dim = 128
q_pe_dim = 64
scale = (q_nope_dim + q_pe_dim)**(-0.5)
use_fp8 = torch_dtype == torch.float8_e4m3fn
scale = math.sqrt(d)**(-1)
cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2)
seq_lens = seq_lens.clip(2).to(torch.int32)
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
s_q)
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256

q = torch.randn(b, s_q, h_q, d)
block_table = torch.arange(b * max_seqlen_pad // block_size,
dtype=torch.int32).view(
b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
blocked_v = blocked_k[..., :dv]

init_dtype = q.dtype
if use_fp8:
fp8_dtype = torch.float8_e4m3fn
descale_q = torch.ones((1), dtype=torch.float32)
descale_k = torch.ones((1), dtype=torch.float32)

q = q.to(fp8_dtype)
blocked_k = blocked_k.to(fp8_dtype)
blocked_v = blocked_v.to(fp8_dtype)
else:
seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32)
max_seq_len = seq_lens.max().item()
block_num = (max_seq_len + block_size - 1) // block_size

# Pad block_num so that small blocks can be packed into full 128-sized
# CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small
# blocks.
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

# Amplify input values to ensure test coverage of edge cases where CUTLASS
# kernel errors occur with split_k settings.
q = torch.randn(bs, h_q, d) * 100
block_table = torch.randint(0,
bs * block_num, (bs, block_num),
dtype=torch.int32)

kv_cache = torch.randn(block_table.numel(), block_size, d)

out_ref = q.new_zeros(bs, h_q, dv)
ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens)
out_ans = torch.zeros_like(out_ref)
q_nope = q[:, :, :dv].clone()
q_pe = q[:, :, dv:].clone()
ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens,
block_table, scale)

torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2)
descale_q = None
descale_k = None

def cutlass_mla():
MAX_HEADS = 128

q_reshaped = q.squeeze(1)
q_nope = q_reshaped[:, :, :dv].clone()
q_pe = q_reshaped[:, :, dv:].clone()

if h_q < MAX_HEADS:
q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv))
q_nope_padded[:, :h_q] = q_nope
q_nope = q_nope_padded

q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv))
q_pe_padded[:, :h_q] = q_pe
q_pe = q_pe_padded

kv_cache_flat = blocked_k.squeeze(2)
device_properties = torch.cuda.get_device_properties(
torch.device("cuda:0"))
sm_count = device_properties.multi_processor_count
workspace_size = ops.sm100_cutlass_mla_get_workspace_size(
max_seqlen * block_size, b, sm_count, num_kv_splits=1)
workspace = torch.empty(workspace_size,
device="cuda",
dtype=torch.uint8)

out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype)

ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat,
cache_seqlens, block_table, workspace,
scale, 1)
return out_ans[:, :h_q].contiguous()

def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k,
dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse

def ref_mla():
q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
blocked_k_ = (blocked_k.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_k
blocked_v_ = (blocked_v.to(torch.float) *
descale_k).to(init_dtype) if use_fp8 else blocked_v
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
out_i, lse_i = scaled_dot_product_attention(
q_[i].transpose(0, 1),
blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = out_i.transpose(0, 1)
lse[i] = lse_i
return out, lse

out_cutlass = cutlass_mla()
out_torch, lse_torch = ref_mla()
# Extract the single token (s_q=1) slice to match cutlass output shape
out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv]
cal_diff(out_cutlass, out_torch_slice, "out", use_fp8)

t = triton.testing.do_bench(cutlass_mla)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d +
b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
f"{bytes / 10 ** 6 / t:.0f} GB/s")
4 changes: 2 additions & 2 deletions vllm/platforms/cuda.py
@@ -500,8 +500,8 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str,
else:
attention_backend = "FLASHMLA"

# Only FlashMLA supports fp8
if attention_backend == "FLASHMLA":
# Only FlashMLA and CUTLASS_MLA support fp8
if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
supported = True
else:
supported = (not fp8_attention)
23 changes: 9 additions & 14 deletions vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -108,10 +108,6 @@ def __init__(
"are not implemented for "
"CutlassMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CutlassMLA V1 with FP8 KV cache not yet supported")

self._use_old_cutlass_mla = False
force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None)
if force_old_cutlass:
@@ -182,11 +178,10 @@ def _sm100_cutlass_mla_decode(
> 0), f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0

# TODO(kaixih@nvidia): support fp8
assert q_nope.dtype in (
torch.float16,
torch.bfloat16,
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got "
f"{q_nope.dtype}.")
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
assert (
seq_lens.dtype == torch.int32
@@ -195,7 +190,9 @@
page_table.dtype == torch.int32
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
dtype = (torch.bfloat16 if is_quantized_kv_cache(self.kv_cache_dtype)
else q_nope.dtype)
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)

ops.sm100_cutlass_mla_decode(
out,
@@ -220,9 +217,6 @@ def _sm100_forward_decode(
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")

# Adjust workspace size (if necessary)
self._workspace.ensure_size(attn_metadata, self._num_kv_splits)

@@ -252,8 +246,9 @@ def _old_forward_decode(
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FP8 Cutlass MLA not supported with FORCE_OLD_CUTLASS_MLA")

B = q_nope.shape[0]

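With the backend guards above removed, an FP8 (e4m3) KV cache is accepted end to end on the SM100 path and the kernel writes bfloat16 output. The sketch below mirrors the call pattern of the new test (workspace sizing, 128-head layout, num_kv_splits=1); the concrete batch and block shapes are illustrative assumptions, not values taken from this PR.

import torch
import vllm._custom_ops as ops

b, h_q, d, dv = 4, 128, 576, 512        # full 128 query heads, MLA head dims
block_size, blocks_per_seq = 128, 8     # 8 pages of 128 tokens per sequence
scale = d ** -0.5

# FP8 query and KV cache (values generated in bf16, then cast to e4m3).
q_nope = torch.randn(b, h_q, dv, device="cuda",
                     dtype=torch.bfloat16).to(torch.float8_e4m3fn)
q_pe = torch.randn(b, h_q, d - dv, device="cuda",
                   dtype=torch.bfloat16).to(torch.float8_e4m3fn)
kv_cache = torch.randn(b * blocks_per_seq, block_size, d, device="cuda",
                       dtype=torch.bfloat16).to(torch.float8_e4m3fn)
seq_lens = torch.full((b,), blocks_per_seq * block_size,
                      dtype=torch.int32, device="cuda")
page_table = torch.arange(b * blocks_per_seq, dtype=torch.int32,
                          device="cuda").view(b, blocks_per_seq)

sm_count = torch.cuda.get_device_properties("cuda").multi_processor_count
ws_size = ops.sm100_cutlass_mla_get_workspace_size(
    blocks_per_seq * block_size, b, sm_count, num_kv_splits=1)
workspace = torch.empty(ws_size, dtype=torch.uint8, device="cuda")

# FP8 inputs now produce a bfloat16 output buffer (TOut = bf16 in the kernel).
out = torch.empty(b, h_q, dv, dtype=torch.bfloat16, device="cuda")
ops.sm100_cutlass_mla_decode(out, q_nope, q_pe, kv_cache, seq_lens,
                             page_table, workspace, scale, 1)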