From c5faa922fddc88644a9111bc88609aabe6655461 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 05:30:46 +0000 Subject: [PATCH 01/26] torch library bindings, unit tests running Signed-off-by: Lucas Wilkinson --- CMakeLists.txt | 2 + cmake/flashmla.cmake | 66 +++++++++++ cmake/flashmla.patch | 32 ++++++ csrc/pytorch_shim.h | 110 ++++++++++++++++++ setup.py | 1 + test_copy.py | 25 +++++ test_merge.py | 196 +++++++++++++++++++++++++++++++++ tests/kernels/test_flashmla.py | 136 +++++++++++++++++++++++ vllm/_custom_ops.py | 61 ++++++++++ vllm/attention/ops/flashmla.py | 87 +++++++++++++++ 10 files changed, 716 insertions(+) create mode 100644 cmake/flashmla.cmake create mode 100644 cmake/flashmla.patch create mode 100644 csrc/pytorch_shim.h create mode 100644 test_copy.py create mode 100644 test_merge.py create mode 100644 tests/kernels/test_flashmla.py create mode 100644 vllm/attention/ops/flashmla.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 02a60c0e3520..c14e1708a6ec 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -580,6 +580,8 @@ if (NOT VLLM_GPU_LANG STREQUAL "CUDA") return() endif () +include(cmake/flashmla.cmake) + # vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target # arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the # arches in the CUDA case (and instead set the gencodes on a per file basis) diff --git a/cmake/flashmla.cmake b/cmake/flashmla.cmake new file mode 100644 index 000000000000..a4cbcc8957de --- /dev/null +++ b/cmake/flashmla.cmake @@ -0,0 +1,66 @@ +include(FetchContent) + +# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory +# instead of downloading. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{FLASH_MLA_SRC_DIR}) + set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR}) +endif() + +if(FLASH_MLA_SRC_DIR) + FetchContent_Declare( + flashmla + SOURCE_DIR ${FLASH_MLA_SRC_DIR} + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + ) +else() + FetchContent_Declare( + flashmla + GIT_REPOSITORY https://github.com/deepseek-ai/FlashMLA + GIT_TAG 414a2f3eedeb5ad3c4a6e89d8641e059519cacc9 + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + PATCH_COMMAND git apply --ignore-whitespace + "${CMAKE_CURRENT_LIST_DIR}/flashmla.patch" + # For incremental builds to prevent the patch from being reapplied, + # https://stackoverflow.com/a/73725257 + UPDATE_DISCONNECTED TRUE + ) +endif() + + +FetchContent_MakeAvailable(flashmla) +message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") + +# The machete kernels only work on hopper and require CUDA 12.0 or later. 
+# Only build Machete kernels if we are building for something compatible with sm90a +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FLASH_MLA_ARCHS) + set(FlashMLA_SOURCES + ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu) + + set(FlashMLA_INCLUDES + ${PROJECT_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/include) + + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + + define_gpu_extension_target( + _C_flashmla + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_SOURCES} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} + USE_SABI 3 + WITH_SOABI) +endif() + diff --git a/cmake/flashmla.patch b/cmake/flashmla.patch new file mode 100644 index 000000000000..0e97600e2906 --- /dev/null +++ b/cmake/flashmla.patch @@ -0,0 +1,32 @@ +diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp +index 5a1cb8e..65fbfb0 100644 +--- a/csrc/flash_api.cpp ++++ b/csrc/flash_api.cpp +@@ -1,6 +1,6 @@ + // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp + +-#include ++// #include + #include + #include + #include +@@ -196,8 +196,14 @@ mha_fwd_kvcache_mla( + return {out, softmax_lse}; + } + +-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +- m.doc() = "FlashMLA"; +- m.def("get_mla_metadata", &get_mla_metadata); +- m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); ++#include "core/registration.h" ++#include "pytorch_shim.h" ++ ++TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ++ m.def("get_mla_metadata", make_pytorch_shim(&get_mla_metadata)); ++ m.impl("get_mla_metadata", torch::kCUDA, make_pytorch_shim(&get_mla_metadata)); ++ ++ m.def("fwd_kvcache_mla", make_pytorch_shim(&mha_fwd_kvcache_mla)); ++ m.impl("fwd_kvcache_mla", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache_mla)); + } ++REGISTER_EXTENSION(TORCH_EXTENSION_NAME) +\ No newline at end of file diff --git a/csrc/pytorch_shim.h b/csrc/pytorch_shim.h new file mode 100644 index 000000000000..4bedfcd47656 --- /dev/null +++ b/csrc/pytorch_shim.h @@ -0,0 +1,110 @@ +#pragma once + +#include + +/** + * Unforunately, the type signatures of the flash_attn ops are not compatible + * with the PyTorch library bindings. To get around that we use + * `make_pytorch_shim` which creates a lambda that exponses the API using + * PyTorch compatible types to the types, then converts them to the types + * expected by the flash_attn ops. This shims allows us to make minimal changes + * to `flash_api.cpp` making it easier to synchronize with upstream changes. + * + * The `pytorch_library_compatible_type` struct is used to map from the + * flash_attn ops types to a PyTorch library compatible one. 
The main issues is
+ * that the following types are not support by PyTorch libary bindings:
+ *  - `int`
+ *  - `float`
+ *  - `c10::optional<T> &`
+ *  - `c10::optional<at::ScalarType> &`
+ * So we convert them to (respectively):
+ *  - `int64_t`
+ *  - `double`
+ *  - `const c10::optional<T>&`
+ *  - `const c10::optional<int64_t>&`
+ */
+
+template <typename T>
+struct pytorch_library_compatible_type {
+  using type = T;
+  static T convert_from_type(T arg) { return arg; }
+};
+
+template <typename T>
+using pytorch_library_compatible_type_t =
+    typename pytorch_library_compatible_type<T>::type;
+
+template <typename T>
+T convert_from_pytorch_compatible_type(
+    pytorch_library_compatible_type_t<T> arg) {
+  return pytorch_library_compatible_type<T>::convert_from_type(arg);
+}
+
+// Map `c10::optional<T> &` -> `const c10::optional<T>&`
+//  (NOTE: this is a bit unsafe but none of the ops in flash_attn mutate
+//   the optional container)
+template <typename T>
+struct pytorch_library_compatible_type<c10::optional<T>&> {
+  using type = const c10::optional<T>&;
+  static c10::optional<T>& convert_from_type(const c10::optional<T>& arg) {
+    return const_cast<c10::optional<T>&>(arg);
+  }
+};
+
+// Map `c10::optional<T>` ->
+//   `c10::optional<pytorch_library_compatible_type_t<T>>`
+//  (NOTE: tested for `c10::optional<int>` -> `c10::optional<int64_t>`)
+template <typename T>
+struct pytorch_library_compatible_type<c10::optional<T>> {
+  using type = c10::optional<pytorch_library_compatible_type_t<T>>;
+  static c10::optional<pytorch_library_compatible_type_t<T>> convert_from_type(
+      c10::optional<T> arg) {
+    return arg;
+  }
+};
+
+// Map `c10::optional<at::ScalarType>&` -> `const c10::optional<int64_t>&`
+template <>
+struct pytorch_library_compatible_type<c10::optional<at::ScalarType>&> {
+  using type = const c10::optional<int64_t>&;
+  static c10::optional<at::ScalarType>& convert_from_type(
+      const c10::optional<int64_t>& arg) {
+    return const_cast<c10::optional<at::ScalarType>&>(
+        reinterpret_cast<const c10::optional<at::ScalarType>&>(arg));
+  }
+};
+
+// Map `int` -> `int64_t`
+template <>
+struct pytorch_library_compatible_type<int> {
+  using type = int64_t;
+  static int convert_from_type(int64_t arg) {
+    TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
+                "int64_t value is too large to be converted to int");
+    TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
+                "int64_t value is too small to be converted to int");
+    return arg;
+  }
+};
+
+// Map `float` -> `double`
+template <>
+struct pytorch_library_compatible_type<float> {
+  using type = double;
+  static float convert_from_type(double arg) {
+    TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
+                "double value is too large to be converted to float");
+    return arg;
+  }
+};
+
+//
+//  Shim Utils
+//
+
+template <typename Ret, typename... Args>
+auto make_pytorch_shim(Ret (*fun)(Args... args)) {
+  return [fun](pytorch_library_compatible_type_t<Args>... args) {
+    return fun(convert_from_pytorch_compatible_type<Args>(args)...);
+  };
+}
diff --git a/setup.py b/setup.py
index d8a336c2d426..8163d29749b8 100755
--- a/setup.py
+++ b/setup.py
@@ -612,6 +612,7 @@ def _read_requirements(filename: str) -> List[str]:
             # FA3 requires CUDA 12.0 or later
             ext_modules.append(
                 CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+        ext_modules.append(CMakeExtension(name="vllm._C_flashmla"))
         ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
 
     if _build_custom_ops():
diff --git a/test_copy.py b/test_copy.py
new file mode 100644
index 000000000000..33ed5d81c61f
--- /dev/null
+++ b/test_copy.py
@@ -0,0 +1,25 @@
+import torch
+from torch.profiler import profile, record_function, ProfilerActivity
+
+x = torch.randn(512, 512).cuda()
+y = torch.randn(512, 512).cuda()
+
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
+    x[...] 
= x + +print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) + +with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + x.copy_(x) + +print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) + +with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + y[...] = x + +print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) + +with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + y.copy_(x) + +print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) \ No newline at end of file diff --git a/test_merge.py b/test_merge.py new file mode 100644 index 000000000000..9e6b88f4bbf4 --- /dev/null +++ b/test_merge.py @@ -0,0 +1,196 @@ +import torch +import triton +import triton.language as tl +import unittest +from typing import Optional +from vllm.vllm_flash_attn import flash_attn_varlen_func + +# === Use your provided merge_attn_states implementation exactly === + +def merge_attn_states( + output: torch.Tensor, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + suffix_output: torch.Tensor, + suffix_lse: torch.Tensor, + output_lse: Optional[torch.Tensor] = None, +) -> None: + num_tokens = output.shape[0] + num_query_heads = output.shape[1] + head_size = output.shape[2] + padded_head_size = triton.next_power_of_2(head_size) + + merge_attn_states_kernel[(num_tokens, num_query_heads)]( + output, + output_lse, + prefix_output, + prefix_lse, + suffix_output, + suffix_lse, + head_size, + padded_head_size, + output_lse is not None, + ) + + +@triton.jit +def merge_attn_states_kernel( + output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + output_lse, # [NUM_HEADS, NUM_TOKENS] + prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + prefix_lse, # [NUM_HEADS, NUM_TOKENS] + suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + suffix_lse, # [NUM_HEADS, NUM_TOKENS] + HEAD_SIZE: tl.constexpr, + PADDED_HEAD_SIZE: tl.constexpr, + OUTPUT_LSE: tl.constexpr, +): + token_idx = tl.program_id(0) + num_tokens = tl.num_programs(0) + head_idx = tl.program_id(1) + num_heads = tl.num_programs(1) + + # Load lse values for this token & head. + p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) + s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) + + # Determine validity (for causal masking, masked positions will be -∞, which is not finite). + p_valid = tl.isfinite(p_lse) + s_valid = tl.isfinite(s_lse) + both_valid = p_valid & s_valid + only_p = p_valid & (~s_valid) + only_s = s_valid & (~p_valid) + + # Compute merged candidate only if both sides are valid. + max_lse = tl.maximum(p_lse, s_lse) + p_shift = p_lse - max_lse + s_shift = s_lse - max_lse + out_se = tl.exp(p_shift) + tl.exp(s_shift) + + merged_lse_candidate = tl.log(out_se) + max_lse + # If both are valid, merge; otherwise, choose the valid side. + merged_lse = tl.where(both_valid, merged_lse_candidate, tl.where(only_p, p_lse, s_lse)) + + # Optionally store merged lse. + if OUTPUT_LSE: + tl.store(output_lse + head_idx * num_tokens + token_idx, merged_lse) + + # Load the attention outputs. 
+ head_arange = tl.arange(0, PADDED_HEAD_SIZE) + head_mask = head_arange < HEAD_SIZE + base_offset = token_idx * num_heads * HEAD_SIZE + p_out = tl.load(prefix_output + base_offset + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + s_out = tl.load(suffix_output + base_offset + head_idx * HEAD_SIZE + head_arange, + mask=head_mask) + + # Compute candidate merged output if both sides valid. + p_scale = tl.exp(p_shift) / out_se + s_scale = tl.exp(s_shift) / out_se + merged_output_candidate = p_out * p_scale + s_out * s_scale + merged_output = tl.where(both_valid, merged_output_candidate, + tl.where(only_p, p_out, s_out)) + + tl.store(output + base_offset + head_idx * HEAD_SIZE + head_arange, + merged_output, + mask=head_mask) + + +# === Single test: iterative merge (via multiple flash_attn_varlen_func calls) +# vs. a single unchunked call. We transpose the softmax lse outputs +# because FlashAttention returns them as [NUM_TOKENS, NUM_HEADS], +# but our merge kernel expects [NUM_HEADS, NUM_TOKENS]. === + +class TestFlashAttnMerge(unittest.TestCase): + def test_flash_attn_merge(self): + torch.manual_seed(0) + device = "cuda" + # Dimensions: + num_tokens = 16 # number of query tokens + num_heads = 4 + HEAD_SIZE = 8 + chunk_max_seq_len = 16 # keys/values length per chunk + num_chunks = 3 + max_query_len = num_tokens # for simplicity + softmax_scale = 1.0 + + # Create a fixed query tensor in fp16. + q = torch.randn(num_tokens, num_heads, HEAD_SIZE, device=device, dtype=torch.float16) + cu_seqlens_q = torch.tensor([0, num_tokens], device=device, dtype=torch.int32) + + # Compute chunked attention outputs. + # (Note: flash_attn_varlen_func returns softmax_lse in shape [NUM_TOKENS, NUM_HEADS], + # so we transpose it to [NUM_HEADS, NUM_TOKENS] for merging.) + chunks_output = [] + chunks_lse = [] + chunks_k = [] + chunks_v = [] + for _ in range(num_chunks): + chunk_k = torch.randn(chunk_max_seq_len, num_heads, HEAD_SIZE, device=device, dtype=torch.float16) + chunk_v = torch.randn(chunk_max_seq_len, num_heads, HEAD_SIZE, device=device, dtype=torch.float16) + chunks_k.append(chunk_k) + chunks_v.append(chunk_v) + cu_seqlens_k = torch.tensor([0, chunk_max_seq_len], device=device, dtype=torch.int32) + attn_output, attn_softmax_lse = flash_attn_varlen_func( + q=q, + k=chunk_k, + v=chunk_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_query_len, + max_seqlen_k=chunk_max_seq_len, + softmax_scale=softmax_scale, + causal=True, + return_softmax_lse=True, + fa_version=3, + ) + chunks_output.append(attn_output) + # Transpose lse from [num_tokens, num_heads] to [num_heads, num_tokens] + chunks_lse.append(attn_softmax_lse.transpose(0, 1).contiguous()) + + # Iteratively merge the chunk outputs. + # Allocate temporary tensor for merged lse with shape [num_heads, num_tokens]. + merged_output = chunks_output[0].clone() + merged_lse = chunks_lse[0].clone() + for i in range(1, num_chunks): + tmp_output = torch.empty_like(merged_output) + tmp_lse = torch.empty_like(merged_lse) + merge_attn_states( + tmp_output, + merged_output, + merged_lse, + chunks_output[i], + chunks_lse[i], + tmp_lse, + ) + merged_output = tmp_output + merged_lse = tmp_lse + + # Unchunked version: concatenate keys and values and call flash_attn_varlen_func once. 
+ full_k = torch.cat(chunks_k, dim=0) # shape: (num_chunks*chunk_max_seq_len, num_heads, HEAD_SIZE) + full_v = torch.cat(chunks_v, dim=0) + total_seq_len = num_chunks * chunk_max_seq_len + cu_seqlens_k_full = torch.tensor([0, total_seq_len], device=device, dtype=torch.int32) + attn_output_full, attn_softmax_lse_full = flash_attn_varlen_func( + q=q, + k=full_k, + v=full_v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k_full, + max_seqlen_q=max_query_len, + max_seqlen_k=total_seq_len, + softmax_scale=softmax_scale, + causal=True, + return_softmax_lse=True, + fa_version=3, + ) + # Transpose the full lse to [num_heads, num_tokens] for comparison. + attn_softmax_lse_full = attn_softmax_lse_full.transpose(0, 1).contiguous() + + # Compare the merged (iterative) result with the unchunked result. + # (fp16 numerics are less precise, so we use a looser tolerance.) + torch.testing.assert_close(merged_output, attn_output_full, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(merged_lse, attn_softmax_lse_full, atol=1e-3, rtol=1e-3) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py new file mode 100644 index 000000000000..c2d5f53d3fe3 --- /dev/null +++ b/tests/kernels/test_flashmla.py @@ -0,0 +1,136 @@ +import math +import random +import pytest + +import torch +import triton + +from vllm.attention.ops.flashmla import ( + is_flashmla_supported, + flash_mla_with_kvcache, + get_mla_metadata, +) + + +def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: + x, y = x.double(), y.double() + RMSE = ((x - y) * (x - y)).mean().sqrt().item() + cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) + amax_diff = (x - y).abs().max().item() + # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") + assert cos_diff < 1e-5 + +@pytest.mark.parametrize("b", [128]) +@pytest.mark.parametrize("s_q", [1, 2]) +@pytest.mark.parametrize("mean_sk", [4096, 8192]) +@pytest.mark.parametrize("h_q", [16, 32, 64, 128]) +@pytest.mark.parametrize("h_kv", [1]) +@pytest.mark.parametrize("d", [576]) +@pytest.mark.parametrize("dv", [512]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("varlen", [False, True]) +@torch.inference_mode() +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen): + # TODO: parametrize using pytest + dtype = torch.bfloat16 + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") + + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_v = blocked_k[..., :dv] + + 
tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, blocked_k, block_table, cache_seqlens, dv, + tile_scheduler_metadata, num_splits, causal=causal, + ) + + def scaled_dot_product_attention(query, key, value, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_flash, lse_flash = flash_mla() + out_torch, lse_torch = ref_mla() + cal_diff(out_flash, out_torch, "out") + cal_diff(lse_flash, lse_torch, "lse") + + t = triton.testing.do_bench(flash_mla, fast_flush=False) + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + + +# if __name__ == "__main__": +# dtype = torch.bfloat16 +# device = torch.device("cuda:0") +# torch.set_default_dtype(dtype) +# torch.set_default_device(device) +# torch.cuda.set_device(device) +# torch.manual_seed(0) +# random.seed(0) + +# h_kv = 1 +# d, dv = 576, 512 +# causal = True + +# for b in [128]: +# for s in [4096, 8192]: +# for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 +# for s_q in [1, 2]: # MTP = 1, 2 +# for varlen in [False, True]: +# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3306610ad800..584c5f83e201 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1163,3 +1163,64 @@ def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + +def get_flash_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. 
+ """ + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. + head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py new file mode 100644 index 000000000000..c1850f44047e --- /dev/null +++ b/vllm/attention/ops/flashmla.py @@ -0,0 +1,87 @@ +# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py +from typing import Optional, Tuple + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) +from vllm.platforms import current_platform + + +if not current_platform.is_tpu() and not current_platform.is_hpu(): + try: + import vllm._C_flashmla + _C_flashmla_AVAILABLE = True + except ImportError as e: + logger.warning("Failed to import from vllm._C_flashmla with %r", e) + _C_flashmla_AVAILABLE = False +else: + _C_flashmla_AVAILABLE = False + + +def is_flashmla_supported() -> bool: + return _C_flashmla_AVAILABLE and \ + current_platform.get_device_capability()[0] == 9 + + +def get_mla_metadata( + cache_seqlens: torch.Tensor, + num_heads_per_head_k: int, + num_heads_k: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + cache_seqlens: (batch_size), dtype torch.int32. + num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. + num_heads_k: num_heads_k. + + Return: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: (batch_size + 1), dtype torch.int32. + """ + return torch.ops._C_flashmla.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: torch.Tensor, + cache_seqlens: torch.Tensor, + head_dim_v: int, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. + cache_seqlens: (batch_size), torch.int32. 
+ head_dim_v: Head_dim of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + causal: bool. Whether to apply causal attention mask. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, softmax_lse = torch.ops._C_flashmla.fwd_kvcache_mla( + q, + k_cache, + None, + head_dim_v, + cache_seqlens, + block_table, + softmax_scale, + causal, + tile_scheduler_metadata, + num_splits, + ) + return out, softmax_lse From a1832e90711fed1b8e1b269b5e734f7e3f8a0cb9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 05:32:59 +0000 Subject: [PATCH 02/26] comments Signed-off-by: Lucas Wilkinson --- vllm/attention/ops/flashmla.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index c1850f44047e..0f139cef6310 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -85,3 +85,15 @@ def flash_mla_with_kvcache( num_splits, ) return out, softmax_lse + +# +# TODO: Add fake functions +# +# @register_fake("_C_flashmla::get_mla_metadata") +# def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# +# @register_fake("_C_flashmla::fwd_kvcache_mla") +# def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: +# return .... +# From bef305b3eafc70a66572530b3a035ff3fa972abb Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 06:23:38 +0000 Subject: [PATCH 03/26] working in eager mode Signed-off-by: Lucas Wilkinson --- tests/kernels/test_flashmla.py | 21 --- vllm/attention/backends/flashmla.py | 177 ++++++++++++++++++++++++++ vllm/attention/backends/mla/common.py | 20 +-- vllm/platforms/cuda.py | 4 +- vllm/platforms/interface.py | 1 + 5 files changed, 190 insertions(+), 33 deletions(-) create mode 100644 vllm/attention/backends/flashmla.py diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py index c2d5f53d3fe3..330cc3e9a0f7 100644 --- a/tests/kernels/test_flashmla.py +++ b/tests/kernels/test_flashmla.py @@ -113,24 +113,3 @@ def ref_mla(): FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") - - -# if __name__ == "__main__": -# dtype = torch.bfloat16 -# device = torch.device("cuda:0") -# torch.set_default_dtype(dtype) -# torch.set_default_device(device) -# torch.cuda.set_device(device) -# torch.manual_seed(0) -# random.seed(0) - -# h_kv = 1 -# d, dv = 576, 512 -# causal = True - -# for b in [128]: -# for s in [4096, 8192]: -# for h_q in [16, 32, 64, 128]: # TP = 8, 4, 2, 1 -# for s_q in [1, 2]: # MTP = 1, 2 -# for varlen in [False, True]: -# test_flash_mla(b, s_q, s, h_q, h_kv, d, dv, causal, varlen) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py new file mode 100644 index 000000000000..771f13dc444c --- /dev/null +++ b/vllm/attention/backends/flashmla.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Dict, List, Optional, Type, Set, Tuple + +import torch + 
+from dataclasses import asdict, dataclass +from vllm.attention.backends.abstract import AttentionType +from vllm.attention.backends.mla.common import (MLACommonBackend, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder, + MLACommonState) +from vllm.attention.ops.flashmla import ( + is_flashmla_supported, + flash_mla_with_kvcache, + get_mla_metadata, +) + + +class FlashMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASHMLA" + + @staticmethod + def get_impl_cls() -> Type["FlashMLAImpl"]: + return FlashMLAImpl + + @staticmethod + def get_metadata_cls() -> Type["FlashMLAMetadata"]: + return FlashMLAMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: + return FlashMLAMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashMLAState"]: + return FlashMLAState + +@dataclass +class FlashMLAMetadata(MLACommonMetadata): + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + decode_num_splits: Optional[torch.Tensor] = None + + _cached_decode_metadata: Optional["MLACommonMetadata"] = None + + @property + def decode_metadata(self): + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + + common_decode_metadata = super().decode_metadata + self._cached_decode_metadata = FlashMLAMetadata( + # TODO: cached but can this be faster? + **asdict(common_decode_metadata), + decode_tile_scheduler_metadata=self.decode_tile_scheduler_metadata, + decode_num_splits=self.decode_num_splits, + ) + return self._cached_decode_metadata + + +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + common_metadata = super().build( + seq_lens, query_lens, cuda_graph_pad_size, batch_size) + + decode_tile_scheduler_metadata, decode_num_splits = None, None + if common_metadata.num_decode_tokens > 0: + decode_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + common_metadata.seq_lens_tensor[common_metadata.num_prefills:], + self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + 1, + ) + + return FlashMLAMetadata( + # TODO: not on hotpath but can this be faster? 
+ **asdict(common_metadata), + decode_tile_scheduler_metadata=decode_tile_scheduler_metadata, + decode_num_splits=decode_num_splits, + ) + +class FlashMLAState(MLACommonState): + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + input_buffers = super().get_graph_input_buffers(attn_metadata, + is_encoder_decoder_model) + if attn_metadata.tile_scheduler_metadata is not None: + tile_scheduler_metadata = attn_metadata.tile_scheduler_metadata + num_splits = attn_metadata.num_splits + input_buffers["tile_scheduler_metadata"] = tile_scheduler_metadata + input_buffers["num_splits"] = num_splits + + return input_buffers + + + +class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + assert is_flashmla_supported(), \ + "FlashMLA is not supported on this device" + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "FlashMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashMLAImpl") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Triton MLA not yet supported") + + decode_meta = attn_metadata.decode_metadata + assert decode_meta is not None + + q = torch.cat([q_nope, q_pe], dim=-1)\ + .unsqueeze(1) # Add seqlen dim of 1 (decode) + + o, _ = flash_mla_with_kvcache( + q=q, + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + head_dim_v=self.kv_lora_rank, + tile_scheduler_metadata=decode_meta.decode_tile_scheduler_metadata, + num_splits=decode_meta.decode_num_splits, + softmax_scale=self.scale, + causal=True, + ) + + return self._v_up_proj_and_o_proj(o) diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 4dd562be3838..8223584466d7 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -506,8 +506,8 @@ class MLACommonMetadata(AttentionMetadata): # [4, 6], it is [0, 4, 10]. 
seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional["MLACommonMetadata"] = None - _cached_decode_metadata: Optional["MLACommonMetadata"] = None + _cached_prefill_common_metadata: Optional["MLACommonMetadata"] = None + _cached_decode_common_metadata: Optional["MLACommonMetadata"] = None num_prefill_tokens: int @@ -540,8 +540,8 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: if self.num_prefills == 0: return None - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata + if self._cached_prefill_common_metadata is not None: + return self._cached_prefill_common_metadata assert self.seq_lens is not None assert self.seq_lens_tensor is not None @@ -564,7 +564,7 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: input_positions = (None if self.input_positions is None else self.input_positions[:self.num_prefill_tokens]) - self._cached_prefill_metadata = MLACommonMetadata( + self._cached_prefill_common_metadata = MLACommonMetadata( # Required by ModelRunner use_cuda_graph=False, # Not Attention Related # Required by Attention Metadata @@ -595,15 +595,15 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: context_chunk_seq_tot=self.context_chunk_seq_tot, context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, ) - return self._cached_prefill_metadata + return self._cached_prefill_common_metadata @property def decode_metadata(self) -> Optional["MLACommonMetadata"]: if self.num_decode_tokens == 0: return None - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata + if self._cached_decode_common_metadata is not None: + return self._cached_decode_common_metadata assert self.seq_lens_tensor is not None # Compute some attn_metadata fields which default to None @@ -616,7 +616,7 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: input_positions = (None if self.input_positions is None else self.input_positions[self.num_prefill_tokens:]) - self._cached_decode_metadata = MLACommonMetadata( + self._cached_decode_common_metadata = MLACommonMetadata( # Required by ModelRunner use_cuda_graph=self.use_cuda_graph, # Not Attention Related # Required by Attention Metadata @@ -647,7 +647,7 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: input_positions=input_positions, head_dim=self.head_dim, is_profile_run=self.is_profile_run) - return self._cached_decode_metadata + return self._cached_decode_common_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bf425b89132e..4f3f1658bf6a 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -157,8 +157,8 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" if use_mla: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" + logger.info("Using Flash MLA backend.") + return "vllm.attention.backends.flashmla.FlashMLABackend" if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d6dae2e526dc..0e4988a4fa74 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -35,6 +35,7 @@ class _Backend(enum.Enum): OPENVINO = enum.auto() FLASHINFER 
= enum.auto() TRITON_MLA = enum.auto() + FLASHMLA = enum.auto() HPU_ATTN = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() From c74a4f0425e247be7db1553580e1d71cfac69921 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 06:30:28 +0000 Subject: [PATCH 04/26] format Signed-off-by: Lucas Wilkinson --- tests/kernels/test_flashmla.py | 61 +++++++++++++++++------------ vllm/attention/backends/flashmla.py | 41 ++++++++++--------- vllm/attention/ops/flashmla.py | 22 +++++++---- 3 files changed, 72 insertions(+), 52 deletions(-) diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py index 330cc3e9a0f7..a4dceb3da16f 100644 --- a/tests/kernels/test_flashmla.py +++ b/tests/kernels/test_flashmla.py @@ -1,25 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 import math import random -import pytest +import pytest import torch import triton -from vllm.attention.ops.flashmla import ( - is_flashmla_supported, - flash_mla_with_kvcache, - get_mla_metadata, -) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata) def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: x, y = x.double(), y.double() - RMSE = ((x - y) * (x - y)).mean().sqrt().item() - cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) - amax_diff = (x - y).abs().max().item() - # print(f"{name}: {cos_diff=}, {RMSE=}, {amax_diff=}") + cos_diff = 1 - 2 * (x * y).sum().item() / max( + (x * x + y * y).sum().item(), 1e-12) assert cos_diff < 1e-5 + @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192]) @@ -31,7 +28,8 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @pytest.mark.parametrize("causal", [True]) @pytest.mark.parametrize("varlen", [False, True]) @torch.inference_mode() -def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen): +def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, + varlen): # TODO: parametrize using pytest dtype = torch.bfloat16 device = torch.device("cuda:0") @@ -41,32 +39,42 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, varlen torch.manual_seed(0) random.seed(0) - print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " + f"{d=}, {dv=}, {causal=}, {varlen=}") - cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) + cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) if varlen: for i in range(b): - cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), + s_q) total_seqlens = cache_seqlens.sum().item() - mean_seqlens = cache_seqlens.float().mean().int().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") q = torch.randn(b, s_q, h_q, d) - block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + block_table = torch.arange(b * max_seqlen_pad // block_size, + dtype=torch.int32).view( + b, max_seqlen_pad // block_size) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) for i in range(b): - blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan") + blocked_k.view(b, max_seqlen_pad, h_kv, + d)[i, cache_seqlens[i].item():] = float("nan") 
blocked_v = blocked_k[..., :dv] tile_scheduler_metadata, num_splits = get_mla_metadata( - cache_seqlens, s_q * h_q // h_kv, h_kv) + cache_seqlens, s_q * h_q // h_kv, h_kv) def flash_mla(): return flash_mla_with_kvcache( - q, blocked_k, block_table, cache_seqlens, dv, - tile_scheduler_metadata, num_splits, causal=causal, + q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, ) def scaled_dot_product_attention(query, key, value, is_causal=False): @@ -80,7 +88,8 @@ def scaled_dot_product_attention(query, key, value, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) - temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, + dtype=torch.bool).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -94,13 +103,13 @@ def ref_mla(): for i in range(b): begin = i * max_seqlen_pad end = begin + cache_seqlens[i] - O, LSE = scaled_dot_product_attention( + ref_O, LSE = scaled_dot_product_attention( q[i].transpose(0, 1), blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), is_causal=causal, ) - out[i] = O.transpose(0, 1) + out[i] = ref_O.transpose(0, 1) lse[i] = LSE return out, lse @@ -111,5 +120,7 @@ def ref_mla(): t = triton.testing.do_bench(flash_mla, fast_flush=False) FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 - bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} " + f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s") diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index 771f13dc444c..255e230f849d 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -1,21 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Type, Set, Tuple +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional, Tuple, Type import torch -from dataclasses import asdict, dataclass from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.mla.common import (MLACommonBackend, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder, MLACommonState) -from vllm.attention.ops.flashmla import ( - is_flashmla_supported, - flash_mla_with_kvcache, - get_mla_metadata, -) +from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, + get_mla_metadata, + is_flashmla_supported) class FlashMLABackend(MLACommonBackend): @@ -40,9 +38,11 @@ def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]: def get_state_cls() -> Type["FlashMLAState"]: return FlashMLAState + @dataclass class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, + torch.Tensor]] = None decode_num_splits: Optional[torch.Tensor] = None _cached_decode_metadata: Optional["MLACommonMetadata"] = None @@ -66,45 +66,48 @@ def decode_metadata(self): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder): + def __init__(self, *args, **kwargs): 
super().__init__(*args, **kwargs) def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): - common_metadata = super().build( - seq_lens, query_lens, cuda_graph_pad_size, batch_size) + common_metadata = super().build(seq_lens, query_lens, + cuda_graph_pad_size, batch_size) decode_tile_scheduler_metadata, decode_num_splits = None, None if common_metadata.num_decode_tokens > 0: - decode_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + decode_tile_scheduler_metadata, decode_num_splits = \ + get_mla_metadata( common_metadata.seq_lens_tensor[common_metadata.num_prefills:], self.runner.model_config.get_num_attention_heads( self.runner.parallel_config), - 1, + 1, ) return FlashMLAMetadata( # TODO: not on hotpath but can this be faster? - **asdict(common_metadata), + **asdict(common_metadata), decode_tile_scheduler_metadata=decode_tile_scheduler_metadata, decode_num_splits=decode_num_splits, ) + class FlashMLAState(MLACommonState): + def get_graph_input_buffers(self, attn_metadata, is_encoder_decoder_model: bool = False): - input_buffers = super().get_graph_input_buffers(attn_metadata, - is_encoder_decoder_model) + input_buffers = super().get_graph_input_buffers( + attn_metadata, is_encoder_decoder_model) if attn_metadata.tile_scheduler_metadata is not None: tile_scheduler_metadata = attn_metadata.tile_scheduler_metadata num_splits = attn_metadata.num_splits input_buffers["tile_scheduler_metadata"] = tile_scheduler_metadata input_buffers["num_splits"] = num_splits - + return input_buffers - - + class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): @@ -164,7 +167,7 @@ def _forward_decode( o, _ = flash_mla_with_kvcache( q=q, - k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 + k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 block_table=decode_meta.block_tables, cache_seqlens=decode_meta.seq_lens_tensor, head_dim_v=self.kv_lora_rank, diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 0f139cef6310..09a10deda4da 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -1,17 +1,17 @@ +# SPDX-License-Identifier: Apache-2.0 # adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py from typing import Optional, Tuple import torch from vllm.logger import init_logger - -logger = init_logger(__name__) from vllm.platforms import current_platform +logger = init_logger(__name__) if not current_platform.is_tpu() and not current_platform.is_hpu(): try: - import vllm._C_flashmla + import vllm._C_flashmla # noqa: F401 _C_flashmla_AVAILABLE = True except ImportError as e: logger.warning("Failed to import from vllm._C_flashmla with %r", e) @@ -37,10 +37,13 @@ def get_mla_metadata( num_heads_k: num_heads_k. Return: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._C_flashmla.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return torch.ops._C_flashmla.get_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) def flash_mla_with_kvcache( @@ -61,9 +64,11 @@ def flash_mla_with_kvcache( block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. head_dim_v: Head_dim of v. 
- tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), + torch.int32, return by get_mla_metadata. num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. - softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim). + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. Return: @@ -71,7 +76,7 @@ def flash_mla_with_kvcache( softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = q.shape[-1]**(-0.5) out, softmax_lse = torch.ops._C_flashmla.fwd_kvcache_mla( q, k_cache, @@ -86,6 +91,7 @@ def flash_mla_with_kvcache( ) return out, softmax_lse + # # TODO: Add fake functions # From 1cb71c7a35739e6df16c1f5e40c4c92720df4fb6 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 07:01:45 +0000 Subject: [PATCH 05/26] cuda-graphs still broken but closer i think Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flashmla.py | 54 ++++++++++++++++++++++++--- vllm/attention/backends/mla/common.py | 2 +- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index 255e230f849d..bf6660792a9e 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Tuple, Type @@ -95,19 +96,62 @@ def build(self, seq_lens: List[int], query_lens: List[int], class FlashMLAState(MLACommonState): + @contextmanager + def graph_capture(self, max_batch_size: int): + # Run a dummy `get_mla_metadata` so we can get the right shapes + self._graph_decoder_tile_scheduler_metadata, \ + self._graph_decode_num_splits = get_mla_metadata( + torch.ones( + max_batch_size, dtype=torch.int32, device=self.runner.device), + self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + 1, + ) + + with super().graph_capture(max_batch_size): + yield + + del self._graph_decoder_tile_scheduler_metadata + del self._graph_decode_num_splits + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + common_metadata = super().graph_capture_get_metadata_for_batch( + batch_size, is_encoder_decoder_model) + assert common_metadata.num_decode_tokens > 0 + + return FlashMLAMetadata( + # TODO: not on hotpath but can this be faster? 
+ **asdict(common_metadata), + decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata, + decode_num_splits=self._graph_decode_num_splits[:batch_size + 1], + ) + def get_graph_input_buffers(self, attn_metadata, is_encoder_decoder_model: bool = False): input_buffers = super().get_graph_input_buffers( attn_metadata, is_encoder_decoder_model) - if attn_metadata.tile_scheduler_metadata is not None: - tile_scheduler_metadata = attn_metadata.tile_scheduler_metadata - num_splits = attn_metadata.num_splits - input_buffers["tile_scheduler_metadata"] = tile_scheduler_metadata - input_buffers["num_splits"] = num_splits + input_buffers["decode_tile_scheduler_metadata"] = \ + attn_metadata.decode_metadata.decode_tile_scheduler_metadata + input_buffers["decode_num_splits"] = \ + attn_metadata.decode_metadata.decode_num_splits return input_buffers + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + super().prepare_graph_input_buffers(input_buffers, attn_metadata, + is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"].copy_( + attn_metadata.decode_metadata.decode_tile_scheduler_metadata, + non_blocking=True) + input_buffers["decode_num_splits"].copy_( + attn_metadata.decode_metadata.decode_num_splits, non_blocking=True) + class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 8223584466d7..0633c8610b7a 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -357,7 +357,7 @@ def graph_capture_get_metadata_for_batch( self, batch_size: int, is_encoder_decoder_model: bool = False): assert self._is_graph_capturing - attn_metadata = self.runner.attn_backend.make_metadata( + attn_metadata = MLACommonMetadata( multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, use_cuda_graph=True, From 205f2bc5114ad381fe6082537175328fe08098eb Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 07:06:58 +0000 Subject: [PATCH 06/26] better comments Signed-off-by: Lucas Wilkinson --- csrc/pytorch_shim.h | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/csrc/pytorch_shim.h b/csrc/pytorch_shim.h index 4bedfcd47656..779d40bdcc4a 100644 --- a/csrc/pytorch_shim.h +++ b/csrc/pytorch_shim.h @@ -3,16 +3,29 @@ #include /** - * Unforunately, the type signatures of the flash_attn ops are not compatible - * with the PyTorch library bindings. To get around that we use - * `make_pytorch_shim` which creates a lambda that exponses the API using - * PyTorch compatible types to the types, then converts them to the types - * expected by the flash_attn ops. This shims allows us to make minimal changes - * to `flash_api.cpp` making it easier to synchronize with upstream changes. + * PyBind and PyTorch Library apis generally require different type signatures. + * This file provides a shim to (mostly, there may be missing conversions) to + * convert from function designed to be used with PyBind to one that can be used + * with PyTorch Library. This is done using `make_pytorch_shim` which creates a + * lambda that exponses the API using PyTorch compatible types to the types. + * This is useful when trying to ingergate PyBind based external libraries into + * vLLM. 
+ * + * Example: + * + * PYBIND11_MODULE(NAME, m) { + * m.def("foo", &foo); + * } + * + * could be replaced with (using the shim): + * TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { + * m.def("foo", make_pytorch_shim(&foo)); + * m.impl("foo", torch::kCUDA, make_pytorch_shim(&foo)); + * } * * The `pytorch_library_compatible_type` struct is used to map from the * flash_attn ops types to a PyTorch library compatible one. The main issues is - * that the following types are not support by PyTorch libary bindings: + * that the following types are not support by PyTorch library bindings: * - `int` * - `float` * - `c10::optional &` From 00a1f7a3f7362a08edb66fc7419cdcd65acab13b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 07:12:26 +0000 Subject: [PATCH 07/26] remove extra files Signed-off-by: Lucas Wilkinson --- test_copy.py | 25 ------- test_merge.py | 196 -------------------------------------------------- 2 files changed, 221 deletions(-) delete mode 100644 test_copy.py delete mode 100644 test_merge.py diff --git a/test_copy.py b/test_copy.py deleted file mode 100644 index 33ed5d81c61f..000000000000 --- a/test_copy.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from torch.profiler import profile, record_function, ProfilerActivity - -x = torch.randn(512, 512).cuda() -y = torch.randn(512, 512).cuda() - -with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - x[...] = x - -print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) - -with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - x.copy_(x) - -print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) - -with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - y[...] 
= x - -print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) - -with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - y.copy_(x) - -print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=5)) \ No newline at end of file diff --git a/test_merge.py b/test_merge.py deleted file mode 100644 index 9e6b88f4bbf4..000000000000 --- a/test_merge.py +++ /dev/null @@ -1,196 +0,0 @@ -import torch -import triton -import triton.language as tl -import unittest -from typing import Optional -from vllm.vllm_flash_attn import flash_attn_varlen_func - -# === Use your provided merge_attn_states implementation exactly === - -def merge_attn_states( - output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, - output_lse: Optional[torch.Tensor] = None, -) -> None: - num_tokens = output.shape[0] - num_query_heads = output.shape[1] - head_size = output.shape[2] - padded_head_size = triton.next_power_of_2(head_size) - - merge_attn_states_kernel[(num_tokens, num_query_heads)]( - output, - output_lse, - prefix_output, - prefix_lse, - suffix_output, - suffix_lse, - head_size, - padded_head_size, - output_lse is not None, - ) - - -@triton.jit -def merge_attn_states_kernel( - output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - output_lse, # [NUM_HEADS, NUM_TOKENS] - prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - prefix_lse, # [NUM_HEADS, NUM_TOKENS] - suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - suffix_lse, # [NUM_HEADS, NUM_TOKENS] - HEAD_SIZE: tl.constexpr, - PADDED_HEAD_SIZE: tl.constexpr, - OUTPUT_LSE: tl.constexpr, -): - token_idx = tl.program_id(0) - num_tokens = tl.num_programs(0) - head_idx = tl.program_id(1) - num_heads = tl.num_programs(1) - - # Load lse values for this token & head. - p_lse = tl.load(prefix_lse + head_idx * num_tokens + token_idx) - s_lse = tl.load(suffix_lse + head_idx * num_tokens + token_idx) - - # Determine validity (for causal masking, masked positions will be -∞, which is not finite). - p_valid = tl.isfinite(p_lse) - s_valid = tl.isfinite(s_lse) - both_valid = p_valid & s_valid - only_p = p_valid & (~s_valid) - only_s = s_valid & (~p_valid) - - # Compute merged candidate only if both sides are valid. - max_lse = tl.maximum(p_lse, s_lse) - p_shift = p_lse - max_lse - s_shift = s_lse - max_lse - out_se = tl.exp(p_shift) + tl.exp(s_shift) - - merged_lse_candidate = tl.log(out_se) + max_lse - # If both are valid, merge; otherwise, choose the valid side. - merged_lse = tl.where(both_valid, merged_lse_candidate, tl.where(only_p, p_lse, s_lse)) - - # Optionally store merged lse. - if OUTPUT_LSE: - tl.store(output_lse + head_idx * num_tokens + token_idx, merged_lse) - - # Load the attention outputs. - head_arange = tl.arange(0, PADDED_HEAD_SIZE) - head_mask = head_arange < HEAD_SIZE - base_offset = token_idx * num_heads * HEAD_SIZE - p_out = tl.load(prefix_output + base_offset + head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - s_out = tl.load(suffix_output + base_offset + head_idx * HEAD_SIZE + head_arange, - mask=head_mask) - - # Compute candidate merged output if both sides valid. 
- p_scale = tl.exp(p_shift) / out_se - s_scale = tl.exp(s_shift) / out_se - merged_output_candidate = p_out * p_scale + s_out * s_scale - merged_output = tl.where(both_valid, merged_output_candidate, - tl.where(only_p, p_out, s_out)) - - tl.store(output + base_offset + head_idx * HEAD_SIZE + head_arange, - merged_output, - mask=head_mask) - - -# === Single test: iterative merge (via multiple flash_attn_varlen_func calls) -# vs. a single unchunked call. We transpose the softmax lse outputs -# because FlashAttention returns them as [NUM_TOKENS, NUM_HEADS], -# but our merge kernel expects [NUM_HEADS, NUM_TOKENS]. === - -class TestFlashAttnMerge(unittest.TestCase): - def test_flash_attn_merge(self): - torch.manual_seed(0) - device = "cuda" - # Dimensions: - num_tokens = 16 # number of query tokens - num_heads = 4 - HEAD_SIZE = 8 - chunk_max_seq_len = 16 # keys/values length per chunk - num_chunks = 3 - max_query_len = num_tokens # for simplicity - softmax_scale = 1.0 - - # Create a fixed query tensor in fp16. - q = torch.randn(num_tokens, num_heads, HEAD_SIZE, device=device, dtype=torch.float16) - cu_seqlens_q = torch.tensor([0, num_tokens], device=device, dtype=torch.int32) - - # Compute chunked attention outputs. - # (Note: flash_attn_varlen_func returns softmax_lse in shape [NUM_TOKENS, NUM_HEADS], - # so we transpose it to [NUM_HEADS, NUM_TOKENS] for merging.) - chunks_output = [] - chunks_lse = [] - chunks_k = [] - chunks_v = [] - for _ in range(num_chunks): - chunk_k = torch.randn(chunk_max_seq_len, num_heads, HEAD_SIZE, device=device, dtype=torch.float16) - chunk_v = torch.randn(chunk_max_seq_len, num_heads, HEAD_SIZE, device=device, dtype=torch.float16) - chunks_k.append(chunk_k) - chunks_v.append(chunk_v) - cu_seqlens_k = torch.tensor([0, chunk_max_seq_len], device=device, dtype=torch.int32) - attn_output, attn_softmax_lse = flash_attn_varlen_func( - q=q, - k=chunk_k, - v=chunk_v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_query_len, - max_seqlen_k=chunk_max_seq_len, - softmax_scale=softmax_scale, - causal=True, - return_softmax_lse=True, - fa_version=3, - ) - chunks_output.append(attn_output) - # Transpose lse from [num_tokens, num_heads] to [num_heads, num_tokens] - chunks_lse.append(attn_softmax_lse.transpose(0, 1).contiguous()) - - # Iteratively merge the chunk outputs. - # Allocate temporary tensor for merged lse with shape [num_heads, num_tokens]. - merged_output = chunks_output[0].clone() - merged_lse = chunks_lse[0].clone() - for i in range(1, num_chunks): - tmp_output = torch.empty_like(merged_output) - tmp_lse = torch.empty_like(merged_lse) - merge_attn_states( - tmp_output, - merged_output, - merged_lse, - chunks_output[i], - chunks_lse[i], - tmp_lse, - ) - merged_output = tmp_output - merged_lse = tmp_lse - - # Unchunked version: concatenate keys and values and call flash_attn_varlen_func once. 
- full_k = torch.cat(chunks_k, dim=0) # shape: (num_chunks*chunk_max_seq_len, num_heads, HEAD_SIZE) - full_v = torch.cat(chunks_v, dim=0) - total_seq_len = num_chunks * chunk_max_seq_len - cu_seqlens_k_full = torch.tensor([0, total_seq_len], device=device, dtype=torch.int32) - attn_output_full, attn_softmax_lse_full = flash_attn_varlen_func( - q=q, - k=full_k, - v=full_v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k_full, - max_seqlen_q=max_query_len, - max_seqlen_k=total_seq_len, - softmax_scale=softmax_scale, - causal=True, - return_softmax_lse=True, - fa_version=3, - ) - # Transpose the full lse to [num_heads, num_tokens] for comparison. - attn_softmax_lse_full = attn_softmax_lse_full.transpose(0, 1).contiguous() - - # Compare the merged (iterative) result with the unchunked result. - # (fp16 numerics are less precise, so we use a looser tolerance.) - torch.testing.assert_close(merged_output, attn_output_full, atol=1e-3, rtol=1e-3) - torch.testing.assert_close(merged_lse, attn_softmax_lse_full, atol=1e-3, rtol=1e-3) - -if __name__ == '__main__': - unittest.main() From 905be3bfdff0bd808d097bb14cf4ac6718a2a110 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 07:19:08 +0000 Subject: [PATCH 08/26] add attribution Signed-off-by: Lucas Wilkinson --- tests/kernels/test_flashmla.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py index a4dceb3da16f..155698b3b4b3 100644 --- a/tests/kernels/test_flashmla.py +++ b/tests/kernels/test_flashmla.py @@ -1,3 +1,4 @@ +# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py # SPDX-License-Identifier: Apache-2.0 import math import random From 728e0b6410c4777225882b5691e4562160a269d7 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 20:50:48 +0000 Subject: [PATCH 09/26] fix cuda graphs Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flashmla.py | 70 +++++++++++++++++++-------- vllm/attention/backends/mla/common.py | 14 +++--- 2 files changed, 58 insertions(+), 26 deletions(-) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index bf6660792a9e..dfaf1f8ef5c5 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from dataclasses import asdict, dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -16,6 +16,9 @@ get_mla_metadata, is_flashmla_supported) +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + class FlashMLABackend(MLACommonBackend): @@ -65,12 +68,25 @@ def decode_metadata(self): ) return self._cached_decode_metadata + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + raise NotImplementedError( + "advance_step is not implemented for FlashMLA") + class FlashMLAMetadataBuilder(MLACommonMetadataBuilder): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): common_metadata = super().build(seq_lens, query_lens, @@ -81,9 +97,8 @@ 
def build(self, seq_lens: List[int], query_lens: List[int], decode_tile_scheduler_metadata, decode_num_splits = \ get_mla_metadata( common_metadata.seq_lens_tensor[common_metadata.num_prefills:], - self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config), - 1, + self.num_q_heads, + 1, # MQA for the decode path ) return FlashMLAMetadata( @@ -94,7 +109,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], ) -class FlashMLAState(MLACommonState): +class FlashMLAState(MLACommonState[FlashMLAMetadata]): + + def __init__(self, *args, **kwds): + super().__init__(*args, **kwds) + + self.num_q_heads = self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config) @contextmanager def graph_capture(self, max_batch_size: int): @@ -103,9 +124,8 @@ def graph_capture(self, max_batch_size: int): self._graph_decode_num_splits = get_mla_metadata( torch.ones( max_batch_size, dtype=torch.int32, device=self.runner.device), - self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config), - 1, + self.num_q_heads, + 1, # MQA for the decode path ) with super().graph_capture(max_batch_size): @@ -116,18 +136,28 @@ def graph_capture(self, max_batch_size: int): def graph_capture_get_metadata_for_batch( self, batch_size: int, is_encoder_decoder_model: bool = False): - common_metadata = super().graph_capture_get_metadata_for_batch( + metadata = super().graph_capture_get_metadata_for_batch( batch_size, is_encoder_decoder_model) - assert common_metadata.num_decode_tokens > 0 + assert metadata.num_decode_tokens > 0 - return FlashMLAMetadata( - # TODO: not on hotpath but can this be faster? - **asdict(common_metadata), - decode_tile_scheduler_metadata=\ - self._graph_decoder_tile_scheduler_metadata, - decode_num_splits=self._graph_decode_num_splits[:batch_size + 1], + decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( + metadata.seq_lens_tensor, + self.num_q_heads, + 1, # MQA for the decode path ) + self._graph_decoder_tile_scheduler_metadata.copy_( + decoder_tile_scheduler_metadata, non_blocking=True) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits, + non_blocking=True) + + metadata.decode_tile_scheduler_metadata=\ + self._graph_decoder_tile_scheduler_metadata + metadata.decode_num_splits=\ + self._graph_decode_num_splits[:batch_size + 1] + + return metadata + def get_graph_input_buffers(self, attn_metadata, is_encoder_decoder_model: bool = False): @@ -146,11 +176,11 @@ def prepare_graph_input_buffers(self, is_encoder_decoder_model: bool = False): super().prepare_graph_input_buffers(input_buffers, attn_metadata, is_encoder_decoder_model) + input_buffers["decode_tile_scheduler_metadata"].copy_( - attn_metadata.decode_metadata.decode_tile_scheduler_metadata, - non_blocking=True) + attn_metadata.decode_metadata.decode_tile_scheduler_metadata) input_buffers["decode_num_splits"].copy_( - attn_metadata.decode_metadata.decode_num_splits, non_blocking=True) + attn_metadata.decode_metadata.decode_num_splits) class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): @@ -201,7 +231,7 @@ def _forward_decode( ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 if self.kv_cache_dtype.startswith("fp8"): - raise NotImplementedError("FP8 Triton MLA not yet supported") + raise NotImplementedError("FP8 FlashMLA not yet supported") decode_meta = attn_metadata.decode_metadata assert decode_meta is not None diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 0633c8610b7a..cbb78782310b 
100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -292,7 +292,10 @@ def get_supported_head_sizes() -> List[int]: return [576] -class MLACommonState(AttentionState): +T = TypeVar("T", bound="MLACommonMetadata") + + +class MLACommonState(AttentionState, Generic[T]): def __init__(self, runner): self.runner = runner @@ -354,10 +357,12 @@ def graph_clone(self, batch_size: int): return self.__class__(self.runner) def graph_capture_get_metadata_for_batch( - self, batch_size: int, is_encoder_decoder_model: bool = False): + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: assert self._is_graph_capturing - attn_metadata = MLACommonMetadata( + attn_metadata = self.runner.attn_backend.make_metadata( multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, use_cuda_graph=True, @@ -722,9 +727,6 @@ def advance_step(self, block_tables=self.block_tables) -T = TypeVar("T", bound=MLACommonMetadata) - - class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): """ NOTE: Please read the comment at the top of the file before trying to From 6681e43abeb0fd48713557db64d2c8f5bd855953 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 21:44:32 +0000 Subject: [PATCH 10/26] cleaner build fallbacks Signed-off-by: Lucas Wilkinson --- CMakeLists.txt | 79 +------------------ cmake/{ => external_projects}/flashmla.cmake | 5 +- cmake/{ => external_projects}/flashmla.patch | 0 cmake/external_projects/vllm_flash_attn.cmake | 67 ++++++++++++++++ setup.py | 6 +- vllm/attention/ops/flashmla.py | 34 +++++--- vllm/platforms/cuda.py | 20 ++++- 7 files changed, 120 insertions(+), 91 deletions(-) rename cmake/{ => external_projects}/flashmla.cmake (95%) rename cmake/{ => external_projects}/flashmla.patch (100%) create mode 100644 cmake/external_projects/vllm_flash_attn.cmake diff --git a/CMakeLists.txt b/CMakeLists.txt index c14e1708a6ec..0dd350c93ed5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -575,79 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP") WITH_SOABI) endif() -# vllm-flash-attn currently only supported on CUDA -if (NOT VLLM_GPU_LANG STREQUAL "CUDA") - return() +# For CUDA we also build and ship some external projects. +if (VLLM_GPU_LANG STREQUAL "CUDA") + include(cmake/external_projects/flashmla.cmake) + include(cmake/external_projects/vllm_flash_attn.cmake) endif () - -include(cmake/flashmla.cmake) - -# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target -# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the -# arches in the CUDA case (and instead set the gencodes on a per file basis) -# we need to manually set VLLM_GPU_ARCHES here. -if(VLLM_GPU_LANG STREQUAL "CUDA") - foreach(_ARCH ${CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") - endforeach() -endif() - -# -# Build vLLM flash attention from source -# -# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. -# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. -# They should be identical but if they aren't, this is a massive footgun. -# -# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). -# If no component is specified, vllm-flash-attn is still installed. 
- -# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. -# This is to enable local development of vllm-flash-attn within vLLM. -# It can be set as an environment variable or passed as a cmake argument. -# The environment variable takes precedence. -if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) - set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) -endif() - -if(VLLM_FLASH_ATTN_SRC_DIR) - FetchContent_Declare( - vllm-flash-attn SOURCE_DIR - ${VLLM_FLASH_ATTN_SRC_DIR} - BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn - ) -else() - FetchContent_Declare( - vllm-flash-attn - GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade - GIT_PROGRESS TRUE - # Don't share the vllm-flash-attn build between build types - BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn - ) -endif() - - -# Fetch the vllm-flash-attn library -FetchContent_MakeAvailable(vllm-flash-attn) -message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") - -# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in -# case only one is built, in the case both are built redundant work is done) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn - COMPONENT _vllm_fa2_C - FILES_MATCHING PATTERN "*.py" -) - -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm_flash_attn - COMPONENT _vllm_fa3_C - FILES_MATCHING PATTERN "*.py" -) - -# Nothing after vllm-flash-attn, see comment about macros above diff --git a/cmake/flashmla.cmake b/cmake/external_projects/flashmla.cmake similarity index 95% rename from cmake/flashmla.cmake rename to cmake/external_projects/flashmla.cmake index a4cbcc8957de..2d3204842d3a 100644 --- a/cmake/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -53,7 +53,7 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FLASH_MLA_ARCHS) CUDA_ARCHS "${FLASH_MLA_ARCHS}") define_gpu_extension_target( - _C_flashmla + _flashmla_C DESTINATION vllm LANGUAGE ${VLLM_GPU_LANG} SOURCES ${FlashMLA_SOURCES} @@ -62,5 +62,8 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FLASH_MLA_ARCHS) INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES} USE_SABI 3 WITH_SOABI) +else() + # Create an empty target for setup.py + add_custom_target(_flashmla_C) endif() diff --git a/cmake/flashmla.patch b/cmake/external_projects/flashmla.patch similarity index 100% rename from cmake/flashmla.patch rename to cmake/external_projects/flashmla.patch diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake new file mode 100644 index 000000000000..ef6261fa6d9b --- /dev/null +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -0,0 +1,67 @@ +# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() + +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. 
+# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). +# If no component is specified, vllm-flash-attn is still installed. + +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) +endif() + +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare( + vllm-flash-attn SOURCE_DIR + ${VLLM_FLASH_ATTN_SRC_DIR} + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade + GIT_PROGRESS TRUE + # Don't share the vllm-flash-attn build between build types + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn + ) +endif() + + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in +# case only one is built, in the case both are built redundant work is done) +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa2_C + FILES_MATCHING PATTERN "*.py" +) + +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm_flash_attn + COMPONENT _vllm_fa3_C + FILES_MATCHING PATTERN "*.py" +) \ No newline at end of file diff --git a/setup.py b/setup.py index 8163d29749b8..f4eda1ea3abd 100755 --- a/setup.py +++ b/setup.py @@ -328,6 +328,7 @@ def run(self) -> None: files_to_copy = [ "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", + "vllm/_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", @@ -612,7 +613,10 @@ def _read_requirements(filename: str) -> List[str]: # FA3 requires CUDA 12.0 or later ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) - ext_modules.append(CMakeExtension(name="vllm._C_flashmla")) + # Optional since this doesn't get built (produce an .so file) when + # not targeting a hopper system + ext_modules.append( + CMakeExtension(name="vllm._flashmla_C", optional=True)) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 09a10deda4da..6b25eafab007 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -11,18 +11,28 @@ if not current_platform.is_tpu() and not current_platform.is_hpu(): try: - import vllm._C_flashmla # noqa: F401 - _C_flashmla_AVAILABLE = True + import vllm._flashmla_C # noqa: F401 + _flashmla_C_AVAILABLE = True except ImportError as e: - logger.warning("Failed to import from vllm._C_flashmla with %r", e) - _C_flashmla_AVAILABLE = False + logger.warning("Failed to import from vllm._flashmla_C with %r", e) + _flashmla_C_AVAILABLE = False else: - _C_flashmla_AVAILABLE = False + _flashmla_C_AVAILABLE = 
False -def is_flashmla_supported() -> bool: - return _C_flashmla_AVAILABLE and \ - current_platform.get_device_capability()[0] == 9 +def is_flashmla_supported() -> Tuple[bool, Optional[str]]: + """ + Return: is_supported_flag, unsupported_reason (optional). + """ + if not current_platform.is_cuda(): + return False, "FlashMLA is only supported on CUDA devices." + if current_platform.get_device_capability()[0] != 9: + return False, "FlashMLA is only supported on Hopper devices." + if not _flashmla_C_AVAILABLE: + return False, "vllm._flashmla_C is not available, likely was not "\ + "compiled due to insufficient nvcc version or a supported arch"\ + "was not in the list of target arches to compile for." + return True, None def get_mla_metadata( @@ -41,7 +51,7 @@ def get_mla_metadata( dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._C_flashmla.get_mla_metadata(cache_seqlens, + return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) @@ -77,7 +87,7 @@ def flash_mla_with_kvcache( """ if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) - out, softmax_lse = torch.ops._C_flashmla.fwd_kvcache_mla( + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( q, k_cache, None, @@ -95,11 +105,11 @@ def flash_mla_with_kvcache( # # TODO: Add fake functions # -# @register_fake("_C_flashmla::get_mla_metadata") +# @register_fake("_flashmla_C::get_mla_metadata") # def _get_mla_metadata_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: # return .... # -# @register_fake("_C_flashmla::fwd_kvcache_mla") +# @register_fake("_flashmla_C::fwd_kvcache_mla") # def _fwd_kvcache_mla_fake(....) -> Tuple[torch.Tensor, torch.Tensor]: # return .... # diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 4f3f1658bf6a..f1d82b1069df 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -157,8 +157,24 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using Flash Attention backend on V1 engine.") return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" if use_mla: - logger.info("Using Flash MLA backend.") - return "vllm.attention.backends.flashmla.FlashMLABackend" + if selected_backend == _Backend.FLASHMLA: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + if not is_flashmla_supported()[0]: + logger.warning( + "FlashMLA backend is not supported due to %s", + is_flashmla_supported()[1]) + elif block_size != 64: + logger.warning( + "FlashMLA backend is not supported for block size %d" + " (currently only supports block size 64).", + block_size) + else: + logger.info("Using FlashMLA backend.") + return "vllm.attention.backends.flashmla.FlashMLABackend" + + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") return "vllm.attention.backends.flashinfer.FlashInferBackend" From 4a755b995a1de4eba2b83e072d25b13b67d6e529 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 24 Feb 2025 22:17:20 +0000 Subject: [PATCH 11/26] ok cuda-graphs actually fixed now I think Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flashmla.py | 31 +++++++++++---------------- vllm/attention/backends/mla/common.py | 26 +++++++++++----------- 2 files changed, 25 insertions(+), 32 deletions(-) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index dfaf1f8ef5c5..494a3bfca0a9 100644 --- 
a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -44,7 +44,7 @@ def get_state_cls() -> Type["FlashMLAState"]: @dataclass -class FlashMLAMetadata(MLACommonMetadata): +class FlashMLAMetadata(MLACommonMetadata["FlashMLAMetadata"]): decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None decode_num_splits: Optional[torch.Tensor] = None @@ -53,20 +53,14 @@ class FlashMLAMetadata(MLACommonMetadata): @property def decode_metadata(self): - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - - common_decode_metadata = super().decode_metadata - self._cached_decode_metadata = FlashMLAMetadata( - # TODO: cached but can this be faster? - **asdict(common_decode_metadata), - decode_tile_scheduler_metadata=self.decode_tile_scheduler_metadata, - decode_num_splits=self.decode_num_splits, - ) - return self._cached_decode_metadata + decode_metadata = super().decode_metadata + # TODO: cache assignment? + if decode_metadata is not None: + decode_metadata.decode_tile_scheduler_metadata=\ + self.decode_tile_scheduler_metadata + decode_metadata.decode_num_splits=\ + self.decode_num_splits + return decode_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", @@ -141,15 +135,14 @@ def graph_capture_get_metadata_for_batch( assert metadata.num_decode_tokens > 0 decoder_tile_scheduler_metadata, decode_num_splits = get_mla_metadata( - metadata.seq_lens_tensor, + self._graph_seq_lens[:batch_size], self.num_q_heads, 1, # MQA for the decode path ) self._graph_decoder_tile_scheduler_metadata.copy_( - decoder_tile_scheduler_metadata, non_blocking=True) - self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits, - non_blocking=True) + decoder_tile_scheduler_metadata) + self._graph_decode_num_splits[:batch_size + 1].copy_(decode_num_splits) metadata.decode_tile_scheduler_metadata=\ self._graph_decoder_tile_scheduler_metadata diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index cbb78782310b..1efd469d60ba 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -443,7 +443,7 @@ def begin_forward(self, model_input): @dataclass -class MLACommonMetadata(AttentionMetadata): +class MLACommonMetadata(AttentionMetadata, Generic[T]): """Metadata for MLACommon. NOTE: Please read the comment at the top of the file before trying to @@ -511,8 +511,8 @@ class MLACommonMetadata(AttentionMetadata): # [4, 6], it is [0, 4, 10]. 
seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_common_metadata: Optional["MLACommonMetadata"] = None - _cached_decode_common_metadata: Optional["MLACommonMetadata"] = None + _cached_prefill_metadata: Optional["MLACommonMetadata"] = None + _cached_decode_metadata: Optional["MLACommonMetadata"] = None num_prefill_tokens: int @@ -541,12 +541,12 @@ def __post_init__(self): f" received {self.head_dim}.") @property - def prefill_metadata(self) -> Optional["MLACommonMetadata"]: + def prefill_metadata(self) -> Optional[T]: if self.num_prefills == 0: return None - if self._cached_prefill_common_metadata is not None: - return self._cached_prefill_common_metadata + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata assert self.seq_lens is not None assert self.seq_lens_tensor is not None @@ -569,7 +569,7 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: input_positions = (None if self.input_positions is None else self.input_positions[:self.num_prefill_tokens]) - self._cached_prefill_common_metadata = MLACommonMetadata( + self._cached_prefill_metadata = self.__class__( # Required by ModelRunner use_cuda_graph=False, # Not Attention Related # Required by Attention Metadata @@ -600,15 +600,15 @@ def prefill_metadata(self) -> Optional["MLACommonMetadata"]: context_chunk_seq_tot=self.context_chunk_seq_tot, context_chunk_max_seq_lens=self.context_chunk_max_seq_lens, ) - return self._cached_prefill_common_metadata + return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional["MLACommonMetadata"]: + def decode_metadata(self) -> Optional[T]: if self.num_decode_tokens == 0: return None - if self._cached_decode_common_metadata is not None: - return self._cached_decode_common_metadata + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata assert self.seq_lens_tensor is not None # Compute some attn_metadata fields which default to None @@ -621,7 +621,7 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: input_positions = (None if self.input_positions is None else self.input_positions[self.num_prefill_tokens:]) - self._cached_decode_common_metadata = MLACommonMetadata( + self._cached_decode_metadata = self.__class__( # Required by ModelRunner use_cuda_graph=self.use_cuda_graph, # Not Attention Related # Required by Attention Metadata @@ -652,7 +652,7 @@ def decode_metadata(self) -> Optional["MLACommonMetadata"]: input_positions=input_positions, head_dim=self.head_dim, is_profile_run=self.is_profile_run) - return self._cached_decode_common_metadata + return self._cached_decode_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", From 649a7bf17a3a66d87e593ce95a5291bd81efaaf6 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 02:46:21 +0000 Subject: [PATCH 12/26] format Signed-off-by: Lucas Wilkinson --- vllm/_custom_ops.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 584c5f83e201..0e83bcaead94 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1164,6 +1164,7 @@ def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + def get_flash_mla_metadata( cache_seqlens: torch.Tensor, num_heads_per_head_k: int, @@ -1179,7 +1180,9 @@ def get_flash_mla_metadata( tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. 
num_splits: (batch_size + 1), dtype torch.int32. """ - return torch.ops._C.get_flash_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k) + return torch.ops._C.get_flash_mla_metadata(cache_seqlens, + num_heads_per_head_k, + num_heads_k) def flash_mla_with_kvcache( @@ -1210,7 +1213,7 @@ def flash_mla_with_kvcache( softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + softmax_scale = q.shape[-1]**(-0.5) out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache( q, k_cache, From 20315610b399e46a2c832f13d968b03a676d850c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 04:33:53 +0000 Subject: [PATCH 13/26] clean up Signed-off-by: Lucas Wilkinson --- tests/kernels/test_flashmla.py | 1 - vllm/attention/backends/flashmla.py | 22 ++++++++-------------- vllm/attention/backends/mla/common.py | 8 ++++---- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py index 155698b3b4b3..1eebabc351a3 100644 --- a/tests/kernels/test_flashmla.py +++ b/tests/kernels/test_flashmla.py @@ -51,7 +51,6 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, total_seqlens = cache_seqlens.sum().item() max_seqlen = cache_seqlens.max().item() max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 - # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") q = torch.randn(b, s_q, h_q, d) block_table = torch.arange(b * max_seqlen_pad // block_size, diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index 494a3bfca0a9..afd776eb27d2 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from contextlib import contextmanager -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type import torch @@ -73,7 +73,7 @@ def advance_step(self, "advance_step is not implemented for FlashMLA") -class FlashMLAMetadataBuilder(MLACommonMetadataBuilder): +class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -83,24 +83,18 @@ def __init__(self, *args, **kwargs): def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): - common_metadata = super().build(seq_lens, query_lens, - cuda_graph_pad_size, batch_size) + m = super().build(seq_lens, query_lens, cuda_graph_pad_size, + batch_size) - decode_tile_scheduler_metadata, decode_num_splits = None, None - if common_metadata.num_decode_tokens > 0: - decode_tile_scheduler_metadata, decode_num_splits = \ + if m.num_decode_tokens > 0: + m.decode_tile_scheduler_metadata, m.decode_num_splits = \ get_mla_metadata( - common_metadata.seq_lens_tensor[common_metadata.num_prefills:], + m.seq_lens_tensor[m.num_prefills:], self.num_q_heads, 1, # MQA for the decode path ) - return FlashMLAMetadata( - # TODO: not on hotpath but can this be faster? 
- **asdict(common_metadata), - decode_tile_scheduler_metadata=decode_tile_scheduler_metadata, - decode_num_splits=decode_num_splits, - ) + return m class FlashMLAState(MLACommonState[FlashMLAMetadata]): diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 1efd469d60ba..ca2ed4feedba 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -511,8 +511,8 @@ class MLACommonMetadata(AttentionMetadata, Generic[T]): # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional["MLACommonMetadata"] = None - _cached_decode_metadata: Optional["MLACommonMetadata"] = None + _cached_prefill_metadata: Optional[T] = None + _cached_decode_metadata: Optional[T] = None num_prefill_tokens: int @@ -727,7 +727,7 @@ def advance_step(self, block_tables=self.block_tables) -class MLACommonMetadataBuilder(AttentionMetadataBuilder[MLACommonMetadata]): +class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -960,7 +960,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], assert max(context_chunk_seq_tot) <= \ self.chunked_prefill_workspace_size - return MLACommonMetadata( + return self.runner.attn_backend.make_metadata( # Required by ModelRunner use_cuda_graph=use_captured_graph, # Not Attention Related # Required by Attention Metadata From 1722bb06b55722538e34432025c3692ecfe992a9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 04:39:07 +0000 Subject: [PATCH 14/26] review comment Signed-off-by: Lucas Wilkinson --- vllm/attention/ops/flashmla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 6b25eafab007..09326c696c63 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -9,7 +9,7 @@ logger = init_logger(__name__) -if not current_platform.is_tpu() and not current_platform.is_hpu(): +if not current_platform.is_cuda(): try: import vllm._flashmla_C # noqa: F401 _flashmla_C_AVAILABLE = True From 8e64f74dfdf07d7b7e91cb82db49b71acf0bce87 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 05:01:11 +0000 Subject: [PATCH 15/26] fix mypy Signed-off-by: Lucas Wilkinson --- vllm/attention/backends/flashmla.py | 4 +--- vllm/attention/backends/mla/common.py | 10 +++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index afd776eb27d2..273c69b63ec6 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -44,13 +44,11 @@ def get_state_cls() -> Type["FlashMLAState"]: @dataclass -class FlashMLAMetadata(MLACommonMetadata["FlashMLAMetadata"]): +class FlashMLAMetadata(MLACommonMetadata): decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None decode_num_splits: Optional[torch.Tensor] = None - _cached_decode_metadata: Optional["MLACommonMetadata"] = None - @property def decode_metadata(self): decode_metadata = super().decode_metadata diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index ca2ed4feedba..ff66ed4c0e9a 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -443,7 +443,7 @@ def begin_forward(self, model_input): @dataclass -class MLACommonMetadata(AttentionMetadata, 
Generic[T]): +class MLACommonMetadata(AttentionMetadata): """Metadata for MLACommon. NOTE: Please read the comment at the top of the file before trying to @@ -511,8 +511,8 @@ class MLACommonMetadata(AttentionMetadata, Generic[T]): # [4, 6], it is [0, 4, 10]. seq_start_loc: Optional[torch.Tensor] = None - _cached_prefill_metadata: Optional[T] = None - _cached_decode_metadata: Optional[T] = None + _cached_prefill_metadata: Optional[Any] = None + _cached_decode_metadata: Optional[Any] = None num_prefill_tokens: int @@ -541,7 +541,7 @@ def __post_init__(self): f" received {self.head_dim}.") @property - def prefill_metadata(self) -> Optional[T]: + def prefill_metadata(self): if self.num_prefills == 0: return None @@ -603,7 +603,7 @@ def prefill_metadata(self) -> Optional[T]: return self._cached_prefill_metadata @property - def decode_metadata(self) -> Optional[T]: + def decode_metadata(self): if self.num_decode_tokens == 0: return None From b05779280c9632da673105b84feee15e2c0014dc Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 05:14:01 +0000 Subject: [PATCH 16/26] review comments Signed-off-by: Lucas Wilkinson format Signed-off-by: Lucas Wilkinson format Signed-off-by: Lucas Wilkinson format Signed-off-by: Lucas Wilkinson --- vllm/platforms/cuda.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index f1d82b1069df..c6f3ccf0a3c4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -141,6 +141,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config if cache_config and cache_config.block_size is None: cache_config.block_size = 16 + # TODO(lucas): handle this more gracefully + if envs.VLLM_ATTENTION_BACKEND is not None \ + and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \ + and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "FlashMLA: Forcing kv cache block size to 64 since this" + " is currently the only block size supported by the kernel.") @classmethod def get_current_memory_usage(cls, From 87499c3426c747a8035e1f11ae2c3f3daecc91a7 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 05:24:32 +0000 Subject: [PATCH 17/26] cleanup Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 9 +++++---- setup.py | 1 + 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 2d3204842d3a..fc3b48d994e5 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -35,10 +35,11 @@ endif() FetchContent_MakeAvailable(flashmla) message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") -# The machete kernels only work on hopper and require CUDA 12.0 or later. -# Only build Machete kernels if we are building for something compatible with sm90a +# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. 
+# Only build FlashMLA kernels if we are building for something compatible with +# sm90a cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FLASH_MLA_ARCHS) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu) @@ -63,7 +64,7 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND FLASH_MLA_ARCHS) USE_SABI 3 WITH_SOABI) else() - # Create an empty target for setup.py + # Create an empty target for setup.py when not targeting sm90a systems add_custom_target(_flashmla_C) endif() diff --git a/setup.py b/setup.py index f4eda1ea3abd..a636d266cfbd 100755 --- a/setup.py +++ b/setup.py @@ -613,6 +613,7 @@ def _read_requirements(filename: str) -> List[str]: # FA3 requires CUDA 12.0 or later ext_modules.append( CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C")) + if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"): # Optional since this doesn't get built (produce an .so file) when # not targeting a hopper system ext_modules.append( From c47e8144e481e5f8c8fc46ef33ec87b63e67793c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 07:09:00 +0000 Subject: [PATCH 18/26] fix bad logic Signed-off-by: Lucas Wilkinson --- vllm/attention/ops/flashmla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 09326c696c63..1edf338bb759 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -9,7 +9,7 @@ logger = init_logger(__name__) -if not current_platform.is_cuda(): +if current_platform.is_cuda(): try: import vllm._flashmla_C # noqa: F401 _flashmla_C_AVAILABLE = True From cf3e5bd51bd6e5220468fa01b3834bd8b34b44ed Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 16:01:31 +0000 Subject: [PATCH 19/26] review comments Signed-off-by: Lucas Wilkinson --- vllm/platforms/cuda.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index c6f3ccf0a3c4..1ed006f0b142 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -177,6 +177,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", block_size) + elif dtype != torch.bfloat16: + logger.warning( + "FlashMLA backend is not supported for dtype %s" + " (currently only supports torch.bfloat16).", dtype) else: logger.info("Using FlashMLA backend.") return "vllm.attention.backends.flashmla.FlashMLABackend" From d474a4b772d1ad20cda5d4af1c6366f418c1e483 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 16:13:26 +0000 Subject: [PATCH 20/26] update to latest flashMLA which supports fp16 Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 2 +- vllm/platforms/cuda.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index fc3b48d994e5..003ef49fb910 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -19,7 +19,7 @@ else() FetchContent_Declare( flashmla GIT_REPOSITORY https://github.com/deepseek-ai/FlashMLA - GIT_TAG 414a2f3eedeb5ad3c4a6e89d8641e059519cacc9 + GIT_TAG 4edea86f9e85eea6ea41dd14b2798fc6a0e2d80c GIT_PROGRESS TRUE 
CONFIGURE_COMMAND "" BUILD_COMMAND "" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 1ed006f0b142..c6f3ccf0a3c4 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -177,10 +177,6 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", block_size) - elif dtype != torch.bfloat16: - logger.warning( - "FlashMLA backend is not supported for dtype %s" - " (currently only supports torch.bfloat16).", dtype) else: logger.info("Using FlashMLA backend.") return "vllm.attention.backends.flashmla.FlashMLABackend" From 337f3ee5fdceeb056e94eecde44e51b96f0cb28c Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 16:44:06 +0000 Subject: [PATCH 21/26] update to use fork Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 9 +- cmake/external_projects/flashmla.patch | 32 ------- csrc/pytorch_shim.h | 123 ------------------------- 3 files changed, 2 insertions(+), 162 deletions(-) delete mode 100644 cmake/external_projects/flashmla.patch delete mode 100644 csrc/pytorch_shim.h diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 003ef49fb910..693d63768789 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -18,16 +18,11 @@ if(FLASH_MLA_SRC_DIR) else() FetchContent_Declare( flashmla - GIT_REPOSITORY https://github.com/deepseek-ai/FlashMLA - GIT_TAG 4edea86f9e85eea6ea41dd14b2798fc6a0e2d80c + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git + GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845 GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" - PATCH_COMMAND git apply --ignore-whitespace - "${CMAKE_CURRENT_LIST_DIR}/flashmla.patch" - # For incremental builds to prevent the patch from being reapplied, - # https://stackoverflow.com/a/73725257 - UPDATE_DISCONNECTED TRUE ) endif() diff --git a/cmake/external_projects/flashmla.patch b/cmake/external_projects/flashmla.patch deleted file mode 100644 index 0e97600e2906..000000000000 --- a/cmake/external_projects/flashmla.patch +++ /dev/null @@ -1,32 +0,0 @@ -diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp -index 5a1cb8e..65fbfb0 100644 ---- a/csrc/flash_api.cpp -+++ b/csrc/flash_api.cpp -@@ -1,6 +1,6 @@ - // Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp - --#include -+// #include - #include - #include - #include -@@ -196,8 +196,14 @@ mha_fwd_kvcache_mla( - return {out, softmax_lse}; - } - --PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { -- m.doc() = "FlashMLA"; -- m.def("get_mla_metadata", &get_mla_metadata); -- m.def("fwd_kvcache_mla", &mha_fwd_kvcache_mla); -+#include "core/registration.h" -+#include "pytorch_shim.h" -+ -+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { -+ m.def("get_mla_metadata", make_pytorch_shim(&get_mla_metadata)); -+ m.impl("get_mla_metadata", torch::kCUDA, make_pytorch_shim(&get_mla_metadata)); -+ -+ m.def("fwd_kvcache_mla", make_pytorch_shim(&mha_fwd_kvcache_mla)); -+ m.impl("fwd_kvcache_mla", torch::kCUDA, make_pytorch_shim(&mha_fwd_kvcache_mla)); - } -+REGISTER_EXTENSION(TORCH_EXTENSION_NAME) -\ No newline at end of file diff --git a/csrc/pytorch_shim.h b/csrc/pytorch_shim.h deleted file mode 100644 index 779d40bdcc4a..000000000000 --- a/csrc/pytorch_shim.h +++ /dev/null @@ -1,123 +0,0 @@ -#pragma once - -#include - -/** - * PyBind and PyTorch Library apis generally require different type signatures. 
- * This file provides a shim to (mostly, there may be missing conversions) to - * convert from function designed to be used with PyBind to one that can be used - * with PyTorch Library. This is done using `make_pytorch_shim` which creates a - * lambda that exponses the API using PyTorch compatible types to the types. - * This is useful when trying to ingergate PyBind based external libraries into - * vLLM. - * - * Example: - * - * PYBIND11_MODULE(NAME, m) { - * m.def("foo", &foo); - * } - * - * could be replaced with (using the shim): - * TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { - * m.def("foo", make_pytorch_shim(&foo)); - * m.impl("foo", torch::kCUDA, make_pytorch_shim(&foo)); - * } - * - * The `pytorch_library_compatible_type` struct is used to map from the - * flash_attn ops types to a PyTorch library compatible one. The main issues is - * that the following types are not support by PyTorch library bindings: - * - `int` - * - `float` - * - `c10::optional &` - * - `c10::optional &` - * So we convert them to (respectively): - * - `int64_t` - * - `double` - * - `const c10::optional&` - * - `const c10::optional&` - */ - -template -struct pytorch_library_compatible_type { - using type = T; - static T convert_from_type(T arg) { return arg; } -}; - -template -using pytorch_library_compatible_type_t = - typename pytorch_library_compatible_type::type; - -template -T convert_from_pytorch_compatible_type( - pytorch_library_compatible_type_t arg) { - return pytorch_library_compatible_type::convert_from_type(arg); -} - -// Map `c10::optional &` -> `const c10::optional&` -// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate -// the optional container) -template -struct pytorch_library_compatible_type&> { - using type = const c10::optional&; - static c10::optional& convert_from_type(const c10::optional& arg) { - return const_cast&>(arg); - } -}; - -// Map `c10::optional` -> -// `c10::optional>` -// (NOTE: tested for `c10::optional` -> `c10::optional`) -template -struct pytorch_library_compatible_type> { - using type = c10::optional>; - static c10::optional> convert_from_type( - c10::optional arg) { - return arg; - } -}; - -// Map `c10::optional&` -> `const c10::optional&` -template <> -struct pytorch_library_compatible_type&> { - using type = const c10::optional&; - static c10::optional& convert_from_type( - const c10::optional& arg) { - return const_cast&>( - reinterpret_cast&>(arg)); - } -}; - -// Map `int` -> `int64_t` -template <> -struct pytorch_library_compatible_type { - using type = int64_t; - static int convert_from_type(int64_t arg) { - TORCH_CHECK(arg <= std::numeric_limits::max(), - "int64_t value is too large to be converted to int"); - TORCH_CHECK(arg >= std::numeric_limits::min(), - "int64_t value is too small to be converted to int"); - return arg; - } -}; - -// Map `float` -> `double` -template <> -struct pytorch_library_compatible_type { - using type = double; - static float convert_from_type(double arg) { - TORCH_CHECK(std::abs(arg) <= std::numeric_limits::max(), - "double value is too large to be converted to float"); - return arg; - } -}; - -// -// Shim Utils -// - -template -auto make_pytorch_shim(Ret (*fun)(Args... args)) { - return [fun](pytorch_library_compatible_type_t... 
args) { - return fun(convert_from_pytorch_compatible_type(args)...); - }; -} From 48207c9e54fafd101b95a1f3c1bd2569b52fd328 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 17:02:04 +0000 Subject: [PATCH 22/26] remove unnessary include Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 693d63768789..5ae961c833ed 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -40,7 +40,6 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu) set(FlashMLA_INCLUDES - ${PROJECT_SOURCE_DIR}/csrc ${flashmla_SOURCE_DIR}/csrc/cutlass/include ${flashmla_SOURCE_DIR}/csrc/include) From c215c6a30ede2a8b541b7b55d4c7b350aac126a9 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 18:17:25 +0000 Subject: [PATCH 23/26] add fp16 source Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 5ae961c833ed..e4fc8337737f 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -37,7 +37,8 @@ cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu) + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include From 02a46a3d53ba26c184fd0eb24bf7de5d56fac724 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Tue, 25 Feb 2025 18:29:40 +0000 Subject: [PATCH 24/26] missing symbol Signed-off-by: Lucas Wilkinson --- cmake/external_projects/flashmla.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index e4fc8337737f..6291475164ba 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -38,7 +38,8 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) set(FlashMLA_SOURCES ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu - ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu) + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu + ${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu) set(FlashMLA_INCLUDES ${flashmla_SOURCE_DIR}/csrc/cutlass/include From b055298289516d85dfdc786ecd60c225b4d0d0a2 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 26 Feb 2025 08:11:46 +0000 Subject: [PATCH 25/26] improve logging, skip flashmla tests when not supported Signed-off-by: Lucas Wilkinson --- tests/kernels/test_flashmla.py | 5 ++++- vllm/attention/ops/flashmla.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py index 1eebabc351a3..7b6d79870d38 100644 --- a/tests/kernels/test_flashmla.py +++ b/tests/kernels/test_flashmla.py @@ -8,7 +8,8 @@ import triton from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, - get_mla_metadata) + get_mla_metadata, + 
is_flashmla_supported) def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: @@ -18,6 +19,8 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: assert cos_diff < 1e-5 +@pytest.mark.skipif(not is_flashmla_supported()[0], + "FlashMLA is not supported") @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192]) diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 1edf338bb759..18b69a6b3ddf 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -13,8 +13,7 @@ try: import vllm._flashmla_C # noqa: F401 _flashmla_C_AVAILABLE = True - except ImportError as e: - logger.warning("Failed to import from vllm._flashmla_C with %r", e) + except ImportError: _flashmla_C_AVAILABLE = False else: _flashmla_C_AVAILABLE = False @@ -30,8 +29,9 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]: return False, "FlashMLA is only supported on Hopper devices." if not _flashmla_C_AVAILABLE: return False, "vllm._flashmla_C is not available, likely was not "\ - "compiled due to insufficient nvcc version or a supported arch"\ - "was not in the list of target arches to compile for." + "compiled due to insufficient nvcc version or a supported arch "\ + "(only sm90a currently) was not in the list of target arches to "\ + "compile for." return True, None From ca7fa2d3df79e488ff04714dd53256dbb223b738 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 26 Feb 2025 15:43:37 +0000 Subject: [PATCH 26/26] fix pytest errors Signed-off-by: Lucas Wilkinson --- tests/kernels/test_flashmla.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_flashmla.py b/tests/kernels/test_flashmla.py index 7b6d79870d38..21c1079fc8eb 100644 --- a/tests/kernels/test_flashmla.py +++ b/tests/kernels/test_flashmla.py @@ -18,9 +18,12 @@ def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: (x * x + y * y).sum().item(), 1e-12) assert cos_diff < 1e-5 +FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \ + if not is_flashmla_supported()[0] else "FlashMLA is supported" + @pytest.mark.skipif(not is_flashmla_supported()[0], - "FlashMLA is not supported") + reason=FLASH_MLA_UNSUPPORTED_REASON) @pytest.mark.parametrize("b", [128]) @pytest.mark.parametrize("s_q", [1, 2]) @pytest.mark.parametrize("mean_sk", [4096, 8192])