Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 4 additions & 73 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -575,77 +575,8 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
WITH_SOABI)
endif()

# vllm-flash-attn currently only supported on CUDA
if (NOT VLLM_GPU_LANG STREQUAL "CUDA")
return()
# For CUDA we also build and ship some external projects.
if (VLLM_GPU_LANG STREQUAL "CUDA")
include(cmake/external_projects/flashmla.cmake)
include(cmake/external_projects/vllm_flash_attn.cmake)
endif ()

# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
# arches in the CUDA case (and instead set the gencodes on a per file basis)
# we need to manually set VLLM_GPU_ARCHES here.
if(VLLM_GPU_LANG STREQUAL "CUDA")
foreach(_ARCH ${CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
endforeach()
endif()

#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
endif()


# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)

install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)

# Nothing after vllm-flash-attn, see comment about macros above
66 changes: 66 additions & 0 deletions cmake/external_projects/flashmla.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
include(FetchContent)

# If FLASH_MLA_SRC_DIR is set, flash-mla is installed from that directory
# instead of downloading.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{FLASH_MLA_SRC_DIR})
set(FLASH_MLA_SRC_DIR $ENV{FLASH_MLA_SRC_DIR})
endif()

if(FLASH_MLA_SRC_DIR)
FetchContent_Declare(
flashmla
SOURCE_DIR ${FLASH_MLA_SRC_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG 575f7724b9762f265bbee5889df9c7d630801845
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
)
endif()


FetchContent_MakeAvailable(flashmla)
message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")

# The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
# Only build FlashMLA kernels if we are building for something compatible with
# sm90a
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_bf16_sm90.cu
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_fp16_sm90.cu
${flashmla_SOURCE_DIR}/csrc/flash_fwd_mla_metadata.cu)

set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/include)

set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")

define_gpu_extension_target(
_flashmla_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${FlashMLA_SOURCES}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
USE_SABI 3
WITH_SOABI)
else()
# Create an empty target for setup.py when not targeting sm90a systems
add_custom_target(_flashmla_C)
endif()

67 changes: 67 additions & 0 deletions cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target
# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the
# arches in the CUDA case (and instead set the gencodes on a per file basis)
# we need to manually set VLLM_GPU_ARCHES here.
if(VLLM_GPU_LANG STREQUAL "CUDA")
foreach(_ARCH ${CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real")
endforeach()
endif()

#
# Build vLLM flash attention from source
#
# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM.
# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs.
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# If no component is specified, vllm-flash-attn is still installed.

# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
# This is to enable local development of vllm-flash-attn within vLLM.
# It can be set as an environment variable or passed as a cmake argument.
# The environment variable takes precedence.
if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR})
set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR})
endif()

if(VLLM_FLASH_ATTN_SRC_DIR)
FetchContent_Declare(
vllm-flash-attn SOURCE_DIR
${VLLM_FLASH_ATTN_SRC_DIR}
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 720c94869cf2e0ff5a706e9c7f1dce0939686ade
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
endif()


# Fetch the vllm-flash-attn library
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")

# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)

install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def run(self) -> None:
files_to_copy = [
"vllm/_C.abi3.so",
"vllm/_moe_C.abi3.so",
"vllm/_flashmla_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
"vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
"vllm/vllm_flash_attn/flash_attn_interface.py",
Expand Down Expand Up @@ -612,6 +613,11 @@ def _read_requirements(filename: str) -> List[str]:
# FA3 requires CUDA 12.0 or later
ext_modules.append(
CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.3"):
# Optional since this doesn't get built (produce an .so file) when
# not targeting a hopper system
ext_modules.append(
CMakeExtension(name="vllm._flashmla_C", optional=True))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

if _build_custom_ops():
Expand Down
132 changes: 132 additions & 0 deletions tests/kernels/test_flashmla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/tests/test_flash_mla.py
# SPDX-License-Identifier: Apache-2.0
import math
import random

import pytest
import torch
import triton

from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)


def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
x, y = x.double(), y.double()
cos_diff = 1 - 2 * (x * y).sum().item() / max(
(x * x + y * y).sum().item(), 1e-12)
assert cos_diff < 1e-5

FLASH_MLA_UNSUPPORTED_REASON = is_flashmla_supported()[1] \
if not is_flashmla_supported()[0] else "FlashMLA is supported"


@pytest.mark.skipif(not is_flashmla_supported()[0],
reason=FLASH_MLA_UNSUPPORTED_REASON)
@pytest.mark.parametrize("b", [128])
@pytest.mark.parametrize("s_q", [1, 2])
@pytest.mark.parametrize("mean_sk", [4096, 8192])
@pytest.mark.parametrize("h_q", [16, 32, 64, 128])
@pytest.mark.parametrize("h_kv", [1])
@pytest.mark.parametrize("d", [576])
@pytest.mark.parametrize("dv", [512])
@pytest.mark.parametrize("block_size", [64])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("varlen", [False, True])
@torch.inference_mode()
def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal,
varlen):
# TODO: parametrize using pytest
dtype = torch.bfloat16
device = torch.device("cuda:0")
torch.set_default_dtype(dtype)
torch.set_default_device(device)
torch.cuda.set_device(device)
torch.manual_seed(0)
random.seed(0)

print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, "
f"{d=}, {dv=}, {causal=}, {varlen=}")

cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32)
if varlen:
for i in range(b):
cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2),
s_q)
total_seqlens = cache_seqlens.sum().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256

q = torch.randn(b, s_q, h_q, d)
block_table = torch.arange(b * max_seqlen_pad // block_size,
dtype=torch.int32).view(
b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
for i in range(b):
blocked_k.view(b, max_seqlen_pad, h_kv,
d)[i, cache_seqlens[i].item():] = float("nan")
blocked_v = blocked_k[..., :dv]

tile_scheduler_metadata, num_splits = get_mla_metadata(
cache_seqlens, s_q * h_q // h_kv, h_kv)

def flash_mla():
return flash_mla_with_kvcache(
q,
blocked_k,
block_table,
cache_seqlens,
dv,
tile_scheduler_metadata,
num_splits,
causal=causal,
)

def scaled_dot_product_attention(query, key, value, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
temp_mask = torch.ones(s_q, s_k,
dtype=torch.bool).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse

def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
ref_O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
is_causal=causal,
)
out[i] = ref_O.transpose(0, 1)
lse[i] = LSE
return out, lse

out_flash, lse_flash = flash_mla()
out_torch, lse_torch = ref_mla()
cal_diff(out_flash, out_torch, "out")
cal_diff(lse_flash, lse_torch, "lse")

t = triton.testing.do_bench(flash_mla, fast_flush=False)
FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d +
b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} "
f"TFLOPS, {bytes / 10 ** 6 / t:.0f} GB/s")
Loading