
Commit e087890

MatthewBonanni authored and epwalsh committed
[Kernel] Add FP8 support with FlashMLA backend (vllm-project#22668)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
1 parent 46d9ec2 commit e087890

19 files changed: 235 additions and 109 deletions

cmake/external_projects/flashmla.cmake

Lines changed: 5 additions & 4 deletions
@@ -19,7 +19,7 @@ else()
   FetchContent_Declare(
     flashmla
     GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-    GIT_TAG 0e43e774597682284358ff2c54530757b654b8d1
+    GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
     GIT_PROGRESS TRUE
     CONFIGURE_COMMAND ""
     BUILD_COMMAND ""
@@ -37,13 +37,14 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
 if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
   set(FlashMLA_SOURCES
     ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
     ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu)
+    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)

   set(FlashMLA_INCLUDES
     ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-    ${flashmla_SOURCE_DIR}/csrc/include)
+    ${flashmla_SOURCE_DIR}/csrc)

   set_gencode_flags_for_srcs(
     SRCS "${FlashMLA_SOURCES}"
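
The new kernels_fp8 source only builds when FLASH_MLA_ARCHS resolves to SM90 (Hopper) and the CUDA compiler is newer than 12.3, so it is worth probing at runtime that the FlashMLA backend actually came up before relying on the FP8 path. A minimal sketch, assuming the is_flashmla_supported() helper used by the test further below is importable from vllm.attention.ops.flashmla (that import path is an assumption, not part of this commit):

```python
# Availability probe (sketch). Assumes vllm was built with the FlashMLA
# extension; the import path mirrors the helper that
# tests/kernels/attention/test_flashmla.py relies on.
from vllm.attention.ops.flashmla import is_flashmla_supported

supported, reason = is_flashmla_supported()
if not supported:
    # e.g. pre-Hopper GPU, or the extension was built without FLASH_MLA_ARCHS
    print(f"FlashMLA backend unavailable: {reason}")
```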

csrc/cache.h

Lines changed: 4 additions & 2 deletions
@@ -40,9 +40,11 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);

-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
     torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
+    std::optional<torch::Tensor> seq_starts = std::nullopt);

csrc/cache_kernels.cu

Lines changed: 29 additions & 28 deletions
@@ -624,16 +624,17 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
 namespace vllm {

 // grid is launched with dimensions (batch, num_splits)
-template <typename scalar_t>
-__global__ void gather_cache(
-    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
+template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void gather_and_maybe_dequant_cache(
+    const cache_t* __restrict__ src_cache,    // [NUM_BLOCKS, BLOCK_SIZE,
                                               // ENTRIES...]
     scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
     const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
     const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
     const int32_t block_size, const int32_t entry_size,
     const int64_t block_table_stride, const int64_t cache_block_stride,
     const int64_t cache_entry_stride, const int64_t dst_entry_stride,
+    const float* __restrict__ scale,
     const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets per
                                                // batch

@@ -675,10 +676,16 @@ __global__ void gather_cache(
     if (partial_block_size) full_blocks_end -= 1;
   }

-  auto copy_entry = [&](const scalar_t* __restrict__ _src,
+  auto copy_entry = [&](const cache_t* __restrict__ _src,
                         scalar_t* __restrict__ _dst) {
-    for (int i = threadIdx.x; i < entry_size; i += blockDim.x)
-      _dst[i] = _src[i];
+    for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
+      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+        _dst[i] = static_cast<scalar_t>(_src[i]);
+      } else {
+        _dst[i] =
+            fp8::scaled_convert<scalar_t, cache_t, kv_dt>(_src[i], *scale);
+      }
+    }
   };

   for (int pid = split_start; pid < full_blocks_end; ++pid) {
@@ -705,25 +712,31 @@ __global__ void gather_cache(
 }  // namespace vllm

 // Macro to dispatch the kernel based on the data type.
-#define CALL_GATHER_CACHE(CPY_DTYPE)                                        \
-  vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>(                \
-      reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()),                   \
-      reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()),                         \
-      block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(),     \
-      block_size, entry_size, block_table_stride, cache_block_stride,       \
-      cache_entry_stride, dst_entry_stride, seq_starts_ptr);
+// SCALAR_T is the data type of the destination tensor.
+// CACHE_T is the stored data type of kv-cache.
+// KV_DTYPE is the real data type of kv-cache.
+#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE)                      \
+  vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE>         \
+      <<<grid, block, 0, stream>>>(                                         \
+          reinterpret_cast<CACHE_T*>(src_cache.data_ptr()),                 \
+          reinterpret_cast<SCALAR_T*>(dst.data_ptr()),                      \
+          block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
+          block_size, entry_size, block_table_stride, cache_block_stride,   \
+          cache_entry_stride, dst_entry_stride,                             \
+          reinterpret_cast<const float*>(scale.data_ptr()), seq_starts_ptr);

 // Gather sequences from the cache into the destination tensor.
 // - cu_seq_lens contains the cumulative sequence lengths for each batch
 // - block_table contains the cache block indices for each sequence
 // - Optionally, seq_starts (if provided) offsets the starting block index by
 //   (seq_starts[bid] / page_size)
-void gather_cache(
+void gather_and_maybe_dequant_cache(
     torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
     torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
     torch::Tensor const& cu_seq_lens,  // [BATCH+1]
-    int64_t batch_size,
+    int64_t batch_size, const std::string& kv_cache_dtype,
+    torch::Tensor const& scale,
     std::optional<torch::Tensor> seq_starts = std::nullopt) {
   at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -761,20 +774,8 @@ void gather_cache(
   dim3 grid(batch_size, num_splits);
   dim3 block(1024);

-  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
-              "src_cache and dst must have the same dtype");
-
-  const int dtype_bits = src_cache.element_size() * 8;
   const int32_t* seq_starts_ptr =
       seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

-  if (dtype_bits == 32) {
-    CALL_GATHER_CACHE(uint32_t);
-  } else if (dtype_bits == 16) {
-    CALL_GATHER_CACHE(uint16_t);
-  } else if (dtype_bits == 8) {
-    CALL_GATHER_CACHE(uint8_t);
-  } else {
-    TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
-  }
+  DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
 }
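
For intuition, the kernel's contract can be mirrored in plain PyTorch: walk each sequence's block list from the block table, copy block-sized chunks into a flat [TOT_TOKENS, ENTRIES] tensor, and, when the cache is stored as FP8, dequantize each element by converting it and multiplying by the per-tensor scale (what fp8::scaled_convert does per element). The following is a rough reference sketch, not the CUDA implementation; it assumes an e4m3 byte layout for the "fp8" case and omits seq_starts and the split-work grid:

```python
import torch

def gather_and_maybe_dequant_ref(
        src_cache: torch.Tensor,    # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES]
        block_table: torch.Tensor,  # [BATCH, BLOCK_INDICES], int32
        cu_seq_lens: torch.Tensor,  # [BATCH + 1], int32
        kv_cache_dtype: str,        # "auto" (plain copy) or "fp8"
        scale: float,
        out_dtype: torch.dtype) -> torch.Tensor:
    block_size = src_cache.shape[1]
    rows = []
    for b in range(block_table.shape[0]):
        seq_len = int(cu_seq_lens[b + 1] - cu_seq_lens[b])
        num_blocks = (seq_len + block_size - 1) // block_size
        for i in range(num_blocks):
            take = min(block_size, seq_len - i * block_size)
            block = src_cache[block_table[b, i], :take]
            if kv_cache_dtype == "fp8":
                # Reinterpret the stored bytes as e4m3, then apply the scale.
                block = block.view(torch.float8_e4m3fn).to(torch.float32) * scale
            rows.append(block.to(out_dtype))
    return torch.cat(rows, dim=0)  # [TOT_TOKENS, ENTRIES]
```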

csrc/torch_bindings.cpp

Lines changed: 9 additions & 4 deletions
@@ -672,11 +672,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);

-  // Gather cache blocks from src_cache to dst.
+  // Gather cache blocks from src_cache to dst, dequantizing from
+  // src_cache's dtype to dst's dtype if necessary.
   cache_ops.def(
-      "gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, "
-      "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()");
-  cache_ops.impl("gather_cache", torch::kCUDA, &gather_cache);
+      "gather_and_maybe_dequant_cache(Tensor src_cache, Tensor! dst, "
+      "                               Tensor block_table, Tensor cu_seq_lens, "
+      "                               int batch_size, "
+      "                               str kv_cache_dtype, "
+      "                               Tensor scale, Tensor? seq_starts) -> ()");
+  cache_ops.impl("gather_and_maybe_dequant_cache", torch::kCUDA,
+                 &gather_and_maybe_dequant_cache);
 }

 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
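
From Python, the new schema maps one-to-one onto the call the updated test below issues through vLLM's ops wrapper. A minimal direct-invocation sketch (shapes are illustrative only; it requires the compiled _C extension and a CUDA device):

```python
import torch

num_blocks, block_size, entry_size, batch_size = 16, 64, 576, 2
device = "cuda"

src_cache = torch.randint(0, 256, (num_blocks, block_size, entry_size),
                          dtype=torch.uint8, device=device)   # FP8 bytes
dst = torch.zeros(2 * block_size, entry_size,
                  dtype=torch.bfloat16, device=device)        # dequantized output
block_table = torch.arange(num_blocks, dtype=torch.int32,
                           device=device).view(batch_size, -1)
cu_seq_lens = torch.tensor([0, block_size, 2 * block_size],
                           dtype=torch.int32, device=device)
scale = torch.tensor(0.1, dtype=torch.float32, device=device)

torch.ops._C_cache_ops.gather_and_maybe_dequant_cache(
    src_cache, dst, block_table, cu_seq_lens, batch_size,
    "fp8",   # kv_cache_dtype: "auto" copies as-is, "fp8" dequantizes with `scale`
    scale,
    None)    # seq_starts (optional)
```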

tests/kernels/attention/test_cache.py

Lines changed: 29 additions & 13 deletions
@@ -709,14 +709,15 @@ def test_swap_blocks_mla(
 @pytest.mark.parametrize("max_seq_len", [512])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("dtype", [torch.float32])
-@pytest.mark.parametrize("kv_cache_dtype",
-                         ["auto"])  # You can also test "fp8" if needed.
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
-                          num_blocks, max_seq_len, batch_size, dtype,
-                          kv_cache_dtype, device):
+def test_gather_and_maybe_dequant_cache_mla(kv_lora_rank, qk_rope_head_dim,
+                                            block_size, num_blocks,
+                                            max_seq_len, batch_size, dtype,
+                                            kv_cache_dtype, device):
     entry_size = kv_lora_rank + qk_rope_head_dim
+    scale = torch.tensor(0.1, dtype=torch.float32, device=device)
     src_cache = _create_mla_cache(num_blocks, block_size, entry_size, dtype,
                                   kv_cache_dtype, device)
     _fill_mla_cache(src_cache, kv_cache_dtype=kv_cache_dtype)
@@ -742,9 +743,7 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,
         perm = torch.randperm(num_blocks, device=device)
         block_table[b, :] = perm

-    dst = torch.zeros((total_tokens, entry_size),
-                      dtype=src_cache.dtype,
-                      device=device)
+    dst = torch.zeros((total_tokens, entry_size), dtype=dtype, device=device)

     expected_batches = []
     for b in range(batch_size):
@@ -756,21 +755,38 @@ def test_gather_cache_mla(kv_lora_rank, qk_rope_head_dim, block_size,

         gathered_rows = []
         for i in range(tot - 1):
-            gathered_rows.append(src_cache[blocks[i]])
+            block_data = src_cache[blocks[i]]
+            if kv_cache_dtype == "fp8":
+                dequantized_block = torch.empty_like(block_data, dtype=dtype)
+                ops.convert_fp8(dequantized_block, block_data, scale.item())
+                gathered_rows.append(dequantized_block)
+            else:
+                gathered_rows.append(block_data)
         remaining = s - (tot - 1) * block_size
-        gathered_rows.append(src_cache[blocks[-1], :remaining, :])
+        last_block_data = src_cache[blocks[-1], :remaining, :]
+        if kv_cache_dtype == "fp8":
+            dequantized_last_block = torch.empty_like(last_block_data,
+                                                      dtype=dtype)
+            ops.convert_fp8(dequantized_last_block, last_block_data,
+                            scale.item())
+            gathered_rows.append(dequantized_last_block)
+        else:
+            gathered_rows.append(last_block_data)

         batch_expected = torch.cat(gathered_rows, dim=0)
         expected_batches.append(batch_expected)
     expected = torch.cat(expected_batches, dim=0)

     opcheck(
-        torch.ops._C_cache_ops.gather_cache,
-        (src_cache, dst, block_table, cu_seq_lens, batch_size, None),
+        torch.ops._C_cache_ops.gather_and_maybe_dequant_cache,
+        (src_cache, dst, block_table, cu_seq_lens, batch_size, kv_cache_dtype,
+         scale, None),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
     )

-    ops.gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size)
+    ops.gather_and_maybe_dequant_cache(src_cache, dst, block_table,
+                                       cu_seq_lens, batch_size, kv_cache_dtype,
+                                       scale, None)
     torch.testing.assert_close(dst, expected)
tests/kernels/attention/test_flashmla.py

Lines changed: 51 additions & 18 deletions
@@ -13,11 +13,17 @@
 from vllm.triton_utils import triton


-def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
+def cal_diff(x: torch.Tensor,
+             y: torch.Tensor,
+             name: str,
+             use_fp8: bool = False) -> None:
     x, y = x.double(), y.double()
     cos_diff = 1 - 2 * (x * y).sum().item() / max(
         (x * x + y * y).sum().item(), 1e-12)
-    assert cos_diff < 1e-5
+    if (use_fp8):
+        assert cos_diff < 1e-4
+    else:
+        assert cos_diff < 1e-5

 FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
     if not is_flashmla_supported()[0] else "FlashMLA is supported"
@@ -27,28 +33,34 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
                     reason=FLASH_MLA_UNSUPPORTED_REASON)
 @pytest.mark.parametrize("b", [128])
 @pytest.mark.parametrize("s_q", [1, 2])
-@pytest.mark.parametrize("mean_sk", [4096, 8192])
+@pytest.mark.parametrize("mean_sk", [4096, 8192, 16384])
 @pytest.mark.parametrize("h_q", [16, 32, 64, 128])
 @pytest.mark.parametrize("h_kv", [1])
 @pytest.mark.parametrize("d", [576])
 @pytest.mark.parametrize("dv", [512])
 @pytest.mark.parametrize("block_size", [64])
 @pytest.mark.parametrize("causal", [True])
 @pytest.mark.parametrize("varlen", [False, True])
-@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("torch_dtype",
+                         [torch.bfloat16, torch.float16, torch.float8_e4m3fn])
 @torch.inference_mode()
 def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
-                   varlen, dtype):
+                   varlen, torch_dtype):
     device = torch.device("cuda:0")
-    torch.set_default_dtype(dtype)
+    if torch_dtype == torch.float8_e4m3fn:
+        init_dtype = torch.bfloat16
+    else:
+        init_dtype = torch_dtype
+    torch.set_default_dtype(init_dtype)
     torch.set_default_device(device)
     torch.cuda.set_device(device)
     torch.manual_seed(0)
     random.seed(0)

     print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
-          f"{d=}, {dv=}, {causal=}, {varlen=}, {dtype=}")
+          f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}")

+    use_fp8 = torch_dtype == torch.float8_e4m3fn
     cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
     if varlen:
         for i in range(b):
@@ -71,6 +83,19 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
     tile_scheduler_metadata, num_splits = get_mla_metadata(
         cache_seqlens, s_q * h_q // h_kv, h_kv)

+    init_dtype = q.dtype
+    if use_fp8:
+        fp8_dtype = torch.float8_e4m3fn
+        descale_q = torch.ones((1), dtype=torch.float32)
+        descale_k = torch.ones((1), dtype=torch.float32)
+
+        q = q.to(fp8_dtype)
+        blocked_k = blocked_k.to(fp8_dtype)
+        blocked_v = blocked_v.to(fp8_dtype)
+    else:
+        descale_q = None
+        descale_k = None
+
     def flash_mla():
         return flash_mla_with_kvcache(
             q,
@@ -81,6 +106,8 @@ def flash_mla():
             tile_scheduler_metadata,
             num_splits,
             causal=causal,
+            descale_q=descale_q,
+            descale_k=descale_k,
         )

     def scaled_dot_product_attention(query, key, value, is_causal=False):
@@ -104,29 +131,35 @@ def scaled_dot_product_attention(query, key, value, is_causal=False):
         return attn_weight @ value, lse

     def ref_mla():
+        q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q
+        blocked_k_ = (blocked_k.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_k
+        blocked_v_ = (blocked_v.to(torch.float) *
+                      descale_k).to(init_dtype) if use_fp8 else blocked_v
         out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
         lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
         for i in range(b):
             begin = i * max_seqlen_pad
             end = begin + cache_seqlens[i]
-            ref_O, LSE = scaled_dot_product_attention(
-                q[i].transpose(0, 1),
-                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
-                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
+            out_i, lse_i = scaled_dot_product_attention(
+                q_[i].transpose(0, 1),
+                blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1),
+                blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                 is_causal=causal,
             )
-            out[i] = ref_O.transpose(0, 1)
-            lse[i] = LSE
+            out[i] = out_i.transpose(0, 1)
+            lse[i] = lse_i
         return out, lse

     out_flash, lse_flash = flash_mla()
     out_torch, lse_torch = ref_mla()
-    cal_diff(out_flash, out_torch, "out")
+    cal_diff(out_flash, out_torch, "out", use_fp8)
     cal_diff(lse_flash, lse_torch, "lse")

     t = triton.testing.do_bench(flash_mla)
     FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
-    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
-             b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
-    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
-          f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
+    bytes = (total_seqlens * h_kv * d +
+             b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + (
+                 b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8)
+    print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,",
+          f"{bytes / 10 ** 6 / t:.0f} GB/s")
