6 changes: 3 additions & 3 deletions docker/Dockerfile
@@ -15,7 +15,7 @@ ARG PYTHON_VERSION=3.12
# Example:
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04

-# Important: We build with an old version of Ubuntu to maintain broad
+# Important: We build with an old version of Ubuntu to maintain broad
# compatibility with other Linux OSes. The main reason for this is that the
# glibc version is baked into the distro, and binaries built with one glibc
# version are not backwards compatible with OSes that use an earlier version.
@@ -371,7 +371,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
# Install FlashInfer from source
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
# Keep this in sync with "flashinfer" extra in setup.py
ARG FLASHINFER_GIT_REF="v0.3.1"
ARG FLASHINFER_GIT_REF="v0.4.0"
# Flag to control whether to compile FlashInfer AOT kernels
# Set to "true" to enable AOT compilation:
# docker build --build-arg FLASHINFER_AOT_COMPILE=true ...
@@ -392,7 +392,7 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
fi
pushd flashinfer
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ]; then
if [[ "${CUDA_VERSION}" == 12.8.* ]] && [ "$TARGETPLATFORM" = "linux/amd64" ] && [ "${FLASHINFER_GIT_REF}" = "v0.3.1" ]; then
# NOTE: To make new precompiled wheels, see tools/flashinfer-build.sh
echo "🏗️ Installing FlashInfer from pre-compiled wheel"
uv pip install --system https://wheels.vllm.ai/flashinfer-python/flashinfer_python-0.3.1-cp39-abi3-manylinux1_x86_64.whl \
4 changes: 2 additions & 2 deletions docker/Dockerfile.nightly_torch
@@ -246,15 +246,15 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.


# build flashinfer for torch nightly from source around 10 mins
-# release version: v0.3.1
+# release version: v0.4.0
# todo(elainewy): cache flashinfer build result for faster build
ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
--mount=type=cache,target=/root/.cache/uv \
echo "git clone flashinfer..." \
&& git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
&& cd flashinfer \
-&& git checkout v0.3.1 \
+&& git checkout v0.4.0 \
&& git submodule update --init --recursive \
&& echo "finish git clone flashinfer..." \
&& rm -rf build \
2 changes: 1 addition & 1 deletion setup.py
@@ -715,7 +715,7 @@ def _read_requirements(filename: str) -> list[str]:
], # Required for audio processing
"video": [], # Kept for backwards compatibility
# FlashInfer should be updated together with the Dockerfile
"flashinfer": ["flashinfer-python==0.3.1"],
"flashinfer": ["flashinfer-python==0.4.0"],
# Optional deps for AMD FP4 quantization support
"petit-kernel": ["petit-kernel"],
},
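For a quick sanity check after installing the `flashinfer` extra, the pinned version can be read back with `importlib.metadata` (a minimal sketch; only the distribution name `flashinfer-python` from the pin above is assumed):

```python
from importlib.metadata import PackageNotFoundError, version

try:
    # Should report 0.4.0 once the updated extra is installed.
    print("flashinfer-python:", version("flashinfer-python"))
except PackageNotFoundError:
    print("flashinfer-python is not installed")
```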
21 changes: 9 additions & 12 deletions tests/kernels/attention/test_flashinfer_trtllm_attention.py
@@ -7,9 +7,8 @@
import torch

from tests.kernels.quantization.nvfp4_utils import (
-    FLOAT4_E2M1_MAX,
-    FLOAT8_E4M3_MAX,
    dequantize_nvfp4_to_dtype,
+    get_nvfp4_global_scale,
)
from vllm.platforms import current_platform
from vllm.utils import round_up
@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
-o_sf_scale = None
+o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
-o_sf_scale = (
-    (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-).to(torch.float32)
+o_sf_scale = get_nvfp4_global_scale(output)
+o_sf_scale_float = o_sf_scale.item()

# TRTLLM Decode
if o_quant_dtype == FP4_DTYPE:
@@ -204,7 +202,7 @@
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
-o_sf_scale=o_sf_scale,
+o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
output = torch.empty(ref_query.shape, dtype=dtype)
wrapper.run(ref_query, ref_kv_cache, out=output)
o_scale = 1.0
-o_sf_scale = None
+o_sf_scale_float = None
if o_quant_dtype == FP8_DTYPE:
_, o_scale = to_float8(output)
elif o_quant_dtype == FP4_DTYPE:
-o_sf_scale = (
-    (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-).to(torch.float32)
+o_sf_scale = get_nvfp4_global_scale(output)
+o_sf_scale_float = o_sf_scale.item()

# TRTLLM Prefill
if o_quant_dtype == FP4_DTYPE:
@@ -398,7 +395,7 @@
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
-o_sf_scale=o_sf_scale,
+o_sf_scale=o_sf_scale_float,
out=output_trtllm,
)
if o_quant_dtype == FP8_DTYPE:
8 changes: 5 additions & 3 deletions tests/kernels/quantization/nvfp4_utils.py
@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
return values.reshape(m, n * 2).to(dtype=dtype)


+def get_nvfp4_global_scale(a: torch.Tensor):
+    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
+
+
def quant_nvfp4_tensor(a: torch.Tensor):
-a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
-    torch.float32
-)
+a_global_scale = get_nvfp4_global_scale(a)
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
return a_quant, a_block_scale, a_global_scale
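To illustrate the refactor, here is a self-contained sketch of how the new helper is used in the TRTLLM attention tests: the global scale stays a tensor for quantization, and `.item()` converts it to the plain float that the kernel's `o_sf_scale` argument now receives. The FP8/FP4 max constants below are assumed to match the values defined in nvfp4_utils.py.

```python
import torch

# Assumed to match the constants in tests/kernels/quantization/nvfp4_utils.py.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0


def get_nvfp4_global_scale(a: torch.Tensor) -> torch.Tensor:
    # Same formula as the helper added above: pick the scale so the largest
    # magnitude in `a` maps to the top of the FP8-scaled FP4 range.
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)


output = torch.randn(8, 128)                  # stand-in for the attention output
o_sf_scale = get_nvfp4_global_scale(output)   # tensor, used for quantization
o_sf_scale_float = o_sf_scale.item()          # plain float, passed as o_sf_scale
print(o_sf_scale_float)
```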
2 changes: 1 addition & 1 deletion tests/quantization/test_blackwell_moe.py
@@ -50,7 +50,7 @@ def can_initialize(model: str, extra_args: Optional[list[str]] = None):
with RemoteOpenAIServer(
model,
server_args,
-max_wait_seconds=1000, # Due to FlashInfer compile
+max_wait_seconds=1500, # Due to FlashInfer compile
override_hf_configs=dummy_hf_overrides,
) as server:
client = server.get_client()
5 changes: 4 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
@@ -1191,7 +1191,7 @@ def fast_plan_decode(
qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

try:
-# Make sure we pass exactly 15 arguments for tensor core version
+# Make sure we pass exactly 18 arguments for tensor core version
self._plan_info = self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
@@ -1208,6 +1208,9 @@ def fast_plan_decode(
head_dim,
head_dim,
False, # causal
+window_left,
+-1, # fixed_split_size
+False, # disable_split_kv
)
except Exception as e:
raise RuntimeError(f"Error in tensor core plan: {e}") from e