diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index 8afde017d383..37e2980eea97 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -17,7 +17,7 @@ Latest reproduction guilde: [github issue link](https://github.com/vllm-project/ - SGLang: `lmsysorg/sglang:v0.3.2-cu121` - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` - - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - *NOTE: we use r24.07 as the current implementation only works for this version. We are going to bump this up.* - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. - Hardware - 8x Nvidia A100 GPUs diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index df2735fefeed..20f3ce1adb46 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -843,3 +843,10 @@ steps: commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 + +- label: Qwen MoE EP Test # optional + gpu: h200 + optional: true + num_gpus: 2 + commands: + - CUDA_VISIBLE_DEVICES=1,2 VLLM_ALL2ALL_BACKEND=deepep_high_throughput VLLM_USE_DEEP_GEMM=1 VLLM_LOGGING_LEVEL=DEBUG python3 /vllm-workspace/examples/offline_inference/data_parallel.py --model Qwen/Qwen1.5-MoE-A2.7B --tp-size=1 --dp-size=2 --max-model-len 2048 diff --git a/CMakeLists.txt b/CMakeLists.txt index a1deefb07f09..aca42c3fe555 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -750,6 +750,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "found in CUDA target architectures") endif() endif() + + # Only build W4A8 kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(W4A8_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND W4A8_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${W4A8_ARCHS}") + + list(APPEND VLLM_EXT_SRC "${SRCS}") + + message(STATUS "Building W4A8 kernels for archs: ${W4A8_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 + AND W4A8_ARCHS) + message(STATUS "Not building W4A8 kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building W4A8 kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() + # if CUDA endif endif() diff --git a/README.md b/README.md index fd8b02ac1f78..ef5b43588953 100644 --- a/README.md +++ b/README.md @@ -18,14 +18,15 @@ Easy, fast, and cheap LLM serving for everyone *Latest News* šŸ”„ +- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH). - [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! 
Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152). -- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
Previous News +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). diff --git a/benchmarks/README.md b/benchmarks/README.md index 176b40212978..a2dd5bb58325 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -59,6 +59,12 @@ become available. āœ… synthetic + + RandomMultiModal (Image/Video) + 🟔 + 🚧 + synthetic + Prefix Repetition āœ… @@ -722,4 +728,75 @@ python benchmarks/benchmark_serving.py \ --endpoint /v1/chat/completion ``` +### Synthetic Random Images (random-mm) + +Generate synthetic image inputs alongside random text prompts to stress-test vision models without external datasets. + +Notes: + +- Works only with online benchmark via the OpenAI backend (`--backend openai-chat`) and endpoint `/v1/chat/completions`. +- Video sampling is not yet implemented. + +Start the server (example): + +```bash +vllm serve Qwen/Qwen2.5-VL-3B-Instruct \ + --dtype bfloat16 \ + --max-model-len 16384 \ + --limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --mm-processor-kwargs max_pixels=1003520 +``` + +Run the benchmark. It is recommended to use the flag `--ignore-eos` to simulate real responses. You can set the size of the output via `--random-output-len`. + +Ex.1: Fixed number of items and a single image resolution, enforcing generation of approximately 40 tokens: + +```bash +vllm bench serve \ + --backend openai-chat \ + --model Qwen/Qwen2.5-VL-3B-Instruct \ + --endpoint /v1/chat/completions \ + --dataset-name random-mm \ + --num-prompts 100 \ + --max-concurrency 10 \ + --random-prefix-len 25 \ + --random-input-len 300 \ + --random-output-len 40 \ + --random-range-ratio 0.2 \ + --random-mm-base-items-per-request 2 \ + --random-mm-limit-mm-per-prompt '{"image": 3, "video": 0}' \ + --random-mm-bucket-config '{(224, 224, 1): 1.0}' \ + --request-rate inf \ + --ignore-eos \ + --seed 42 +``` + +The number of items per request can be controlled by passing multiple image buckets: + +```bash + --random-mm-base-items-per-request 2 \ + --random-mm-num-mm-items-range-ratio 0.5 \ + --random-mm-limit-mm-per-prompt '{"image": 4, "video": 0}' \ + --random-mm-bucket-config '{(256, 256, 1): 0.7, (720, 1280, 1): 0.3}' \ +``` + +Flags specific to `random-mm`: + +- `--random-mm-base-items-per-request`: base number of multimodal items per request. +- `--random-mm-num-mm-items-range-ratio`: vary item count uniformly in the closed integer range [floor(nĀ·(1āˆ’r)), ceil(nĀ·(1+r))]. Set r=0 to keep it fixed; r=1 allows 0 items. +- `--random-mm-limit-mm-per-prompt`: per-modality hard caps, e.g. '{"image": 3, "video": 0}'.
+- `--random-mm-bucket-config`: dict mapping (H, W, T) → probability. Entries with probability 0 are removed; remaining probabilities are renormalized to sum to 1. Use T=1 for images. Set any T>1 for videos (video sampling not yet supported). + +Behavioral notes: + +- If the requested base item count cannot be satisfied under the provided per-prompt limits, the tool raises an error rather than silently clamping. + +How sampling works (illustrated in the sketch below): + +- Determine per-request item count k by sampling uniformly from the integer range defined by `--random-mm-base-items-per-request` and `--random-mm-num-mm-items-range-ratio`, then clamp k to at most the sum of per-modality limits. +- For each of the k items, sample a bucket (H, W, T) according to the normalized probabilities in `--random-mm-bucket-config`, while tracking how many items of each modality have been added. +- If a modality (e.g., image) reaches its limit from `--random-mm-limit-mm-per-prompt`, all buckets of that modality are excluded and the remaining bucket probabilities are renormalized before continuing. +This should be seen as an edge case; it can be avoided by setting `--random-mm-limit-mm-per-prompt` to a large number, although that in turn may trigger errors from the engine config `--limit-mm-per-prompt`. +- The resulting request contains synthetic image data in `multi_modal_data` (OpenAI Chat format). When `random-mm` is used with the OpenAI Chat backend, prompts remain text and MM content is attached via `multi_modal_data`. +
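To make the sampling rules above concrete, here is a small, self-contained Python sketch of the procedure as described. The function name, argument names, and structure are illustrative only and are not the actual `random-mm` implementation in vLLM:

```python
import math
import random


def _modality(bucket: tuple[int, int, int]) -> str:
    # Convention from the flag description above: T == 1 is an image, T > 1 would be a video.
    return "image" if bucket[2] == 1 else "video"


def sample_mm_items(base_items, range_ratio, limit_per_prompt, bucket_config, rng=None):
    rng = rng or random.Random()
    # Item count k: uniform over [floor(n*(1-r)), ceil(n*(1+r))], clamped by the summed limits.
    lo = math.floor(base_items * (1 - range_ratio))
    hi = math.ceil(base_items * (1 + range_ratio))
    k = min(rng.randint(lo, hi), sum(limit_per_prompt.values()))

    # Drop zero-probability buckets; random.choices renormalizes the remaining weights.
    buckets = {b: p for b, p in bucket_config.items() if p > 0}
    items, counts = [], dict.fromkeys(limit_per_prompt, 0)
    for _ in range(k):
        # Exclude buckets whose modality already reached its per-prompt limit.
        available = {
            b: p
            for b, p in buckets.items()
            if counts.get(_modality(b), 0) < limit_per_prompt.get(_modality(b), 0)
        }
        if not available:
            break
        chosen = rng.choices(list(available), weights=list(available.values()))[0]
        items.append(chosen)
        counts[_modality(chosen)] = counts.get(_modality(chosen), 0) + 1
    return items


# Mirrors the multi-bucket invocation shown above.
print(
    sample_mm_items(
        base_items=2,
        range_ratio=0.5,
        limit_per_prompt={"image": 4, "video": 0},
        bucket_config={(256, 256, 1): 0.7, (720, 1280, 1): 0.3},
        rng=random.Random(42),
    )
)
```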
diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index a9c4d30d9b18..1b1c3b321cce 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -284,6 +284,25 @@ def machete_create_bench_fn( ) +def cutlass_w4a8_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.cutlass_encode_and_reorder_int4b(w_q) + # expects fp8 scales + w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn)) + + return lambda: ops.cutlass_w4a8_mm( + a=bt.a, + b_q=w_q, + b_group_scales=w_s, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + maybe_schedule=schedule, + ) + + # impl # bench @@ -385,6 +404,20 @@ def bench( ) ) + # cutlass w4a8 + if types.act_type == torch.float8_e4m3fn and group_size == 128: + timers.append( + bench_fns( + label, + sub_label, + f"cutlass w4a8 ({name_type_string})", + [ + cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) + if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS diff --git a/benchmarks/kernels/benchmark_trtllm_decode_attention.py b/benchmarks/kernels/benchmark_trtllm_decode_attention.py index 72b54b40a2d1..603ce5ecf0d2 100644 --- a/benchmarks/kernels/benchmark_trtllm_decode_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_decode_attention.py @@ -9,8 +9,11 @@ import flashinfer import torch +from vllm.utils import round_up + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -61,13 +64,13 @@ def benchmark_decode( else: raise ValueError(f"Invalid kv_layout: {kv_layout}") - query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype) if q_quant_dtype == FP8_DTYPE: - query, q_scale = to_float8(query) - ref_query = query.to(dtype) * q_scale + query, _ = to_float8(ref_query) else: - q_scale = 1.0 - ref_query = query + query = ref_query kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_seq_len @@ -75,14 +78,13 @@ def benchmark_decode( seq_lens = kv_lens max_seq_len = torch.max(seq_lens).item() - kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) if kv_quant_dtype == FP8_DTYPE: - kv_cache, kv_scale = to_float8(kv_cache) - ref_kv_cache = kv_cache.to(dtype) * kv_scale + kv_cache, _ = to_float8(ref_kv_cache) else: - kv_scale = 1.0 - ref_kv_cache = kv_cache - k_scale = v_scale = kv_scale + kv_cache = ref_kv_cache max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( @@ -142,11 +144,31 @@ def time_fn(fn, warmup=10, trials=20): return sum(times) / len(times), torch.std(torch.tensor(times)) o_scale = 1.0 + o_sf_scale = None output_baseline = torch.empty(ref_query.shape, dtype=dtype) - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + 
round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_decode(): - return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + out=output_baseline, + ) def trtllm_decode(): return flashinfer.decode.trtllm_batch_decode_with_kv_cache( @@ -158,6 +180,7 @@ def trtllm_decode(): max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -237,6 +260,7 @@ def write_results_to_csv(results, filename=None): (None, None, None), (None, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] for quant_dtype in quant_dtypes: diff --git a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py index 49810e20c7d8..40903c6c3444 100644 --- a/benchmarks/kernels/benchmark_trtllm_prefill_attention.py +++ b/benchmarks/kernels/benchmark_trtllm_prefill_attention.py @@ -9,8 +9,11 @@ import flashinfer import torch +from vllm.utils import round_up + FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = torch.float8_e4m3fn +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -72,13 +75,15 @@ def benchmark_prefill( ] ) - query = torch.randn(torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + q_scale = 1.0 + ref_query = torch.randn( + torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype + ) if q_quant_dtype == FP8_DTYPE: - query, q_scale = to_float8(query) - ref_query = query.to(dtype) * q_scale + query, _ = to_float8(ref_query) else: - q_scale = 1.0 - ref_query = query + query = ref_query kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32) kv_lens[-1] = max_kv_len @@ -86,14 +91,13 @@ def benchmark_prefill( seq_lens = kv_lens + q_lens max_seq_len = torch.max(seq_lens).item() - kv_cache = torch.randn(kv_cache_shape, dtype=dtype) + # Always using 1.0 scale to reflect the real perf in benchmarking + k_scale = v_scale = 1.0 + ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype) if kv_quant_dtype == FP8_DTYPE: - kv_cache, kv_scale = to_float8(kv_cache) - ref_kv_cache = kv_cache.to(dtype) * kv_scale + kv_cache, _ = to_float8(ref_kv_cache) else: - kv_scale = 1.0 - ref_kv_cache = kv_cache - k_scale = v_scale = kv_scale + kv_cache = ref_kv_cache max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables = torch.randint( @@ -152,11 +156,31 @@ def time_fn(fn, warmup=10, trials=20): return sum(times) / len(times), torch.std(torch.tensor(times)) o_scale = 1.0 + o_sf_scale = None output_baseline = torch.empty(ref_query.shape, dtype=dtype) - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + o_sf_scale = 500.0 + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8), + torch.empty( + ( + round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4), + ), + dtype=torch.float8_e4m3fn, + ), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) def baseline_prefill(): - return wrapper.run(ref_query, ref_kv_cache, out=output_baseline) + return wrapper.run( + ref_query, + ref_kv_cache, + k_scale=k_scale, + v_scale=v_scale, + 
out=output_baseline, + ) def trtllm_prefill(): return flashinfer.prefill.trtllm_batch_context_with_kv_cache( @@ -172,6 +196,7 @@ def trtllm_prefill(): batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) @@ -250,6 +275,7 @@ def write_results_to_csv(results, filename=None): # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] for quant_dtype in quant_dtypes: diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 4fcdbadd65ec..e648a91077fd 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -11,8 +11,8 @@ from typing import Any import torch -import tqdm import triton +from tqdm import tqdm from vllm.model_executor.layers.quantization.utils.fp8_utils import ( _w8a8_block_fp8_matmul, diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index a27f02394afb..9a057990bda5 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -95,4 +95,10 @@ ([2048, 2816], 1), ([1408, 2048], 0), ], + "CohereLabs/c4ai-command-a-03-2025": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 73728], 1), + ([36864, 12288], 0), + ], } diff --git a/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu new file mode 100644 index 000000000000..fdac47c425d6 --- /dev/null +++ b/csrc/quantization/cutlass_w4a8/w4a8_mm_entry.cu @@ -0,0 +1,418 @@ +// +// Based off of: +// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu +// + +#include +#include +#include +#include "cutlass_extensions/torch_utils.hpp" + +#include "core/registration.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" + +#include "cutlass_extensions/common.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm::cutlass_w4a8 { + +using namespace cute; + +// ------------------------------------------------------------------------------------- +// Static configuration shared across all instantiations +// ------------------------------------------------------------------------------------- +using MmaType = cutlass::float_e4m3_t; // A/scale element type +using QuantType = cutlass::int4b_t; // B element type (packed int4) + +static int constexpr TileShapeK = 128 * 8 / sizeof_bits::value; +static int constexpr ScalePackSize = 8; // pack 8 scale elements together +static int constexpr PackFactor = 8; // 8 4-bit packed into int32 + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) +using StrideA = cutlass::detail::TagToStrideA_t; + +// B matrix configuration +using ElementB = QuantType; // 
Element type for B matrix operand +using LayoutB = + cutlass::layout::ColumnMajor; // Layout type for B matrix operand +using LayoutB_Transpose = + typename cutlass::layout::LayoutTranspose::type; +constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reordered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in +// contiguous locations in global memory. It specifies the reordering within a +// single warp's fragment +using LayoutAtomQuant = + decltype(cutlass::compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(cute::tile_to_shape( + LayoutAtomQuant{}, Layout, StrideB>{})); + +// Group-wise scales +using ElementScale = MmaType; +using LayoutScale = cutlass::layout::RowMajor; + +// Per-tok, per-chan scales +using ElementSChannel = float; + +// C/D matrix configuration +using ElementC = + cutlass::bfloat16_t; // Element type for C and D matrix operands +using LayoutC = + cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Memory access granularity/alignment of C + // matrix in units of elements (up to 16 bytes) + +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that + // supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using KernelSchedule = + cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch + // based on the default + // setting in the + // Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +// ---------------------------------------------------------------------------- +// Kernel template — Tile/Cluster shapes +// ---------------------------------------------------------------------------- +template +struct W4A8GemmKernel { + using TileShape = + decltype(cute::append(TileShape_MN{}, cute::Int{})); + using ClusterShape = ClusterShape_MNK; + + // Epilogue per-tok, per-chan scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + using EVTCompute = typename ChTokScalesEpilogue::EVTCompute; + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, + ElementAccumulator, ElementSChannel, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C + // matrix. We can enable this if beta == 0 by changing ElementC to + // void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, + AlignmentC, ElementD, + typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule, // This is the only epi supporting the required + // swap + transpose. + EVTCompute>::CollectiveOp; + + // The Scale information must get paired with the operand that will be scaled. 
+ // In this example, B is scaled so we make a tuple of B's information and the + // scale information. + using CollectiveMainloopShuffled = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, + LayoutB_Reordered, AlignmentB, ElementA, LayoutA_Transpose, + AlignmentA, ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, CollectiveEpilogue>; + using GemmShuffled = + cutlass::gemm::device::GemmUniversalAdapter; + + using StrideC = typename GemmKernelShuffled::StrideC; + using StrideD = typename GemmKernelShuffled::StrideD; + using StrideS = typename CollectiveMainloopShuffled::StrideScale; + + static torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type) { + // TODO: param validation + int m = A.size(0); + int k = A.size(1); + int n = B.size(1); + + // Allocate output + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto device = A.device(); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + torch::Tensor D = + torch::empty({m, n}, torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); + // prepare arg pointers + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.data_ptr()); + // can we avoid harcode the 8 here + auto S_ptr = + static_cast const*>( + group_scales.const_data_ptr()); + + // runtime layout for B + auto shape_B = cute::make_shape(n, k, 1); + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + // strides + int const scale_k = cutlass::ceil_div(k, group_size); + StrideA stride_A = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + // Reverse stride here due to swap and transpose + StrideD stride_D = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1)); + StrideS stride_S = cutlass::make_cute_packed_stride( + StrideS{}, cute::make_shape(n, scale_k, 1)); + + // Create a structure of gemm kernel arguments suitable for invoking an + // instance of Gemm auto arguments = + // args_from_options(options); + /// Populates a Gemm::Arguments structure from the given arguments + /// Swap the A and B tensors, as well as problem shapes here. 
+ using Args = typename GemmShuffled::Arguments; + using MainloopArguments = typename GemmKernelShuffled::MainloopArguments; + using EpilogueArguments = typename GemmKernelShuffled::EpilogueArguments; + + MainloopArguments mainloop_arguments{ + B_ptr, layout_B_reordered, A_ptr, stride_A, + S_ptr, stride_S, group_size}; + + EpilogueArguments epilogue_arguments{ + ChTokScalesEpilogue::prepare_args(channel_scales, token_scales), + nullptr, + {}, // no C + D_ptr, + stride_D}; + + Args arguments{cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, 1}, // shape + mainloop_arguments, + epilogue_arguments}; + + // Workspace + size_t workspace_size = GemmShuffled::get_workspace_size(arguments); + torch::Tensor workspace = + torch::empty(workspace_size, + torch::TensorOptions().dtype(torch::kU8).device(device)); + + // Run GEMM + GemmShuffled gemm; + CUTLASS_CHECK(gemm.can_implement(arguments)); + CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream)); + CUTLASS_CHECK(gemm.run(stream)); + + return D; + } +}; + +// ---------------------------------------------------------------------------- +// Kernel instantiations and dispatch logic +// ---------------------------------------------------------------------------- +using Kernel_256x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_256x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x256_2x1x1 = + W4A8GemmKernel, Shape<_2, _1, _1>>; +using Kernel_128x256_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x128_1x1x1 = + W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x64_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x32_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; +using Kernel_128x16_1x1x1 = W4A8GemmKernel, Shape<_1, _1, _1>>; + +torch::Tensor mm_dispatch(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, + torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + const std::string& schedule) { + if (schedule == "256x128_1x1x1") { + return Kernel_256x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x64_1x1x1") { + return Kernel_256x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x32_1x1x1") { + return Kernel_256x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "256x16_1x1x1") { + return Kernel_256x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_2x1x1") { + return Kernel_128x256_2x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x256_1x1x1") { + return Kernel_128x256_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x128_1x1x1") { + return Kernel_128x128_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x64_1x1x1") { + return Kernel_128x64_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x32_1x1x1") { + return 
Kernel_128x32_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } else if (schedule == "128x16_1x1x1") { + return Kernel_128x16_1x1x1::mm(A, B, group_scales, group_size, + channel_scales, token_scales, + maybe_out_type); + } + TORCH_CHECK(false, "Unknown W4A8 schedule: ", schedule); + return {}; +} + +torch::Tensor mm(torch::Tensor const& A, + torch::Tensor const& B, // already packed + torch::Tensor const& group_scales, // already packed + int64_t group_size, torch::Tensor const& channel_scales, + torch::Tensor const& token_scales, + std::optional const& maybe_out_type, + std::optional maybe_schedule) { + // requested a specific schedule + if (maybe_schedule) { + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, *maybe_schedule); + } + std::string schedule; + int M = A.size(0); + int K = A.size(1); + int N = B.size(1); + // heuristic + if (M <= 16) { + schedule = (K == 16384 && N == 18432) ? "256x16_1x1x1" : "128x16_1x1x1"; + } else if (M <= 32) { + schedule = (K == 16384 && N == 18432) ? "256x32_1x1x1" : "128x32_1x1x1"; + } else if (M <= 64) { + if (K == 16384 && N == 18432) + schedule = "256x64_1x1x1"; + else if (N <= 8192 && K <= 8192) + schedule = "128x32_1x1x1"; + else + schedule = "128x64_1x1x1"; + } else if (M <= 128) { + if (K == 16384 && N == 18432) + schedule = "256x128_1x1x1"; + else if (N <= 8192) + schedule = "128x64_1x1x1"; + else + schedule = "128x128_1x1x1"; + } else if (M <= 256) { + if (N <= 4096) + schedule = "128x64_1x1x1"; + else if (N <= 8192) + schedule = "128x128_1x1x1"; + else + schedule = "128x256_1x1x1"; + } else if (M <= 512 && N <= 4096) { + schedule = "128x128_1x1x1"; + } else if (M <= 1024) { + schedule = "128x256_1x1x1"; + } else { + schedule = "128x256_2x1x1"; + } + return mm_dispatch(A, B, group_scales, group_size, channel_scales, + token_scales, maybe_out_type, schedule); +} + +// ---------------------------------------------------------------------------- +// Pre-processing utils +// ---------------------------------------------------------------------------- +torch::Tensor pack_scale_fp8(torch::Tensor const& scales) { + TORCH_CHECK(scales.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(scales.is_cuda()); + + auto packed_scales = torch::empty( + {scales.numel() * ScalePackSize}, + torch::TensorOptions().dtype(scales.dtype()).device(scales.device())); + auto scales_ptr = static_cast(scales.const_data_ptr()); + auto packed_scales_ptr = + static_cast*>( + packed_scales.data_ptr()); + + cutlass::pack_scale_fp8(scales_ptr, packed_scales_ptr, scales.numel()); + + return packed_scales; +} + +torch::Tensor encode_and_reorder_int4b(torch::Tensor const& B) { + TORCH_CHECK(B.dtype() == torch::kInt32); + TORCH_CHECK(B.dim() == 2); + + torch::Tensor B_packed = torch::empty_like(B); + + int k = B.size(0) * PackFactor; // logical k + int n = B.size(1); + + auto B_ptr = static_cast(B.const_data_ptr()); + auto B_packed_ptr = static_cast(B_packed.data_ptr()); + auto shape_B = cute::make_shape(n, k, 1); + auto layout_B = make_layout(shape_B, LayoutRight{}); // row major + LayoutB_Reordered layout_B_reordered = + cute::tile_to_shape(LayoutAtomQuant{}, shape_B); + + cutlass::unified_encode_int4b(B_ptr, B_packed_ptr, n * k); + cutlass::reorder_tensor(B_packed_ptr, layout_B, layout_B_reordered); + + return B_packed; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_w4a8_mm", &mm); + 
m.impl("cutlass_pack_scale_fp8", &pack_scale_fp8); + m.impl("cutlass_encode_and_reorder_int4b", &encode_and_reorder_int4b); +} + +} // namespace vllm::cutlass_w4a8 \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4edb7af50f10..7ae054dc19fb 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -309,6 +309,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file + + // CUTLASS w4a8 GEMM + ops.def( + "cutlass_w4a8_mm(" + " Tensor A," + " Tensor B," + " Tensor group_scales," + " int group_size," + " Tensor channel_scales," + " Tensor token_scales," + " ScalarType? out_type," + " str? maybe_schedule" + ") -> Tensor", + {stride_tag}); + // pack scales + ops.def("cutlass_pack_scale_fp8(Tensor scales) -> Tensor"); + // encode and reorder weight matrix + ops.def("cutlass_encode_and_reorder_int4b(Tensor B) -> Tensor"); + // conditionally compiled so impl registration is in source file + #endif // Dequantization for GGML. diff --git a/docker/Dockerfile b/docker/Dockerfile index cfaa59868215..839ac501dbaf 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -432,31 +432,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # Install DeepGEMM from source -ARG DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" ARG DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" -RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' - . /etc/environment - CUDA_MAJOR="${CUDA_VERSION%%.*}" - CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" - CUDA_MINOR="${CUDA_MINOR%%.*}" - if [ "$CUDA_MAJOR" -ge 12 ] && [ "$CUDA_MINOR" -ge 8 ]; then - git clone --recursive --shallow-submodules \ - ${DEEPGEMM_GIT_REPO} deepgemm - echo "šŸ—ļø Building DeepGEMM" - pushd deepgemm - git checkout ${DEEPGEMM_GIT_REF} - # Build DeepGEMM - # (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) - rm -rf build dist - rm -rf *.egg-info - python3 setup.py bdist_wheel - uv pip install --system dist/*.whl - popd - rm -rf deepgemm - else - echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" - fi -BASH +COPY tools/install_deepgemm.sh /tmp/install_deepgemm.sh +RUN --mount=type=cache,target=/root/.cache/uv \ + VLLM_DOCKER_BUILD_CONTEXT=1 /tmp/install_deepgemm.sh --cuda-version "${CUDA_VERSION}" --ref "${DEEPGEMM_GIT_REF}" \ + && rm /tmp/install_deepgemm.sh + +# Install EP kernels(pplx-kernels and DeepEP), NixL +COPY tools/ep_kernels/install_python_libraries.sh install_python_libraries.sh +COPY tools/install_nixl.sh install_nixl.sh +ENV CUDA_HOME=/usr/local/cuda +RUN export TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST:-9.0a+PTX}" \ + && bash install_python_libraries.sh \ + && bash install_nixl.sh --force #################### vLLM installation IMAGE #################### diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 36232e6ad96c..61ea44220ad2 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -2,6 +2,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. 
Please find the materials of our previous meetups below: +- [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg), August 23rd 2025. [[Slides]](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH) - [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA), August 2nd 2025. [[Slides]](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) [[Recording]](https://www.chaspark.com/#/live/1166916873711665152). - [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index 357a5eb59406..6c7c31f503c1 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -172,6 +172,7 @@ The availablilty of batch-level DP is based on model implementation. Currently, the following models support `mm_encoder_tp_mode="data"`: - Llama4 () +- MiniCPM-V-4 () - Qwen2.5-VL () - Step3 () @@ -195,6 +196,13 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 !!! note API server scale-out is only available for online inference. +!!! warning + By default, 8 CPU threads are used in each API server to load media items (e.g. images) + from request data. + + If you apply API server scale-out, consider adjusting `VLLM_MEDIA_LOADING_THREAD_COUNT` + to avoid CPU resource exhaustion. + !!! note [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled because it requires a one-to-one correspondance between API and engine core processes. diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index e62a33b2085c..0b41e73b030c 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -18,7 +18,7 @@ vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096 - Download and install [Anything LLM desktop](https://anythingllm.com/desktop). -- On the bottom left of open settings, AI Prooviders --> LLM: +- On the bottom left of open settings, AI Providers --> LLM: - LLM Provider: Generic OpenAI - Base URL: http://{vllm server host}:{vllm server port}/v1 - Chat Model Name: `Qwen/Qwen1.5-32B-Chat-AWQ` diff --git a/docs/design/fused_moe_modular_kernel.md b/docs/design/fused_moe_modular_kernel.md index 4b917ab408ee..3c4c7d210217 100644 --- a/docs/design/fused_moe_modular_kernel.md +++ b/docs/design/fused_moe_modular_kernel.md @@ -226,7 +226,7 @@ Doing this will add the new implementation to the test suite. The unit test file [test_modular_kernel_combinations.py](gh-file:tests/kernels/moe/test_modular_kernel_combinations.py) can also be executed as a standalone script. Example: `python3 -m tests.kernels.moe.test_modular_kernel_combinations --pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts` -As a side-effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. When invoked +As a side effect, this script can be used to test `FusedMoEPrepareAndFinalize` & `FusedMoEPermuteExpertsUnpermute` compatibility. 
When invoked with incompatible types, the script will error. ### How To Profile diff --git a/docs/design/metrics.md b/docs/design/metrics.md index b01838883f31..b24364247b3f 100644 --- a/docs/design/metrics.md +++ b/docs/design/metrics.md @@ -565,7 +565,7 @@ model and then validate those tokens with the larger model. - `vllm:spec_decode_num_emitted_tokens_total` (Counter) There is a PR under review () to add "prompt lookup (ngram)" -seculative decoding to v1. Other techniques will follow. We should +speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. !!! note diff --git a/docs/design/multiprocessing.md b/docs/design/multiprocessing.md index 06ebd7725858..247072d1cb27 100644 --- a/docs/design/multiprocessing.md +++ b/docs/design/multiprocessing.md @@ -77,7 +77,7 @@ The `multiproc_xpu_executor` forces the use of `spawn`. There are other miscellaneous places hard-coding the use of `spawn`: -- +- - Related PRs: diff --git a/docs/design/paged_attention.md b/docs/design/paged_attention.md index fb991a35caf3..d87b2a639df1 100644 --- a/docs/design/paged_attention.md +++ b/docs/design/paged_attention.md @@ -422,7 +422,7 @@ a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle a whole block of value tokens. And each `accs` in each thread contains 8 elements that accumulated at 8 different head positions. For the thread 0, the `accs` variable will have 8 elements, which -are 0th, 32th … 224th elements of a value head that are accumulated +are 0th, 32nd … 224th elements of a value head that are accumulated from all assigned 8 tokens. ## LV diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md index 13b151bc7f38..5e86e9388f32 100644 --- a/docs/features/quantization/inc.md +++ b/docs/features/quantization/inc.md @@ -7,7 +7,7 @@ Intel Gaudi supports quantization of various modules and functions, including, b [Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). !!! note - Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vLLM HPU extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. !!! note `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). 
diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 37d502ef9ce0..afc605a504b3 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -284,6 +284,14 @@ Supported models: Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` + +### DeepSeek-V3.1 Models (`deepseek_v31`) + +Supported models: + +* `deepseek-ai/DeepSeek-V3.1` (use with ) + +Flags: `--tool-call-parser deepseek_v31 --chat-template {see_above}` + ### Kimi-K2 Models (`kimi_k2`) Supported models: diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 7a34d47d8e49..e76ec35e1edc 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -170,7 +170,7 @@ This value is 4GB by default. Larger space can support more concurrent requests, First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. -Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: +Inference batch size is an important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: - Offline Inference: `4096 * world_size` @@ -179,7 +179,7 @@ Inference batch size is a important parameter for the performance. Larger batch - Offline Inference: `256 * world_size` - Online Serving: `128 * world_size` -vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. +vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more details of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommended to use TP and PP together if there are enough CPU sockets and memory nodes. ### Which quantization configs does vLLM CPU support? @@ -190,6 +190,6 @@ vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage mu ### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? -- Both of them requires `amx` CPU flag. +- Both of them require `amx` CPU flag. - `VLLM_CPU_MOE_PREPACK` can provides better performance for MoE models - `VLLM_CPU_SGL_KERNEL` can provides better performance for MoE models and small-batch scenarios.
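To show where the two batch-size knobs from the CPU FAQ above plug in, here is a minimal offline sketch. The model name and values are placeholders to tune per platform; the same settings correspond to `--max-num-batched-tokens`, `--max-num-seqs`, and `-tp` when serving:

```python
from vllm import LLM, SamplingParams

# Placeholder values: a larger token budget generally raises throughput,
# a smaller one favors first-token latency; tune per platform as described above.
llm = LLM(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    max_num_batched_tokens=4096,  # token limit per batch, mainly affects first-token performance
    max_num_seqs=64,              # max concurrent sequences per batch
    tensor_parallel_size=2,       # e.g. one rank per CPU socket / memory node
)

out = llm.generate(["What is vLLM?"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```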
diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md index 61b2b02aa10b..ff912efec9ca 100644 --- a/docs/getting_started/installation/intel_gaudi.md +++ b/docs/getting_started/installation/intel_gaudi.md @@ -261,13 +261,13 @@ Lower value corresponds to less usable graph memory reserved for prefill stage, User can also configure the strategy for capturing HPU Graphs for prompt and decode stages separately. Strategy affects the order of capturing graphs. There are two strategies implemented: -- `max_bs` - graph capture queue will sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode +- `max_bs` - graph capture queue will be sorted in descending order by their batch sizes. Buckets with equal batch sizes are sorted by sequence length in ascending order (e.g. `(64, 128)`, `(64, 256)`, `(32, 128)`, `(32, 256)`, `(1, 128)`, `(1,256)`), default strategy for decode - `min_tokens` - graph capture queue will be sorted in ascending order by the number of tokens each graph processes (`batch_size*sequence_length`), default strategy for prompt When there's large amount of requests pending, vLLM scheduler will attempt to fill the maximum batch size for decode as soon as possible. When a request is finished, decode batch size decreases. When that happens, vLLM will attempt to schedule a prefill iteration for requests in the waiting queue, to fill the decode batch size to its previous state. This means that in a full load scenario, decode batch size is often at its maximum, which makes large batch size HPU Graphs crucial to capture, as reflected by `max_bs` strategy. On the other hand, prefills will be executed most frequently with very low batch sizes (1-4), which is reflected in `min_tokens` strategy. !!! note - `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. + `VLLM_GRAPH_PROMPT_RATIO` does not set a hard limit on memory taken by graphs for each stage (prefill and decode). vLLM will first attempt to use up entirety of usable prefill graph memory (usable graph memory * `VLLM_GRAPH_PROMPT_RATIO`) for capturing prefill HPU Graphs, next it will attempt to do the same for decode graphs and usable decode graph memory pool. If one stage is fully captured, and there is unused memory left within usable graph memory pool, vLLM will attempt further graph capture for the other stage, until no more HPU Graphs can be captured without exceeding reserved memory pool. The behavior on that mechanism can be observed in the example below. 
Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 297d98142b5f..8fb1019f2bdf 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -328,7 +328,7 @@ th { | `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | āœ…ļøŽ | āœ…ļøŽ | āœ…ļøŽ | | `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | āœ…ļøŽ | āœ…ļøŽ | āœ…ļøŽ | | `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | āœ…ļøŽ | āœ…ļøŽ | āœ…ļøŽ | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | āœ…ļøŽ | | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | āœ…ļøŽ | āœ…ļøŽ | | `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | | `MBartForConditionalGeneration` | mBART | `facebook/mbart-large-en-ro`, `facebook/mbart-large-50`, etc. | | | | | `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `zai-org/chatglm2-6b`, `zai-org/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | āœ…ļøŽ | āœ…ļøŽ | āœ…ļøŽ | @@ -615,6 +615,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen | `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | āœ…ļøŽ | āœ…ļøŽ | | `Cohere2VisionForConditionalGeneration` | Command A Vision | T + I+ | `CohereLabs/command-a-vision-07-2025`, etc. | | āœ…ļøŽ | āœ…ļøŽ | | `DeepseekVLV2ForCausalLM`^ | DeepSeek-VL2 | T + I+ | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | āœ…ļøŽ | āœ…ļøŽ | +| `DonutForConditionalGeneration`^ | Donut | T + I | `ByteDance/Dolphin`, `naver-clova-ix/donut-base-finetuned-docvqa`, etc. | | | | | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | āœ…ļøŽ | āœ…ļøŽ | | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. 
| āœ…ļøŽ | āœ…ļøŽ | āš ļø | diff --git a/examples/offline_inference/dolphin.py b/examples/offline_inference/dolphin.py new file mode 100644 index 000000000000..d2ba27cd1e02 --- /dev/null +++ b/examples/offline_inference/dolphin.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import copy +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +import regex as re +from PIL import Image +from transformers import DonutProcessor + +from vllm import LLM, SamplingParams +from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt +from vllm.multimodal.utils import fetch_image + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +@dataclass +class ImageDimensions: + original_w: int + original_h: int + padded_w: int + padded_h: int + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def map_to_original_coordinates( + x1, y1, x2, y2, dims: ImageDimensions +) -> tuple[int, int, int, int]: + try: + top = (dims.padded_h - dims.original_h) // 2 + left = (dims.padded_w - dims.original_w) // 2 + orig_x1 = max(0, x1 - left) + orig_y1 = max(0, y1 - top) + orig_x2 = min(dims.original_w, x2 - left) + orig_y2 = min(dims.original_h, y2 - top) + if orig_x2 <= orig_x1: + orig_x2 = min(orig_x1 + 1, dims.original_w) + if orig_y2 <= orig_y1: + orig_y2 = min(orig_y1 + 1, dims.original_h) + return int(orig_x1), int(orig_y1), int(orig_x2), int(orig_y2) + except Exception as e: + print(f"map_to_original_coordinates error: {str(e)}") + return 0, 0, min(100, dims.original_w), min(100, dims.original_h) + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def adjust_box_edges(image, boxes: list[list[float]], max_pixels=15, threshold=0.2): + if isinstance(image, str): + image = cv2.imread(image) + img_h, img_w = image.shape[:2] + new_boxes = [] + for box in boxes: + best_box = copy.deepcopy(box) + + def check_edge(img, current_box, i, is_vertical): + edge = current_box[i] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + _, binary = cv2.threshold( + gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU + ) + if is_vertical: + line = binary[current_box[1] : current_box[3] + 1, edge] + else: + line = binary[edge, current_box[0] : current_box[2] + 1] + transitions = np.abs(np.diff(line)) + return np.sum(transitions) / len(transitions) + + edges = [(0, -1, True), (2, 1, True), (1, -1, False), (3, 1, False)] + current_box = copy.deepcopy(box) + current_box[0] = min(max(current_box[0], 0), img_w - 1) + current_box[1] = min(max(current_box[1], 0), img_h - 1) + current_box[2] = min(max(current_box[2], 0), img_w - 1) + current_box[3] = min(max(current_box[3], 0), img_h - 1) + + for i, direction, is_vertical in edges: + best_score = check_edge(image, current_box, i, is_vertical) + if best_score <= threshold: + continue + for step in range(max_pixels): + current_box[i] += direction + if i == 0 or i == 2: + current_box[i] = min(max(current_box[i], 0), img_w - 1) + else: + current_box[i] = min(max(current_box[i], 0), img_h - 1) + score = check_edge(image, current_box, i, is_vertical) + if score < best_score: + best_score = score + best_box = copy.deepcopy(current_box) + if score <= threshold: + break + new_boxes.append(best_box) + return new_boxes + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def process_coordinates(coords, padded_image, dims: ImageDimensions, previous_box=None): + try: + x1, y1 = int(coords[0] * 
dims.padded_w), int(coords[1] * dims.padded_h) + x2, y2 = int(coords[2] * dims.padded_w), int(coords[3] * dims.padded_h) + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_boxes = adjust_box_edges(padded_image, [[x1, y1, x2, y2]]) + x1, y1, x2, y2 = new_boxes[0] + x1, y1, x2, y2 = ( + max(0, min(x1, dims.padded_w - 1)), + max(0, min(y1, dims.padded_h - 1)), + max(0, min(x2, dims.padded_w)), + max(0, min(y2, dims.padded_h)), + ) + if x2 <= x1: + x2 = min(x1 + 1, dims.padded_w) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + if previous_box is not None: + prev_x1, prev_y1, prev_x2, prev_y2 = previous_box + if (x1 < prev_x2 and x2 > prev_x1) and (y1 < prev_y2 and y2 > prev_y1): + y1 = prev_y2 + y1 = min(y1, dims.padded_h - 1) + if y2 <= y1: + y2 = min(y1 + 1, dims.padded_h) + new_previous_box = [x1, y1, x2, y2] + orig_x1, orig_y1, orig_x2, orig_y2 = map_to_original_coordinates( + x1, y1, x2, y2, dims + ) + return x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, new_previous_box + except Exception as e: + print(f"process_coordinates error: {str(e)}") + orig_x1, orig_y1, orig_x2, orig_y2 = ( + 0, + 0, + min(100, dims.original_w), + min(100, dims.original_h), + ) + return 0, 0, 100, 100, orig_x1, orig_y1, orig_x2, orig_y2, [0, 0, 100, 100] + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def prepare_image(image) -> tuple[np.ndarray, ImageDimensions]: + try: + image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) + original_h, original_w = image_cv.shape[:2] + max_size = max(original_h, original_w) + top = (max_size - original_h) // 2 + bottom = max_size - original_h - top + left = (max_size - original_w) // 2 + right = max_size - original_w - left + padded_image = cv2.copyMakeBorder( + image_cv, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0) + ) + padded_h, padded_w = padded_image.shape[:2] + dimensions = ImageDimensions( + original_w=original_w, + original_h=original_h, + padded_w=padded_w, + padded_h=padded_h, + ) + return padded_image, dimensions + except Exception as e: + print(f"prepare_image error: {str(e)}") + h, w = image.height, image.width + dimensions = ImageDimensions(original_w=w, original_h=h, padded_w=w, padded_h=h) + return np.zeros((h, w, 3), dtype=np.uint8), dimensions + + +# Copied from https://github.com/bytedance/Dolphin/utils/utils.py +def parse_layout_string(bbox_str): + """Parse layout string using regular expressions""" + pattern = r"\[(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+),\s*(\d*\.?\d+)\]\s*(\w+)" + matches = re.finditer(pattern, bbox_str) + + parsed_results = [] + for match in matches: + coords = [float(match.group(i)) for i in range(1, 5)] + label = match.group(5).strip() + parsed_results.append((coords, label)) + + return parsed_results + + +model_id = "ByteDance/Dolphin" + +# The input image size for Dolphin is 896 x 896, +# and the patch_size is 4 x 4. +# Therefore, the initial number of patches is: +# Height: 896 / 4 = 224 patches +# Width: 896 / 4 = 224 patches + +# The Dolphin model uses a staged downsampling approach, +# defined by the "depths": [2, 2, 14, 2] configuration. +# Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, +# which halves the feature map's dimensions (dividing both height and width by 2). 
+# Before Stage 2: The size changes from 224 x 224 to (224/2) x (224/2) = 112 x 112. +# Before Stage 3: The size changes from 112 x 112 to (112/2) x (112/2) = 56 x 56. +# Before Stage 4: The size changes from 56 x 56 to (56/2) x (56/2) = 28 x 28. + +# Because vLLM needs to fill the image features with an encoder_prompt, +# and the encoder_prompt will have `` tokens added when tokenized, +# we need to construct an encoder_prompt with a length of 28 x 28 - 1 = 783. +encoder_prompt = "".join(["0"] * 783) +sampling_params = SamplingParams( + temperature=0.0, + max_tokens=2048, +) + +processor = DonutProcessor.from_pretrained(model_id) +llm = LLM( + model=model_id, + dtype="float16", + max_num_seqs=8, + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, +) + +parser = argparse.ArgumentParser() +parser.add_argument( + "--image_path", type=str, default=None, help="Path to a local image file." +) +args = parser.parse_args() + +if args.image_path: + if not os.path.exists(args.image_path): + raise FileNotFoundError(f"Error: File not found at {args.image_path}") + image = Image.open(args.image_path).convert("RGB") +else: + image = fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) + + +prompt = "Parse the reading order of this document. " +decoder_prompt = f"{prompt}" +decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer(decoder_prompt, add_special_tokens=False)[ + "input_ids" + ] +) +enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt(prompt=encoder_prompt, multi_modal_data={"image": image}), + decoder_prompt=decoder_prompt_tokens, +) +layout_outputs = llm.generate(prompts=enc_dec_prompt, sampling_params=sampling_params) +layout_result_str = layout_outputs[0].outputs[0].text +print(f"Layout analysis output:\n{layout_result_str}") + +padded_image, dims = prepare_image(image) +layout_results = parse_layout_string(layout_result_str) +text_table_elements = [] +previous_box = None +reading_order = 0 +for bbox_coords, label in layout_results: + if label == "fig": + continue + try: + x1, y1, x2, y2, orig_x1, orig_y1, orig_x2, orig_y2, previous_box = ( + process_coordinates(bbox_coords, padded_image, dims, previous_box) + ) + cropped = padded_image[y1:y2, x1:x2] + if cropped.size > 0 and cropped.shape[0] > 3 and cropped.shape[1] > 3: + pil_crop = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)) + prompt_ocr = ( + "Parse the table in the image. " + if label == "tab" + else "Read text in the image. 
" + ) + text_table_elements.append( + { + "crop": pil_crop, + "prompt": prompt_ocr, + "reading_order": reading_order, + } + ) + reading_order += 1 + except Exception as e: + print(f"Error processing bbox (label: {label}): {str(e)}") + continue + +if text_table_elements: + batch_prompts = [] + for elem in text_table_elements: + decoder_prompt_str = f"{elem['prompt']}" + decoder_prompt_tokens = TokensPrompt( + prompt_token_ids=processor.tokenizer( + decoder_prompt_str, add_special_tokens=False + )["input_ids"] + ) + enc_dec_prompt = ExplicitEncoderDecoderPrompt( + encoder_prompt=TextPrompt( + prompt=encoder_prompt, multi_modal_data={"image": elem["crop"]} + ), + decoder_prompt=decoder_prompt_tokens, + ) + batch_prompts.append(enc_dec_prompt) + batch_outputs = llm.generate(prompts=batch_prompts, sampling_params=sampling_params) + for i, output in enumerate(batch_outputs): + text_table_elements[i]["text"] = output.outputs[0].text.strip() + +print("------" * 8) +text_table_elements.sort(key=lambda x: x["reading_order"]) +for elem in text_table_elements: + print(elem.get("text", "")) diff --git a/examples/offline_inference/encoder_decoder_multimodal.py b/examples/offline_inference/encoder_decoder_multimodal.py index d27a902edb7e..655f9f3fce7a 100644 --- a/examples/offline_inference/encoder_decoder_multimodal.py +++ b/examples/offline_inference/encoder_decoder_multimodal.py @@ -13,6 +13,7 @@ from vllm import LLM, EngineArgs, PromptType, SamplingParams from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -21,6 +22,50 @@ class ModelRequestData(NamedTuple): prompts: Sequence[PromptType] +def run_donut(): + engine_args = EngineArgs( + model="naver-clova-ix/donut-base-finetuned-docvqa", + max_num_seqs=2, + limit_mm_per_prompt={"image": 1}, + dtype="float16", + hf_overrides={"architectures": ["DonutForConditionalGeneration"]}, + ) + + # The input image size for donut-base-finetuned-docvqa is 2560 x 1920, + # and the patch_size is 4 x 4. + # Therefore, the initial number of patches is: + # Height: 1920 / 4 = 480 patches + # Width: 2560 / 4 = 640 patches + # The Swin model uses a staged downsampling approach, + # defined by the "depths": [2, 2, 14, 2] configuration. + # Before entering stages 2, 3, and 4, a "Patch Merging" operation is performed, + # which halves the feature map's dimensions (dividing both height and width by 2). + # Before Stage 2: The size changes from 480 x 640 to (480/2) x (640/2) = 240 x 320. + # Before Stage 3: The size changes from 240 x 320 to (240/2) x (320/2) = 120 x 160. + # Before Stage 4: The size changes from 120 x 160 to (120/2) x (160/2) = 60 x 80. + # Because vLLM needs to fill the image features with an encoder_prompt, + # and the encoder_prompt will have `` tokens added when tokenized, + # we need to construct an encoder_prompt with a length of 60 x 80 - 1 = 4799. 
+ prompts = [ + { + "encoder_prompt": { + "prompt": "".join(["$"] * 4799), + "multi_modal_data": { + "image": fetch_image( + "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/0.jpg" + ) # noqa: E501 + }, + }, + "decoder_prompt": "What time is the coffee break?", # noqa: E501 + }, + ] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + def run_florence2(): engine_args = EngineArgs( model="microsoft/Florence-2-large", @@ -118,6 +163,7 @@ def run_whisper(): model_example_map = { + "donut": run_donut, "florence2": run_florence2, "mllama": run_mllama, "whisper": run_whisper, diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index c4972f02d0f8..cae1d2801829 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -138,7 +138,7 @@ def main(): sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) if not args.custom_mm_prompts: outputs = llm.generate( - TokensPrompt(prompt_token_ids=prompt_ids), + [TokensPrompt(prompt_token_ids=prompt_id) for prompt_id in prompt_ids], sampling_params=sampling_params, ) else: diff --git a/examples/tool_chat_template_deepseekv31.jinja b/examples/tool_chat_template_deepseekv31.jinja new file mode 100644 index 000000000000..863be69d60b6 --- /dev/null +++ b/examples/tool_chat_template_deepseekv31.jinja @@ -0,0 +1,91 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} +{% if not thinking is defined %} + {% set thinking = false %} +{% endif %} +{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{% if tools is defined and tools is not none %} + {% set tool_ns = namespace(text='## Tools\nYou have access to the following tools:\n') %} + {% for tool in tools %} + {% set tool_ns.text = tool_ns.text + '\n### ' + tool.function.name + '\nDescription: ' + tool.function.description + '\n\nParameters: ' + (tool.function.parameters | tojson) + '\n' %} + {% endfor %} + {% set tool_ns.text = tool_ns.text + "\nIMPORTANT: ALWAYS adhere to this exact format for tool use:\n<|tool▁calls▁begin|><|tool▁call▁begin|>tool_call_name<|tool▁sep|>tool_call_arguments<|tool▁call▁end|>{{additional_tool_calls}}<|tool▁calls▁end|>\n\nWhere:\n\n- `tool_call_name` must be an exact match to one of the available tools\n- `tool_call_arguments` must be valid JSON that strictly follows the tool's Parameters Schema\n- For multiple tool calls, chain them directly without separators or spaces\n" %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} +{% endif %} + +{{ bos_token }}{{ ns.system_prompt }} +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content']}} + {%- endif %} + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- endif %} + {%- set ns.is_last_user = false -%} + 
{%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments']|tojson + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %} + {%- if ns.is_last_user %} + {{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking %} + {{''}} + {%- else %} + {{''}} + {%- endif %} + {%- endif %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {%- set content = message['content'] -%} + {%- if '' in content %} + {%- set content = content.split('', 1)[1] -%} + {%- endif %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} + {%- if not thinking %} + {{''}} + {%- else %} + {{''}} + {%- endif %} +{% endif %} diff --git a/requirements/common.txt b/requirements/common.txt index 8acf634526ff..e21abfb9a30b 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -18,7 +18,7 @@ prometheus_client >= 0.18.0 pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer -lm-format-enforcer >= 0.10.11, < 0.11 +lm-format-enforcer == 0.11.3 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines_core == 0.2.10 ; platform_machine != "s390x" outlines == 0.1.11 ; platform_machine == "s390x" diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 7038c9024c6b..c3bb65b70a0b 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -17,4 +17,4 @@ setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 -conch-triton-kernels==1.2.1 +conch-triton-kernels==1.2.1 \ No newline at end of file diff --git a/setup.py b/setup.py index fa406b868c07..ca6e0a8592cc 100644 --- a/setup.py +++ b/setup.py @@ -695,6 +695,8 @@ def _read_requirements(filename: str) -> list[str]: "video": [], # Kept for backwards compatibility # FlashInfer should be updated together with the Dockerfile "flashinfer": ["flashinfer-python==0.2.12"], + # Optional deps for AMD FP4 quantization support + "petit-kernel": ["petit-kernel"], }, cmdclass=cmdclass, package_data=package_data, diff --git a/tests/benchmarks/test_random_dataset.py b/tests/benchmarks/test_random_dataset.py new file mode 100644 index 000000000000..26cae369cdd5 --- /dev/null +++ b/tests/benchmarks/test_random_dataset.py @@ -0,0 +1,344 @@ +# 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random +from typing import Any, NamedTuple, Optional, cast + +import numpy as np +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.benchmarks.datasets import (RandomDataset, RandomMultiModalDataset, + SampleRequest) + + +@pytest.fixture(scope="session") +def hf_tokenizer() -> PreTrainedTokenizerBase: + # Use a small, commonly available tokenizer + return AutoTokenizer.from_pretrained("gpt2") + + +class Params(NamedTuple): + num_requests: int + prefix_len: int + range_ratio: float + input_len: int + output_len: int + + +@pytest.fixture(scope="session") +def random_dataset_params() -> Params: + return Params(num_requests=16, + prefix_len=7, + range_ratio=0.3, + input_len=50, + output_len=20) + + +def _fingerprint_sample(req: SampleRequest) -> tuple[str, int, int]: + """Project a SampleRequest into a comparable tuple.""" + return (req.prompt, req.prompt_len, req.expected_output_len) + + +def _collect_samples(dataset: RandomDataset, + tokenizer: PreTrainedTokenizerBase, + num_requests: int = 16, + prefix_len: int = 7, + range_ratio: float = 0.3, + input_len: int = 50, + output_len: int = 20) -> list[tuple[str, int, int]]: + samples = dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + ) + return [_fingerprint_sample(s) for s in samples] + + +@pytest.mark.benchmark +def test_random_dataset_same_seed( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Same seed should yield identical outputs, even if global RNGs change. + + This guards against accidental reliance on Python's random or np.random + in RandomDataset after moving to numpy.default_rng. 
+ """ + p = random_dataset_params + common_seed = 123 + dataset_a = RandomDataset(random_seed=common_seed) + dataset_b = RandomDataset(random_seed=common_seed) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + # Perturb global RNG state to ensure isolation + random.seed(999) + _ = [random.random() for _ in range(100)] + np.random.seed(888) + _ = [np.random.random() for _ in range(100)] + + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a == b + +@pytest.mark.benchmark +def test_random_dataset_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, + random_dataset_params: Params) -> None: + """Different seeds should change outputs with overwhelming likelihood.""" + p = random_dataset_params + seed_a = 0 + dataset_a = RandomDataset(random_seed=seed_a) + a = _collect_samples(dataset_a, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + + seed_b = 999 + dataset_b = RandomDataset(random_seed=seed_b) + # Perturb global RNG with same seed as dataset_a to ensure isolation + random.seed(seed_a) + np.random.seed(seed_a) + b = _collect_samples(dataset_b, + hf_tokenizer, + num_requests=p.num_requests, + prefix_len=p.prefix_len, + range_ratio=p.range_ratio, + input_len=p.input_len, + output_len=p.output_len) + assert a != b + + +# ----------------------------- +# RandomMultiModalDataset tests +# ----------------------------- + +def _mm_fingerprint_sample( + req: SampleRequest, +) -> tuple[str, int, int, int, list[str]]: + """Create a compact fingerprint for multimodal samples. 
+ + Includes: + - prompt string + - prompt_len + - expected_output_len + - count of multimodal items + - per-item type and URL prefix (e.g., 'data:image/jpeg;base64,') + """ + items = req.multi_modal_data or [] + item_prefixes: list[str] = [] + for it in items: + if isinstance(it, dict) and it.get("type") == "image_url": + url = it.get("image_url", {}).get("url", "") + # Only keep a short identifying prefix to avoid huge strings + item_prefixes.append(f"image:{url[:22]}") + elif isinstance(it, dict) and it.get("type") == "video_url": + url = it.get("video_url", {}).get("url", "") + item_prefixes.append(f"video:{url[:22]}") + else: + item_prefixes.append("unknown:") + return (req.prompt, req.prompt_len, req.expected_output_len, len(items), + item_prefixes) + + +def _collect_mm_samples( + dataset: RandomMultiModalDataset, + tokenizer: PreTrainedTokenizerBase, + *, + num_requests: int = 8, + prefix_len: int = 3, + range_ratio: float = 0.0, + input_len: int = 20, + output_len: int = 5, + base_items_per_request: int = 2, + num_mm_items_range_ratio: float = 0.0, + limit_mm_per_prompt: Optional[dict[str, int]] = None, + bucket_config: Optional[dict[tuple[int, int, int], float]] = None, + enable_multimodal_chat: bool = False, +) -> list[SampleRequest]: + if limit_mm_per_prompt is None: + limit_mm_per_prompt = {"image": 5, "video": 0} + if bucket_config is None: + bucket_config = {(32, 32, 1): 0.5, (52, 64, 1): 0.5} + return dataset.sample( + tokenizer=tokenizer, + num_requests=num_requests, + prefix_len=prefix_len, + range_ratio=range_ratio, + input_len=input_len, + output_len=output_len, + base_items_per_request=base_items_per_request, + num_mm_items_range_ratio=num_mm_items_range_ratio, + limit_mm_per_prompt=limit_mm_per_prompt, + bucket_config=bucket_config, + enable_multimodal_chat=enable_multimodal_chat, + ) + + +@pytest.mark.benchmark +def test_random_mm_same_seed(hf_tokenizer: PreTrainedTokenizerBase) -> None: + seed = 42 + ds_a = RandomMultiModalDataset(random_seed=seed) + ds_b = RandomMultiModalDataset(random_seed=seed) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa == fb + + +@pytest.mark.benchmark +def test_random_mm_different_seeds( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds_a = RandomMultiModalDataset(random_seed=0) + ds_b = RandomMultiModalDataset(random_seed=999) + a = _collect_mm_samples(ds_a, hf_tokenizer) + b = _collect_mm_samples(ds_b, hf_tokenizer) + fa = [_mm_fingerprint_sample(s) for s in a] + fb = [_mm_fingerprint_sample(s) for s in b] + assert fa != fb + +@pytest.mark.benchmark +def test_random_mm_respects_limits( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Requesting 3 items with a per-prompt limit of 1 should error per current + # design (dataset refuses to silently clamp below the requested baseline). 
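+    # In other words, the expected guard is roughly of this shape
+    # (illustrative sketch with hypothetical names, not the actual
+    # implementation):
+    #     if base_items_per_request > limit_mm_per_prompt.get("image", 0):
+    #         raise ValueError("requested items exceed the per-prompt limit")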
+ with pytest.raises(ValueError): + _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=12, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + + +@pytest.mark.benchmark +def test_random_mm_zero_prob_entries_are_removed( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Second bucket has zero probability and should be ignored after + # normalization + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=6, + base_items_per_request=2, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 10, "video": 0}, + bucket_config={(32, 32, 1): 1.0, (52, 64, 1): 0.0}, + ) + for s in samples: + assert isinstance(s.multi_modal_data, list) + typed_mm = cast(list[dict[str, Any]], s.multi_modal_data) + for it in typed_mm: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_zero_items(hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + samples = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=0, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 5, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + for s in samples: + assert s.multi_modal_data == [] + +@pytest.mark.benchmark +def test_random_mm_num_items_per_prompt( + hf_tokenizer: PreTrainedTokenizerBase) -> None: + ds = RandomMultiModalDataset(random_seed=0) + # Fixed number of images per prompt + # set num_mm_items_range_ratio to 0.0 + # TODO: modify video values when video sampling is implemented + samples_fixed_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=3, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 3, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with 3 mm items per prompt + assert len(samples_fixed_items) == 5 + for s in samples_fixed_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) == 3 + for it in mm_data: + assert it.get("type") == "image_url" + + +@pytest.mark.benchmark +def test_random_mm_bucket_config_not_mutated( + hf_tokenizer: PreTrainedTokenizerBase, +) -> None: + + ds = RandomMultiModalDataset(random_seed=0) + # This bucket config is not normalized to sum to 1 + # and has more buckets than requested images + original = {(32, 32, 1): 0.2, (52, 64, 1): 6, (25, 64, 1): 3} + # Keep a snapshot to compare after sampling + snapshot = dict(original) + + _ = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=4, + base_items_per_request=1, + num_mm_items_range_ratio=0.0, + limit_mm_per_prompt={"image": 1, "video": 0}, + bucket_config=original, + ) + + # Ensure the original dict content is unchanged + assert original == snapshot + + + # Vary number of mm items per prompt + # set num_mm_items_range_ratio to 0.5 + samples_varying_items = _collect_mm_samples( + ds, + hf_tokenizer, + num_requests=5, + base_items_per_request=2, + num_mm_items_range_ratio=0.5, + limit_mm_per_prompt={"image": 4, "video": 0}, + bucket_config={(32, 32, 1): 1.0}, + ) + # Must have 5 requests each with less than 4 mm items per prompt + # but at least 1 mm item per prompt + assert len(samples_varying_items) == 5 + for s in samples_varying_items: + mm_data = cast(list[dict[str, Any]], s.multi_modal_data) + assert len(mm_data) <= 4 + assert len(mm_data) >= 1 + for it in mm_data: + assert it.get("type") == "image_url" 
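+
+# The determinism contract exercised above comes down to the dataset owning
+# its own numpy Generator instead of touching global RNG state. A minimal
+# sketch of that pattern (hypothetical class and method names, for
+# illustration only):
+#
+#     class SeededSampler:
+#         def __init__(self, random_seed: int = 0) -> None:
+#             # Isolated generator: random.seed()/np.random.seed() elsewhere
+#             # cannot perturb it.
+#             self._rng = np.random.default_rng(random_seed)
+#
+#         def sample_lengths(self, n: int, low: int, high: int):
+#             return self._rng.integers(low, high, size=n)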
diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index aade29b99de7..0c7e6fbccf20 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -8,11 +8,12 @@ from vllm import LLM, SamplingParams from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, - kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fusion import FUSED_OPS, FusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from .backend import TestBackend diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4a3820e20fd8..5cfad935a0fb 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -7,11 +7,13 @@ import vllm.envs as envs import vllm.plugins from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, - FusionPass, GroupShape, QuantKey) + FusionPass) from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape, QuantKey, ScaleDesc) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) from vllm.platforms import current_platform @@ -30,10 +32,8 @@ def __init__(self, hidden_size: int, eps: float, static: bool, self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN - self.key = QuantKey(dtype=FP8_DTYPE, - static=static, - group_shape=group_shape, - symmetric=True) + quant_scale = ScaleDesc(torch.float32, static, group_shape) + self.key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) if static: self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] else: diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index bef0fdef985e..dba668cfa16a 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -11,9 +11,10 @@ from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata) from vllm import LLM, SamplingParams +from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention import Attention from vllm.attention.selector import global_force_attn_backend_context_manager -from vllm.compilation.fusion import QUANT_OPS, QuantKey, kFp8StaticTensorSym +from vllm.compilation.fusion import QUANT_OPS from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass from vllm.compilation.fx_utils import find_op_nodes from vllm.compilation.noop_elimination import NoOpEliminationPass @@ -22,13 +23,14 @@ set_current_vllm_config) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform from vllm.v1.kv_cache_interface import AttentionSpec FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 # globals needed for string-import custom Dynamo backend field backend: Optional[TestBackend] = None @@ -105,9 +107,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # check support attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) + layer.impl.fused_output_quant_supported(quant_key) for key, layer in compile_config.static_forward_context.items() ] @@ -149,12 +149,12 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, backend = None -class TestAttentionStaticQuantPatternModel(torch.nn.Module): - """Test model for AttentionStaticQuantPattern fusion.""" +class AttentionQuantPatternModel(torch.nn.Module): + """Base model for AttentionQuantPattern fusion.""" def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype, device: torch.device, - vllm_config: VllmConfig): + vllm_config: VllmConfig, **kwargs): super().__init__() self.num_qo_heads = num_qo_heads self.num_kv_heads = num_kv_heads @@ -172,11 +172,6 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, prefix="model.layers.0.self_attn.attn", ) - self.fp8_linear = Fp8LinearOp( - act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) - self.wscale = torch.tensor([1.0], dtype=torch.float32) - self.scale = torch.tensor([1.0], dtype=torch.float32) - self.block_size = 16 # Initialize attn MetadataBuilder @@ -230,23 +225,86 @@ def build_attn_metadata(self, batch_size: int): return self.attn_metadata - def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - w: torch.Tensor): + +class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionFp8StaticQuantPattern fusion.""" + + quant_key = kFp8StaticTensorSym + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.quant_key.scale.static, + act_quant_group_shape=self.quant_key.scale.group_shape) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randn(hidden_size, hidden_size).to( + dtype=FP8_DTYPE, device=self.device).t(), + "wscale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([1.0], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): """Forward pass that creates the pattern to be fused.""" attn_output = self.attn(q, k, v) return self.fp8_linear.apply(input=attn_output, - weight=w, - weight_scale=self.wscale, - input_scale=self.scale) + weight=self.w["weight"], + weight_scale=self.w["wscale"], + input_scale=self.w["scale"]) + + +class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): + """Test model for AttentionNvfp4QuantPattern fusion.""" + + quant_key = kNvfp4Quant + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + hidden_size = self.num_qo_heads * self.head_size + self.w = kwargs.get( + "w", { + "weight": + torch.randint(256, (hidden_size, hidden_size // 2), + dtype=FP4_DTYPE, + device=self.device), + "wscale_swizzled": + torch.randn(hidden_size, hidden_size // 16).to( + dtype=FP8_DTYPE, device=self.device), + 
"wscale": + torch.tensor([500], dtype=torch.float32, device=self.device), + "scale": + torch.tensor([0.002], dtype=torch.float32, device=self.device), + }) + + def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """Forward pass that creates the pattern to be fused.""" + attn_output = self.attn(q, k, v) + quant_output, output_block_scale = scaled_fp4_quant( + attn_output, 1 / self.w["scale"]) + return cutlass_scaled_fp4_mm(a=quant_output, + b=self.w["weight"], + block_scale_a=output_block_scale, + block_scale_b=self.w["wscale_swizzled"], + alpha=self.w["scale"] * self.w["wscale"], + out_dtype=attn_output.dtype) @pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("batch_size", [7, 256, 533]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) -@pytest.mark.parametrize( - "model_name, quant_key", - [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", kFp8StaticTensorSym)]) +@pytest.mark.parametrize("model_name, model_class", + [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", + TestAttentionFp8StaticQuantPatternModel), + ("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", + TestAttentionNvfp4QuantPatternModel)]) @pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) @pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @@ -255,8 +313,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, head_size: int, batch_size: int, dtype: torch.dtype, model_name: str, - quant_key: QuantKey, backend: _Backend, - monkeypatch, dist_init): + model_class: type[AttentionQuantPatternModel], + backend: _Backend, monkeypatch, dist_init): """Test AttentionStaticQuantPattern fusion pass""" monkeypatch.setenv("VLLM_USE_V1", "1") @@ -277,8 +335,10 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, cache_config=CacheConfig(cache_dtype="fp8")) # Create test inputs - hidden_size = num_qo_heads * head_size - q = torch.randn(batch_size, hidden_size, dtype=dtype, device=device) + q = torch.randn(batch_size, + num_qo_heads * head_size, + dtype=dtype, + device=device) k = torch.randn(batch_size, num_kv_heads * head_size, dtype=dtype, @@ -287,7 +347,6 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, num_kv_heads * head_size, dtype=dtype, device=device) - linear_w = torch.randn(hidden_size, hidden_size).to(FP8_DTYPE).t() # Mark first dimension as dynamic for realistic testing torch._dynamo.mark_dynamic(q, 0) @@ -299,9 +358,12 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, with set_current_vllm_config(vllm_config_unfused), set_forward_context( attn_metadata=None, vllm_config=vllm_config_unfused ), global_force_attn_backend_context_manager(backend): - model_unfused = TestAttentionStaticQuantPatternModel( - num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config_unfused) + model_unfused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config_unfused) model_unfused = model_unfused.to(device) forward_ctx = get_forward_context() @@ -309,7 +371,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, batch_size) # Run model directly without compilation and fusion - result_unfused = model_unfused(q, k, v, linear_w) + result_unfused = 
model_unfused(q, k, v) # Run model with attn fusion enabled vllm_config.compilation_config.pass_config = PassConfig( @@ -317,9 +379,13 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, with set_current_vllm_config(vllm_config), set_forward_context( attn_metadata=None, vllm_config=vllm_config ), global_force_attn_backend_context_manager(backend): - model_fused = TestAttentionStaticQuantPatternModel( - num_qo_heads, num_kv_heads, head_size, FP8_DTYPE, device, - vllm_config) + model_fused = model_class(num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + kv_cache_dtype=FP8_DTYPE, + device=device, + vllm_config=vllm_config, + w=model_unfused.w) model_fused = model_fused.to(device) forward_ctx = get_forward_context() @@ -336,21 +402,20 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, backend=test_backend, fullgraph=True) assert model_compiled.attn._o_scale_float is None - result_fused_1 = model_compiled(q, k, v, linear_w) + result_fused_1 = model_compiled(q, k, v) # After the 1st round of the forward pass, output quant scale should be # loaded into the attn layer's _o_scale_float, the 2nd round should # reuse the loaded _o_scale_float assert model_compiled.attn._o_scale_float is not None - result_fused_2 = model_compiled(q, k, v, linear_w) + result_fused_2 = model_compiled(q, k, v) assert model_compiled.attn._o_scale_float is not None # Check attn fusion support + quant_key = model_class.quant_key attn_fusion_supported = [ - layer.impl.fused_output_quant_supported(quant_key.dtype, - quant_key.static, - quant_key.group_shape) for key, - layer in vllm_config.compilation_config.static_forward_context.items() + layer.impl.fused_output_quant_supported(quant_key) for key, layer in + vllm_config.compilation_config.static_forward_context.items() ] if any(attn_fusion_supported): # Check quantization ops in the graph before and after fusion @@ -370,6 +435,15 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, assert attn_nodes_post[0].kwargs.get("output_scale") is not None, \ "Attention should have output_scale after fusion" + assert attn_nodes_pre[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale before fusion" + if quant_key.dtype == FP8_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is None, \ + "Attention should not have output_block_scale after FP8 fusion" + elif quant_key.dtype == FP4_DTYPE: + assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ + "Attention should have output_block_scale after FP4 fusion" # noqa: E501 + # Check that results are closed torch.testing.assert_close(result_unfused, result_fused_1, diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py new file mode 100644 index 000000000000..5a804a389123 --- /dev/null +++ b/tests/distributed/test_symm_mem_allreduce.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +import typing + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed import cleanup_dist_env_and_memory +from vllm.distributed.communication_op import tensor_model_parallel_all_reduce +from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator) +from vllm.distributed.parallel_state import (get_tensor_model_parallel_group, + 
get_tp_group, + init_distributed_environment, + initialize_model_parallel) +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +torch.manual_seed(42) +random.seed(44) + +test_size_elements = 4 * 1024 * 1024 + + +def symm_mem_allreduce_worker(local_rank: int, world_size: int): + monkeypatch = pytest.MonkeyPatch() + with monkeypatch.context() as m: + m.delenv("CUDA_VISIBLE_DEVICES", raising=False) + dtype = torch.bfloat16 + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + cuda_communicator = typing.cast(CudaCommunicator, + get_tp_group().device_communicator) + symm_mem_comm = cuda_communicator.symm_mem_comm + if symm_mem_comm is None or symm_mem_comm.disabled: + pytest.skip("SymmMemCommunicator is not available or disabled.") + + inp_direct_symm_mem = torch.randint(1, + 23, (test_size_elements, ), + dtype=dtype, + device=device) + if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem): + pytest.skip( + "SymmMemCommunicator isn't used for this world and input size." + ) + + original_inp_direct_symm_mem = inp_direct_symm_mem.clone() + out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem) + assert out_direct_symm_mem is not None + + group = get_tensor_model_parallel_group().device_group + dist.all_reduce(original_inp_direct_symm_mem, group=group) + torch.testing.assert_close(out_direct_symm_mem, + original_inp_direct_symm_mem, + atol=2.5, + rtol=0.1) + + # Test tensor_model_parallel_all_reduce which should use symm_mem + inp_tensor_parallel = torch.randint(-23, + 1, (test_size_elements, ), + dtype=dtype, + device=device) + original_inp_tensor_parallel = inp_tensor_parallel.clone() + out_tensor_parallel = tensor_model_parallel_all_reduce( + inp_tensor_parallel) + dist.all_reduce(original_inp_tensor_parallel, group=group) + torch.testing.assert_close(out_tensor_parallel, + original_inp_tensor_parallel, + atol=2.5, + rtol=0.1) + + +@pytest.mark.skipif( + not current_platform.is_cuda(), + reason="SymmMemAllreduce is only available for CUDA platforms.") +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("pipeline_parallel_size", [1]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size, + pipeline_parallel_size): + world_size = tp_size * pipeline_parallel_size + if world_size > torch.cuda.device_count(): + pytest.skip("Not enough GPUs to run the test.") + + # Enable SymmMemCommunicator + monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1") + + mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size) + cleanup_dist_env_and_memory() diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py index 79b6ce059ce4..121c0413e1af 100644 --- a/tests/entrypoints/openai/test_truncation.py +++ b/tests/entrypoints/openai/test_truncation.py @@ -64,6 +64,28 @@ async def test_smaller_truncation_size(client: openai.AsyncOpenAI): assert response["usage"]["prompt_tokens"] == truncation_size +@pytest.mark.asyncio +async def test_zero_truncation_size(client: openai.AsyncOpenAI): + 
truncation_size = 0 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + with pytest.raises(openai.BadRequestError) as err: + await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + + assert error_details["type"] == "BadRequestError" + assert "This model's maximum context length is" in error_details["message"] + assert "tokens in the input for embedding generation" in error_details[ + "message"] + assert "Please reduce the length of the input" in error_details["message"] + + @pytest.mark.asyncio async def test_bigger_truncation_size(client: openai.AsyncOpenAI): truncation_size = max_model_len + 1 @@ -74,18 +96,15 @@ async def test_bigger_truncation_size(client: openai.AsyncOpenAI): } with pytest.raises(openai.BadRequestError) as err: - err = await client.post(path="embeddings", - cast_to=object, - body={**kwargs}) - - assert str(err) == f"""openai.BadRequestError: - Error code: 400 - {{'object': 'error', - 'message': 'truncate_prompt_tokens value - ({truncation_size}) - is greater than max_model_len ({max_model_len}). - Please, select a smaller truncation size.', - 'type': 'BadRequestError', - 'param': None, 'code': 400}}""" + await client.post(path="embeddings", cast_to=object, body={**kwargs}) + + assert err.value.status_code == 400 + error_details = err.value.response.json()["error"] + assert error_details["type"] == "BadRequestError" + expected_message = ("truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + assert error_details["message"] == expected_message @pytest.mark.asyncio diff --git a/tests/kernels/attention/test_flashinfer_trtllm_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_attention.py index 69e44264cd44..8d0a11d8eb8a 100644 --- a/tests/kernels/attention/test_flashinfer_trtllm_attention.py +++ b/tests/kernels/attention/test_flashinfer_trtllm_attention.py @@ -6,7 +6,11 @@ import pytest import torch +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm.platforms import current_platform +from vllm.utils import round_up if not current_platform.is_device_capability(100): pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", @@ -14,6 +18,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def to_float8(x, dtype=torch.float8_e4m3fn): @@ -29,7 +34,9 @@ def to_float8(x, dtype=torch.float8_e4m3fn): QUANT_DTYPES = [ # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) (None, None, None), + (None, FP8_DTYPE, None), (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE), + (FP8_DTYPE, FP8_DTYPE, FP4_DTYPE), ] BATCH_SIZE = [4, 12] MAX_SEQ_LENS = [(1024, 4096)] @@ -153,11 +160,25 @@ def test_flashinfer_trtllm_decode_with_baseline( output = torch.empty(ref_query.shape, dtype=dtype) wrapper.run(ref_query, ref_kv_cache, out=output) o_scale = 1.0 + o_sf_scale = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) # TRTLLM Decode - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + 
torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + flashinfer.decode.trtllm_batch_decode_with_kv_cache( query=query, kv_cache=kv_cache, @@ -167,15 +188,27 @@ def test_flashinfer_trtllm_decode_with_baseline( max_seq_len=max_seq_len, bmm1_scale=q_scale * k_scale * sm_scale, bmm2_scale=v_scale / o_scale, + o_sf_scale=o_sf_scale, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: output_trtllm = output_trtllm.to(dtype) * o_scale - - if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 3e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 5e-2, 7e-2 else: - rtol, atol = 1e-2, 1e-2 + rtol, atol = 1e-2, 2e-2 torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \ f"{torch.max(torch.abs(output - output_trtllm))}" @@ -211,6 +244,9 @@ def test_flashinfer_trtllm_prefill_with_baseline( kv_quant_dtype = kv_quant_dtype or dtype o_quant_dtype = o_quant_dtype or dtype + if q_quant_dtype != kv_quant_dtype: + pytest.skip("Skipped mixed QKV dtypes for prefill") + max_q_len, max_kv_len = max_seq_lens num_qo_heads, num_kv_heads = num_heads @@ -303,11 +339,25 @@ def test_flashinfer_trtllm_prefill_with_baseline( output = torch.empty(ref_query.shape, dtype=dtype) wrapper.run(ref_query, ref_kv_cache, out=output) o_scale = 1.0 + o_sf_scale = None if o_quant_dtype == FP8_DTYPE: _, o_scale = to_float8(output) + elif o_quant_dtype == FP4_DTYPE: + o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(output.flatten(), dim=-1)).to(torch.float32) # TRTLLM Prefill - output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + if o_quant_dtype == FP4_DTYPE: + output_trtllm = flashinfer.utils.FP4Tensor( + torch.empty(query.shape[:-1] + (query.shape[-1] // 2, ), + dtype=torch.uint8), + torch.empty((round_up(query.shape[0], 128), + round_up(query.shape[1] * query.shape[2] // 16, 4)), + dtype=torch.float8_e4m3fn), + ) + else: + output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype) + flashinfer.prefill.trtllm_batch_context_with_kv_cache( query=query, kv_cache=kv_cache, @@ -321,12 +371,24 @@ def test_flashinfer_trtllm_prefill_with_baseline( batch_size=batch_size, cum_seq_lens_q=q_indptr, cum_seq_lens_kv=kv_indptr, + o_sf_scale=o_sf_scale, out=output_trtllm, ) if o_quant_dtype == FP8_DTYPE: output_trtllm = output_trtllm.to(dtype) * o_scale - - if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: + elif o_quant_dtype == FP4_DTYPE: + output_trtllm.data = output_trtllm.data.reshape( + -1, query.shape[1] * query.shape[2] // 2) + output_trtllm = dequantize_nvfp4_to_dtype(output_trtllm.data, + output_trtllm.scale, + o_sf_scale, dtype, + query.device) + output_trtllm = output_trtllm.reshape(-1, query.shape[1], + query.shape[2]) + + if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE: + rtol, atol = 4e-1, 1e0 + elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE: rtol, atol = 5e-2, 7e-2 else: rtol, atol = 1e-2, 1e-2 diff --git 
a/tests/kernels/quantization/test_cutlass_w4a8.py b/tests/kernels/quantization/test_cutlass_w4a8.py new file mode 100644 index 000000000000..7832f8179d0e --- /dev/null +++ b/tests/kernels/quantization/test_cutlass_w4a8.py @@ -0,0 +1,259 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for the CUTLASS W4A8 kernel. + +Run `pytest tests/kernels/test_cutlass_w4a8.py`. +""" + +from dataclasses import dataclass +from typing import Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [(1, 128, 128), (1, 512, 1024), (1, 4096, 4096), (1, 8192, 28672), + (13, 8192, 4096), (26, 4096, 8192), (64, 4096, 4096), + (64, 8192, 28672), (257, 128, 4096), (257, 4096, 4096), + (1024, 4096, 8192), (1024, 8192, 4096)] + +# TODO(czhu): get supported schedules from fn +SCHEDULES = [ + '128x16_1x1x1', '256x16_1x1x1', '128x32_1x1x1', '256x32_1x1x1', + '128x64_1x1x1', '256x64_1x1x1', '128x128_1x1x1', '256x128_1x1x1', + '128x256_1x1x1', '128x256_2x1x1' +] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: torch.Tensor + w_ch_s: torch.Tensor + w_tok_s: torch.Tensor + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +TestTypeTuple = tuple[list[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + *( + TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=w_type, + output_type=o_type, + group_scale_type=torch.float8_e4m3fn, + channel_scale_type=torch.float32, + token_scale_type=torch.float32) + for w_type in [scalar_types.int4] + # TODO(czhu): fp16 out type + for o_type in [torch.bfloat16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. 
+IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +# For testing quantized linear kernels +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return tensor.clamp(min=finfo.min, + max=finfo.max).to(dtype=torch.float8_e4m3fn) + + +def cutlass_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights(w, + wtype, + group_size=group_size, + zero_points=zero_points) + + # since scales are cast to fp8, we need to compute w_ref this way + w_ref = ((w_q).to(torch.float32) * w_s.to(atype).to( + torch.float32).repeat_interleave(group_size, dim=0)).to(atype) + + # bit mask prevents sign extending int4 when packing + w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q) + w_s_packed = ops.cutlass_pack_scale_fp8(w_s.to(atype)) + + return w_ref, w_q_packed, w_s_packed, w_zp + + +def create_test_tensors(shape: tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> Tensors: + m, n, k = shape + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = to_fp8(torch.randn((m, k), device="cuda")) + w = to_fp8(torch.randn((k, n), device="cuda")) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + False) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + # for the practical use case we need per-tok scales for fp8 activations + w_tok_s = torch.randn((m, ), device='cuda', dtype=types.token_scale_type) + # weights are already per-group quantized, use placeholder here + w_ch_s = torch.ones((n, ), device='cuda', dtype=types.channel_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +def mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + # CUTLASS upstream uses fp8 with fastaccum as reference + # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406 + output_ref = torch._scaled_mm( + tensors.a_ref.to(types.act_type), + tensors.w_ref.to(types.act_type).t().contiguous().t(), # col major + tensors.w_tok_s.unsqueeze(1), + tensors.w_ch_s.unsqueeze(0), + out_dtype=types.output_type, + use_fast_accum=True) + + output = ops.cutlass_w4a8_mm( + a=tensors.a, + b_q=tensors.w_q, + b_group_scales=tensors.w_g_s, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + ) + + print(output) + print(output_ref) + + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=1e-3, + atol=1e-3) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +@pytest.mark.parametrize("schedule", SCHEDULES) +def test_cutlass_w4a8(shape, types: TypeConfig, schedule): + group_sizes = [128] + for group_size in group_sizes: + 
tensors = create_test_tensors(shape, types, group_size) + mm_test_helper(types, tensors, group_size, schedule) + + +# Test to make sure cuda graphs work +class W4A8Layer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.cutlass_w4a8_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="CUTLASS W4A8 is not supported on this GPU type.") +def test_w4a8_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = to_fp8(torch.randn((m, k), device="cuda")) + b = to_fp8(torch.randn((k, n), device="cuda")) + + wtype = scalar_types.int4 + stype = torch.float8_e4m3fn + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack( + a.dtype, b.to(torch.float16), wtype, stype, group_size, zero_points) + + w_tok_s = torch.randn((m, ), device='cuda', dtype=torch.float32) + w_ch_s = torch.ones((n, ), device='cuda', dtype=torch.float32) + + # Construct a trivial model with a single layer that calls the kernel + model = W4A8Layer( + b_q=w_q_packed, + b_group_scales=w_s, + b_group_size=group_size, + b_channel_scales=w_ch_s, + a_token_scales=w_tok_s, + ) + + output_ref = torch._scaled_mm( + a, + w_ref.to(a.dtype).t().contiguous().t(), # col major + w_tok_s.unsqueeze(1), + w_ch_s.unsqueeze(0), + out_dtype=torch.bfloat16, + use_fast_accum=True) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + + output.zero_() + g.replay() + + torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3) diff --git a/tests/kernels/test_flex_attention.py b/tests/kernels/test_flex_attention.py index f76bd192460c..39753c0cc15b 100644 --- a/tests/kernels/test_flex_attention.py +++ b/tests/kernels/test_flex_attention.py @@ -9,12 +9,17 @@ import torch from packaging import version -from vllm import SamplingParams +from tests.v1.attention.utils import (BatchSpec, create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config) +from vllm.v1.attention.backends.flex_attention import ( + FlexAttentionMetadataBuilder) -from ..models.utils import check_embeddings_close +from ..models.utils import check_embeddings_close, check_logprobs_close TORCH_VERSION = version.parse(torch.__version__) MINIMUM_TORCH_VERSION = version.parse("2.7.0") +DIRECT_BUILD_VERSION = version.parse("2.9.dev0") def set_seed(seed): @@ -34,22 +39,18 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): """Test that FlexAttention produces the same outputs as the default backend. This test compares the outputs from the FlexAttention backend with - the default backend, ensuring they are identical when using the same seed. + the default backend, ensuring they are similar when using the same seed. 
""" model_name = "Qwen/Qwen2.5-1.5B-Instruct" seed = 42 max_tokens = 24 + num_logprobs = 5 prompts = [ "Hello, my name is", "The president of the United States is", "The capital of France is", ] - sampling_params = SamplingParams(temperature=0.0, - top_p=1.0, - seed=seed, - max_tokens=max_tokens) - # Run with flex attention with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") @@ -61,7 +62,8 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): tensor_parallel_size=1, num_gpu_blocks_override=128, enforce_eager=True) as llm_flex: - output_flex = llm_flex.generate(prompts, sampling_params) + output_flex = llm_flex.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) # Run with default backend with monkeypatch.context() as m: @@ -71,20 +73,17 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch): runner="generate", tensor_parallel_size=1, num_gpu_blocks_override=128, - enforce_eager=True) as llm_default: - output_default = llm_default.generate(prompts, sampling_params) - - # Compare outputs from both backends - for i, (flex_result, - default_result) in enumerate(zip(output_flex, output_default)): - prompt = prompts[i] - flex_text = flex_result[1][0] - default_text = default_result[1][0] - - assert flex_text == default_text, ( - f"FlexAttention output doesn't match default for: {prompt!r}\n" - f"FlexAttention: {flex_text!r}\n" - f"Default: {default_text!r}") + enforce_eager=True, + gpu_memory_utilization=0.85) as llm_default: + output_default = llm_default.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=output_flex, + outputs_1_lst=output_default, + name_0="flex", + name_1="default", + ) @pytest.mark.skipif( @@ -136,5 +135,70 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch): ) +@pytest.mark.skipif( + not torch.cuda.is_available() or TORCH_VERSION < DIRECT_BUILD_VERSION, + reason="CUDA not available or PyTorch version < 2.7", +) +def test_block_mask_direct_vs_slow_path(): + """Test that direct path block mask is a superset of slow path. + + The direct path may include extra blocks for performance (over-estimation), + but must include all blocks that the slow path determines are necessary. 
+ """ + device = torch.device("cuda") + + vllm_config = create_vllm_config(model_name="meta-llama/Meta-Llama-3-8B", + block_size=16, + max_model_len=1024) + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # Use a mixed batch that will create groups spanning multiple sequences + batch_spec = BatchSpec(seq_lens=[35, 64, 128, 256], + query_lens=[33, 5, 32, 64], + name="test_mixed_batch") + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + builder = FlexAttentionMetadataBuilder(kv_cache_spec, [], vllm_config, + device) + + metadata_direct = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + builder.direct_build = False + metadata_slow = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + assert metadata_direct.block_mask is not None + assert metadata_slow.block_mask is not None + + # Extract block indices for comparison, B, H are the same + direct_indices = metadata_direct.block_mask.kv_indices[0, 0] + slow_indices = metadata_slow.block_mask.kv_indices[0, 0] + direct_num = metadata_direct.block_mask.kv_num_blocks[0, 0] + slow_num = metadata_slow.block_mask.kv_num_blocks[0, 0] + + # main test: every block needed by slow path must be in direct path + num_groups = direct_num.shape[0] + all_contained = True + missing_details = [] + + for group_idx in range(num_groups): + direct_blocks = set( + direct_indices[group_idx, :direct_num[group_idx]].tolist()) + slow_blocks = set( + slow_indices[group_idx, :slow_num[group_idx]].tolist()) + + missing_blocks = slow_blocks - direct_blocks + if missing_blocks: + all_contained = False + missing_details.append( + f"Group {group_idx}: missing {sorted(missing_blocks)}") + + assert all_contained, ( + "Direct path is missing blocks required by slow path:\n" + + "\n".join(missing_details)) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/models/language/pooling/test_st_projector.py b/tests/models/language/pooling/test_st_projector.py new file mode 100644 index 000000000000..51ddbcc5ab24 --- /dev/null +++ b/tests/models/language/pooling/test_st_projector.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from ...utils import CLSPoolingEmbedModelInfo, EmbedModelInfo +from .mteb_utils import mteb_test_embed_models + +# ST models with projector (Dense) layers +ST_PROJECTOR_MODELS = [ + CLSPoolingEmbedModelInfo( + "TencentBAC/Conan-embedding-v1", + architecture="BertModel", + enable_test=True, + ), +] + + +@pytest.mark.parametrize("model_info", ST_PROJECTOR_MODELS) +def test_embed_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + + mteb_test_embed_models(hf_runner, vllm_runner, model_info) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index adc8b2510d67..a604d11f0e76 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -160,6 +160,7 @@ def _test_processing_correctness( # incorrect token ids. So we need use `add_special_tokens=False` here # to leave bos_token to be added by the processor. 
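For context on the `kv_num_blocks` / `kv_indices` fields that `test_block_mask_direct_vs_slow_path` compares, a FlexAttention block mask can be built and inspected directly (requires a torch version that ships `torch.nn.attention.flex_attention`; the sizes here are arbitrary):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# kv_num_blocks[b, h, q_block] counts how many kv blocks that query block may
# touch; kv_indices[b, h, q_block, :] lists which ones. These are exactly the
# tensors whose direct-path values must be a superset of the slow path's.
mask = create_block_mask(causal, B=1, H=1, Q_LEN=256, KV_LEN=256, device="cpu")
print(mask.kv_num_blocks.shape, mask.kv_indices.shape)
```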
_ADD_SPECIAL_TOKENS_OVERRIDES = { + "donut": False, "mllama": False, "ovis": False, "ovis2_5": False, @@ -270,6 +271,7 @@ def _test_processing_correctness_one( "facebook/chameleon-7b", "CohereLabs/command-a-vision-07-2025", "deepseek-ai/deepseek-vl2-tiny", + "naver-clova-ix/donut-base-finetuned-docvqa", "microsoft/Florence-2-base", "adept/fuyu-8b", "google/gemma-3-4b-it", diff --git a/tests/models/registry.py b/tests/models/registry.py index 25dbbd7fa983..b34c6f2e5dc8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -513,6 +513,9 @@ def check_available_online( is_available_online=False, ), # [Encoder-decoder] + "DonutForConditionalGeneration": _HfExamplesInfo("naver-clova-ix/donut-base-finetuned-docvqa", # noqa: E501 + hf_overrides={"architectures": ["DonutForConditionalGeneration"], "model_type": "donut"}, # noqa: E501 + extras={"dolphin": "ByteDance/Dolphin"}), # noqa: E501 # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index cb489c47fd8f..6ce5fcfe644b 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -17,13 +17,11 @@ PromptReplacement, apply_text_matches, apply_token_matches, find_mm_placeholders, - find_text_matches, find_token_matches, iter_token_matches, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import MultiModalProfiler from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.utils import full_groupby from .utils import random_image @@ -75,12 +73,15 @@ ), ], ) +@pytest.mark.parametrize("start_idx", [0, 4, 8]) # yapf: enable -def test_iter_token_matches(token_ids, match_ids, expected): - result = list(iter_token_matches(token_ids, match_ids)) +def test_iter_token_matches(token_ids, match_ids, expected, start_idx): + result = list(iter_token_matches(token_ids, match_ids, + start_idx=start_idx)) # Manually constructed results - assert [item._asdict() for item in result] == expected + assert [item._asdict() for item in result + ] == [item for item in expected if item["start_idx"] >= start_idx] # Invariants match_lens = [end - start for start, end in result] @@ -241,21 +242,23 @@ def test_find_token_matches( # Should not be used since there is nothing to convert to token IDs mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, []).resolve(0) for key, target in target_by_key.items() - ] - result = find_token_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_token_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -388,21 +391,23 @@ def test_find_text_matches( # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_updates = [ - update_type(key, target, []).bind(mock_tokenizer) + prompt_updates = { + key: update_type(key, target, 
[]).resolve(0) for key, target in target_by_key.items() - ] - result = find_text_matches(prompt, prompt_updates) + } + result = { + key: list(update.iter_text_matches(prompt, mock_tokenizer)) + for key, update in prompt_updates.items() + } # Only displayed on error print("result:", result) # Manually constructed results - result_groups = dict(full_groupby(result, key=lambda x: x.modality)) assert { key: [ dict(start_idx=item.start_idx, end_idx=item.end_idx) - for item in result_groups.get(key, []) + for item in result.get(key, []) ] for key in expected_by_key } == expected_by_key @@ -552,39 +557,35 @@ def test_find_update_text( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_text_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_text_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_text_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @pytest.mark.parametrize( ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501 [ - # Tokenized test cases of `test_find_replace_text` + # Tokenized test cases of `test_find_update_text` # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf ( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], @@ -726,32 +727,28 @@ def test_find_update_tokens( update_type, expected_by_mm_count, ) in expected_by_update_type_mm_count.items(): - mm_prompt_updates = { - key: - [update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)] - for key, target in target_by_key.items() - } - mm_matches = { - key: find_token_matches(prompt, updates) - for key, updates in mm_prompt_updates.items() - } - for mm_count, expected in expected_by_mm_count.items(): - result = apply_token_matches( + mm_prompt_updates = { + key: [[update_type(key, target, repl_by_key[key]).resolve(i)] + for i in range(mm_count)] + for key, target in target_by_key.items() + } + + new_prompt, result = apply_token_matches( prompt, - mm_matches, - {key: mm_count - for key in repl_by_key}, + mm_prompt_updates, + mock_tokenizer, ) # Only displayed on error print("update_type:", update_type) print("mm_count:", mm_count) - print("mm_matches:", mm_matches) + print("mm_prompt_updates:", mm_prompt_updates) + print("new_prompt:", new_prompt) print("result:", result) # Manually constructed results - assert result == expected + assert new_prompt == expected # yapf: disable @@ -878,17 +875,11 @@ def test_find_mm_placeholders( mock_tokenizer = cast(AnyTokenizer, object()) mm_prompt_updates = { - key: [update_type(key, [], repl).bind(mock_tokenizer)] + key: [[update_type(key, [], repl).resolve(i)] for i in range(3)] for key, repl in repl_by_key.items() } - result = find_mm_placeholders( - mm_prompt_updates, - prompt, - # 
Effectively match all occurrences in the prompt - {key: 3 - for key in repl_by_key}, - ) + result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer) # Only displayed on error print("result:", result) diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index 60e04ad9069e..e4c07aae0ebe 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -10,14 +10,15 @@ create_standard_kv_cache_spec, create_vllm_config, get_attention_backend) -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, set_kv_cache_layout) from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, - _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, _Backend.TREE_ATTN, + "FLEX_ATTENTION_SLOW" ] # Remove flashinfer from the list if it's not available @@ -97,7 +98,7 @@ def create_and_prepopulate_kv_cache( common_attn_metadata: CommonAttentionMetadata, randomize_blocks: bool = True) -> torch.Tensor: """Create and prepopulate a KV cache with context data. - + Args: k_contexts: List of key context tensors for each sequence v_contexts: List of value context tensors for each sequence @@ -109,9 +110,9 @@ def create_and_prepopulate_kv_cache( device: Device to create the cache on num_blocks: Total number of blocks in the cache block_table: Block table tensor to populate - randomize_blocks: Whether to randomly permute blocks + randomize_blocks: Whether to randomly permute blocks or use sequential order - + Returns: Tuple of (kv_cache, updated_block_table) """ @@ -206,10 +207,18 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, kv_cache: torch.Tensor) -> torch.Tensor: """Run attention computation using the specified backend's AttentionImpl.""" - builder_cls, impl_cls = get_attention_backend(backend) + # Handle special case for FLEX_ATTENTION_SLOW + actual_backend = backend + + use_direct_block_mask = is_torch_equal_or_newer("2.9.0.dev0") + if backend == "FLEX_ATTENTION_SLOW": + actual_backend = _Backend.FLEX_ATTENTION + use_direct_block_mask = False + + builder_cls, impl_cls = get_attention_backend(actual_backend) # Mock flashinfer's get_per_layer_parameters if needed - if backend == _Backend.FLASHINFER_VLLM_V1: + if actual_backend == _Backend.FLASHINFER_VLLM_V1: import unittest.mock from vllm.v1.attention.backends.utils import PerLayerParameters @@ -239,6 +248,8 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls): else: # Build metadata builder = builder_cls(kv_cache_spec, layer_names, vllm_config, device) + if actual_backend == _Backend.FLEX_ATTENTION: + builder.direct_build = use_direct_block_mask attn_metadata = builder.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, @@ -453,11 +464,6 @@ def test_backend_correctness(batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-3 - if backend_name == _Backend.FLEX_ATTENTION: - atol = 5e-1 # TODO: figure out why flex_attention has such large - # numerical differences for medium_decode, medium_prefill, - # mixed_medium - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() max_rel_diff = torch.max( torch.abs(backend_output - sdpa_output) / diff --git 
a/tests/v1/attention/test_attention_backends_selection.py b/tests/v1/attention/test_attention_backends_selection.py new file mode 100644 index 000000000000..59e562814946 --- /dev/null +++ b/tests/v1/attention/test_attention_backends_selection.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for mamba attention backend selectors.""" + +from types import SimpleNamespace + +import pytest + +from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.short_conv import ShortConv +from vllm.model_executor.models.minimax_text_01 import ( + MiniMaxText01LinearAttention) +from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend +from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend +from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend +from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + + +@pytest.mark.parametrize( + "layer_class, init_kwargs, expected_backend, expected_mamba_type", [ + ( + MambaMixer, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + time_step_rank=8, + use_conv_bias=True, + use_bias=False, + use_rms_norm=True, + ), + Mamba1AttentionBackend, + "mamba1", + ), + ( + MambaMixer2, + dict( + hidden_size=128, + ssm_state_size=16, + conv_kernel_size=4, + intermediate_size=256, + use_conv_bias=True, + use_bias=False, + n_groups=1, + num_heads=8, + head_dim=32, + ), + Mamba2AttentionBackend, + "mamba2", + ), + ( + MiniMaxText01LinearAttention, + dict( + hidden_size=128, + hidden_inner_size=256, + num_heads=8, + head_dim=32, + max_position=2048, + block_size=64, + num_hidden_layer=12, + layer_idx=0, + linear_layer_idx=0, + ), + LinearAttentionBackend, + "linear_attention", + ), + ( + ShortConv, + dict( + config=SimpleNamespace(conv_L_cache=32, conv_bias=True), + dim=128, + layer_idx=0, + ), + ShortConvAttentionBackend, + "short_conv", + ), + ]) +def test_mamba_layers_get_attn_backend(dist_init, layer_class, init_kwargs, + expected_backend, expected_mamba_type): + """Test that Mamba-like layers return the correct attention backend.""" + layer = layer_class(**init_kwargs) + + backend_class = layer.get_attn_backend() + assert backend_class is expected_backend + assert layer.mamba_type == expected_mamba_type + + +@pytest.mark.parametrize("layer_class,expected_backend,expected_mamba_type", [ + (MambaMixer, Mamba1AttentionBackend, "mamba1"), + (MambaMixer2, Mamba2AttentionBackend, "mamba2"), + (MiniMaxText01LinearAttention, LinearAttentionBackend, "linear_attention"), + (ShortConv, ShortConvAttentionBackend, "short_conv"), +]) +def test_mamba_layers_have_unified_interface(layer_class, expected_backend, + expected_mamba_type): + """Test that all Mamba layers have the unified get_attn_backend + interface.""" + assert hasattr(layer_class, 'get_attn_backend'), ( + f"{layer_class.__name__} should have get_attn_backend method") + assert hasattr(layer_class, 'mamba_type'), ( + f"{layer_class.__name__} should have mamba_type property") diff --git a/tests/v1/attention/test_mamba_selectors.py b/tests/v1/attention/test_mamba_selectors.py deleted file mode 100644 index 4245b50c7131..000000000000 --- a/tests/v1/attention/test_mamba_selectors.py +++ /dev/null @@ -1,25 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to 
the vLLM project -"""Tests for mamba attention backend selectors.""" - -import pytest - -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend - - -@pytest.mark.parametrize(argnames=["mamba_type", "expected_backend"], - argvalues=[("mamba2", Mamba2AttentionBackend)]) -def test_get_mamba_attn_backend_mamba2(mamba_type, expected_backend): - backend_class = get_mamba_attn_backend(mamba_type) - - assert backend_class is expected_backend - - -def test_get_mamba_attn_backend_unsupported(): - unsupported_types = ["mamba", ""] - - for mamba_type in unsupported_types: - err_message = f"Mamba Attention type {mamba_type} is not supported yet." - with pytest.raises(NotImplementedError, match=err_message): - get_mamba_attn_backend(mamba_type) diff --git a/tests/v1/core/test_encoder_cache_manager.py b/tests/v1/core/test_encoder_cache_manager.py new file mode 100644 index 000000000000..60d932a878ab --- /dev/null +++ b/tests/v1/core/test_encoder_cache_manager.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager + + +# ------------------ Mock Classes ------------------ # +class MockRequest: + + def __init__(self, request_id, mm_hashes, token_counts): + self.request_id = request_id + self.mm_hashes = mm_hashes + self._token_counts = token_counts + + def get_num_encoder_tokens(self, input_id: int) -> int: + return self._token_counts[input_id] + + +# ------------------ Unit Tests ------------------ # +def test_basic_allocate_and_reuse(): + cache = EncoderCacheManager(cache_size=10) + req = MockRequest("r1", ["imgA"], [4]) + + assert not cache.check_and_update_cache(req, 0) + assert cache.try_allocate(req, 0, int(1e9)) + + cache.allocate(req, 0) + + assert cache.check_and_update_cache(req, 0) + assert "r1" in cache.cached["imgA"] + assert cache.num_free_slots == 6 + + # Free twice to bring refcount to 0. 
+ cache.free_encoder_input(req, 0) + cache.free_encoder_input(req, 0) + + assert not cache.cached["imgA"] + assert "imgA" in cache.freeable + assert cache.num_freeable_slots == 10 + assert cache.num_free_slots == 6 + + +def test_freeing_decreases_refcount_and_moves_to_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req2", ["img3"], [5]) + + assert manager.try_allocate(req, 0, int(1e9)) + manager.allocate(req, 0) + + assert len(manager.cached["img3"]) == 1 + + manager.free_encoder_input(req, 0) + + assert not manager.cached["img3"] + assert "img3" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_free_request_frees_all_inputs(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("req3", ["a", "b"], [2, 3]) + + assert manager.try_allocate(req, 0, int(1e9)) + manager.allocate(req, 0) + + assert manager.try_allocate(req, 1, int(1e9)) + manager.allocate(req, 1) + + assert len(manager.cached["a"]) == 1 + assert len(manager.cached["b"]) == 1 + + manager.free(req) + + assert not manager.cached["a"] + assert not manager.cached["b"] + assert "a" in manager.freeable + assert "b" in manager.freeable + assert manager.num_freeable_slots == 10 + + +def test_eviction_when_cache_is_full(): + manager = EncoderCacheManager(cache_size=10) + + req1 = MockRequest("req1", ["x"], [6]) + req2 = MockRequest("req2", ["y"], [5]) + + assert manager.try_allocate(req1, 0, int(1e9)) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + assert manager.try_allocate(req2, 0, int(1e9)) + manager.allocate(req2, 0) + + # 'x' should have been evicted. + assert "x" not in manager.cached + assert "x" in manager.get_freed_mm_hashes() + + +def test_get_cached_input_ids(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqX", ["m", "n", "o"], [2, 4, 3]) + + assert manager.try_allocate(req, 0, int(1e9)) + manager.allocate(req, 0) + + assert manager.try_allocate(req, 2, int(1e9)) + manager.allocate(req, 2) + + cached_ids = manager.get_cached_input_ids(req) + assert cached_ids == {0, 2} + + +def test_has_cache_restores_from_freeable(): + manager = EncoderCacheManager(cache_size=10) + req = MockRequest("reqY", ["imgZ"], [4]) + + assert manager.try_allocate(req, 0, int(1e9)) + manager.allocate(req, 0) + + manager.free_encoder_input(req, 0) + + # Should restore from freeable. + assert manager.check_and_update_cache(req, 0) + assert len(manager.cached["imgZ"]) == 1 + assert "imgZ" not in manager.freeable + assert manager.num_freeable_slots == 6 + + +def test_get_freed_mm_hashes_clears_freed_list(): + manager = EncoderCacheManager(cache_size=10) + req1 = MockRequest("reqA", ["a"], [5]) + req2 = MockRequest("reqB", ["b"], [6]) + + assert manager.try_allocate(req1, 0, int(1e9)) + manager.allocate(req1, 0) + manager.free_encoder_input(req1, 0) + + # Should trigger eviction of 'a'. 
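Taken together, these tests pin down the lifecycle the manager has to implement: entries are keyed by mm hash, reference-counted while in use, parked on a freeable list once unreferenced, and only evicted (and reported via `get_freed_mm_hashes`) when space is actually needed. A deliberately tiny model of that lifecycle, not the real `EncoderCacheManager` (reuse of an already-cached hash is omitted for brevity):

```python
from collections import OrderedDict

class TinyEncoderCache:
    def __init__(self, cache_size: int):
        self.free = cache_size                        # slots holding no entry at all
        self.refs: dict[str, int] = {}                # in-use entries
        self.sizes: dict[str, int] = {}
        self.freeable: OrderedDict[str, int] = OrderedDict()  # refcount == 0
        self.just_freed: list[str] = []

    @property
    def num_free_slots(self) -> int:
        return self.free

    @property
    def num_freeable_slots(self) -> int:
        return self.free + sum(self.freeable.values())

    def allocate(self, mm_hash: str, num_tokens: int) -> bool:
        while self.free < num_tokens and self.freeable:
            victim, size = self.freeable.popitem(last=False)   # evict oldest
            del self.sizes[victim]
            self.free += size
            self.just_freed.append(victim)
        if self.free < num_tokens:
            return False
        self.free -= num_tokens
        self.refs[mm_hash] = self.refs.get(mm_hash, 0) + 1
        self.sizes[mm_hash] = num_tokens
        return True

    def release(self, mm_hash: str) -> None:
        self.refs[mm_hash] -= 1
        if self.refs[mm_hash] == 0:        # still cached, but reclaimable now
            del self.refs[mm_hash]
            self.freeable[mm_hash] = self.sizes[mm_hash]
```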
+ assert manager.try_allocate(req2, 0, int(1e9)) + manager.allocate(req2, 0) + + freed = manager.get_freed_mm_hashes() + assert "a" in freed + assert manager.get_freed_mm_hashes() == [] diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 070008fcbf59..07d7c12a4f5e 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -338,7 +338,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) @@ -391,7 +391,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -443,7 +443,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -490,7 +490,7 @@ def test_stop_via_update_from_output(): }, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None) diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 849c3f59ae52..78a71f10a594 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -143,7 +143,11 @@ def create_requests( mm_position = mm_positions[i] mm_item = MultiModalKwargsItem.dummy("dummy_m") mm_kwargs = [mm_item] * len(mm_position) - mm_hashes = ["hash"] * len(mm_position) + # Dummy hash for each mm item should be unique + # since encoder cache tracks entries by hash + mm_hashes = [ + "hash" + str(i) + "_" + str(j) for j in range(len(mm_position)) + ] else: mm_position = None mm_kwargs = None diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index 572af0175d11..cd82eb2ac419 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -41,8 +41,11 @@ PARAMS_MODELS_BACKENDS_TOKENIZER_MODE = [ ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "lm-format-enforcer", "auto", + None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), + ("Qwen/Qwen2.5-1.5B-Instruct", "lm-format-enforcer", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", @@ -148,7 +151,8 @@ def test_structured_output( generated_text = output.outputs[0].text assert generated_text is not None - assert "\n" not in generated_text + if guided_decoding_backend != 'lm-format-enforcer': + assert "\n" not in generated_text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=sample_json_schema) @@ -225,7 +229,7 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 4: Generate SQL 
statement using EBNF grammar # @@ -439,7 +443,7 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) - if guided_decoding_backend != "outlines": + if guided_decoding_backend not in ["outlines", "lm-format-enforcer"]: # # Test 11: Generate structured output using structural_tag format # diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 5a05781a03f2..941aa0a77692 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -85,7 +85,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -164,7 +164,7 @@ def test_update_states_request_finished(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -194,7 +194,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -221,7 +221,7 @@ def test_update_states_request_resumed(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -252,7 +252,7 @@ def test_update_states_no_changes(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -287,7 +287,7 @@ def test_update_states_request_unscheduled(model_runner): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index d7b4746562be..703185907826 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -205,6 +205,7 @@ def _construct_cached_request_state(req_id_suffix: int): pooling_params=None, mm_kwargs=[], mm_positions=[], + mm_hashes=[], block_ids=([], ), generator=None, num_computed_tokens=len(output_token_ids), diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index b9b2314ce573..d6cd03fb01a7 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -141,7 +141,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput: scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -207,7 +207,7 @@ def test_update_states_request_finished(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids={req_id}, - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -239,7 +239,7 @@ def 
test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -266,7 +266,7 @@ def test_update_states_request_resumed(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -347,7 +347,7 @@ def test_update_states_no_changes(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) @@ -384,7 +384,7 @@ def test_update_states_request_unscheduled(model_runner, dist_init): scheduled_encoder_inputs={}, num_common_prefix_blocks=0, finished_req_ids=set(), - free_encoder_input_ids=[], + free_encoder_mm_hashes=[], structured_output_request_ids={}, grammar_bitmask=None, ) diff --git a/tools/check_pickle_imports.py b/tools/check_pickle_imports.py index 444e2bf53f99..ad0ae45d1d46 100644 --- a/tools/check_pickle_imports.py +++ b/tools/check_pickle_imports.py @@ -37,7 +37,7 @@ 'vllm/distributed/utils.py', 'vllm/distributed/parallel_state.py', 'vllm/engine/multiprocessing/client.py', - 'vllm/distributed/device_communicators/custom_all_reduce_utils.py', + 'vllm/distributed/device_communicators/all_reduce_utils.py', 'vllm/distributed/device_communicators/shm_broadcast.py', 'vllm/engine/multiprocessing/engine.py', 'benchmarks/kernels/graph_machete_bench.py', diff --git a/tools/install_deepgemm.sh b/tools/install_deepgemm.sh new file mode 100755 index 000000000000..33849581d2c0 --- /dev/null +++ b/tools/install_deepgemm.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# Script to install DeepGEMM from source +# This script can be used both in Docker builds and by users locally + +set -e + +# Default values +DEEPGEMM_GIT_REPO="https://github.com/deepseek-ai/DeepGEMM.git" +DEEPGEMM_GIT_REF="7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --ref) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --ref requires an argument." >&2 + exit 1 + fi + DEEPGEMM_GIT_REF="$2" + shift 2 + ;; + --cuda-version) + if [[ -z "$2" || "$2" =~ ^- ]]; then + echo "Error: --cuda-version requires an argument." >&2 + exit 1 + fi + CUDA_VERSION="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --ref REF Git reference to checkout (default: $DEEPGEMM_GIT_REF)" + echo " --cuda-version VER CUDA version (auto-detected if not provided)" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + exit 1 + ;; + esac +done + +# Auto-detect CUDA version if not provided +if [ -z "$CUDA_VERSION" ]; then + if command -v nvcc >/dev/null 2>&1; then + CUDA_VERSION=$(nvcc --version | grep "release" | sed -n 's/.*release \([0-9]\+\.[0-9]\+\).*/\1/p') + echo "Auto-detected CUDA version: $CUDA_VERSION" + else + echo "Warning: Could not auto-detect CUDA version. 
Please specify with --cuda-version" + exit 1 + fi +fi + +# Extract major and minor version numbers +CUDA_MAJOR="${CUDA_VERSION%%.*}" +CUDA_MINOR="${CUDA_VERSION#${CUDA_MAJOR}.}" +CUDA_MINOR="${CUDA_MINOR%%.*}" + +echo "CUDA version: $CUDA_VERSION (major: $CUDA_MAJOR, minor: $CUDA_MINOR)" + +# Check CUDA version requirement +if [ "$CUDA_MAJOR" -lt 12 ] || { [ "$CUDA_MAJOR" -eq 12 ] && [ "$CUDA_MINOR" -lt 8 ]; }; then + echo "Skipping DeepGEMM installation (requires CUDA 12.8+ but got ${CUDA_VERSION})" + exit 0 +fi + +echo "Installing DeepGEMM from source..." +echo "Repository: $DEEPGEMM_GIT_REPO" +echo "Reference: $DEEPGEMM_GIT_REF" + +# Create a temporary directory for the build +INSTALL_DIR=$(mktemp -d) +trap 'rm -rf "$INSTALL_DIR"' EXIT + +# Clone the repository +git clone --recursive --shallow-submodules "$DEEPGEMM_GIT_REPO" "$INSTALL_DIR/deepgemm" + +echo "šŸ—ļø Building DeepGEMM" +pushd "$INSTALL_DIR/deepgemm" + +# Checkout the specific reference +git checkout "$DEEPGEMM_GIT_REF" + +# Build DeepGEMM +# (Based on https://github.com/deepseek-ai/DeepGEMM/blob/main/install.sh) +rm -rf build dist +rm -rf *.egg-info +python3 setup.py bdist_wheel + +# Install the wheel +if command -v uv >/dev/null 2>&1; then + echo "Installing DeepGEMM wheel using uv..." + # Use --system in Docker contexts, respect user's environment otherwise + if [ -n "$VLLM_DOCKER_BUILD_CONTEXT" ]; then + uv pip install --system dist/*.whl + else + uv pip install dist/*.whl + fi +else + echo "Installing DeepGEMM wheel using pip..." + python3 -m pip install dist/*.whl +fi + +popd + +echo "āœ… DeepGEMM installation completed successfully" \ No newline at end of file diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 0043456e0009..3e3b43ce2abe 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -474,6 +474,30 @@ def machete_prepack_B_fake( return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) + @register_fake("_C::cutlass_w4a8_mm") + def cutlass_w4a8_mm_fake( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + out_dtype = out_type if out_type is not None else torch.bfloat16 + return torch.empty((m, n), device=a.device, dtype=out_dtype) + + @register_fake("_C::cutlass_pack_scale_fp8") + def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: + return torch.empty_like(scales, memory_format=torch.contiguous_format) + + @register_fake("_C::cutlass_encode_and_reorder_int4b") + def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: + return torch.empty_like(b, memory_format=torch.contiguous_format) + if hasattr(torch.ops._C, "allspark_w8a16_gemm"): @@ -1032,6 +1056,30 @@ def machete_prepack_B( group_scales_type) +# CUTLASS W4A8 +def cutlass_w4a8_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b + b_q: torch.Tensor, + b_group_scales: torch.Tensor, + b_group_size: int, + b_channel_scales: torch.Tensor, + a_token_scales: torch.Tensor, + out_type: Optional[torch.dtype] = None, + maybe_schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.cutlass_w4a8_mm(a, b_q, b_group_scales, b_group_size, + b_channel_scales, a_token_scales, + out_type, maybe_schedule) + + +def 
cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_pack_scale_fp8(scales) + + +def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: + return torch.ops._C.cutlass_encode_and_reorder_int4b(b) + + if hasattr(torch.ops._C, "permute_cols"): @register_fake("_C::permute_cols") diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index d21f07756871..0b9c625533cb 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,8 +9,7 @@ import torch -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -285,20 +284,17 @@ def forward( attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization onto implementations that support it. - TODO(luka) merge parameters into QuantDescriptor - :param dtype: quantized dtype - :param static: static or dynamic quantization - :param group_shape: quant group shape. + :param quant_key: QuantKey object that describes the quantization op :return: is fusion supported for this type of quantization """ return False @@ -317,6 +313,7 @@ def forward( attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index fac3c318a87a..ce9467efd23c 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -800,6 +800,7 @@ def forward( attn_metadata: DifferentialFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -817,6 +818,11 @@ def forward( {q,k,v}_descale to be (num_sequences, num_kv_heads). We use torch's .expand() to avoid duplicating values """ + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for DifferentialFlashAttentionImpl") + if self.lambda_full is None: self.lambda_init = self.differential_flash_attention_config[ "lambda_init"] diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index fa6f3f1b39cc..85957bea1e26 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -371,6 +371,7 @@ def forward( # type: ignore attn_metadata: DualChunkFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with DualChunkFlashAttention. 
Args: @@ -386,7 +387,7 @@ def forward( # type: ignore """ assert output is None, "Output tensor not supported for DualChunk" - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index e52480d5c5ce..ba7a9afe8678 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -596,6 +596,7 @@ def forward( attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -615,7 +616,7 @@ def forward( """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 9d6ab7e3217b..c5ed4c6e4032 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1238,12 +1238,13 @@ def forward( attn_metadata: T, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLAImplBase") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 63e467f5a7a2..e4c27a0ef36e 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -20,7 +20,7 @@ from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -529,11 +529,9 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): + def fused_output_quant_supported(self, quant_key: QuantKey): if self.use_triton_flash_attn: - return dtype == current_platform.fp8_dtype( - ) and static and group_shape == GroupShape.PER_TENSOR + return quant_key == kFp8StaticTensorSym # Only supported in the Triton backend return False @@ -548,6 +546,7 @@ def forward( attn_metadata: ROCmFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. 
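Stepping back to the op-registration side of this change: the `register_fake` entries added in `vllm/_custom_ops.py` give the compiler shape and dtype information for the new CUTLASS ops without executing them. The same pattern in a self-contained form, using a hypothetical toy op rather than the real `_C` bindings:

```python
import torch

# Real kernel: what runs in eager mode (plain PyTorch here, just for the demo).
@torch.library.custom_op("demo::row_sum", mutates_args=())
def row_sum(x: torch.Tensor) -> torch.Tensor:
    return x.sum(dim=-1)

# Fake (meta) kernel: only describes the output's shape/dtype/device, which is
# all that shape propagation and torch.compile tracing need.
@row_sum.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return x.new_empty(x.shape[:-1])

out = row_sum(torch.randn(4, 8))   # eager call uses the real implementation
assert out.shape == (4,)
# Under fake-tensor tracing (e.g. inside torch.compile), the registered fake
# is called instead, so the output shape is known without real computation.
```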
@@ -606,6 +605,11 @@ def forward( "fused output quantization only supported for Triton" " implementation in ROCMFlashAttentionImpl for now") + if output_block_scale is not None: + raise NotImplementedError( + "fused nvfp4 output quantization is not supported" + " for ROCMFlashAttentionImpl") + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 0bc38b414290..c1213f7620a7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -432,6 +432,7 @@ def forward( attn_metadata: "XFormersMetadata", output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -484,7 +485,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersImpl") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 04ab100c8775..2d288bcbe0c9 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -18,6 +18,7 @@ is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -54,7 +55,7 @@ def check_xformers_availability(): return USE_XFORMERS_OPS -class Attention(nn.Module): +class Attention(nn.Module, AttentionLayerBase): """Attention layer. This class takes query, key, and value tensors as input. 
The input tensors @@ -495,6 +496,7 @@ def unified_attention_with_output( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() @@ -510,7 +512,8 @@ def unified_attention_with_output( kv_cache, attn_metadata, output=output, - output_scale=output_scale) + output_scale=output_scale, + output_block_scale=output_block_scale) maybe_save_kv_layer_to_connector(layer_name, kv_cache) @@ -522,6 +525,7 @@ def unified_attention_with_output_fake( output: torch.Tensor, layer_name: str, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> None: return @@ -529,7 +533,7 @@ def unified_attention_with_output_fake( direct_register_custom_op( op_name="unified_attention_with_output", op_func=unified_attention_with_output, - mutates_args=["output"], + mutates_args=["output", "output_block_scale"], fake_impl=unified_attention_with_output_fake, dispatch_key=current_platform.dispatch_key, ) diff --git a/vllm/attention/layers/encoder_only_attention.py b/vllm/attention/layers/encoder_only_attention.py index 7b3dcbd823c0..cea05df5b96d 100644 --- a/vllm/attention/layers/encoder_only_attention.py +++ b/vllm/attention/layers/encoder_only_attention.py @@ -5,13 +5,13 @@ from typing import Optional import torch -from transformers import CacheConfig from vllm import envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadata, AttentionType) from vllm.attention.layer import Attention from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig from vllm.v1.attention.backends.utils import (CommonAttentionMetadata, subclass_attention_backend) diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 920d21bda3c5..ec763e9baac6 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -11,18 +11,21 @@ - HuggingFace - VisionArena """ +import ast import base64 import io import json import logging +import math import random from abc import ABC, abstractmethod -from collections.abc import Mapping +from collections.abc import Iterator, Mapping +from contextlib import suppress from copy import deepcopy from dataclasses import dataclass from functools import cache from io import BytesIO -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, Union, cast import numpy as np from PIL import Image @@ -114,7 +117,9 @@ def __init__( def apply_multimodal_chat_transformation( self, prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + mm_content: Optional[ + Union[MultiModalDataDict, dict, list[dict]] + ] = None) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. 
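For concreteness, the chat structure this helper builds looks roughly like the following, with illustrative values only; the real media entries come from `process_image`/`process_video`, and a list input contributes one entry per item:

```python
conversation = [{
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe the image."},
        # One entry like this per multimodal item.
        {"type": "image_url",
         "image_url": {"url": "data:image/jpeg;base64,..."}},
    ],
}]
```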
This method is used for chat models that expect a specific conversation @@ -122,7 +127,15 @@ def apply_multimodal_chat_transformation( """ content = [{"text": prompt, "type": "text"}] if mm_content is not None: - content.append(mm_content) + if isinstance(mm_content, list): + content.extend(cast(list[dict[str, Any]], mm_content)) + elif isinstance(mm_content, dict): + content.append(mm_content) + else: + raise TypeError( + "Could not process multimodal content of type: " + + f"{type(mm_content)}" + ) return [{"role": "user", "content": content}] def load_data(self) -> None: @@ -362,90 +375,536 @@ def process_video(video: Any) -> Mapping[str, Any]: class RandomDataset(BenchmarkDataset): + """ + Synthetic text-only dataset for serving/throughput benchmarks. + + Strategy: + - Sample input/output token lengths per request from integer-uniform ranges + around configured means (controlled by range_ratio). + - Prepend a fixed random prefix of length prefix_len. + - Generate the remaining tokens as a reproducible sequence: + (offset + index + arange(input_len)) % vocab_size. + - Decode then re-encode/truncate to ensure prompt token counts match. + - Uses numpy.default_rng seeded with random_seed for reproducible sampling. + """ # Default values copied from benchmark_serving.py for the random dataset. DEFAULT_PREFIX_LEN = 0 DEFAULT_RANGE_RATIO = 0.0 DEFAULT_INPUT_LEN = 1024 DEFAULT_OUTPUT_LEN = 128 - def __init__( - self, - **kwargs, - ) -> None: + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - random.seed(self.random_seed) - np.random.seed(self.random_seed) + # Use numpy's default_rng for deterministic sampling + # Do not use random.seed() or np.random.seed() elsewhere in this class. + # This ensures that the RNG is isolated from global RNG state. + self._rng = np.random.default_rng(self.random_seed) def sample( self, tokenizer: PreTrainedTokenizerBase, num_requests: int, + request_id_prefix: str = "", prefix_len: int = DEFAULT_PREFIX_LEN, range_ratio: float = DEFAULT_RANGE_RATIO, input_len: int = DEFAULT_INPUT_LEN, output_len: int = DEFAULT_OUTPUT_LEN, - request_id_prefix: str = "", **kwargs, ) -> list[SampleRequest]: - # Enforce range_ratio < 1 - assert range_ratio < 1.0, ( - "random_range_ratio must be < 1.0 to ensure a valid sampling range" + + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer ) + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) vocab_size = tokenizer.vocab_size - num_special_tokens = tokenizer.num_special_tokens_to_add() - real_input_len = input_len - num_special_tokens - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + requests = [] + for i in range(num_requests): + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + requests.append( + SampleRequest( + prompt=prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + request_id=request_id_prefix + str(i), + ) + ) + return requests + + def get_prefix( + self, tokenizer: PreTrainedTokenizerBase, prefix_len: int + ) -> list[int]: + """ + Get the prefix for the dataset. 
+ """ + return ( + self._rng.integers( + 0, tokenizer.vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) - # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(real_input_len * (1 - range_ratio)) - input_high = int(real_input_len * (1 + range_ratio)) - output_low = int(output_len * (1 - range_ratio)) - output_high = int(output_len * (1 + range_ratio)) + def get_sampling_params( + self, + num_requests: int, + range_ratio: float, + input_len: int, + output_len: int, + tokenizer: PreTrainedTokenizerBase, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get the sampling parameters for the dataset. + """ + # Enforce range_ratio < 1 + if not (0.0 <= range_ratio < 1.0): + raise ValueError("range_ratio must be in [0, 1).") + num_special_tokens = int(tokenizer.num_special_tokens_to_add()) + real_input_len = max(0, int(input_len) - num_special_tokens) + # Bounds use floor for low and ceil for high + input_low = math.floor(real_input_len * (1 - range_ratio)) + input_high = math.ceil(real_input_len * (1 + range_ratio)) + output_low = math.floor(output_len * (1 - range_ratio)) + output_high = math.ceil(output_len * (1 + range_ratio)) + # Ensure the lower bound for output length is at least 1 to + # prevent sampling 0 tokens. + output_low = max(output_low, 1) + + if input_low > input_high: + raise ValueError( + "Invalid input sampling interval: " + f"low={input_low} > high={input_high}" + ) + if output_low > output_high: + raise ValueError( + "Invalid output sampling interval: " + f"low={output_low} > high={output_high}" + ) - # Add logging for debugging logger.info( "Sampling input_len from [%s, %s] and output_len from [%s, %s]", - input_low, input_high, output_low, output_high) + input_low, + input_high, + output_low, + output_high, + ) - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, + input_lens = self._rng.integers(input_low, input_high + 1, + size=num_requests) + output_lens = self._rng.integers(output_low, output_high + 1, + size=num_requests) + offsets = self._rng.integers(0, tokenizer.vocab_size, size=num_requests) - offsets = np.random.randint(0, vocab_size, size=num_requests) + return input_lens, output_lens, offsets - requests = [] + + def generate_token_sequence( + self, + *, + tokenizer: PreTrainedTokenizerBase, + prefix_token_ids: list[int], + prefix_len: int, + vocab_size: int, + input_len: int, + offset: int, + index: int, + ) -> tuple[str, int]: + """ + Returns (prompt, total_input_len). + + NOTE: After decoding the prompt we have to encode and decode it again. + This is done because in some cases N consecutive tokens + give a string tokenized into != N number of tokens. + For example for GPT2Tokenizer: + [6880, 6881] -> ['Ä calls', 'here'] -> + [1650, 939, 486] -> ['Ä call', 'sh', 'ere'] + To avoid uncontrolled change of the prompt length, + the encoded sequence is truncated before being decode again. 
+ """ + # Build the inner sequence by sampling sequentially from the vocab + inner_seq = ((offset + index + np.arange(input_len)) + % vocab_size).tolist() + token_sequence = prefix_token_ids + inner_seq + + # Decode, then re-encode and truncate to preserve token count invariants + prompt = tokenizer.decode(token_sequence) + total_input_len = prefix_len + int(input_len) + + re_encoded_sequence = tokenizer.encode( + prompt, add_special_tokens=False)[:total_input_len] + prompt = tokenizer.decode(re_encoded_sequence) + total_input_len = len(re_encoded_sequence) + + return prompt, total_input_len + + +# ----------------------------------------------------------------------------- +# MultiModalDataset Implementation +# ----------------------------------------------------------------------------- + +class RandomMultiModalDataset(RandomDataset): + """ + Synthetic multimodal dataset (text + images) that extends RandomDataset. + + Status: + - Images: supported via synthetic RGB data. + - Video: not yet supported (TODO: implement video generation method). + - Audio: not yet supported. + + Sampling overview: + 1) Number of items per request is sampled uniformly from the integer range + [floor(nĀ·(1āˆ’r)), ceil(nĀ·(1+r))], where n is the base count and r is + `num_mm_items_range_ratio` in [0, 1]. r=0 keeps it fixed; r=1 allows 0. + The maximum is further clamped to the sum of per-modality limits. + 2) Each item’s modality and shape is sampled from `bucket_config`, a dict + mapping (height, width, num_frames) → probability. We treat + `num_frames`=1 as image and and `num_frames` > 1 as video. + Entries with zero probability are removed and the rest are renormalized + to sum to 1. + 3) Per-modality hard caps are enforced via `limit_mm_per_prompt`. + When a modality reaches its cap, all of its buckets are excluded and the + remaining probabilities are renormalized. + + Example bucket configuration: + {(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.1} + - Two image buckets (`num_frames`=1) and one video bucket + (`num_frames`=16). + OBS.: Only image sampling is supported for now. + """ + + IS_MULTIMODAL = True + # NOTE: video sampling is WIP. Setting it to 0. + DEFAULT_LIMIT_MM_PER_PROMPT = {"image": 255, "video": 0} + + DEFAULT_BASE_ITEMS_PER_REQUEST = 1 + DEFAULT_NUM_MM_ITEMS_RANGE_RATIO = 0.0 + DEFAULT_MM_ITEM_BUCKET_CONFIG = { + (256, 256, 1): 0.5, + (720, 1280, 1): 0.5, + (720, 1280, 16): 0.0, + } + DEFAULT_ENABLE_MULTIMODAL_CHAT = False + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + + def generate_synthetic_image(self, width: int, height: int) -> Image.Image: + """Generate synthetic PIL image with random RGB values. + + NOTE: iid pixel sampling results in worst-case compression + (good for stressing I/O), but very unlike real photos. + We could consider a ā€œlow-freqā€ mode (e.g., noise blur) + to emulate network realism instead of max stress. + """ + random_pixels = self._rng.integers( + 0, + 256, + (height, width, 3), + dtype=np.uint8, + ) + return Image.fromarray(random_pixels) + + def generate_synthetic_video(self, width: int, + height: int, + num_frames: int) -> Any: + """Generate synthetic video with random values. + + TODO: Finish this method. 
+ """ + raise NotImplementedError("Video sampling is WIP.") + + def map_config_to_modality(self, config: tuple[int, int, int]) -> str: + """Map the configuration to the modality.""" + if config[-1] == 1: + return "image" + elif config[-1] > 1: + return "video" + else: + raise ValueError(f"Invalid multimodal item configuration: {config}") + + def normalize_bucket_config(self, bucket_config: dict[tuple[int, int, int], + float]) -> dict[tuple[int, int, int], float]: + """ + Remove zero probability entries + and normalize the bucket config to sum to 1. + """ + # Raise error if value is negative + if any(v < 0 for v in bucket_config.values()): + raise ValueError("Bucket config values must be non-negative.") + # Remove zero probability entries + bucket_config = {k: v for k, v in bucket_config.items() if v > 0} + # if bucket config is empty, raise error + if not bucket_config: + raise ValueError("Got invalid bucket config. " + "Bucket config values must be non-zero.") + # Normalize the remaining bucket config to sum to 1 + total = sum(bucket_config.values()) + return {k: v / total for k, v in bucket_config.items()} + + + def generate_mm_item(self, + mm_item_config: tuple[int, int, int], + ) -> Mapping[str, Any]: + """ + Create synthetic images and videos and + apply process_image/process_video respectively. + This follows the OpenAI API chat completions + https://github.com/openai/openai-python + """ + + if self.map_config_to_modality(mm_item_config) == "image": + return process_image(self.generate_synthetic_image( + mm_item_config[1], + mm_item_config[0])) + elif self.map_config_to_modality(mm_item_config) == "video": + return process_video(self.generate_synthetic_video( + mm_item_config[1], + mm_item_config[0], + mm_item_config[2])) + else: + raise ValueError(f"Invalid multimodal item configuration: " + f"{mm_item_config}") + + + def get_mm_item_sampling_params( + self, + base_items_per_request: int, + num_mm_items_range_ratio: float, + limit_mm_per_prompt: dict[str, int], + bucket_config: dict[tuple[int, int, int], float], + ) -> tuple[int, int, dict[str, int], dict[tuple[int, int, int], float]]: + """ + Get the sampling parameters for the multimodal items. 
+ """ + # Enforce num_mm_items_range_ratio <= 1 + if not (0.0 <= num_mm_items_range_ratio <= 1.0): + raise ValueError("num_mm_items_range_ratio must be in [0, 1].") + + # Ensure modalities to sample are in limit_mm_per_prompt + for k, v in bucket_config.items(): + # get modality from bucket config + modality = self.map_config_to_modality(k) + if modality not in limit_mm_per_prompt: + raise ValueError(f"Modality {modality} is not in " + f"limit_mm_per_prompt: " + f"{limit_mm_per_prompt.keys()}") + + # Remove zero probability entries + # and normalize bucket config to sum to 1 + bucket_config = self.normalize_bucket_config(bucket_config) + logger.info( + "Normalized bucket config: %s", bucket_config, + ) + # Only consider limit per prompt for modalities in bucket config + allowed_modalities = {self.map_config_to_modality(cfg) + for cfg in bucket_config} + limit_mm_per_prompt = { + k: v for k, v in limit_mm_per_prompt.items() + if k in allowed_modalities} + if not limit_mm_per_prompt: + raise ValueError("No valid limits for modalities present in " + "bucket_config.") + + logger.info( + "Updated mm-limit-per-prompt: %s", limit_mm_per_prompt, + ) + + # Get max and min num mm items and ensure + # it is at most the sum of limit_mm_per_prompt for all modalities + max_num_mm_items = min( + sum(limit_mm_per_prompt.values()), + math.ceil(base_items_per_request * (1 + num_mm_items_range_ratio)) + ) + # Ensure min num mm items is at least 0 + min_num_mm_items = max( + 0, + math.floor(base_items_per_request * (1 - num_mm_items_range_ratio)) + ) + # Raise error if min num mm items is greater than max num mm items + if min_num_mm_items > max_num_mm_items: + raise ValueError(f"Min num mm items is greater than max mm items: " + f"{min_num_mm_items} > {max_num_mm_items}") + + logger.info( + "Sampling number of multimodal items from [%s, %s]", + min_num_mm_items, max_num_mm_items, + ) + + return ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) + + def get_mm_item_iterator( + self, + min_num_mm_items: int, + max_num_mm_items: int, + bucket_config: dict[tuple[int, int, int], float], + limit_mm_per_prompt: dict[str, int], + ) -> Iterator[tuple[int,int, int]]: + """ + Iterator over the multimodal items for each request + whose size is between min_num_mm_items and max_num_mm_items. + + Loop over the bucket config and sample a multimodal item. + Loop until the number of multimodal items sampled is equal to + request_num_mm_items or limit of multimodal items per prompt + for all modalities is reached. + + Note: + - This function operates on a per-request shallow copy of + `bucket_config` (tuple->float). The original dict passed to + `sample` is not mutated. If this ever changes, a test + is implemented and will fail. 
+ """ + # Get the number of multimodal items to sample + request_num_mm_items = int( + self._rng.integers(min_num_mm_items, max_num_mm_items + 1) + ) + # If request_num_mm_items is 0, yield an empty iterator + if request_num_mm_items == 0: + return + # Initialize modality counters + modality_counter = {self.map_config_to_modality(k): 0 + for k in bucket_config} + # Copy the bucket config to avoid modifying the original + bucket_config_copy = bucket_config.copy() + # Loop over the number of multimodal items to sample + while sum(modality_counter.values()) < request_num_mm_items: + # Sample a multimodal item config + mm_item_config = self._rng.choice(list(bucket_config_copy.keys()), + p=list(bucket_config_copy.values())) + modality = self.map_config_to_modality(mm_item_config) + # Check that modality count is less than limit per prompt + if modality_counter[modality] < limit_mm_per_prompt[modality]: + modality_counter[modality] += 1 + yield ( + mm_item_config + ) + else: + # If the counter is greater than the limit per prompt + # set all multimodal items of this modality to 0 + for k, v in bucket_config_copy.items(): + if self.map_config_to_modality(k) == modality: + bucket_config_copy[k] = 0 + # If all configs are 0, break the loop + # This should not happen as request_num_mm_items is at most + # the sum of limit_mm_per_prompt for all modalities + if all(v == 0 for v in bucket_config_copy.values()): + logger.warning("Exhausted all multimodal items " + "of modality %s", + modality) + break + # Renormalize the bucket config + bucket_config_copy = self.normalize_bucket_config( + bucket_config_copy) + + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + request_id_prefix: str = "", + prefix_len: int = RandomDataset.DEFAULT_PREFIX_LEN, + range_ratio: float = RandomDataset.DEFAULT_RANGE_RATIO, + input_len: int = RandomDataset.DEFAULT_INPUT_LEN, + output_len: int = RandomDataset.DEFAULT_OUTPUT_LEN, + limit_mm_per_prompt: dict[str, int] = DEFAULT_LIMIT_MM_PER_PROMPT, + base_items_per_request: int = DEFAULT_BASE_ITEMS_PER_REQUEST, + num_mm_items_range_ratio: float = DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + bucket_config: dict[tuple[int, int, int], float] = + DEFAULT_MM_ITEM_BUCKET_CONFIG, + enable_multimodal_chat: bool = DEFAULT_ENABLE_MULTIMODAL_CHAT, + **kwargs, + ) -> list[SampleRequest]: + + # NOTE: Video sampling is WIP. Raise error if video is in bucket config + # and probability is non-zero. + if any(self.map_config_to_modality(cfg) == "video" and p > 0 + for cfg, p in bucket_config.items()): + raise NotImplementedError("Video sampling not implemented; " + "set its probability to 0.") + + # Get the sampling parameters for the dataset + input_lens, output_lens, offsets = self.get_sampling_params( + num_requests, range_ratio, input_len, output_len, tokenizer + ) + + ( + min_num_mm_items, + max_num_mm_items, + limit_mm_per_prompt, + bucket_config, + ) = self.get_mm_item_sampling_params( + base_items_per_request, + num_mm_items_range_ratio, + limit_mm_per_prompt, + bucket_config, + ) + + # Generate prefix once + prefix_token_ids = self.get_prefix(tokenizer, prefix_len) + vocab_size = tokenizer.vocab_size + # Add synthetic multimodal items to each request + mm_requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() - token_sequence = prefix_token_ids + inner_seq - prompt = tokenizer.decode(token_sequence) - # After decoding the prompt we have to encode and decode it again. 
- # This is done because in some cases N consecutive tokens - # give a string tokenized into != N number of tokens. - # For example for GPT2Tokenizer: - # [6880, 6881] -> ['Ä calls', 'here'] -> - # [1650, 939, 486] -> ['Ä call', 'sh', 'ere'] - # To avoid uncontrolled change of the prompt length, - # the encoded sequence is truncated before being decode again. - total_input_len = prefix_len + int(input_lens[i]) - re_encoded_sequence = tokenizer.encode( - prompt, add_special_tokens=False)[:total_input_len] - prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = len(re_encoded_sequence) - requests.append( - SampleRequest( + prompt, total_input_len = self.generate_token_sequence( + tokenizer=tokenizer, + prefix_token_ids=prefix_token_ids, + prefix_len=prefix_len, + vocab_size=vocab_size, + input_len=int(input_lens[i]), + offset=int(offsets[i]), + index=i, + ) + # Get multimodal item iterator for a given request + mm_item_iterator = self.get_mm_item_iterator( + min_num_mm_items, + max_num_mm_items, + bucket_config, + limit_mm_per_prompt, + ) + + mm_content = cast(list[dict[str, Any]], [ + self.generate_mm_item(mm_item_config) + for mm_item_config in mm_item_iterator + ]) + + if enable_multimodal_chat: + # NOTE: For now this option is only provided for completeness + # given that the serve.py benchmark currently does not use it. + mm_chat_prompt: Any = prompt + mm_chat_prompt = self.apply_multimodal_chat_transformation( + prompt, mm_content) + sample_request = SampleRequest( + prompt=mm_chat_prompt, + prompt_len=total_input_len, + expected_output_len=int(output_lens[i]), + multi_modal_data=None, + request_id=request_id_prefix + str(i), + ) + else: + sample_request = SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), + multi_modal_data=mm_content, request_id=request_id_prefix + str(i), - )) - return requests - + ) + mm_requests.append(sample_request) + return mm_requests # ----------------------------------------------------------------------------- # ShareGPT Dataset Implementation @@ -545,8 +1004,8 @@ def add_dataset_parser(parser: FlexibleArgumentParser): type=str, default="random", choices=[ - "sharegpt", "burstgpt", "sonnet", "random", "hf", "custom", - "prefix_repetition" + "sharegpt", "burstgpt", "sonnet", "random", "random-mm", "hf", + "custom", "prefix_repetition" ], help="Name of the dataset to benchmark on.", ) @@ -647,6 +1106,98 @@ def add_dataset_parser(parser: FlexibleArgumentParser): "input_len * (1 + range_ratio)]."), ) + # random multimodal dataset options + random_mm_group = parser.add_argument_group( + "random multimodal dataset options extended from random dataset") + random_mm_group.add_argument( + "--random-mm-base-items-per-request", + type=int, + default=RandomMultiModalDataset.DEFAULT_BASE_ITEMS_PER_REQUEST, + help=( + "Base number of multimodal items per request for random-mm. " + "Actual per-request count is sampled around this base using " + "--random-mm-num-mm-items-range-ratio." + ), + ) + random_mm_group.add_argument( + "--random-mm-num-mm-items-range-ratio", + type=float, + default=RandomMultiModalDataset.DEFAULT_NUM_MM_ITEMS_RANGE_RATIO, + help=( + "Range ratio r in [0, 1] for sampling items per request. " + "We sample uniformly from the closed integer range " + "[floor(n*(1-r)), ceil(n*(1+r))] " + "where n is the base items per request. " + "r=0 keeps it fixed; r=1 allows 0 items. The maximum is clamped " + "to the sum of per-modality limits from " + "--random-mm-limit-mm-per-prompt. 
" + "An error is raised if the computed min exceeds the max." + ), + ) + random_mm_group.add_argument( + "--random-mm-limit-mm-per-prompt", + type=json.loads, + default=RandomMultiModalDataset.DEFAULT_LIMIT_MM_PER_PROMPT, + help=( + "Per-modality hard caps for items attached per request, e.g. " + "'{\"image\": 3, \"video\": 0}'. The sampled per-request item " + "count is clamped to the sum of these limits. When a modality " + "reaches its cap, its buckets are excluded and probabilities are " + "renormalized." + "OBS.: Only image sampling is supported for now." + ), + ) + + def _parse_mm_bucket_config(v: object) -> dict[tuple[int, int, int], float]: + # If already a dict (e.g., programmatic call), normalize keys + def normalize(d: dict) -> dict[tuple[int, int, int], float]: + out: dict[tuple[int, int, int], float] = {} + for k, val in d.items(): + key = k + if isinstance(key, str): + with suppress(Exception): + key = ast.literal_eval(key) + if not (isinstance(key, tuple) and len(key) == 3 + and all(isinstance(x, int) for x in key)): + raise ValueError( + f"Invalid bucket key {k!r}. Expected tuple (H, W, T)." + ) + out[(int(key[0]), int(key[1]), int(key[2]))] = float(val) + return out + + if isinstance(v, dict): + return normalize(v) + if isinstance(v, str): + # Python literal (supports tuple keys) + parsed = ast.literal_eval(v) + if not isinstance(parsed, dict): + raise ValueError("Bucket config must parse to a dict.") + return normalize(parsed) + raise ValueError("Unsupported value for --random-mm-bucket-config.") + + random_mm_group.add_argument( + "--random-mm-bucket-config", + type=_parse_mm_bucket_config, + default=RandomMultiModalDataset.DEFAULT_MM_ITEM_BUCKET_CONFIG, + help=( + "The bucket config is a dictionary mapping a multimodal item" + "sampling configuration to a probability." + "Currently allows for 2 modalities: images and videos. " + "An bucket key is a tuple of (height, width, num_frames)" + "The value is the probability of sampling that specific item. " + "Example: " + "--random-mm-bucket-config " + "{(256, 256, 1): 0.5, (720, 1280, 1): 0.4, (720, 1280, 16): 0.10} " + "First item: images with resolution 256x256 w.p. 0.5" + "Second item: images with resolution 720x1280 w.p. 0.4 " + "Third item: videos with resolution 720x1280 and 16 frames w.p. 0.1" + "OBS.: If the probabilities do not sum to 1, they are normalized." + "OBS bis.: Only image sampling is supported for now." 
+ ), + ) + + + hf_group = parser.add_argument_group("hf dataset options") hf_group.add_argument("--hf-subset", type=str, @@ -697,6 +1248,10 @@ def add_dataset_parser(parser: FlexibleArgumentParser): def get_samples(args, tokenizer) -> list[SampleRequest]: + + if not hasattr(args, "request_id_prefix"): + args.request_id_prefix = "" + if args.dataset_name == "custom": dataset = CustomDataset(dataset_path=args.dataset_path) input_requests = dataset.sample( @@ -821,6 +1376,22 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: range_ratio=args.random_range_ratio, request_id_prefix=args.request_id_prefix, ), + "random-mm": + lambda: RandomMultiModalDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + prefix_len=args.random_prefix_len, + range_ratio=args.random_range_ratio, + input_len=args.random_input_len, + output_len=args.random_output_len, + base_items_per_request=args.random_mm_base_items_per_request, + limit_mm_per_prompt=args.random_mm_limit_mm_per_prompt, + num_mm_items_range_ratio=args.random_mm_num_mm_items_range_ratio, + bucket_config=args.random_mm_bucket_config, + request_id_prefix=args.request_id_prefix, + ), "prefix_repetition": lambda: PrefixRepetitionRandomDataset( random_seed=args.seed, dataset_path=args.dataset_path @@ -836,6 +1407,13 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: } try: + # Enforce endpoint compatibility for multimodal datasets. + if args.dataset_name == "random-mm" and args.endpoint_type not in [ + "openai-chat"]: + raise ValueError( + "Multi-modal content (images) is only supported on " + "'openai-chat' backend." + ) input_requests = dataset_mapping[args.dataset_name]() except KeyError as err: raise ValueError(f"Unknown dataset: {args.dataset_name}") from err diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 3dec939c2835..0d8d562514e3 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -12,7 +12,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + GroupShape, QuantKey, ScaleDesc, kFp8DynamicTensorSym, kFp8DynamicTokenSym, + kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe @@ -21,6 +22,7 @@ logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 def empty_bf16(*args, **kwargs): @@ -31,41 +33,12 @@ def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") -RMS_OP = torch.ops._C.rms_norm.default -RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default - - -class QuantKey(NamedTuple): - """ - Named tuple for identifying the type of quantization. 
- dtype: quantized data type - static: static quantization if True, dynamic if False - group_shape: quantization group shape - symmetric: symmetric if True, asymmetric if False - - TODO(luka) use QuantDescriptor once standardized: - https://github.com/vllm-project/vllm/issues/8913 +def empty_i32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.int32, device="cuda") - """ - dtype: torch.dtype - static: bool - group_shape: GroupShape - symmetric: bool = True - def __str__(self): - group_shape = ('per_tensor' - if self.group_shape == GroupShape.PER_TENSOR else - ('per_token' if self.group_shape == GroupShape.PER_TOKEN - else str(self.group_shape))) - - return (f"QuantKey({'static' if self.static else 'dynamic'}," - f"{fx.graph.dtype_abbrs[self.dtype]},{group_shape}," - f"{'a' if not self.symmetric else ''}symmetric)") - - -kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, GroupShape.PER_TENSOR, True) -kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TENSOR, True) -kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, GroupShape.PER_TOKEN, True) +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default QUANT_OPS: dict[QuantKey, OpOverload] = { kFp8StaticTensorSym: @@ -75,6 +48,9 @@ def __str__(self): kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa: E501 } +if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"): + QUANT_OPS[ + kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501 class FusedRMSQuantKey(NamedTuple): @@ -187,11 +163,9 @@ def __init__(self, quant_dtype: torch.dtype, symmetric=True): fused_key = FusedRMSQuantKey(fused_add=False, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, fused_key) def register(self, pm_pass: PatternMatcherPass): @@ -244,11 +218,9 @@ def __init__(self, quant_dtype: torch.dtype, symmetric=True): key = FusedRMSQuantKey(fused_add=True, - quant=QuantKey( - dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric)) + quant=QuantKey(dtype=quant_dtype, + scale=kStaticTensorScale, + symmetric=symmetric)) super().__init__(epsilon, key) def register(self, pm_pass: PatternMatcherPass, @@ -337,10 +309,10 @@ def __init__(self, quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=False, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) @@ -435,10 +407,10 @@ def __init__(self, quant_dtype: torch.dtype, group_shape: GroupShape = GroupShape.PER_TOKEN, symmetric=True): + scale = ScaleDesc(torch.float32, False, group_shape) key = FusedRMSQuantKey(fused_add=True, quant=QuantKey(dtype=quant_dtype, - static=False, - group_shape=group_shape, + scale=scale, symmetric=symmetric)) super().__init__(epsilon, key) diff --git a/vllm/compilation/fusion_attn.py b/vllm/compilation/fusion_attn.py index 1f77a2667613..f942afe6a28e 100644 --- a/vllm/compilation/fusion_attn.py +++ b/vllm/compilation/fusion_attn.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from abc import ABC, abstractmethod + import torch import torch._inductor.pattern_matcher as pm from 
torch._higher_order_ops.auto_functionalize import auto_functionalized @@ -11,44 +13,41 @@ from vllm.attention import Attention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, kNvfp4Quant, kStaticTensorScale) from vllm.platforms import current_platform +from vllm.utils import round_up -from .fusion import QUANT_OPS, GroupShape, QuantKey, empty_bf16, empty_fp32 +from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32 from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 ATTN_OP = torch.ops.vllm.unified_attention_with_output.default RESHAPE_OP = torch.ops.aten.reshape.default -class AttentionStaticQuantPattern: +class AttentionQuantPattern(ABC): """ - Fusion for Attention+StaticQuant. - - Only triggers when the attention implementation returns True in - `fused_output_quant_supported()`. If the pattern is found, the StaticQuant - op will be removed from the graph, and its scale will be passed into - Attention op as the `output_scale` argument. + The base class for Attn+Quant fusions. + Should not be used directly. """ def __init__( self, layer: Attention, - quant_dtype: torch.dtype, - symmetric=True, + quant_key: QuantKey, ): self.layer = layer self.layer_name = layer.layer_name self.num_heads = layer.num_heads self.head_size = layer.head_size - self.quant_dtype = quant_dtype - self.quant_key = QuantKey(dtype=quant_dtype, - static=True, - group_shape=GroupShape.PER_TENSOR, - symmetric=symmetric) + self.quant_key = quant_key + self.quant_dtype = quant_key.dtype + assert self.quant_key in QUANT_OPS, \ f"unsupported quantization scheme {self.quant_key}" self.QUANT_OP = QUANT_OPS[self.quant_key] @@ -57,12 +56,49 @@ def empty_quant(self, *args, **kwargs): kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} return torch.empty(*args, **kwargs) + @staticmethod + def wrap_trace_fn(process_fx, trace_fn): + + def wrapped(*args, **kwargs): + return process_fx(trace_fn(*args, **kwargs)) + + return wrapped + + @staticmethod + def fx_view_to_reshape(gm: torch.fx.GraphModule): + from torch._inductor.fx_passes.post_grad import view_to_reshape + view_to_reshape(gm) + return gm + def register_if_supported(self, pm_pass: PatternMatcherPass): - if self.layer.impl.fused_output_quant_supported( - self.quant_dtype, self.quant_key.static, - self.quant_key.group_shape): + if self.layer.impl.fused_output_quant_supported(self.quant_key): self._register(pm_pass) + @abstractmethod + def _register(self, pm_pass: PatternMatcherPass): + raise NotImplementedError + + +class AttentionFp8StaticQuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Fp8StaticQuant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Fp8StaticQuant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. 
+ """ + + def __init__( + self, + layer: Attention, + symmetric: bool = True, + ): + quant_key = QuantKey(dtype=FP8_DTYPE, + scale=kStaticTensorScale, + symmetric=symmetric) + super().__init__(layer, quant_key) + def _register(self, pm_pass: PatternMatcherPass): def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, @@ -74,9 +110,10 @@ def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, value=v, output=output_attn, layer_name=self.layer_name, - output_scale=None) - attn_out_view = RESHAPE_OP(at1[1], - [-1, self.num_heads * self.head_size]) + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) at2 = auto_functionalized(self.QUANT_OP, result=output_quant, input=attn_out_view, @@ -98,7 +135,8 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, value=v, output=output_attn, layer_name=self.layer_name, - output_scale=scale) + output_scale=scale, + output_block_scale=None) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) # Need custom fake mode, otherwise tracing happens with real tensors. @@ -114,21 +152,94 @@ def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, empty_fp32(1, 1) # scale ] - def wrap_trace_fn(process_fx, trace_fn): + pm.register_replacement( + pattern, replacement, inputs, + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) + - def wrapped(*args, **kwargs): - return process_fx(trace_fn(*args, **kwargs)) +class AttentionNvfp4QuantPattern(AttentionQuantPattern): + """ + Fusion for Attention+Nvfp4Quant. + + Only triggers when the attention implementation returns True in + `fused_output_quant_supported()`. If the pattern is found, the + Nvfp4Quant op will be removed from the graph, and its scale + will be passed into Attention op as the `output_scale` argument. 
+ """ - return wrapped + def __init__(self, layer: Attention): + super().__init__(layer, kNvfp4Quant) - def fx_view_to_reshape(gm: torch.fx.GraphModule): - from torch._inductor.fx_passes.post_grad import view_to_reshape - view_to_reshape(gm) - return gm + def _register(self, pm_pass: PatternMatcherPass): + + def pattern(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + at1 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=None, + output_block_scale=None) + attn_out_view = RESHAPE_OP( + at1[1], [q.shape[0], self.num_heads * self.head_size]) + at2 = auto_functionalized(self.QUANT_OP, + output=output_quant, + input=attn_out_view, + output_scale=output_scale, + input_scale=input_scale) + output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) + return at2[1], output_scale_view + + def replacement(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + output_attn: torch.Tensor, output_quant: torch.Tensor, + output_scale: torch.Tensor, input_scale: torch.Tensor): + # attention output in quant_dtype + output_attn = torch.ops.aten.full.default( + [q.shape[0], self.num_heads, self.head_size // 2], + 0.0, + dtype=self.quant_dtype, + device=q.device) + # attention output block scale + output_scale_view = torch.ops.aten.view.dtype( + output_scale, FP8_DTYPE) + at2 = auto_functionalized(ATTN_OP, + query=q, + key=k, + value=v, + output=output_attn, + layer_name=self.layer_name, + output_scale=input_scale, + output_block_scale=output_scale_view) + output = RESHAPE_OP(at2[1], + [-1, self.num_heads * self.head_size // 2]) + return output, at2[2] + + # Need custom fake mode, otherwise tracing happens with real tensors. + # That would not work for the unified_attention custom op. 
+ with unset_fake_temporarily(), FakeTensorMode(): + inputs = [ + empty_bf16(5, self.num_heads, self.head_size), # q + empty_bf16(5, self.num_heads, self.head_size), # k + empty_bf16(5, self.num_heads, self.head_size), # v + empty_bf16(5, self.num_heads, self.head_size), # output_attn + self.empty_quant(5, self.num_heads * self.head_size // + 2), # output_quant + empty_i32(128, + round_up(self.num_heads * self.head_size // 16, + 4)), # output_scale + empty_fp32(1, 1), # input_scale + ] pm.register_replacement( pattern, replacement, inputs, - wrap_trace_fn(fx_view_to_reshape, pm.fwd_only), pm_pass) + AttentionQuantPattern.wrap_trace_fn( + AttentionQuantPattern.fx_view_to_reshape, pm.fwd_only), + pm_pass) class AttnFusionPass(VllmInductorPass): @@ -151,8 +262,12 @@ def __init__(self, config: VllmConfig): attn_layers = get_layers_from_vllm_config(config, Attention) for layer_name, layer in attn_layers.items(): - pattern = AttentionStaticQuantPattern(layer, FP8_DTYPE) - pattern.register_if_supported(self.patterns) + pattern_fp8 = AttentionFp8StaticQuantPattern(layer) + pattern_fp8.register_if_supported(self.patterns) + + pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) + pattern_nvfp4.register_if_supported(self.patterns) + if len(attn_layers) == 0: logger.warning( "Attention + quant fusion is enabled, but no attention layers " @@ -175,4 +290,6 @@ def __call__(self, graph: torch.fx.graph.Graph) -> None: self.end_and_log() def uuid(self): - return VllmInductorPass.hash_source(self, AttentionStaticQuantPattern) + return VllmInductorPass.hash_source(self, AttentionQuantPattern, + AttentionFp8StaticQuantPattern, + AttentionNvfp4QuantPattern) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index fbc4dd3989f5..cd0e17977ede 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1119,9 +1119,20 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = me_quant.QUANTIZATION_METHODS optimized_quantization_methods = [ - "fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", - "fbgemm_fp8", "compressed-tensors", "experts_int8", "quark", - "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" + "fp8", + "modelopt", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "fbgemm_fp8", + "compressed-tensors", + "experts_int8", + "quark", + "modelopt_fp4", + "bitblas", + "gptq_bitblas", + "inc", + "petit_nvfp4", ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -1153,6 +1164,7 @@ def _verify_quantization(self) -> None: "moe_wna16", "modelopt", "modelopt_fp4", + "petit_nvfp4", ] quantization_methods = [ q for q in supported_quantization if q not in overrides @@ -3045,7 +3057,8 @@ def get_served_model_name(model: str, return served_model_name -GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines"] +GuidedDecodingBackend = Literal["auto", "xgrammar", "guidance", "outlines", + "lm-format-enforcer"] @config diff --git a/vllm/config/cache.py b/vllm/config/cache.py index ae11dec3ca5e..a9550d4390ad 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -116,7 +116,7 @@ class CacheConfig: In some KV sharing setups, e.g. YOCO (https://arxiv.org/abs/2405.05254), some layers can skip tokens corresponding to prefill. This flag enables attention metadata for eligible layers to be overriden with metadata - necessary for implementating this optimization in some models (e.g. Gemma3n) + necessary for implementing this optimization in some models (e.g. 
Gemma3n) """ def compute_hash(self) -> str: diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index f7b8b1d0a565..9ea883d4a03c 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from pydantic import TypeAdapter, model_validator +from pydantic import model_validator from pydantic.dataclasses import dataclass from torch.distributed import ProcessGroup, ReduceOp from typing_extensions import Self @@ -56,13 +56,6 @@ class EPLBConfig: This is turned off by default since it will cause communication overhead. """ - @classmethod - def from_cli(cls, cli_value: str) -> "EPLBConfig": - """Parse the CLI value for the compilation config. - -O1, -O2, -O3, etc. is handled in FlexibleArgumentParser. - """ - return TypeAdapter(EPLBConfig).validate_json(cli_value) - @config @dataclass diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 942e866ed97e..7963fb15c419 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -152,8 +152,13 @@ def __init__(self): self.pointer_to_data: dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: dict[str, Any] = {} + # Creating strong references to the two callbacks here to prevent + # these ephemeral bound-method objects being garbage collected. + # See discussions in https://github.com/vllm-project/vllm/pull/22724 + self.python_malloc_callback = self._python_malloc_callback + self.python_free_callback = self._python_free_callback - def python_malloc_callback(self, allocation_handle: HandleType) -> None: + def _python_malloc_callback(self, allocation_handle: HandleType) -> None: """ Internal method to store the allocation data when memory is allocated in the memory pool.""" @@ -162,7 +167,7 @@ def python_malloc_callback(self, allocation_handle: HandleType) -> None: allocation_handle, self.current_tag) return - def python_free_callback(self, ptr: int) -> HandleType: + def _python_free_callback(self, ptr: int) -> HandleType: """ Internal method to look up the allocation data when memory is freed in the memory pool.""" @@ -212,9 +217,9 @@ def sleep( def wake_up(self, tags: Optional[list[str]] = None) -> None: """ Wake up the allocator from sleep mode. - All data that is previously offloaded will be loaded back to GPU + All data that is previously offloaded will be loaded back to GPU memory, and the rest of the data will have empty memory. - + :param tags: The tags of the memory allocation that will be loaded back to GPU memory. If None, all memory allocation will be loaded back to GPU memory. 
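The cumem.py change above works because a bound method such as `self._python_malloc_callback` is a brand-new object on every attribute access, so handing it to a registration mechanism that holds only a weak reference loses it as soon as it is created. A minimal sketch of that behavior (the class and attribute names below are made up for illustration and are not vLLM code):

    import weakref

    class Allocator:
        def _malloc_cb(self, handle):
            print("malloc", handle)

    a = Allocator()
    # `a._malloc_cb` builds a fresh bound-method object on every access; with
    # no other strong reference it is collected immediately under CPython
    # refcounting, so an API that only keeps a weak reference to the callback
    # ends up with a dead reference.
    dead_ref = weakref.ref(a._malloc_cb)
    print(dead_ref() is None)   # True

    # Storing the bound method on the instance, as the patch does with
    # `self.python_malloc_callback = self._python_malloc_callback`, keeps the
    # callback alive for as long as the allocator object itself is alive.
    a.malloc_cb = a._malloc_cb
    live_ref = weakref.ref(a.malloc_cb)
    print(live_ref() is None)   # False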
diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py similarity index 93% rename from vllm/distributed/device_communicators/custom_all_reduce_utils.py rename to vllm/distributed/device_communicators/all_reduce_utils.py index 7c6001e87039..5c64e7d5c4ba 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -23,6 +23,39 @@ logger = init_logger(__name__) +MiB = 1024 * 1024 +# Max size for each world size in case symmetric memory is available +# For different SM architectures +CUSTOM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: MiB // 2, # 512 KB + 8: MiB // 4, # 256 KB + }, + "10.0": { + 2: 2 * MiB, # 2 MB + 4: 2 * MiB, # 2 MB + 6: 2 * MiB, # 2 MB + 8: 2 * MiB, # 2 MB + } +} + +SYMM_MEM_ALL_REDUCE_MAX_SIZES = { + "9.0": { + 2: 64 * MiB, # 64 MB + 4: 32 * MiB, # 32 MB + 6: 64 * MiB, # 64 MB + 8: 64 * MiB, # 64 MB + }, + "10.0": { + 2: 8 * MiB, # 8 MB + 4: 32 * MiB, # 32 MB + 6: 128 * MiB, # 128 MB + 8: 128 * MiB, # 128 MB + } +} + def producer(batch_src: Sequence[int], producer_queue, diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 66d4940c9cec..0ea8de2f36f4 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -44,6 +44,8 @@ def __init__(self, PyNcclCommunicator) from vllm.distributed.device_communicators.quick_all_reduce import ( QuickAllReduce) + from vllm.distributed.device_communicators.symm_mem import ( + SymmMemCommunicator) self.pynccl_comm: Optional[PyNcclCommunicator] = None if use_pynccl and self.world_size > 1: @@ -54,6 +56,7 @@ def __init__(self, self.ca_comm: Optional[CustomAllreduce] = None self.qr_comm: Optional[QuickAllReduce] = None + self.symm_mem_comm: Optional[SymmMemCommunicator] = None if use_custom_allreduce and self.world_size > 1: # Initialize a custom fast all-reduce implementation. self.ca_comm = CustomAllreduce( @@ -69,6 +72,12 @@ def __init__(self, # currently be an MI300 series. 
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device) + if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda(): + self.symm_mem_comm = SymmMemCommunicator( + group=self.cpu_group, + device=self.device, + ) + if self.use_all2all: all2all_backend = envs.VLLM_ALL2ALL_BACKEND if all2all_backend == "naive": @@ -105,6 +114,12 @@ def all_reduce(self, input_): out = ca_comm.custom_all_reduce(input_) assert out is not None return out + symm_mem_comm = self.symm_mem_comm + if symm_mem_comm is not None and \ + symm_mem_comm.should_use_symm_mem(input_): + out = symm_mem_comm.all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None out = pynccl_comm.all_reduce(input_) diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index 8dfb7959a510..80aca81234eb 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -10,8 +10,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.custom_all_reduce_utils import ( - gpu_p2p_access_check) +from vllm.distributed.device_communicators.all_reduce_utils import ( + CUSTOM_ALL_REDUCE_MAX_SIZES, gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as from vllm.logger import init_logger from vllm.platforms import current_platform @@ -109,7 +109,13 @@ def __init__(self, # now `device` is a `torch.device` object assert isinstance(device, torch.device) self.device = device - + device_capability = current_platform.get_device_capability( + ).as_version_str() + if (current_platform.is_cuda() and envs.VLLM_ALLREDUCE_USE_SYMM_MEM + and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES): + max_size = min( + CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], + max_size) cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES if cuda_visible_devices: device_ids = list(map(int, cuda_visible_devices.split(","))) diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py new file mode 100644 index 000000000000..d907e1b833d0 --- /dev/null +++ b/vllm/distributed/device_communicators/symm_mem.py @@ -0,0 +1,111 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.distributed.device_communicators.all_reduce_utils import ( + SYMM_MEM_ALL_REDUCE_MAX_SIZES) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + + symm_mem_available = True +except ImportError: + symm_mem_available = False + +logger = init_logger(__name__) + + +class SymmMemCommunicator: + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], + "10.0": [6, 8], + } + + def __init__(self, group: ProcessGroup, device: Union[int, str, + torch.device]): + self.disabled = True + + if not symm_mem_available: + return + + if not current_platform.is_cuda(): + logger.warning("SymmMemCommunicator: symmetric " + "memory is not available.") + return + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + torch.cuda.set_device(device) + self.dtype = torch.bfloat16 + self.device = 
device + self.group = group + self.world_size = dist.get_world_size(self.group) + self.device_capability = current_platform.get_device_capability( + ).as_version_str() + if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES: + logger.warning( + "SymmMemCommunicator: Device capability %s not supported, " + "communicator is not available.", + self.device_capability, + ) + return + if self.world_size not in SYMM_MEM_ALL_REDUCE_MAX_SIZES[ + self.device_capability]: + logger.warning( + "SymmMemCommunicator: World size %d not supported, " + "communicator is not available.", + self.world_size, + ) + return + self.max_size = SYMM_MEM_ALL_REDUCE_MAX_SIZES[self.device_capability][ + self.world_size] + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=self.device, + dtype=self.dtype, + ) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + if handle.multicast_ptr == 0: + logger.warning("SymmMemCommunicator: symmetric memory " + "multicast operations are not supported.") + return + self.disabled = False + + def should_use_symm_mem(self, inp: torch.Tensor): + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + return inp_size < self.max_size + + def all_reduce( + self, + inp: torch.Tensor, + *, + out: Optional[torch.Tensor] = None) -> Optional[torch.Tensor]: + if not self.should_use_symm_mem(inp): + return None + if out is None: + out = torch.empty_like(inp) + self.buffer[:inp.numel()].copy_(inp.view(-1)) + if self.world_size in self._WORLD_SIZES_MULTIMEM[ + self.device_capability]: + torch.ops.symm_mem.multimem_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + else: + torch.ops.symm_mem.two_shot_all_reduce_(self.buffer[:inp.numel()], + "sum", + self.group.group_name) + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + return out diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 965264ee3097..3ab1115f1446 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -455,8 +455,7 @@ def __post_init__(self): self.compilation_config = CompilationConfig( **self.compilation_config) if isinstance(self.eplb_config, dict): - self.eplb_config = EPLBConfig.from_cli(json.dumps( - self.eplb_config)) + self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index bbe958351e87..dbf8d3ba5014 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1822,7 +1822,7 @@ def _validate_model_input( assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 65aac23ee618..8b50153f0115 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -663,9 +663,9 @@ async def chat_completion_stream_generator( harmony_parser = harmony_parsers[i] for token_id in output.token_ids: harmony_parser.process(token_id) - # FIXME(woosuk): Support function calling - is_final = harmony_parser.current_channel == "final" - if not (request.include_reasoning or is_final): + is_reasoning = \ + 
harmony_parser.current_channel == "analysis" + if not request.include_reasoning and is_reasoning: # Skip the reasoning content. continue delta_text = harmony_parser.last_content_delta or "" @@ -695,11 +695,11 @@ async def chat_completion_stream_generator( current_token_ids = as_list(output.token_ids) if self.use_harmony: - if is_final: - delta_message = DeltaMessage(content=delta_text) - else: + if is_reasoning: delta_message = DeltaMessage( reasoning_content=delta_text) + else: + delta_message = DeltaMessage(content=delta_text) # handle streaming deltas for tools with named tool_choice elif tool_choice_function_name: if (self.reasoning_parser and not reasoning_end_arr[i] diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 6b131bbb04d1..5adcb310e346 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -1069,7 +1069,48 @@ def _send_event(event: BaseModel): delta=ctx.parser.last_content_delta, sequence_number=-1, )) - + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif (ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient == "python": + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + openai_responses_types. + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=openai_responses_types. + ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallInProgressEvent( + type= + "response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + )) + yield _send_event( + openai_responses_types. + ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + delta=ctx.parser.last_content_delta, + )) if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] if (self.tool_server is not None @@ -1165,30 +1206,6 @@ def _send_event(event: BaseModel): and self.tool_server.has_tool("python") and previous_item.recipient is not None and previous_item.recipient.startswith("python")): - yield _send_event( - openai_responses_types.ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=openai_responses_types. - ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code="", - container_id="auto", - outputs=[], - status="in_progress", - ), - )) - yield _send_event( - openai_responses_types. - ResponseCodeInterpreterCallInProgressEvent( - type="response.code_interpreter_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - )) - # TODO: do we need to add delta event here? yield _send_event( openai_responses_types. 
ResponseCodeInterpreterCallCodeDoneEvent(
@@ -1196,7 +1213,8 @@ def _send_event(event: BaseModel):
                             sequence_number=-1,
                             output_index=current_output_index,
                             item_id=current_item_id,
-                            code=previous_item.content[0].text))
+                            code=previous_item.content[0].text,
+                        ))
                     yield _send_event(
                         openai_responses_types.
                         ResponseCodeInterpreterCallInterpretingEvent(
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 468c3799bd1f..44aa1208a54c 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -3,6 +3,7 @@
 from .abstract_tool_parser import ToolParser, ToolParserManager
 from .deepseekv3_tool_parser import DeepSeekV3ToolParser
+from .deepseekv31_tool_parser import DeepSeekV31ToolParser
 from .glm4_moe_tool_parser import Glm4MoeModelToolParser
 from .granite_20b_fc_tool_parser import Granite20bFCToolParser
 from .granite_tool_parser import GraniteToolParser
@@ -36,6 +37,7 @@
     "PythonicToolParser",
     "Phi4MiniJsonToolParser",
     "DeepSeekV3ToolParser",
+    "DeepSeekV31ToolParser",
     "xLAMToolParser",
     "MinimaxToolParser",
     "KimiK2ToolParser",
diff --git a/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py
new file mode 100644
index 000000000000..ff9188190f3f
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py
@@ -0,0 +1,367 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Union
+
+import regex as re
+
+from vllm.entrypoints.chat_utils import make_tool_call_id
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("deepseek_v31")
+class DeepSeekV31ToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: list[dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: list[str] = (
+            [])  # map what has been streamed for each tool so far to a list
+
+        self.tool_calls_start_token: str = "<|tool▁calls▁begin|>"
+        self.tool_calls_end_token: str = "<|tool▁calls▁end|>"
+
+        self.tool_call_start_token: str = "<|tool▁call▁begin|>"
+        self.tool_call_end_token: str = "<|tool▁call▁end|>"
+
+        self.tool_call_regex = re.compile(
+            r"<|tool▁call▁begin|>(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)<|tool▁call▁end|>"
+        )
+
+        self.stream_tool_call_portion_regex = re.compile(
+            r"(?P<function_name>.*)<|tool▁sep|>(?P<function_arguments>.*)")
+
+        self.stream_tool_call_name_regex = re.compile(
+            r"(?P<function_name>.*)<|tool▁sep|>")
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor during construction.")
+        self.tool_calls_start_token_id = self.vocab.get(
+            self.tool_calls_start_token)
+        self.tool_calls_end_token_id = self.vocab.get(
+            self.tool_calls_end_token)
+
+        self.tool_call_start_token_id = self.vocab.get(
+            self.tool_call_start_token)
+        self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
+
+        if (self.tool_calls_start_token_id is None
+ or self.tool_calls_end_token_id is None): + raise RuntimeError( + "DeepSeek-V3.1 Tool parser could not locate tool call " + "start/end tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall( + model_output) + + tool_calls = [] + for match in function_call_tuples: + function_name, function_args = match + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. + find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! 
skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_name, tool_args = current_tool_call_matches.groups() + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_name = current_tool_call_name_matches.groups() + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. 
+ if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. 
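For reference, the DeepSeek-V3.1 tool-call text that the parser above consumes wraps each call in <|tool▁call▁begin|> ... <|tool▁sep|> ... <|tool▁call▁end|>, with the whole block enclosed by <|tool▁calls▁begin|>/<|tool▁calls▁end|>. A minimal sketch of the non-streaming extraction (the pattern is rebuilt here with re.escape for readability and the sample model output is made up; this is not the parser's internal regex):

    import re

    BEGIN = "<|tool▁call▁begin|>"
    SEP = "<|tool▁sep|>"
    END = "<|tool▁call▁end|>"

    # One capture group for the function name, one for the JSON arguments,
    # mirroring how extract_tool_calls() turns matches into ToolCall objects.
    call_re = re.compile(
        re.escape(BEGIN) + r"(.*?)" + re.escape(SEP) + r"(.*?)" + re.escape(END),
        re.DOTALL)

    model_output = (
        "Let me check the weather.<|tool▁calls▁begin|>"
        + BEGIN + "get_weather" + SEP + '{"city": "Paris"}' + END
        + "<|tool▁calls▁end|>")

    # findall returns one (name, arguments) tuple per tool call.
    for name, args in call_re.findall(model_output):
        print(name, args)   # -> get_weather {"city": "Paris"}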
diff --git a/vllm/envs.py b/vllm/envs.py index 296c1730892d..5d0e972f43ad 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -158,8 +158,10 @@ VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False VLLM_ENABLE_RESPONSES_API_STORE: bool = False VLLM_USE_TRTLLM_ATTENTION: Optional[str] = None + VLLM_HAS_FLASHINFER_CUBIN: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False + VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None @@ -1105,6 +1107,11 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_TRTLLM_ATTENTION": lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None), + # If set, it means we pre-downloaded cubin files and flashinfer will + # read the cubin files directly. + "VLLM_HAS_FLASHINFER_CUBIN": + lambda: os.getenv("VLLM_HAS_FLASHINFER_CUBIN", False), + # If set to 1, force the use of TRTLLM FP4 GEMM backend in flashinfer. # Otherwise, uses the first available of: flashinfer cutlass GEMM, # vllm cutlass GEMM, marlin GEMM. @@ -1150,6 +1157,10 @@ def get_vllm_port() -> Optional[int]: "VLLM_ENABLE_RESPONSES_API_STORE": lambda: bool(int(os.getenv("VLLM_ENABLE_RESPONSES_API_STORE", "0"))), + # Whether to use pytorch symmetric memory for allreduce + "VLLM_ALLREDUCE_USE_SYMM_MEM": + lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))), + # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 86ab4f546d12..f3248589abc4 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -422,12 +422,23 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): lambda: nn.SiLU(), "quick_gelu": lambda: QuickGELU(), + "tanh": + lambda: nn.Tanh(), + "sigmoid": + lambda: nn.Sigmoid(), }) def get_act_fn(act_fn_name: str) -> nn.Module: """Get an activation function by name.""" act_fn_name = act_fn_name.lower() + + if act_fn_name.startswith("torch.nn.modules."): + activation_name = act_fn_name.split(".")[-1] + if activation_name == "identity": + return nn.Identity() + act_fn_name = activation_name + if act_fn_name not in _ACTIVATION_REGISTRY: raise ValueError( f"Activation function {act_fn_name!r} is not supported.") diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py new file mode 100644 index 000000000000..782818f55fbc --- /dev/null +++ b/vllm/model_executor/layers/attention_layer_base.py @@ -0,0 +1,23 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Base class for attention-like layers.""" +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class AttentionLayerBase(ABC): + """ + Base class for attention-like layers (Attention, Mamba, etc.) + that support the v1 engine. + + This provides a common interface for getting attention backends + from different layer types. 
+ """ + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this layer.""" + pass diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..63de4bfa4cb5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,122 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..6efcc02b4d9a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,114 @@ +{ + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..d677d69c57a2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,154 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9b1ab7af0ac8..dd54aebeb011 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -52,6 +52,7 @@ "HQQMarlinMethod", "QuarkLinearMethod", "ModelOptNvFp4LinearMethod", + "PetitNvFp4LinearMethod", ] @@ -1377,7 +1378,7 @@ def forward( return output, output_bias def extra_repr(self) -> str: - s = f"input_features={self.input_size_per_partition}" + s = f"in_features={self.input_size_per_partition}" s += f", output_features={self.output_size}" s += f", bias={self.bias is not None}" s += f", tp_size={self.tp_size}" diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py index daebe46f6f77..a524e1340580 100644 --- a/vllm/model_executor/layers/mamba/abstract.py +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -1,12 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from abc import ABC, abstractmethod +from abc import abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING import torch +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -class MambaBase(ABC): +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +class MambaBase(AttentionLayerBase): """ Base class for Mamba-like layers which support the v1 engine. Inherit from this class if you implement a custom layer. 
@@ -32,3 +38,8 @@ def get_state_shape(self) -> Iterable[tuple[int, ...]]: @abstractmethod def mamba_type(self) -> str: pass + + @abstractmethod + def get_attn_backend(self) -> type["AttentionBackend"]: + """Get the attention backend class for this Mamba layer.""" + pass diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index a24e72778b34..e704bfd451bc 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import NamedTuple, Optional +from typing import TYPE_CHECKING, NamedTuple, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -404,6 +407,11 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: def mamba_type(self) -> str: return "mamba1" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba1_attn import ( + Mamba1AttentionBackend) + return Mamba1AttentionBackend + def _time_proj_bias(self) -> Optional[torch.Tensor]: if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None: return self.dt_proj.bias.float() diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 743e520ec8ee..bb3fdd38dbef 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch from torch import nn @@ -758,6 +761,11 @@ def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: def mamba_type(self) -> str: return "mamba2" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.mamba2_attn import ( + Mamba2AttentionBackend) + return Mamba2AttentionBackend + def mamba_mixer2( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/short_conv.py b/vllm/model_executor/layers/mamba/short_conv.py index fead1e73e345..335191a5c82c 100644 --- a/vllm/model_executor/layers/mamba/short_conv.py +++ b/vllm/model_executor/layers/mamba/short_conv.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import torch @@ -232,6 +235,11 @@ def get_state_shape(self) -> tuple[tuple[int, ...]]: def mamba_type(self) -> str: return "short_conv" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.short_conv_attn import ( + ShortConvAttentionBackend) + return ShortConvAttentionBackend + def short_conv( hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d34fb58cb5cb..eebf7b2508db 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import IntEnum from itertools import groupby -from typing import Callable, Optional, TypeVar, 
Union +from typing import Callable, Optional, TypeVar, Union, cast import torch import torch.nn as nn @@ -435,9 +435,31 @@ class EmbeddingPoolerHead(PoolerHead): def __init__(self) -> None: super().__init__(activation=PoolerNormalize()) + # Load ST projector if available + from vllm.config import get_current_vllm_config + from vllm.model_executor.models.adapters import _load_st_projector + + vllm_config = get_current_vllm_config() + self.projector = _load_st_projector( + vllm_config.model_config) if vllm_config else None + def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): + # Apply ST projector + if self.projector is not None: + projector = cast(nn.Module, self.projector) + + def _proj(x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + y = projector(x.to(torch.float32)) + return y.to(orig_dtype) + + if isinstance(pooled_data, torch.Tensor): + pooled_data = _proj(pooled_data) + else: + pooled_data = [_proj(t) for t in pooled_data] + pooling_params = get_pooling_params(pooling_metadata) if isinstance(pooled_data, list): diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ea51468422dc..d73fcf368f26 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -35,6 +35,7 @@ "rtn", "inc", "mxfp4", + "petit_nvfp4", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -108,6 +109,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .neuron_quant import NeuronQuantConfig + from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .rtn import RTNConfig from .torchao import TorchAOConfig @@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "rtn": RTNConfig, "inc": INCConfig, "mxfp4": Mxfp4Config, + "petit_nvfp4": PetitNvFp4Config, } # Update the `method_to_config` with customized quantization methods. 
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 637a84372990..ce74375aab42 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -26,10 +26,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, CompressedTensors24, CompressedTensorsScheme, CompressedTensorsW4A4Fp4, - CompressedTensorsW4A8Int, CompressedTensorsW4A16Fp4, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Int, + CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( find_matched_target, is_activation_quantization_format, should_ignore_layer) @@ -200,8 +200,10 @@ def _quantization_scheme_map_from_config( format ) if format is not None else is_activation_quantization_format( quant_format) - if act_quant_format: - input_activations = quant_config.get("input_activations") + # TODO(czhu): w4a8fp8 is in packed-quantized format + # but needs input activation quantization + input_activations = quant_config.get("input_activations") + if act_quant_format or input_activations: # The only case where we have activation quant supported # but no input_activations provided in the config # should be w8a16fp8 w8a16fp8 can also run for cases where @@ -352,6 +354,28 @@ def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant.strategy == QuantizationStrategy.TENSOR) return is_symmetric_activation and is_per_tensor_activation + def _is_fp8_w4a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + if not weight_quant or not input_quant: + return False + is_weight_4_bits = weight_quant.num_bits == 4 + is_activation_8_bits = input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + is_symmetric = weight_quant.symmetric and input_quant.symmetric + # Only per-group symmetric weight (4bit) + # + per-tok symmetric activation (8bit) quantization supported. 
+ return (is_weight_4_bits and is_activation_8_bits and is_token + and is_symmetric and is_dynamic) + + def _is_fp8_w4a8_sm90(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported(90, error=False, match_exact=True) + and self._is_fp8_w4a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: return (self._check_scheme_supported(90, error=False, match_exact=True) @@ -405,6 +429,13 @@ def _get_scheme_from_parts( if self._is_fp4a16_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A16Fp4() + if self._is_fp8_w4a8_sm90(weight_quant, input_quant): + return CompressedTensorsW4A8Fp8(num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + symmetric=weight_quant.symmetric, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + if self._is_wNa16_group_channel(weight_quant, input_quant): if (self.quant_format == CompressionFormat.marlin_24.value and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 734fa603ba7b..cac65cca5093 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -3,6 +3,7 @@ from .compressed_tensors_scheme import CompressedTensorsScheme from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8 from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) @@ -21,5 +22,6 @@ "CompressedTensorsW8A8Int8", "CompressedTensorsW8A8Fp8", "WNA16_SUPPORTED_BITS", "W4A16SPARSE24_SUPPORTED_BITS", "CompressedTensors24", "CompressedTensorsW4A16Fp4", - "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int" + "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", + "CompressedTensorsW4A8Fp8" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py new file mode 100644 index 000000000000..f6cc49c2316b --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py @@ -0,0 +1,160 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Callable, Optional + +import torch +from compressed_tensors.quantization import ActivationOrdering + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +# yapf: enable +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsW4A8Fp8"] +W4A8_SUPPORTED_TYPES_MAP = { + 4: 
scalar_types.int4, +} +W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A8Fp8(CompressedTensorsScheme): + _kernel_backends_being_used: set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + symmetric: Optional[bool] = True, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.symmetric = symmetric + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size != 128 or self.strategy != "group": + raise ValueError("W4A8 kernels require group quantization " \ + "with group size 128") + + if num_bits not in W4A8_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # hopper + return 90 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=torch.float8_e4m3fn, # always use fp8(e4m3) + group_size=self.group_size, + zero_points=not self.symmetric, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsW4A8Fp8", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. + group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + # TODO(czhu): allocate the packed fp8 scales memory here? 
+ # the scales will be expanded by 8x via `cutlass_pack_scale_fp8` + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ) + } + + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name="weight_zero_point", + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index a5084f6ee92c..4bcfcd04b3d8 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -10,6 +10,8 @@ BitBLASLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 ConchLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.cutlass import ( # noqa: E501 + CutlassW4A8LinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.dynamic_4bit import ( # noqa: E501 Dynamic4bitLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 @@ -24,6 +26,7 @@ # in priority/performance order (when available) _POSSIBLE_KERNELS: list[type[MPLinearKernel]] = [ + CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, MarlinLinearKernel, diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py new file mode 100644 index 000000000000..f1d49693fc01 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class CutlassW4A8LinearKernel(MPLinearKernel): + + def __init__(self, *args, **kwargs): 
+ super().__init__(*args, **kwargs) + # dynamic per-tok fp8 activation quantization + self.quant_fp8 = QuantFP8(static=False, + group_shape=GroupShape.PER_TOKEN) + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if not current_platform.is_cuda(): + return False, "CUTLASS only supported on CUDA" + + if not current_platform.is_device_capability(90): + return False, "CUTLASS W4A8 requires compute capability of 90 "\ + "(Hopper)" + + if c.act_type != torch.float8_e4m3fn: + return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations" + + if c.has_g_idx: + return False, "Act reordering not supported by CUTLASS W4A8" + + if c.zero_points: + return False, "Zero points not supported by CUTLASS W4A8" + + if c.weight_type != scalar_types.int4: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "CUTLASS W4A8, only supported int4" + + # TODO(czhu): support -1 (column-wise) + if c.group_size != 128: + return False, "Only group_size 128 is supported" + + in_features, out_features = c.partition_weight_shape + if in_features % 128 or out_features % 128: + return False, "K and N must be divisible by 128, got "\ + f"{c.partition_weight_shape}" + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # TODO(czhu): optimize speed/mem usage + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.cutlass_encode_and_reorder_int4b( + x.data.t().contiguous().t()) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous().to(torch.float8_e4m3fn) + x.data = ops.cutlass_pack_scale_fp8(x.data) + return x + + # Encode/reorder weights and pack scales + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + # TODO(czhu): support loading channel scales + self.w_ch_s = torch.ones((c.partition_weight_shape[1], ), + dtype=torch.float32, + device='cuda') + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None, "bias not supported by CUTLASS W4A8" + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + x_2d, act_scales = self.quant_fp8(x_2d) + output = ops.cutlass_w4a8_mm(a=x_2d, + b_q=w_q, + b_group_scales=w_s, + b_group_size=c.group_size, + a_token_scales=act_scales, + b_channel_scales=self.w_ch_s) + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/petit.py b/vllm/model_executor/layers/quantization/petit.py new file mode 100644 index 000000000000..5b9fee69bb02 --- /dev/null +++ b/vllm/model_executor/layers/quantization/petit.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + +from typing import Any, Optional + +import regex as re +import torch +from torch.nn.parameter import Parameter + 
+from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.petit_utils import ( + apply_petit_nvfp4_linear, prepare_nvfp4_layer_for_petit, + verify_petit_nvfp4_supported) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +# Initialize logger for the module +logger = init_logger(__name__) + + +# Configuration class to support the NVFP4 quantized model +# generated by the ModelOpt quantization tool +class PetitNvFp4Config(QuantizationConfig): + """Config class for Petit FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool = False, + kv_cache_quant_algo: Optional[str] = None, + group_size: Optional[int] = None, + exclude_modules: Optional[list[str]] = None, + ) -> None: + self._check_hardware_support() + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning("Detected nvfp4 checkpoint. Please note that the " + "format is experimental and subject to change.") + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + def _check_hardware_support(self) -> None: + """ + Verifies that the current hardware is supported by the Petit backend. + This backend is specifically designed for AMD GPUs and is not + supported on the CUDA platform. + """ + # This check ensures the code is NOT running on an NVIDIA GPU. + if current_platform.is_cuda(): + raise ValueError( + "The 'petit' quantization backend is designed for AMD GPUs " + "and is not supported on the CUDA platform. For NVIDIA GPUs, " + "please use a different quantization method such as FP8, AWQ, " + "or GPTQ.") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "petit_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # Petit supports the gfx90a and gfx942 GPUs + return 90 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "PetitNvFp4Config": + qc = cls.get_from_keys(config, ["quantization"]) + + quant_method_raw = qc.get("quant_algo") + if not isinstance(quant_method_raw, str) or not quant_method_raw: + raise ValueError( + "Missing or invalid 'quant_algo' in quantization config.") + quant_method = quant_method_raw.upper() + + group_size_raw = qc.get("group_size") + if not isinstance(group_size_raw, int): + raise ValueError( + "Missing or invalid 'group_size' (int) in hf_quant_config.json." 
+ ) + group_size = group_size_raw + + verify_petit_nvfp4_supported(quant_method, group_size) + + kv_cache_quant_algo_raw = qc.get("kv_cache_quant_algo") or "auto" + if not isinstance(kv_cache_quant_algo_raw, str): + raise ValueError( + "'kv_cache_quant_algo' must be a string if provided.") + kv_cache_quant_algo = kv_cache_quant_algo_raw + + exclude_raw = qc.get("exclude_modules", []) + if exclude_raw is None: + exclude_modules: list[str] = [] + elif isinstance(exclude_raw, list) and all( + isinstance(x, str) for x in exclude_raw): + exclude_modules = exclude_raw + else: + raise ValueError( + "'exclude_modules' must be a list[str] (or omitted).") + + is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method + + return cls( + is_checkpoint_nvfp4_serialized=is_checkpoint_nvfp4_serialized, + kv_cache_quant_algo=kv_cache_quant_algo, + group_size=group_size, + exclude_modules=exclude_modules, + ) + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + if not current_platform.is_rocm(): + return None + + qc = hf_quant_cfg.get("quantization", hf_quant_cfg) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + if algo in ("NVFP4", "MODELOPT_FP4", "MODELOPT"): + return cls.get_name() # "petit_nvfp4" + return None + + @classmethod + def is_petit_nvfp4_compatible(cls, quant_config: dict[str, Any]) -> bool: + qc = quant_config.get("quantization", quant_config) + algo = (qc.get("quant_algo") or qc.get("quant_method") or "").upper() + return algo == "NVFP4" + + def is_layer_excluded(self, prefix: str, + exclude_modules: list[str]) -> bool: + for pattern in exclude_modules: + regex_str = pattern.replace(".", r"\.").replace("*", r".*") + if re.fullmatch(regex_str, prefix): + return True + return False + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + exclude = self.require_exclude_modules() + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, exclude) or self.is_layer_excluded( + prefix, exclude): + return UnquantizedLinearMethod() + return PetitNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return PetitFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> list[str]: + return [] + + def require_group_size(self) -> int: + if self.group_size is None: + logger.warning("group_size not set; defaulting to 16 for NVFP4.") + return 16 + return self.group_size + + def require_kv_cache_quant_algo(self) -> str: + return self.kv_cache_quant_algo or "auto" + + def require_exclude_modules(self) -> list[str]: + return list(self.exclude_modules or []) + + +class PetitFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: PetitNvFp4Config): + super().__init__(quant_config) + + +class PetitNvFp4LinearMethod(LinearMethodBase): + """Linear method for NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + |Tensor Name | datatype | shape | + |----------------------------------------------------| + |input_scale | torch.float32 | scalar | + |weight | NVFP4(SE2M1) | [1, X, y/2] | + |weight_scale | FP8-E4M3 | [X, Y] | + |weight_scale_2 | torch.float32 | scalar | + + The weights are quantized per block of 16 elements. + Args: quant_config: The ModelOpt quantization config. 
+ """ + + def __init__(self, quant_config: PetitNvFp4Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + if input_size_per_partition % 16 != 0: + raise ValueError("Unsupported model when in features size is " + "not multiple of 16") + + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_nvfp4_serialized + else params_dtype) + + weight = ModelWeightParameter( + data=torch.empty( + # 2 fp4 data is packed in one uint8 in the input dimension + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + input_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + + layer.register_parameter("input_scale", input_scale) + + weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale_2", weight_scale_2) + + group_size = self.quant_config.require_group_size() + weight_scale = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // group_size, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + input_scale_2 = layer.input_scale.max().to(torch.float32) + weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) + layer.input_scale = Parameter(input_scale_2, requires_grad=False) + layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) + layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2, + requires_grad=False) + + prepare_nvfp4_layer_for_petit(layer) + del layer.input_scale + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_petit_nvfp4_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + weight_scale_2=layer.weight_scale_2, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) diff --git a/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..fbca5ce05d01 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, 
+ "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/quantization/utils/configs/README.md b/vllm/model_executor/layers/quantization/utils/configs/README.md new file mode 100644 index 000000000000..1110ced4fa06 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/configs/README.md @@ -0,0 +1,3 @@ +# Quantization Kernel Config + +Use scripts under `benchmarks/kernels/` to generate these config files. diff --git a/vllm/model_executor/layers/quantization/utils/petit_utils.py b/vllm/model_executor/layers/quantization/utils/petit_utils.py new file mode 100644 index 000000000000..00d3def1db81 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/petit_utils.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + +import torch + +# TYPE_CHECKING is used for static type analysis to prevent circular imports. +if TYPE_CHECKING: + from types import ModuleType + +# 1. Create a global variable as a placeholder for the module +_petit_kernel: Optional["ModuleType"] = None + +_PETIT_INSTALL_MSG = ("Petit is not installed. Please install it with " + "`pip install petit-kernel`.") + + +def _import_petit_kernel() -> "ModuleType": + """ + A helper function to handle the lazy import. + The first time this function is called, it will import the petit_kernel + library and store it in the global _petit_kernel variable. + Subsequent calls will return the already-loaded module directly. + """ + global _petit_kernel + if _petit_kernel is not None: + return _petit_kernel + + try: + import petit_kernel + _petit_kernel = petit_kernel + return _petit_kernel + except ImportError: + # The 'from None' syntax prevents chaining the original ImportError, + # making the traceback cleaner. + raise ImportError(_PETIT_INSTALL_MSG) from None + + +# The _require_petit function can now be a simple alias for consistency. +_require_petit = _import_petit_kernel + + +def _check_petit_nvfp4_supported( + quant_method: str, + group_size: Optional[int]) -> tuple[bool, Optional[str]]: + if quant_method != "NVFP4": + return ( + False, + ("Petit currently only supports: NVFP4 quantizations in sglang. " + "Please check the `hf_quant_config.json` file for your model's " + "quant configuration."), + ) + if group_size is not None and group_size != 16: + return ( + False, + "Petit currently only supports: group_size=16 quantizations.", + ) + return (True, None) + + +def verify_petit_nvfp4_supported(quant_method: str, + group_size: Optional[int]) -> None: + supported, error_msg = _check_petit_nvfp4_supported( + quant_method, group_size) + if not supported: + assert error_msg is not None + raise ValueError(error_msg) + + +def prepare_nvfp4_layer_for_petit(layer: torch.nn.Module) -> None: + # 2. Call _import_petit_kernel() to trigger (or get) the import. + petit_kernel = _import_petit_kernel() + + # Repack weights to petit format + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + qweight = layer.weight.view(torch.int32).contiguous() + + # 3. Call functions through the imported module variable. 
+ petit_qweight = petit_kernel.repack_nvfp4(qweight, + size_n=part_size_n, + size_k=part_size_k) + layer.weight = torch.nn.Parameter(petit_qweight, requires_grad=False) + + # Permute scales + weight_scale = petit_kernel.process_nvfp4_scales(scales=layer.weight_scale, + size_k=part_size_k, + size_n=part_size_n) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + +def apply_petit_nvfp4_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # Trigger (or get) the import here as well. + petit_kernel = _import_petit_kernel() + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + # TODO: Use auto-tuning to find the performant solution_id + # Call the function via the module variable. + output = petit_kernel.mul_nvfp4_a16( + a=reshaped_x, + b=weight, + s=weight_scale, + global_scale=weight_scale_2, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + solution_id=-1, + ) + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 97e5922ebd55..6154fca2e416 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -2,16 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """This file is used for /tests and /benchmarks""" from collections.abc import Mapping +from dataclasses import dataclass from types import MappingProxyType from typing import ClassVar, NamedTuple, Optional import numpy import torch +from torch import fx from vllm._custom_ops import cutlass_scaled_mm_supports_fp4 from vllm.platforms import current_platform from vllm.scalar_type import ScalarType, scalar_types +FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 + # Use proxy as NamedTuple direct subclasses cannot have static members class _GroupShape(NamedTuple): @@ -34,6 +39,64 @@ class GroupShape(_GroupShape): GroupShape.PER_TOKEN = GroupShape(1, -1) +@dataclass(frozen=True) +class ScaleDesc: + """ + Class for describing a single quantization scaling factor. + dtype: data type of the scale + static: static scale if True, dynamic if False + group_shape: group shape of the scale + """ + dtype: torch.dtype + static: bool + group_shape: GroupShape + + def __str__(self): + group_shape = ('per_tensor' + if self.group_shape == GroupShape.PER_TENSOR else + ('per_token' if self.group_shape == GroupShape.PER_TOKEN + else str(self.group_shape))) + + return (f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'static' if self.static else 'dynamic'},{group_shape}") + + +@dataclass(frozen=True) +class QuantKey: + """ + Class for identifying the type of quantization. 
+ dtype: quantized data type + scale: scale descriptor + scale2: second-level scale descriptor + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + scale: ScaleDesc + scale2: Optional[ScaleDesc] = None + symmetric: bool = True + + def __str__(self): + scale2_str = f"scale2({self.scale2})," if self.scale2 else "" + return (f"QuantKey({fx.graph.dtype_abbrs[self.dtype]}," + f"scale({self.scale}),{scale2_str}" + f"{'a' if not self.symmetric else ''}symmetric)") + + +kStaticTensorScale = ScaleDesc(torch.float32, True, GroupShape.PER_TENSOR) +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) + +kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) + +kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) + +kNvfp4GroupScale = ScaleDesc(FP8_DTYPE, False, GroupShape(1, 16)) +kNvfp4Quant = QuantKey(FP4_DTYPE, + scale=kNvfp4GroupScale, + scale2=kStaticTensorScale) + + # Normalize the group_shape to the full extent for any dims that are -1 def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 659029fd37f7..36d16960ec57 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -13,6 +13,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale @@ -156,13 +157,10 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, return output.view(*output_shape) -def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, - weight: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, bias: torch.Tensor, - input_2d: torch.Tensor, - output_shape: list) -> torch.Tensor: +def rocm_per_tensor_w8a8_scaled_mm_impl( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: from vllm.platforms.rocm import on_mi3xx if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: @@ -175,10 +173,38 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, scale_a=scale_a, scale_b=scale_b, bias=bias) + return output + + +def rocm_per_tensor_w8a8_scaled_mm_fake( + qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, + scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor) -> torch.Tensor: + return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), + dtype=out_dtype) + +def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, + weight: torch.Tensor, + out_dtype: torch.dtype, + scale_a: torch.Tensor, + scale_b: torch.Tensor, bias: torch.Tensor, + input_2d: torch.Tensor, + output_shape: list) -> torch.Tensor: + output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( + qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) +direct_register_custom_op( + op_name="rocm_per_tensor_w8a8_scaled_mm_impl", + op_func=rocm_per_tensor_w8a8_scaled_mm_impl, + mutates_args=[], + fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake, + dispatch_key=current_platform.dispatch_key, +) + + def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 1dbe70f84a62..49e9a2d65ea1 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -7,15 +7,21 @@ import torch import torch.nn as nn +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.models.config import VerifyAndUpdateConfig +from vllm.transformers_utils.config import (get_hf_file_bytes, + get_hf_file_to_dict) from .interfaces_base import VllmModelForPooling, is_pooling_model if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig _T = TypeVar("_T", bound=type[nn.Module]) +logger = init_logger(__name__) + _GENERATE_SUFFIXES = [ "ForCausalLM", "ForConditionalGeneration", @@ -24,6 +30,96 @@ ] +def _load_st_projector(model_config: "ModelConfig") -> Optional[nn.Module]: + """Load Sentence-Transformers Dense projection layers.""" + + try: + modules = get_hf_file_to_dict("modules.json", model_config.model, + model_config.revision) + if not modules: + return None + + if isinstance(modules, dict): + modules = modules.get("modules", []) + + dense_modules = [ + m for m in modules + if m.get("type") == "sentence_transformers.models.Dense" + ] + if not dense_modules: + return None + + module = dense_modules[0] + folder = module.get("path", "") + + config_path = f"{folder}/config.json" if folder else "config.json" + layer_config = 
get_hf_file_to_dict(config_path, model_config.model, + model_config.revision) + if not layer_config: + return None + + linear = nn.Linear(layer_config.get("in_features", 768), + layer_config.get("out_features", 768), + bias=layer_config.get("bias", True), + dtype=torch.float32) + + if _load_dense_weights(linear, folder, model_config): + layers = [linear] + if act_name := layer_config.get("activation_function"): + layers.append(get_act_fn(act_name)) + return nn.Sequential(*layers).to(dtype=torch.float32) + + except Exception: + logger.exception("ST projector loading failed") + + return None + + +def _load_dense_weights(linear: nn.Linear, folder: str, + model_config: "ModelConfig") -> bool: + """Load weights using vLLM's weight_loader pattern.""" + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) + + for filename in ["model.safetensors", "pytorch_model.bin"]: + file_path = f"{folder}/{filename}" if folder else filename + + try: + file_bytes = get_hf_file_bytes(file_path, model_config.model, + model_config.revision) + if not file_bytes: + continue + + if filename.endswith(".safetensors"): + from safetensors.torch import load as load_safetensors + state_dict = load_safetensors(file_bytes) + else: + import io + state_dict = torch.load(io.BytesIO(file_bytes), + map_location="cpu", + weights_only=True) + + for weight_key in ["weight", "linear.weight", "dense.weight"]: + if weight_key in state_dict: + weight_loader = getattr(linear.weight, "weight_loader", + default_weight_loader) + weight_loader(linear.weight, + state_dict[weight_key].to(torch.float32)) + + bias_key = weight_key.replace("weight", "bias") + if linear.bias is not None and bias_key in state_dict: + bias_loader = getattr(linear.bias, "weight_loader", + default_weight_loader) + bias_loader(linear.bias, + state_dict[bias_key].to(torch.float32)) + return True + except Exception: + logger.exception("Failed to load %s", filename) + continue + + return False + + def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: model_name = orig_model_name diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 6e4a399f3cc6..126404584892 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only +from .interfaces import SupportsPP, SupportsQuant from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -313,7 +313,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant): +class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() diff --git a/vllm/model_executor/models/donut.py b/vllm/model_executor/models/donut.py new file mode 100644 index 000000000000..c00db52371b6 --- /dev/null +++ b/vllm/model_executor/models/donut.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, Optional, Union + +import torch +import torch.nn as nn +from transformers import BatchFeature, NougatProcessor + 
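+# NOTE: NougatProcessor is special-cased in DonutMultiModalProcessor below,
+# where the tokenized prompt it returns under "labels" is copied into
+# "input_ids".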
+from vllm.config import VllmConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.bart import BartParallelLMHead, MBartDecoder +from vllm.model_executor.models.interfaces import (MultiModalEmbeddings, + SupportsMultiModal, + SupportsV0Only) +from vllm.model_executor.models.swin import SwinModel +from vllm.model_executor.models.utils import (AutoWeightsLoader, + _flatten_embeddings, flatten_bn) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargsItems) +from vllm.multimodal.parse import MultiModalDataItems +from vllm.multimodal.processing import (BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptIndexTargets, PromptInsertion, + PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.utils.tensor_schema import TensorSchema, TensorShape + + +class MBartDecoderWrapper(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.decoder = MBartDecoder(config, + cache_config, + quant_config=quant_config, + prefix=f"{prefix}.decoder") + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) + + +class DonutLanguageForConditionalGeneration(nn.Module, SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + + self.config = config + self.model = MBartDecoderWrapper(vllm_config=vllm_config, + prefix=f"{prefix}.model") + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.vocab_size = config.vocab_size + self.lm_head = BartParallelLMHead(self.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.vocab_size, + config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. 
+ Returns: + Output torch.Tensor + """ + + return self.model(decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=inputs_embeds) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "final_logits_bias" in name: + continue + # if self.config.tie_word_embeddings and "embed_tokens" in name: + # continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class DonutImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] + data: Annotated[torch.Tensor, TensorShape("b", 3, "h", "w")] + + +class DonutProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config() + + def get_hf_processor(self): + return self.ctx.get_hf_processor() + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"image": 1} + + def get_num_image_tokens(self) -> int: + return 1 + + +class DonutDummyInputsBuilder(BaseDummyInputsBuilder[DonutProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_hf_config( + ).encoder.image_size + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images) + } + + +class DonutMultiModalProcessor(EncDecMultiModalProcessor[DonutProcessingInfo]): + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def create_encoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + def create_decoder_prompt( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + ) -> Union[str, list[int]]: + return prompt + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + hf_processor = self.info.get_hf_processor() + if mm_data: + processed_outputs = super()._call_hf_processor( + prompt, mm_data, mm_kwargs, tok_kwargs) + if isinstance(hf_processor, NougatProcessor): + processed_outputs["input_ids"] = 
processed_outputs["labels"] + else: + tokenizer = hf_processor.tokenizer + processed_outputs = tokenizer(prompt, + add_special_tokens=False, + return_tensors="pt") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values=MultiModalFieldConfig.batched("image")) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor() + tokenizer = hf_processor.tokenizer + pad_token_id = tokenizer.pad_token_id + num_image_tokens = self.info.get_num_image_tokens() + image_tokens = [pad_token_id] * num_image_tokens + + return [ + PromptInsertion( + modality="image", + target=PromptIndexTargets.start(), + insertion=image_tokens, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor(DonutMultiModalProcessor, + info=DonutProcessingInfo, + dummy_inputs=DonutDummyInputsBuilder) +class DonutForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsV0Only): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + processor_config = vllm_config.model_config.hf_image_processor_config + + self.config = config + self.vision_config = config.encoder + self.processor_config = processor_config + self.encoder = SwinModel(config=config.encoder) + + self.decoder = DonutLanguageForConditionalGeneration( + vllm_config=vllm_config.with_hf_config(config.decoder), + prefix=f"{prefix}.decoder", + ) + self.pad_token_id = config.pad_token_id + + def _parse_and_validate_image_input(self, **kwargs: object): + pixel_values: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[list[list[torch.Tensor]], + list[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + h, w = self.config.encoder.image_size + return DonutImagePixelInputs(type="pixel_values", + data=flatten_bn(pixel_values, + concat=True), + resolve_bindings={ + "h": h, + "w": w, + }) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, image_input: DonutImagePixelInputs) -> torch.Tensor: + assert image_input["type"] == "pixel_values" + pixel_values = image_input["data"] + dtype = next(self.encoder.parameters()).dtype + pixel_values = pixel_values.to(dtype) + return self.encoder(pixel_values) + + def get_language_model(self) -> torch.nn.Module: + return self.decoder + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings, + ) -> torch.Tensor: + return _flatten_embeddings(multimodal_embeddings) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + *, + 
encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. + encoder_positions + torch.Tensor of *encoder* position indices + Returns: + Output torch.Tensor + """ + + inputs_embeds = None + if encoder_input_ids.numel() > 0: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(encoder_input_ids, + vision_embeddings) + + hidden_states = self.decoder(input_ids, + positions, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.decoder.compute_logits(hidden_states, sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index bf5ad633b94a..f3dc7dde46bd 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -22,11 +22,12 @@ MultiModalDataItems) # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -337,14 +338,10 @@ def get_replacement_gemma3(item_idx: int): def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -373,13 +370,12 @@ def _apply_token_matches( [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -404,8 +400,8 @@ def get_repl_toks(tok: int) -> list[int]: repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py index 79061fd30c39..d59dde1560ae 100644 --- a/vllm/model_executor/models/gemma3n_mm.py +++ b/vllm/model_executor/models/gemma3n_mm.py @@ -29,11 +29,12 @@ MultiModalDataParser) # yapf: disable from 
vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, + MultiModalPromptUpdatesApplyResult, PlaceholderFeaturesInfo, - PromptReplacement, PromptTargetMatch, - PromptUpdate, PromptUpdateDetails, - find_mm_placeholders, + PromptReplacement, PromptUpdate, + PromptUpdateDetails, replace_token_matches) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder @@ -254,14 +255,10 @@ def get_replacement_audio(item_idx: int): def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - token_ids = super()._apply_token_matches( - prompt, - mm_matches, - mm_item_counts, - ) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + token_ids, res = super()._apply_token_matches(prompt, + mm_prompt_updates) # "\n\n\n" and "\n\n\n\n" are single tokens # Since our replacement can insert "\n\n" next to "\n" @@ -290,13 +287,12 @@ def _apply_token_matches( [newline_4], ) - return token_ids + return token_ids, res def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n" tokenizer = self.info.get_tokenizer() @@ -321,8 +317,8 @@ def get_repl_toks(tok: int) -> list[int]: repl_token_ids.extend(repl_toks) repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks))) - repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids, - mm_item_counts) + repls = super()._find_mm_placeholders(repl_token_ids, + mm_prompt_updates) return { modality: [ diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py index 9e27200fb1c8..88b2a295905b 100644 --- a/vllm/model_executor/models/idefics2_vision_model.py +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -27,13 +27,15 @@ Idefics2Config, Idefics2VisionConfig) from vllm.attention.layer import MultiHeadAttention -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import run_dp_sharded_vision_model class Idefics2VisionEmbeddings(nn.Module): @@ -118,6 +120,7 @@ def __init__( config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config @@ -130,22 +133,43 @@ def __init__( f" {self.num_heads}).") self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - ) - self.out_proj = RowParallelLinear( - self.embed_dim, - self.embed_dim, - bias=True, - quant_config=quant_config, - prefix=f"{prefix}.out_proj", - ) - self.tp_size = get_tensor_model_parallel_world_size() - 
self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + tp_size = (1 if use_data_parallel else + get_tensor_model_parallel_world_size()) + assert self.num_heads % tp_size == 0 + self.num_heads_per_partition = self.num_heads // tp_size + + if use_data_parallel: + self.q_size = self.num_heads * self.head_dim + self.qkv_proj = ReplicatedLinear( + self.embed_dim, + 3 * self.q_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = ReplicatedLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + else: + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) self.attn = MultiHeadAttention(self.num_heads_per_partition, self.head_dim, self.scale) @@ -169,18 +193,23 @@ def __init__( config: Idefics2VisionConfig, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.config = config self.activation_fn = get_act_fn(config.hidden_act) - self.fc1 = ColumnParallelLinear( + cls_fc1 = (ReplicatedLinear + if use_data_parallel else ColumnParallelLinear) + self.fc1 = cls_fc1( config.hidden_size, config.intermediate_size, bias=True, quant_config=quant_config, prefix=f"{prefix}.fc1", ) - self.fc2 = RowParallelLinear( + cls_fc2 = (ReplicatedLinear + if use_data_parallel else RowParallelLinear) + self.fc2 = cls_fc2( config.intermediate_size, config.hidden_size, bias=True, @@ -202,17 +231,21 @@ def __init__( config: Idefics2Config, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() self.embed_dim = config.hidden_size - self.self_attn = Idefics2VisionAttention(config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") + self.self_attn = Idefics2VisionAttention( + config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_data_parallel=use_data_parallel) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = Idefics2VisionMLP(config, quant_config=quant_config, - prefix=f"{prefix}.mlp") + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -254,6 +287,7 @@ def __init__( *, num_hidden_layers_override: Optional[int] = None, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() @@ -267,7 +301,8 @@ def __init__( self.layers = nn.ModuleList([ Idefics2EncoderLayer(config, quant_config=quant_config, - prefix=f"{prefix}.layers.{layer_idx}") + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel) for layer_idx in range(num_hidden_layers) ]) @@ -301,17 +336,20 @@ def __init__( num_hidden_layers_override: Optional[int] = None, require_post_norm: bool = True, prefix: str = "", + use_data_parallel: bool = False, ) -> None: super().__init__() embed_dim = config.hidden_size self.config = config + self.use_data_parallel = use_data_parallel self.embeddings = Idefics2VisionEmbeddings(config) self.encoder = Idefics2Encoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, - prefix=f"{prefix}.encoder") + prefix=f"{prefix}.encoder", + 
use_data_parallel=use_data_parallel) num_hidden_layers = config.num_hidden_layers if len(self.encoder.layers) > config.num_hidden_layers: @@ -340,10 +378,38 @@ def forward( patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, ) - encoder_outputs = self.encoder(hidden_states) + if self.use_data_parallel: + encoder_outputs = run_dp_sharded_vision_model( + hidden_states, self.encoder) + else: + encoder_outputs = self.encoder(hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state + def _consolidate_qkv_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> Iterable[tuple[str, torch.Tensor]]: + qkv_idx_mappings = { + ".self_attn.q_proj": 0, + ".self_attn.k_proj": 1, + ".self_attn.v_proj": 2, + } + qkv_weights = {} + for name, loaded_weight in weights: + for weight_name, idx in qkv_idx_mappings.items(): + if weight_name not in name: + continue + new_name = name.replace(weight_name, ".self_attn.qkv_proj") + if new_name not in qkv_weights: + qkv_weights[new_name] = [None] * 3 + qkv_weights[new_name][idx] = loaded_weight + break + else: + yield name, loaded_weight + for key, weight in qkv_weights.items(): + qkv_weight = torch.cat(weight, dim=0) + yield key, qkv_weight + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -356,6 +422,9 @@ def load_weights(self, weights: Iterable[tuple[str, loaded_params: set[str] = set() layer_count = len(self.encoder.layers) + if self.use_data_parallel: + weights = self._consolidate_qkv_weights(weights) + for name, loaded_weight in weights: # skip pooling header if name.startswith("head."): @@ -373,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str, continue for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: + if weight_name not in name or self.use_data_parallel: continue name = name.replace(weight_name, param_name) param = params_dict[name] diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index cd41d4fb4388..bc53982c938c 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -828,26 +828,19 @@ def get_replacement_mantis(item_idx: int): target=[image_token_id] * num_image_tokens, replacement=get_replacement_mantis, ) - ]) + ], mm_item_counts) prompt_ids, prompt, _ = self._apply_prompt_updates( result["prompt_token_ids"], mantis_mm_repls, - mm_item_counts, ) - unbound_orig_repls = self._get_prompt_updates( + orig_repls = self._get_mm_prompt_updates( mm_items, hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_and_group_updates(unbound_orig_repls) - - mm_placeholders = self._find_mm_placeholders( - orig_repls, - prompt_ids, - mm_item_counts, - ) + mm_placeholders = self._find_mm_placeholders(prompt_ids, orig_repls) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) mm_placeholder_ranges = { diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 42ab5e7c74d3..e4ac0cd91910 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -75,7 +75,7 @@ class LlavaOnevisionImagePixelInputs(TensorSchema): pixel_values: Annotated[ Union[torch.Tensor, list[torch.Tensor]], - TensorShape("bn", "np", 3, "h", "w"), + TensorShape("bn", "np", 3, "h", "w", dynamic_dims={"np"}), ] image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)] diff --git a/vllm/model_executor/models/minicpmv.py 
b/vllm/model_executor/models/minicpmv.py index 48ce1b9d38e2..a2a71bdd12b3 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -778,6 +778,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # and config class self.config = config self.multimodal_config = multimodal_config + self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" self.version = get_version_by_config(self.config) self.llm = self.init_llm(vllm_config=vllm_config, @@ -1325,9 +1326,11 @@ def init_vision_module( prefix: str = "", ) -> nn.Module: quant_config = self._maybe_ignore_quant_config(quant_config) - model = Idefics2VisionTransformer(config.vision_config, - quant_config=quant_config, - prefix=prefix) + model = Idefics2VisionTransformer( + config.vision_config, + quant_config=quant_config, + prefix=prefix, + use_data_parallel=self.use_data_parallel) if self.config.drop_vision_last_layer: model.encoder.layers = model.encoder.layers[:-1] return model diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 82e96844cd5f..0e854bd7d913 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -4,7 +4,10 @@ import copy import math from collections.abc import Iterable -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend import regex as re import torch @@ -339,6 +342,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def mamba_type(self) -> str: return "linear_attention" + def get_attn_backend(self) -> type["AttentionBackend"]: + from vllm.v1.attention.backends.linear_attn import ( + LinearAttentionBackend) + return LinearAttentionBackend + def get_state_dtype(self) -> tuple[torch.dtype]: return MambaStateDtypeCalculator.linear_attention_state_dtype( self.model_config.dtype, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 7d6a6207c7c8..95abb190e0a4 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch from torch import nn @@ -21,6 +21,7 @@ PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .siglip import SiglipVisionModel @@ -32,19 +33,27 @@ logger = init_logger(__name__) -class PaliGemmaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size * num_images, num_channels, height, width)`""" - +class PaliGemmaImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class PaliGemmaImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor - """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` - 
`hidden_size` must match the hidden size of language model backbone. +class PaliGemmaImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match language model backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, @@ -279,19 +288,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. " - f"You supplied {tuple(data.shape)}.") - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -301,22 +297,17 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, (torch.Tensor, list)): - raise ValueError("Incorrect type of pixel values. " - f"Got type: {type(pixel_values)}") - pixel_values = flatten_bn(pixel_values, concat=True) - return PaliGemmaImagePixelInputs( - type="pixel_values", - data=self._validate_pixel_values(pixel_values), - ) + h = w = self.config.vision_config.image_size + return PaliGemmaImagePixelInputs(type="pixel_values", + data=pixel_values, + resolve_bindings={ + "h": h, + "w": w + }) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): - raise ValueError("Incorrect type of image embeddings. " - f"Got type: {type(image_embeds)}") - image_embeds = flatten_bn(image_embeds, concat=True) return PaliGemmaImageEmbeddingInputs( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 078251ee2bf4..61e09d56046c 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -38,7 +38,8 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, BoundPromptUpdate, + BaseProcessingInfo, + MultiModalPromptUpdates, PlaceholderFeaturesInfo, PromptReplacement, PromptUpdate) # yapf: enable @@ -431,24 +432,21 @@ def get_replacement_phi3v(item_idx: int): return [_IMAGE_TOKEN_ID] * num_image_tokens - num_images = mm_items.get_count("image", strict=False) - return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_replacement_phi3v, - ) for image_token in image_tokens[:num_images] + ) ] def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: # align to hf behavior when there are images - if len(mm_item_counts): + if len(mm_prompt_updates): tokenizer = self.info.get_tokenizer() # to decode token_ids to the original text, we need to # 1. 
remove the first bos token @@ -484,7 +482,6 @@ def _apply_prompt_updates( token_ids, text, placeholders = super()._apply_prompt_updates( token_ids=token_ids, mm_prompt_updates=mm_prompt_updates, - mm_item_counts=mm_item_counts, ) # Keep the behavior in line with HF processor diff --git a/vllm/model_executor/models/phi4_multimodal.py b/vllm/model_executor/models/phi4_multimodal.py index ee8b71caf336..492d4bfb7d3e 100644 --- a/vllm/model_executor/models/phi4_multimodal.py +++ b/vllm/model_executor/models/phi4_multimodal.py @@ -1032,8 +1032,8 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: tokenizer = self.info.get_tokenizer() - image_token_id = tokenizer.vocab[tokenizer.image_token] - audio_token_id = tokenizer.vocab[tokenizer.audio_token] + image_token_id: int = tokenizer.vocab[tokenizer.image_token] + audio_token_id: int = tokenizer.vocab[tokenizer.audio_token] hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) audio_processor = self.info.get_feature_extractor( @@ -1053,9 +1053,7 @@ def get_image_replacement_phi4mm(item_idx: int): processor=hf_processor, ) - image_tokens = [image_token_id] * num_image_tokens - - return image_tokens + return [image_token_id] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -1066,9 +1064,7 @@ def get_audio_replacement_phi4mm(item_idx: int): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [audio_token_id] * audio_embed_size - - return audio_tokens + return [audio_token_id] * audio_embed_size return [ PromptReplacement( diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index b4aed11b8689..5129770e8d49 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -824,9 +824,7 @@ def get_image_replacement_phi4mm(item_idx: int): processor=hf_processor, ) - image_tokens = [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens - - return image_tokens + return [_IMAGE_PLACEHOLDER_TOKEN_ID] * num_image_tokens def get_audio_replacement_phi4mm(item_idx: int): audios = mm_items.get_items("audio", AudioProcessorItems) @@ -837,28 +835,20 @@ def get_audio_replacement_phi4mm(item_idx: int): audio_embed_size = self.info._compute_audio_embed_size( audio_frames) - audio_tokens = [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - - return audio_tokens + return [_AUDIO_PLACEHOLDER_TOKEN_ID] * audio_embed_size - num_images = mm_items.get_count("image", strict=False) - num_audios = mm_items.get_count("audio", strict=False) - - image_repl = [ + return [ PromptReplacement( modality="image", - target=image_token, + target=image_tokens.__getitem__, replacement=get_image_replacement_phi4mm, - ) for image_token in image_tokens[:num_images] - ] - audio_repl = [ + ), PromptReplacement( modality="audio", - target=audio_token, + target=audio_tokens.__getitem__, replacement=get_audio_replacement_phi4mm, - ) for audio_token in audio_tokens[:num_audios] + ), ] - return image_repl + audio_repl @MULTIMODAL_REGISTRY.register_processor( diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c01074e2122b..461b9c85d1c2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -5,7 +5,7 @@ from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields from functools import cached_property -from typing import Literal, Optional, TypedDict, 
Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -48,6 +48,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import (MistralTokenizer, cached_tokenizer_from_config) +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, @@ -68,15 +69,20 @@ PATCH_MERGE = "patch_merge" -class PixtralImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - - images: Union[torch.Tensor, list[torch.Tensor]] +class PixtralImagePixelInputs(TensorSchema): """ - Shape: `(batch_size * num_images, num_channels, image_width, image_height)` - + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height of each image + - w: Width of each image + The result of stacking `ImageEncoding.tokens` from each prompt. """ + type: Literal["pixel_values"] = "pixel_values" + + images: Annotated[Union[torch.Tensor, list[torch.Tensor]], + TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"})] class PixtralProcessorAdapter: @@ -381,10 +387,6 @@ def _parse_and_validate_image_input( if images is None: return None - if not isinstance(images, (torch.Tensor, list)): - raise ValueError("Incorrect type of images. " - f"Got type: {type(images)}") - return PixtralImagePixelInputs( type="pixel_values", images=flatten_bn(images), diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 664e3f2985a5..a61b8ca8f7ae 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -309,9 +309,8 @@ def _maybe_apply_prompt_updates( if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders( mm_placeholders, @@ -328,7 +327,6 @@ def _maybe_apply_prompt_updates( ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders( mm_placeholders, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 811ecffcc1e4..0f11636ce3bd 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -135,7 +135,7 @@ class Qwen2_5_VLVideoPixelInputs(TypedDict): second_per_grid_ts: torch.Tensor """ - The video time interval (in seconds) for each grid along the temporal + The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when `videos` is not `None`. """ @@ -852,6 +852,10 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsQuant): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + # To ensure correct weight loading and mapping. 
hf_to_vllm_mapper = WeightsMapper( orig_to_new_prefix={ diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index 2812f79a66b7..8498f61b35fd 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -45,6 +45,9 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -146,11 +149,20 @@ def __init__( enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts) - self.gate = ReplicatedLinear(config.hidden_size, - config.num_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate") + self.gate = ReplicatedLinear( + config.hidden_size, + config.num_experts, + bias=False, + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=f"{prefix}.gate") + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid gate quantization. + # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4 + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -682,4 +694,4 @@ def load_weights(self, weights: Iterable[tuple[str, return loader.load_weights(weights) def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - return self.model.get_expert_mapping() \ No newline at end of file + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 465c25f09480..ebf78771e40a 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -252,6 +252,7 @@ "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] + "DonutForConditionalGeneration": ("donut", "DonutForConditionalGeneration"), "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 "Llama4ForConditionalGeneration": ("mllama4", "Llama4ForConditionalGeneration"), # noqa: E501 diff --git a/vllm/model_executor/models/skyworkr1v.py b/vllm/model_executor/models/skyworkr1v.py index 920f4def6917..9857ccdcbe2d 100644 --- a/vllm/model_executor/models/skyworkr1v.py +++ b/vllm/model_executor/models/skyworkr1v.py @@ -8,7 +8,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- from collections.abc import Iterable, Mapping, Sequence -from typing import Literal, Optional, TypedDict, Union +from typing import Annotated, Literal, Optional, Union import torch import torch.nn as nn @@ -35,6 +35,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.transformers_utils.tokenizer import AnyTokenizer 
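+# NOTE: The TensorSchema/TensorShape annotations below replace the removed
+# hand-written _validate_pixel_values check; the expected image height/width
+# are bound via resolve_bindings when SkyworkR1VImagePixelInputs is built.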
+from vllm.utils.tensor_schema import TensorSchema, TensorShape from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, @@ -48,27 +49,42 @@ IMAGENET_STD = (0.229, 0.224, 0.225) -class SkyworkR1VImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values_flat: torch.Tensor +class SkyworkR1VImagePixelInputs(TensorSchema): """ - Shape: - `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + Dimensions: + - bnp: Batch size * number of images * (1 + num_patches) + - c: Number of channels (3) + - h: Height + - w: Width + - bn: Batch size * number of images """ + type: Literal["pixel_values"] = "pixel_values" - num_patches: torch.Tensor - """Shape: `(batch_size * num_images)`""" + pixel_values_flat: Annotated[ + torch.Tensor, + TensorShape("bnp", 3, "h", "w"), + ] + num_patches: Annotated[ + torch.Tensor, + TensorShape("bn"), + ] -class SkyworkR1VImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: Union[torch.Tensor, list[torch.Tensor]] - """ - A tensor of shape `(num_images, total_image_feature_size, hidden_size)` - or a list of tensors of shape `(total_image_feature_size, hidden_size)` - `hidden_size` must match the hidden size of language model backbone. +class SkyworkR1VImageEmbeddingInputs(TensorSchema): """ + Dimensions: + - ni: Number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + + data: Annotated[ + Union[torch.Tensor, list[torch.Tensor]], + TensorShape("ni", "ifs", "hs"), + ] SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs, @@ -731,26 +747,6 @@ def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) - - def _validate_shape(d: torch.Tensor): - actual_dims = tuple(d.shape) - - if actual_dims != expected_dims: - expected_expr = str(expected_dims) - raise ValueError( - "The expected shape of pixel values per image per batch " - f" per patch is {expected_expr}. 
" - f"You supplied {tuple(d.shape)}.") - - for d in data: - _validate_shape(d) - - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]: pixel_values_flat = kwargs.pop("pixel_values_flat", None) @@ -788,10 +784,12 @@ def _parse_and_validate_image_input( return SkyworkR1VImagePixelInputs( type="pixel_values", - pixel_values_flat=self._validate_pixel_values( - pixel_values_flat), + pixel_values_flat=pixel_values_flat, num_patches=image_num_patches, - ) + resolve_bindings={ + "h": self.config.vision_config.image_size, + "w": self.config.vision_config.image_size, + }) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/swin.py b/vllm/model_executor/models/swin.py new file mode 100644 index 000000000000..30b441f5b4df --- /dev/null +++ b/vllm/model_executor/models/swin.py @@ -0,0 +1,475 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import SwinConfig +from transformers.models.swin.modeling_swin import SwinEmbeddings +from transformers.models.swin.modeling_swin import SwinLayer as HFSwinLayer +from transformers.models.swin.modeling_swin import SwinPatchMerging +from transformers.pytorch_utils import meshgrid + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + + +class SwinSelfAttention(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if dim % num_heads != 0: + raise ValueError( + f"The hidden size ({dim}) is not a multiple of the number of " + f"attention heads ({num_heads})") + + self.num_attention_heads = num_heads + self.attention_head_size = int(dim / num_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.window_size = (window_size if isinstance(window_size, Iterable) + else (window_size, window_size)) + self.scale = self.attention_head_size**-0.5 + + self.relative_position_bias_table = nn.Parameter( + torch.zeros( + (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), + num_heads)) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij")) + coords_flatten = torch.flatten(coords, 1) + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, + None, :] + relative_coords = relative_coords.permute(1, 2, 0).contiguous() + relative_coords[:, :, 0] += self.window_size[0] - 1 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) + + self.relative_position_index = nn.Parameter(relative_position_index, + requires_grad=False) + + self.qkv = QKVParallelLinear( + hidden_size=dim, + head_size=self.attention_head_size, + total_num_heads=self.num_attention_heads, + bias=config.qkv_bias, + quant_config=quant_config, + 
prefix=f"{prefix}.qkv", + ) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)] + relative_position_bias = relative_position_bias.view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() + return relative_position_bias.unsqueeze(0) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor, ...]: + batch_size, dim, num_channels = hidden_states.shape + + qkv_output, _ = self.qkv(hidden_states) + query_layer, key_layer, value_layer = qkv_output.chunk(3, dim=-1) + + key_layer = self.transpose_for_scores(key_layer) + value_layer = self.transpose_for_scores(value_layer) + query_layer = self.transpose_for_scores(query_layer) + + attention_scores = self._get_rel_pos_bias() + if attention_mask is not None: + mask_shape = attention_mask.shape[0] + attention_mask_expanded = attention_mask.view( + 1, mask_shape, 1, dim, + dim).expand(batch_size // mask_shape, mask_shape, + self.num_attention_heads, dim, dim) + attention_scores = attention_scores + \ + attention_mask_expanded.unsqueeze( + 1).unsqueeze(0) + attention_scores = attention_scores.view(-1, + self.num_attention_heads, + dim, dim) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_scores, + dropout_p=0., + ) + attention_probs = None + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, ) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, + attention_probs) if output_attentions else (context_layer, ) + + return outputs + + +class SwinSelfOutput(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.dense = RowParallelLinear( + input_size=dim, + output_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + + return hidden_states + + +class SwinAttention(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + num_heads: int, + window_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.self = SwinSelfAttention(config, + dim, + num_heads, + window_size, + quant_config=quant_config, + prefix=f"{prefix}.self") + self.output = SwinSelfOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + self.pruned_heads = set() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + self_outputs = self.self(hidden_states, attention_mask, head_mask, + output_attentions) + attention_output = 
self.output(self_outputs[0], hidden_states) + outputs = (attention_output, ) + self_outputs[1:] + return outputs + + +class SwinIntermediate(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = ColumnParallelLinear(dim, + int(config.mlp_ratio * dim), + quant_config=quant_config, + prefix=f"{prefix}.dense") + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class SwinOutput(nn.Module): + + def __init__(self, + config: SwinConfig, + dim: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + super().__init__() + self.dense = RowParallelLinear(int(config.mlp_ratio * dim), + dim, + quant_config=quant_config, + prefix=f"{prefix}.dense") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + return hidden_states + + +class SwinLayer(HFSwinLayer): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + num_heads: int, + drop_path_rate: float = 0.0, + shift_size: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path_rate, + shift_size=shift_size, + ) + + self.attention = SwinAttention(config, + dim, + num_heads, + window_size=self.window_size, + quant_config=quant_config, + prefix=f"{prefix}.attention") + self.intermediate = SwinIntermediate(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.intermediate") + self.output = SwinOutput(config, + dim, + quant_config=quant_config, + prefix=f"{prefix}.output") + + +class SwinStage(nn.Module): + + def __init__( + self, + config: SwinConfig, + dim: int, + input_resolution: int, + depth: int, + num_heads: int, + drop_path: list[float], + downsample: Optional[SwinPatchMerging] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.dim = dim + self.blocks = nn.ModuleList([ + SwinLayer(config=config, + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + drop_path_rate=drop_path[layer_idx], + shift_size=0 if + (layer_idx % 2 == 0) else config.window_size // 2, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, + dim=dim, + norm_layer=nn.LayerNorm) + else: + self.downsample = None + + self.pointing = False + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + + hidden_states_before_downsampling = hidden_states + if self.downsample is not 
None: + height_downsampled, width_downsampled = (height + 1) // 2, (width + + 1) // 2 + output_dimensions = (height, width, height_downsampled, + width_downsampled) + hidden_states = self.downsample(hidden_states_before_downsampling, + input_dimensions) + else: + output_dimensions = (height, width, height, width) + + stage_outputs = (hidden_states, hidden_states_before_downsampling, + output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs + + +class SwinEncoder(nn.Module): + + def __init__( + self, + config: SwinConfig, + grid_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.num_layers = len(config.depths) + self.config = config + dpr = [ + x.item() for x in torch.linspace( + 0, config.drop_path_rate, sum(config.depths), device="cpu") + ] + self.layers = nn.ModuleList([ + SwinStage(config=config, + dim=int(config.embed_dim * 2**layer_idx), + input_resolution=(grid_size[0] // (2**layer_idx), + grid_size[1] // (2**layer_idx)), + depth=config.depths[layer_idx], + num_heads=config.num_heads[layer_idx], + drop_path=dpr[sum(config.depths[:layer_idx] + ):sum(config.depths[:layer_idx + 1])], + downsample=SwinPatchMerging if + (layer_idx < self.num_layers - 1) else None, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(self.num_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + input_dimensions: tuple[int, int], + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + always_partition: Optional[bool] = False, + ) -> tuple[torch.Tensor]: + for i, layer_module in enumerate(self.layers): + layer_head_mask = head_mask[i] if head_mask is not None else None + + layer_outputs = layer_module(hidden_states, input_dimensions, + layer_head_mask, output_attentions, + always_partition) + + hidden_states = layer_outputs[0] + output_dimensions = layer_outputs[2] + + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + + return hidden_states + + +class SwinModel(nn.Module): + config_class: SwinConfig + + def __init__( + self, + config: SwinConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.num_layers = len(config.depths) + self.num_features = int(config.embed_dim * 2**(self.num_layers - 1)) + + self.embeddings = SwinEmbeddings(config) + self.encoder = SwinEncoder(config, + self.embeddings.patch_grid, + quant_config=quant_config, + prefix=f"{prefix}.encoder") + + def forward( + self, + pixel_values: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + ) -> tuple[torch.Tensor]: + embedding_output, input_dimensions = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + embedding_output, + input_dimensions, + head_mask=head_mask, + output_attentions=output_attentions, + ) + + return encoder_outputs + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv", "query", "q"), + ("qkv", "key", "k"), + ("qkv", "value", "v"), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + 
weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 0990be8d02b9..9b9cca8c6bd3 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar, +from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar, Union, cast) import torch @@ -34,6 +34,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape from .clip import CLIPVisionModel from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP @@ -43,14 +44,28 @@ from .vision import VisionEncoderInfo, get_vision_encoder_info -class TarsierImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - pixel_values: torch.Tensor +class TarsierImagePixelInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - c: Number of channels (3) + - h: Height + - w: Width + """ + type: Literal["pixel_values"] = "pixel_values" + pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")] -class TarsierImageEmbeddingInputs(TypedDict): - type: Literal["image_embeds"] - data: torch.Tensor +class TarsierImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - bn: Batch size * number of images + - ifs: Image feature size + - hs: Hidden size (must match the hidden size of language model + backbone) + """ + type: Literal["image_embeds"] = "image_embeds" + data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")] TarsierImageInputs = Union[TarsierImagePixelInputs, @@ -432,18 +447,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: - h = w = self.config.vision_config.image_size - expected_dims = (3, h, w) # Assuming 3 channels - actual_dims = tuple(data.shape[1:]) - - if actual_dims != expected_dims: - expected_expr = ("batch_size", *map(str, expected_dims)) - raise ValueError( - f"The expected shape of pixel values is {expected_expr}. 
" - f"You supplied {tuple(data.shape)}.") - return data - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[TarsierImageInputs]: pixel_values = kwargs.pop("pixel_values", None) @@ -459,8 +462,7 @@ def _parse_and_validate_image_input( return TarsierImagePixelInputs( type="pixel_values", - pixel_values=self._validate_pixel_values( - flatten_bn(pixel_values, concat=True)), + pixel_values=flatten_bn(pixel_values, concat=True), ) if image_embeds is not None: diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 55fd1479d2de..8c225e2a3c08 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -44,10 +44,59 @@ """A token sequence (list of token IDs) or text.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: Optional[bool] = None, +) -> list[int]: + return encode_tokens(tokenizer, + text, + add_special_tokens=add_special_tokens) + + +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: Optional[bool] = None, +) -> str: + return decode_tokens(tokenizer, + list(token_ids), + skip_special_tokens=skip_special_tokens) + + +def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str: + if isinstance(seq, str): + return seq + + return _cached_decode(tokenizer, tuple(seq)) + + +def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]: + if isinstance(seq, str): + return _cached_encode(tokenizer, seq, add_special_tokens=False) + + return seq + + +class _GetMatchIndex(Protocol): + + def __call__( + self, + tokenizer: AnyTokenizer, + prompt: PromptSeq, + start_idx: int = 0, + ) -> Optional[int]: + ... + + @dataclass class PromptIndex: """Resolves to an index in the prompt.""" - get_match_index: Callable[[AnyTokenizer, PromptSeq], Optional[int]] + get_match_index: _GetMatchIndex class PromptIndexTargets: @@ -59,7 +108,7 @@ def start() -> PromptIndex: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: 0) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: 0) @staticmethod def prefix(seq: PromptSeq) -> PromptIndex: @@ -70,7 +119,11 @@ def prefix(seq: PromptSeq) -> PromptIndex: def get_match_index( tokenizer: AnyTokenizer, prompt: PromptSeq, + start_idx: int = 0, ) -> Optional[int]: + if start_idx != 0: + return None + prefix = seq if isinstance(prompt, str): @@ -96,14 +149,24 @@ def end() -> PromptIndex: This results in a match even if the prompt is empty. """ - return PromptIndex(lambda tok, prompt: len(prompt)) + return PromptIndex(lambda tokenizer, prompt, start_idx=0: len(prompt)) -PromptTarget = Union[PromptSeq, PromptIndex] +UpdateTarget = Union[PromptSeq, PromptIndex] """ The token sequence or text to update. """ +PromptUpdateTarget = Union[Callable[[int], UpdateTarget], UpdateTarget] +""" +Given the index of the processed item within +[`modality`][vllm.multimodal.processing.PromptUpdate.modality], +output the corresponding token sequence (or text). + +For convenience, you can directly pass in the token sequence (or text) +instead of a function if it does not depend on the input. 
+""" + @dataclass class PromptUpdateDetails(Generic[_S]): @@ -112,7 +175,8 @@ class PromptUpdateDetails(Generic[_S]): full: _S """The full content.""" - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None + is_embed: Optional[Callable[[AnyTokenizer, PromptSeq], + torch.Tensor]] = None """ Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full], return a boolean mask of shape `(len(full),)` indicating which positions @@ -134,11 +198,12 @@ def select_text( embed_text: str, ) -> "PromptUpdateDetails[_S]": - def is_embed(full: "_BoundPromptSequence") -> torch.Tensor: - embed_token_ids = encode_tokens(full.tokenizer, embed_text) + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + embed_token_ids = encode_tokens(tokenizer, embed_text) + token_ids = _seq2tokens(tokenizer, full) return torch.isin( - torch.tensor(full.token_ids), + torch.tensor(token_ids), torch.tensor(embed_token_ids), ) @@ -149,10 +214,13 @@ def select_token_id( seq: _S, embed_token_id: int, ) -> "PromptUpdateDetails[_S]": - return PromptUpdateDetails( - full=seq, - is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id, - ) + + def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor: + token_ids = _seq2tokens(tokenizer, full) + + return torch.tensor(token_ids) == embed_token_id + + return PromptUpdateDetails(full=seq, is_embed=is_embed) PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails] @@ -190,7 +258,7 @@ class PromptUpdate(ABC): modality: str """The modality for which the update is made.""" - target: PromptTarget + target: PromptUpdateTarget """The token sequence (or text) to update.""" @property @@ -205,10 +273,35 @@ def mode(self) -> UpdateMode: """Defines how to update the prompt.""" raise NotImplementedError - def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate": - return BoundPromptUpdate( - _origin=self, - tokenizer=tokenizer, + def _resolve_target(self, item_idx: int) -> UpdateTarget: + target = self.target + if callable(target): + target = target(item_idx) + + return target + + def _resolve_content(self, item_idx: int) -> PromptUpdateDetails: + content = self.content + if callable(content): + content = content(item_idx) + + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return content + + def resolve(self, item_idx: int) -> "ResolvedPromptUpdate": + """ + Given the index of the processed item within + [`modality`][vllm.multimodal.processing.PromptUpdate.modality], + output a copy of this object with its lazy attributes resolved. 
+ """ + return ResolvedPromptUpdate( + modality=self.modality, + item_idx=item_idx, + mode=self.mode, + target=self._resolve_target(item_idx), + content=self._resolve_content(item_idx), ) @@ -355,30 +448,6 @@ def mode(self) -> UpdateMode: return UpdateMode.REPLACE -@lru_cache(maxsize=2048) -def _cached_encode( - tokenizer: AnyTokenizer, - text: str, - *, - add_special_tokens: Optional[bool] = None, -) -> list[int]: - return encode_tokens(tokenizer, - text, - add_special_tokens=add_special_tokens) - - -@lru_cache(maxsize=2048) -def _cached_decode( - tokenizer: AnyTokenizer, - token_ids: tuple[int, ...], - *, - skip_special_tokens: Optional[bool] = None, -) -> str: - return decode_tokens(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) - - class _HasModalityAttr(Protocol): modality: str @@ -399,126 +468,94 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: return full_groupby(values, key=lambda x: x.modality) -@dataclass -class _BoundPromptSequence: +class PromptTargetMatch(NamedTuple): + start_idx: int + end_idx: int + + +@dataclass(frozen=True) +class ResolvedPromptUpdate: """ - A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound - to a tokenizer to automatically - convert between token sequence and text representations. + A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] with its + lazy attributes resolved, apart from those related to tokenization. """ - tokenizer: AnyTokenizer = field(repr=False) - _text: Optional[str] - _token_ids: Optional[list[int]] + modality: str + """The modality for which the update is made.""" - @staticmethod - def from_seq( - tokenizer: AnyTokenizer, - seq: PromptSeq, - ) -> "_BoundPromptSequence": - return _BoundPromptSequence( - tokenizer=tokenizer, - _text=seq if isinstance(seq, str) else None, - _token_ids=seq if isinstance(seq, list) else None, - ) + item_idx: int + """The index within `modality` of the item this update pertains to.""" - def __post_init__(self) -> None: - if self._text is None and self._token_ids is None: - raise ValueError("At least one of 'text' and 'token_ids' must be " - "specified") + mode: UpdateMode + """Defines how to update the prompt.""" - @property - def text(self) -> str: - if self._text is None: - assert self._token_ids is not None - self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) + target: UpdateTarget + """The token sequence (or text) to update.""" - return self._text + content: PromptUpdateDetails = field(repr=False) + """The placeholder tokens that are part of the update.""" - @property - def token_ids(self) -> list[int]: - if self._token_ids is None: - assert self._text is not None - self._token_ids = _cached_encode(self.tokenizer, - self._text, - add_special_tokens=False) + def iter_token_matches( + self, + prompt: list[int], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target - return self._token_ids + if isinstance(target, PromptIndex): + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) + return -@dataclass -class _BoundPromptContent: - full: _BoundPromptSequence - is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] + target_token_ids = _seq2tokens(tokenizer, target) + for match in iter_token_matches(prompt, + target_token_ids, + start_idx=start_idx): + yield PromptTargetMatch(match.start_idx, 
match.end_idx) -@dataclass -class BoundPromptUpdate: - """ - A [`PromptUpdate`][vllm.multimodal.processing.PromptUpdate] bound - to a tokenizer to automatically convert - [`target`][vllm.multimodal.processing.PromptUpdate.target] and the result of - [`get_content`][vllm.multimodal.processing.BoundPromptUpdate.get_content] - between token sequence and text representations. - """ - _origin: PromptUpdate - tokenizer: AnyTokenizer = field(repr=False) - - def __post_init__(self) -> None: - self._content_cache = dict[int, _BoundPromptContent]() - - @property - def modality(self) -> str: - return self._origin.modality - - @property - def target(self) -> Union[_BoundPromptSequence, PromptIndex]: - """The token sequence (or text) to update.""" - target = self._origin.target + def iter_text_matches( + self, + prompt: str, + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + target = self.target if isinstance(target, PromptIndex): - return target + match_idx = target.get_match_index(tokenizer, prompt, start_idx) + if match_idx is not None: + yield PromptTargetMatch(match_idx, match_idx) - return _BoundPromptSequence.from_seq(self.tokenizer, target) + return - @property - def content(self) -> PromptUpdateContent: - """The placeholder tokens that are part of the update.""" - return self._origin.content + target_text = _seq2text(tokenizer, target) - @property - def mode(self) -> UpdateMode: - """Defines how to update the prompt.""" - return self._origin.mode + for match in re.finditer(re.escape(target_text), prompt, + pos=start_idx): + yield PromptTargetMatch(match.start(), match.end()) - def get_content(self, item_idx: int) -> _BoundPromptContent: - """ - Given the index of the processed item within - [`modality`][vllm.multimodal.processing.PromptUpdate.modality], - output the token sequence (or text) to update. - """ - content = self.content - if callable(content): - cache_key = item_idx - if cache_key in self._content_cache: - return self._content_cache[cache_key] - - content = content(item_idx) - else: - cache_key = None - - if not isinstance(content, PromptUpdateDetails): - content = PromptUpdateDetails.from_seq(content) - - bound_full = _BoundPromptSequence.from_seq(self.tokenizer, - content.full) - bound_content = _BoundPromptContent(full=bound_full, - is_embed=content.is_embed) - - if cache_key is not None: - self._content_cache[cache_key] = bound_content + def iter_matches( + self, + prompt: Union[list[int], str], + tokenizer: AnyTokenizer, + *, + start_idx: int = 0, + ) -> Generator[PromptTargetMatch]: + """Yield each instance of `self.target` found in `prompt`.""" + if isinstance(prompt, str): + return self.iter_text_matches(prompt, + tokenizer, + start_idx=start_idx) - return bound_content + return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) class _TokenMatch(NamedTuple): @@ -529,6 +566,8 @@ class _TokenMatch(NamedTuple): def iter_token_matches( token_ids: list[int], match_ids: list[int], + *, + start_idx: int = 0, ) -> Generator[_TokenMatch]: """ Yield each occurrence of `match_ids` in `token_ids`. 
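To make the matching contract above concrete, here is a minimal, self-contained sketch (not vLLM's implementation; the helper name is illustrative) of the non-overlapping subsequence matching that `iter_token_matches` now performs from a caller-supplied `start_idx`, as driven by `ResolvedPromptUpdate.iter_token_matches`:

```python
from collections.abc import Generator


def iter_token_matches_sketch(
    token_ids: list[int],
    match_ids: list[int],
    *,
    start_idx: int = 0,
) -> Generator[tuple[int, int], None, None]:
    """Yield (start, end) for each non-overlapping occurrence of match_ids."""
    match_len = len(match_ids)
    if match_len == 0:
        return
    while start_idx <= len(token_ids) - match_len:
        end_idx = start_idx + match_len
        if token_ids[start_idx:end_idx] == match_ids:
            yield start_idx, end_idx
            start_idx = end_idx  # skip past the match to avoid overlaps
        else:
            start_idx += 1


# Starting the scan at index 2 skips the first occurrence of [2, 3].
assert list(iter_token_matches_sketch([1, 2, 3, 2, 3], [2, 3])) == [(1, 3), (3, 5)]
assert list(iter_token_matches_sketch([1, 2, 3, 2, 3], [2, 3], start_idx=2)) == [(3, 5)]
```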
@@ -541,7 +580,6 @@ def iter_token_matches( if match_len == 0: return - start_idx = 0 while start_idx < prompt_len - match_len + 1: end_idx = start_idx + match_len @@ -581,68 +619,6 @@ def replace_token_matches( return flatten_2d_lists(out_seqs) -@dataclass(repr=False) -class PromptTargetMatch(ABC): - _origin: BoundPromptUpdate - - @property - def modality(self) -> str: - return self._origin.modality - - @property - @abstractmethod - def start_idx(self) -> int: - raise NotImplementedError - - @property - @abstractmethod - def end_idx(self) -> int: - raise NotImplementedError - - def __repr__(self) -> str: - return (f"{type(self).__name__}(modality={self.modality!r}, " - f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") - - -@dataclass(repr=False) -class _PromptTargetIndexMatch(PromptTargetMatch): - match_idx: int - - @property - def start_idx(self) -> int: - return self.match_idx - - @property - def end_idx(self) -> int: - return self.match_idx - - -@dataclass(repr=False) -class _PromptTargetTokenMatch(PromptTargetMatch): - match: _TokenMatch - - @property - def start_idx(self) -> int: - return self.match.start_idx - - @property - def end_idx(self) -> int: - return self.match.end_idx - - -@dataclass(repr=False) -class _PromptTargetTextMatch(PromptTargetMatch): - match: re.Match[str] - - @property - def start_idx(self) -> int: - return self.match.start() - - @property - def end_idx(self) -> int: - return self.match.end() - - @dataclass class PlaceholderFeaturesInfo: modality: str @@ -665,163 +641,161 @@ def to_range(self) -> PlaceholderRange: ) -def find_token_matches( - prompt: list[int], - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTokenMatch(update, match) - for match in iter_token_matches(prompt, target.token_ids) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] +_MatchToApply = tuple[tuple[str, int], tuple[PromptTargetMatch, int]] -def find_text_matches( - prompt: str, - prompt_updates: Sequence[BoundPromptUpdate], -) -> Sequence[PromptTargetMatch]: - """Return each target of `prompt_updates` found in `prompt`.""" - - def get_matches(update: BoundPromptUpdate): - target = update.target - - if isinstance(target, PromptIndex): - match_idx = target.get_match_index(update.tokenizer, prompt) - if match_idx is None: - return [] - - return [_PromptTargetIndexMatch(update, match_idx)] - - return [ - _PromptTargetTextMatch(update, match) - for match in re.finditer(re.escape(target.text), prompt) - ] - - return [ - match for update in prompt_updates for match in get_matches(update) - ] - - -def _resolve_matches( - prompt: PromptSeq, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], -) -> list[PromptTargetMatch]: - """ - Resolve `mm_matches` to ensure that there are no overlapping matches, - and sort them such that earlier matches take priority over later ones. 
- """ - matches = [m for matches in mm_matches.values() for m in matches] - - seen_matches: list[Optional[PromptTargetMatch]] = [None] * len(prompt) - - for match in matches: - for idx in range(match.start_idx, match.end_idx): - if seen_matches[idx] is not None: - raise ValueError("Found overlapping matches " - f"({seen_matches[idx]} and {match}) " - f"at index={idx} of prompt={prompt}") - - seen_matches[idx] = match - - return sorted(matches, key=lambda x: x.start_idx) +def _find_matches( + prompt: _S, + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, + *, + prev_end_idx: int = 0, + current_result: "MultiModalPromptUpdatesApplyResult", +) -> tuple[Optional[UpdateMode], list[_MatchToApply]]: + mode: Optional[UpdateMode] = None + mm_matches = dict[tuple[str, int], tuple[PromptTargetMatch, int]]() + + for modality, modality_updates in mm_prompt_updates.items(): + for item_idx, item_updates in enumerate(modality_updates): + if current_result[modality][item_idx] is not None: + continue # Updates have already been applied for this item + + for update_idx, update in enumerate(item_updates): + if (modality, item_idx) in mm_matches: + break # Already found a match for this item + + for match in update.iter_matches( + prompt, + tokenizer, + start_idx=prev_end_idx, + ): + # All matches should share the same mode + if mode is None: + mode = update.mode + elif mode != update.mode: + continue + + mm_matches[(modality, item_idx)] = match, update_idx + break # Get only the first valid match per item + + # Prioritize earlier matches + matches_to_apply = sorted(mm_matches.items(), key=lambda item: item[1][0]) + + # To avoid conflicts, only replace one non-empty item at a time + if mode == UpdateMode.REPLACE: + matches_to_apply_ = list[_MatchToApply]() + has_non_empty_matches = False + + for item in matches_to_apply: + _, (match, _) = item + if match.start_idx == match.end_idx: + matches_to_apply_.append(item) + elif not has_non_empty_matches: + has_non_empty_matches = True + matches_to_apply_.append(item) + + matches_to_apply = matches_to_apply_ + + return mode, matches_to_apply def _apply_matches( prompt: _S, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[_S]: - """Apply the updates in `mm_matches` to `prompt`.""" + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[_S], "MultiModalPromptUpdatesApplyResult"]: + prompt_len = len(prompt) + out_seqs = list[Union[str, list[int]]]() - prev_end_idx = 0 - next_idx_by_modality = defaultdict[str, int](lambda: 0) + out_result: MultiModalPromptUpdatesApplyResult = { + m: [None] * len(items) + for m, items in mm_prompt_updates.items() + } - for match in _resolve_matches(prompt, mm_matches): - modality = match.modality + start_idx = prev_end_idx = 0 + while start_idx < max(prompt_len, 1): # Allow inserts into empty prompt + found = False - item_start_idx = next_idx_by_modality[modality] - max_item_count = mm_item_counts.get(modality, 0) - if item_start_idx >= max_item_count: - continue + mode, matches_to_apply = _find_matches( + prompt, + mm_prompt_updates, + tokenizer, + prev_end_idx=prev_end_idx, + current_result=out_result, + ) - start_idx = match.start_idx - end_idx = match.end_idx - origin = match._origin - mode = origin.mode - - if mode == UpdateMode.INSERT: - out_seqs.append(prompt[prev_end_idx:end_idx]) - num_inserts = max_item_count - elif mode == UpdateMode.REPLACE: - out_seqs.append(prompt[prev_end_idx:start_idx]) - num_inserts = 
max_item_count if start_idx == end_idx else 1 - else: - assert_never(mode) + if mode is not None: + for (modality, item_idx), (match, update_idx) in matches_to_apply: + found = True - item_end_idx = min(item_start_idx + num_inserts, max_item_count) + matched_update = mm_prompt_updates[modality][item_idx][ + update_idx] + matched_content = matched_update.content.full - for item_idx in range(item_start_idx, item_end_idx): - content = origin.get_content(item_idx) - insert_seq = (content.full.text if isinstance(prompt, str) else - content.full.token_ids) + if mode == UpdateMode.INSERT: + end_idx_to_insert = match.end_idx + elif mode == UpdateMode.REPLACE: + end_idx_to_insert = match.start_idx + else: + assert_never(mode) - out_seqs.append(insert_seq) + out_seqs.append(prompt[prev_end_idx:end_idx_to_insert]) + out_seqs.append( + _seq2text(tokenizer, matched_content + ) if isinstance(prompt, str) else _seq2tokens( + tokenizer, matched_content)) + out_result[modality][item_idx] = update_idx - prev_end_idx = end_idx - next_idx_by_modality[modality] += item_end_idx - item_start_idx + # Exclude overlapping matches + start_idx = prev_end_idx = match.end_idx + + if not found: + start_idx += 1 out_seqs.append(prompt[prev_end_idx:]) - return cast(list[_S], out_seqs) + return cast(list[_S], out_seqs), out_result def apply_token_matches( prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> list[int]: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[list[int], "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + token_id_seqs, result = _apply_matches(prompt, mm_prompt_updates, + tokenizer) - return flatten_2d_lists(token_id_seqs) + return flatten_2d_lists(token_id_seqs), result def apply_text_matches( prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], -) -> str: - """Apply the updates in `mm_matches` to `prompt`.""" - if not mm_matches: - return prompt + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, +) -> tuple[str, "MultiModalPromptUpdatesApplyResult"]: + """ + Apply the updates in `mm_prompt_updates` to `prompt`. - texts = _apply_matches(prompt, mm_matches, mm_item_counts) + Matches are exclusive even when multiple modalities share + the same placeholder tokens. In that case, the modality that + appears earlier in `mm_prompt_updates` takes priority. + """ + texts, result = _apply_matches(prompt, mm_prompt_updates, tokenizer) - return "".join(texts) + return "".join(texts), result def _iter_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Iterable[PlaceholderFeaturesInfo]: """ Yield each set of placeholder tokens found in `prompt`. @@ -833,6 +807,8 @@ def _iter_placeholders( Note that empty matches are ignored. 
""" prompt_len = len(prompt) + mm_item_counts = {m: len(items) for m, items in mm_prompt_updates.items()} + item_idx_by_modality = defaultdict[str, int](lambda: 0) start_idx = 0 @@ -844,9 +820,9 @@ def _iter_placeholders( if item_idx >= mm_item_counts.get(modality, 0): continue - for update_info in modality_updates: - content = update_info.get_content(item_idx) - content_tokens_full = content.full.token_ids + for update in modality_updates[item_idx]: + content = update.content + content_tokens_full = _seq2tokens(tokenizer, content.full) content_len_full = len(content_tokens_full) end_idx_full = start_idx + content_len_full @@ -856,7 +832,8 @@ def _iter_placeholders( if prompt[start_idx:end_idx_full] == content_tokens_full: content_is_embed = content.is_embed if content_is_embed is not None: - content_is_embed = content_is_embed(content.full) + content_is_embed = content_is_embed( + tokenizer, content.full) yield PlaceholderFeaturesInfo( modality=modality, @@ -880,11 +857,11 @@ def _iter_placeholders( def find_mm_placeholders( - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], prompt: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: "MultiModalPromptUpdates", + tokenizer: AnyTokenizer, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts) + it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer) return dict(full_groupby_modality(it)) @@ -989,12 +966,20 @@ def get_mm_max_tokens_per_item( [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ -MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]] +MultiModalPromptUpdates = Mapping[str, list[Sequence[ResolvedPromptUpdate]]] """ A collection of prompt updates with a similar structure as [`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems]. """ +MultiModalPromptUpdatesApplyResult = Mapping[str, list[Optional[int]]] +""" +For an item `MultiModalPromptUpdates[k][i]`, +`MultiModalPromptUpdatesApplyResult[k][i]` represents the index of the +`ResolvedPromptUpdate` instance that has been applied, or `None` if none of the +`ResolvedPromptUpdate` instances have been applied. +""" + class MultiModalProcessingInfo(NamedTuple): kwargs: MultiModalKwargsItems @@ -1126,14 +1111,60 @@ def _get_prompt_updates( """ raise NotImplementedError + def _bind_and_group_updates( + self, + prompt_updates: Sequence[PromptUpdate], + mm_item_counts: Mapping[str, int], + ) -> MultiModalPromptUpdates: + return { + modality: [[update.resolve(item_idx) for update in updates] + for item_idx in range(mm_item_counts.get(modality, 0))] + for modality, updates in full_groupby_modality(prompt_updates) + } + + def _get_mm_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> MultiModalPromptUpdates: + unbound_prompt_updates = self._get_prompt_updates( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, + ) + + mm_prompt_updates = self._bind_and_group_updates( + unbound_prompt_updates, + mm_items.get_all_counts(), + ) + + for modality, prompt_updates in mm_prompt_updates.items(): + for item_idx, item_prompt_updates in enumerate(prompt_updates): + if len(item_prompt_updates) > 1: + logger.warning_once( + "Detected %d prompt updates for `mm_items[%r][%s]`. " + "Multiple prompt updates per item is now " + "deprecated and may be removed in v0.13. 
" + "Instead, please specify dynamic update targets " + "in the same prompt update definition by passing " + "a function to `PromptUpdate.target`.", + len(prompt_updates), + modality, + item_idx, + ) + + return mm_prompt_updates + def _find_mm_placeholders( self, - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], new_token_ids: list[int], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> Mapping[str, list[PlaceholderFeaturesInfo]]: - return find_mm_placeholders(mm_prompt_updates, new_token_ids, - mm_item_counts) + tokenizer = self.info.get_tokenizer() + + return find_mm_placeholders(new_token_ids, mm_prompt_updates, + tokenizer) def _get_hf_mm_data( self, @@ -1421,13 +1452,11 @@ def _apply_hf_processor( mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs) - unbound_prompt_updates = self._get_prompt_updates( + mm_prompt_updates = self._get_mm_prompt_updates( mm_data_items, hf_processor_mm_kwargs, mm_kwargs, ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) mm_info = MultiModalProcessingInfo( kwargs=mm_kwargs, @@ -1497,13 +1526,11 @@ def _cached_apply_hf_processor( mm_missing_kwargs=mm_missing_kwargs, ) - unbound_prompt_updates = self._get_prompt_updates( + mm_prompt_updates = self._get_mm_prompt_updates( mm_data_items, hf_processor_mm_kwargs, mm_kwargs, ) - mm_prompt_updates = self._bind_and_group_updates( - unbound_prompt_updates) mm_info = MultiModalProcessingInfo( kwargs=mm_kwargs, @@ -1513,47 +1540,33 @@ def _cached_apply_hf_processor( return prompt_ids, mm_info, is_update_applied - def _bind_and_group_updates( - self, - prompt_updates: Sequence[PromptUpdate], - ) -> dict[str, Sequence[BoundPromptUpdate]]: - tokenizer = self.info.get_tokenizer() - - it = (update.bind(tokenizer) for update in prompt_updates) - return dict(full_groupby_modality(it)) - def _apply_token_matches( self, prompt: list[int], - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> list[int]: - return apply_token_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[list[int], MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_token_matches(prompt, mm_prompt_updates, tokenizer) def _apply_text_matches( self, prompt: str, - mm_matches: Mapping[str, Sequence[PromptTargetMatch]], - mm_item_counts: Mapping[str, int], - ) -> str: - return apply_text_matches(prompt, mm_matches, mm_item_counts) + mm_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[str, MultiModalPromptUpdatesApplyResult]: + tokenizer = self.info.get_tokenizer() + return apply_text_matches(prompt, mm_prompt_updates, tokenizer) def _apply_prompt_updates( self, token_ids: list[int], - mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]], - mm_item_counts: Mapping[str, int], + mm_prompt_updates: MultiModalPromptUpdates, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: tokenizer = self.info.get_tokenizer() - mm_token_matches = { - modality: find_token_matches(token_ids, updates) - for modality, updates in mm_prompt_updates.items() - } - mm_match_counts = { - modality: len(matches) - for modality, matches in mm_token_matches.items() - } + new_token_ids, match_result = self._apply_token_matches( + token_ids, + mm_prompt_updates, + ) # If the search text does not represent a special token, # it may have different token IDs in the prompt, because @@ -1566,48 +1579,38 @@ def 
_apply_prompt_updates( # of the search text in the prompt, we instead perform string-based # updates on the decoded token IDs, then encode them back. if all( - mm_match_counts.get(modality, 0) >= item_count - for modality, item_count in mm_item_counts.items() - ): # yapf: disable - token_ids = self._apply_token_matches( - token_ids, - mm_token_matches, - mm_item_counts, + all(update_idx is not None for update_idx in update_idxs) + for update_idxs in match_result.values()): + new_text = decode_tokens(tokenizer, new_token_ids) + else: + new_text, match_result = self._apply_text_matches( + decode_tokens(tokenizer, token_ids), + mm_prompt_updates, ) - text = decode_tokens(tokenizer, token_ids) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_token_matches.items() - } - else: - text = decode_tokens(tokenizer, token_ids) - - mm_text_matches = { - modality: find_text_matches(text, updates) - for modality, updates in mm_prompt_updates.items() - } - text = self._apply_text_matches( - text, - mm_text_matches, - mm_item_counts, + new_token_ids = encode_tokens( + tokenizer, + new_text, + add_special_tokens=False, ) - token_ids = encode_tokens(tokenizer, - text, - add_special_tokens=False) - matched_updates = { - modality: [match._origin for match in token_matches] - for modality, token_matches in mm_text_matches.items() - } + matched_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, update_idxs in match_result.items(): + for item_idx, update_idx in enumerate(update_idxs): + assert update_idx is not None, ( + "Failed to apply prompt replacement for " + f"mm_items[{modality!r}][{item_idx}]") + + matched_updates[modality].append( + [mm_prompt_updates[modality][item_idx][update_idx]]) placeholders = self._find_mm_placeholders( - matched_updates, - token_ids, - mm_item_counts, + new_token_ids, + dict(matched_updates), ) - return token_ids, text, placeholders + return new_token_ids, new_text, placeholders def _validate_mm_kwargs( self, @@ -1661,9 +1664,8 @@ def _maybe_apply_prompt_updates( if is_update_applied: mm_placeholders = self._find_mm_placeholders( - mm_prompt_updates, prompt_ids, - mm_item_counts, + mm_prompt_updates, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) @@ -1677,7 +1679,6 @@ def _maybe_apply_prompt_updates( ) = self._apply_prompt_updates( prompt_ids, mm_prompt_updates, - mm_item_counts, ) self._validate_mm_placeholders(mm_placeholders, mm_item_counts) diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index 2da9b4c72189..ea2efbdd8b52 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -209,7 +209,7 @@ def get_encoder_dummy_data( if processor.pad_dummy_encoder_prompt: num_tokens_to_pad = max(total_len, seq_len) - total_len encoder_prompt_token_ids.extend([0] * num_tokens_to_pad) - # NOTE: Whisper allows total_len > seq_len. + # NOTE: Whisper and Donut allows total_len > seq_len. elif total_len > seq_len and not envs.VLLM_USE_V1: # `max_num_batched_tokens` is defined by `SchedulerConfig` logger.warning_once( diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 58c71d865dc7..834b2189e4be 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -461,6 +461,8 @@ def run_dp_sharded_vision_model(image_input: torch.Tensor, num_chunks_per_rank, ...] 
vision_embeddings = vision_model(image_input_per_rank) + # Ensure tensor is contiguous before all_gather + vision_embeddings = vision_embeddings.contiguous() vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings, dim=0) vision_embeddings = vision_embeddings[:num_chunks, ...] diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 317bc401a799..323ec591c50a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -171,7 +171,7 @@ class RocmPlatform(Platform): supported_quantization: list[str] = [ "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", - "quark", "ptpc_fp8", "mxfp4" + "quark", "ptpc_fp8", "mxfp4", "petit_nvfp4" ] @classmethod diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index fe345bd8f0a2..674c820daba2 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -927,3 +927,25 @@ def get_model_path(model: Union[str, Path], revision: Optional[str] = None): from huggingface_hub import snapshot_download return snapshot_download(repo_id=model, **common_kwargs) + + +def get_hf_file_bytes(file_name: str, + model: Union[str, Path], + revision: Optional[str] = 'main') -> Optional[bytes]: + """Get file contents from HuggingFace repository as bytes.""" + file_path = try_get_local_file(model=model, + file_name=file_name, + revision=revision) + + if file_path is None: + hf_hub_file = hf_hub_download(model, + file_name, + revision=revision, + token=_get_hf_token()) + file_path = Path(hf_hub_file) + + if file_path is not None and file_path.is_file(): + with open(file_path, 'rb') as file: + return file.read() + + return None diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 4546f60aae67..b3f1977f26cf 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -49,12 +49,11 @@ def decode_tokens( `skip_special_tokens=None` means to use the backend's default settings. """ - decode_method = getattr(tokenizer, "_decode", tokenizer.decode) if skip_special_tokens is not None: - return decode_method(token_ids, - skip_special_tokens=skip_special_tokens) + return tokenizer.decode(token_ids, + skip_special_tokens=skip_special_tokens) - return decode_method(token_ids) + return tokenizer.decode(token_ids) def encode_tokens( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 996be1265667..5dd239c50f63 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -132,6 +132,11 @@ def has_nvidia_artifactory() -> bool: This checks connectivity to the kernel inference library artifactory which is required for downloading certain cubin kernels like TRTLLM FHMA. """ + # Since FLASHINFER_CUBIN_DIR defines the pre-downloaded cubins path, when + # it's true, we could assume the cubins are available. + if envs.VLLM_HAS_FLASHINFER_CUBIN: + return True + try: # Use a short timeout to avoid blocking for too long response = requests.get(FLASHINFER_CUBINS_REPOSITORY, timeout=5) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 9ed46331863c..973979fdf7df 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -483,6 +483,7 @@ def forward( attn_metadata: TorchSDPAMetadata, # type: ignore output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. 
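As a hypothetical usage sketch of the `get_hf_file_bytes` helper added in `vllm/transformers_utils/config.py` above (the repository id and file name below are placeholders, not values taken from this change):

```python
import json

from vllm.transformers_utils.config import get_hf_file_bytes

# Placeholder repo and file; the helper returns None if the file is missing.
raw = get_hf_file_bytes("preprocessor_config.json", "org/model-name")
if raw is not None:
    preprocessor_config = json.loads(raw)
```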
@@ -497,7 +498,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TorchSDPABackendImpl") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index eca83b6d2ee4..6e7096de924c 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -430,6 +430,7 @@ def forward( attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -447,7 +448,7 @@ def forward( """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 1e6e3f1d0abf..50819bb2bb94 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -12,6 +12,7 @@ MultiLevelCascadeAttentionWrapper) from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.utils import FP4Tensor from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -19,7 +20,7 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( - GroupShape) + QuantKey, kFp8StaticTensorSym, kNvfp4Quant) from vllm.platforms import current_platform from vllm.utils import cdiv, is_pin_memory_available from vllm.utils.flashinfer import (supports_trtllm_attention, @@ -40,6 +41,7 @@ FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 FP8_DTYPE = current_platform.fp8_dtype() +FP4_DTYPE = torch.uint8 logger = init_logger(__name__) @@ -653,14 +655,12 @@ def __init__( and num_heads % num_kv_heads == 0) self.bmm1_scale: Optional[float] = None self.bmm2_scale: Optional[float] = None + self.o_sf_scale: Optional[float] = None - def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: GroupShape): - supported_quant_type = (dtype == FP8_DTYPE and static - and group_shape == GroupShape.PER_TENSOR) + def fused_output_quant_supported(self, quant_key: QuantKey): return (self.support_trtllm_attn and self.kv_cache_dtype.startswith("fp8") - and supported_quant_type) + and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)) def forward( self, @@ -672,6 +672,7 @@ def forward( attn_metadata: FlashInferMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashInfer. @@ -705,19 +706,32 @@ def forward( if output_scale is None: assert attn_metadata.q_data_type != FP8_DTYPE, \ "Query can only be FP8 if output fusion happened." + assert output_block_scale is None, "output_block_scale "\ + "is not supported when fusion has not happened" else: assert attn_metadata.q_data_type == FP8_DTYPE, \ "Query must be FP8 when attn+quant fusion happened." 
assert (attn_metadata.prefill_use_trtllm and attn_metadata.decode_use_trtllm), "Must use TRT-LLM attn" - assert output.dtype == FP8_DTYPE, \ - "Output must be FP8 when attn+quant fusion happened." - # TRTLLM attn kernel requires o scale as a host scalar, store the - # o scale to host scalar in warmup run with cuda graph not enabled + if output.dtype == FP8_DTYPE: + assert output_block_scale is None, \ + "output_block_scale should not be provided for fp8 output" + elif output.dtype == FP4_DTYPE: + assert output_block_scale is not None, \ + "output_block_scale is required for nvfp4 output" + else: + raise ValueError(f"Unsupported output dtype: {output.dtype}") + + # TRTLLM attn kernel requires o scale to pass as a host scalar, + # store the o scale as a host scalar in warmup run with cuda graph + # not enabled if layer._o_scale_float is None: layer._o_scale_float = output_scale.cpu().item() - self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + if output.dtype == FP8_DTYPE: + self.bmm2_scale = self.bmm2_scale / layer._o_scale_float + elif output.dtype == FP4_DTYPE: + self.o_sf_scale = layer._o_scale_float # Insert FP8 quant for query num_tokens, num_heads, head_size = query.shape @@ -818,6 +832,16 @@ def forward( assert block_tables_prefill.is_contiguous() assert seq_lens_prefill.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[num_decode_tokens:], + scale=output_block_scale, + scale_start_index=num_decode_tokens, + original_shape=prefill_query.shape) + else: + assert self.o_sf_scale is None + out = output[num_decode_tokens:] + trtllm_batch_context_with_kv_cache( query=prefill_query, kv_cache=kv_cache_permute, @@ -833,7 +857,8 @@ def forward( cum_seq_lens_kv=attn_metadata.paged_kv_indptr_gpu, window_left=self.window_left, sinks=self.sinks, - out=output[num_decode_tokens:], + o_sf_scale=self.o_sf_scale, + out=out, ) if num_decode_tokens > 0: @@ -870,6 +895,16 @@ def forward( assert block_tables_decode.is_contiguous() assert seq_lens_decode.is_contiguous() + if output.dtype == FP4_DTYPE: + assert self.o_sf_scale is not None + out = FP4Tensor(data=output[:num_decode_tokens], + scale=output_block_scale, + scale_start_index=0, + original_shape=decode_query.shape) + else: + assert self.o_sf_scale is None + out = output[:num_decode_tokens] + trtllm_batch_decode_with_kv_cache( query=decode_query, kv_cache=kv_cache_permute, @@ -881,7 +916,8 @@ def forward( bmm2_scale=self.bmm2_scale, window_left=self.window_left, sinks=self.sinks, - out=output[:num_decode_tokens], + o_sf_scale=self.o_sf_scale, + out=out, ) return output_padded diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index abca981035d9..458562ebc8d2 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -1,11 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Attention layer with FlashAttention.""" -from collections import defaultdict +"""Attention layer with FlexAttention.""" + from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional, Union import torch +import torch._dynamo.decorators +import torch.nn.functional as F from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, _score_mod_signature, create_block_mask, @@ -16,13 +18,17 @@ is_quantized_kv_cache) from vllm.config import VllmConfig from vllm.logger import 
init_logger -from vllm.platforms import current_platform +from vllm.utils import cdiv, is_torch_equal_or_newer from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead") @@ -36,6 +42,23 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: torch.arange(len(counts), device=device, dtype=torch.int32), counts) +def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int): + difference = (multiple - (x.shape[dim] % multiple)) % multiple + if difference == 0: + return x + + dim = dim if dim >= 0 else x.ndim + dim + pad_list = [] + + for i in range(x.ndim - 1, dim - 1, -1): + if i == dim: + pad_list.extend([0, difference]) + else: + pad_list.extend([0, 0]) + + return F.pad(x, pad_list, mode="constant", value=0) + + class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True @@ -77,10 +100,10 @@ def use_cascade_attention(*args, **kwargs) -> bool: return False -# @torch.compile(fullgraph=True, mode="reduce-overhead") -def physical_to_logical_mapping( - block_table: torch.Tensor, - total_blocks: Optional[int] = None) -> torch.Tensor: +#@torch.compile(fullgraph=True, mode="reduce-overhead") +def physical_to_logical_mapping(block_table: torch.Tensor, + seq_lens: torch.Tensor, block_size: int, + total_blocks: int) -> torch.Tensor: """ Creates an inverse mapping from physical block locations to logical indices. @@ -114,13 +137,38 @@ def physical_to_logical_mapping( If a physical block is not mapped to by any logical block, its value in the result will be -1. + IMPORTANT: Garbage Value Protection + ──────────────────────────────────── + The block_table tensor may contain garbage values in unused positions + (beyond the actual sequence length). For example, if a sequence only + needs 3 blocks but the table has space for 8: + + block_table[0] = [10, 25, 7, 999, 1234, 888, ...] + ^^^^^^^^^^^^^^^^^^^^ + garbage values + + These garbage values can cause issues because: + 1. They may map to valid physical blocks by coincidence + 2. The scatter_ operation will assign them logical indices + 3. Later attention computations may incorrectly access these blocks + + To prevent this, we use seq_lens and block_size to mask out unused + entries, ensuring only valid block references are processed. Args: block_table: Tensor of shape [max_reqs, max_num_blocks] - mapping logical blocks to physical locations + mapping logical blocks to physical locations. May contain + garbage values in unused positions. + seq_lens: Tensor of sequence lengths for each request. Used to + determine how many blocks are actually needed per sequence. + block_size: Size of each block in tokens. Used with seq_lens to + compute the number of valid blocks per sequence. + total_blocks: Total number of physical blocks available Returns: - A tensor of shape [max_reqs, max_physical_block] + A tensor of shape [max_reqs, total_blocks] where each entry + physical_to_logical[req_id, physical_block] contains the logical + block index for that physical block, or -1 if unused. 
""" max_reqs, max_num_blocks = block_table.shape device = block_table.device @@ -130,17 +178,76 @@ def physical_to_logical_mapping( dtype=torch.long, device=device) - logical_indices = (torch.arange(max_num_blocks, - device=device).unsqueeze(0).expand( - max_reqs, -1)) + # Only process valid blocks to avoid garbage values + num_blocks_per_seq = cdiv(seq_lens, block_size) + mask = torch.arange(max_num_blocks, + device=device)[None, :] < num_blocks_per_seq[:, None] - physical_to_logical.scatter_(-1, block_table.to(torch.int64), - logical_indices) - # TODO Confirm - Seems like block 0 is always empty so we reset it manually + valid_block_table = torch.where(mask, block_table, 0) + valid_logical_indices = torch.where( + mask, + torch.arange(max_num_blocks, device=device)[None, :], 0) + + physical_to_logical.scatter_(-1, valid_block_table.to(torch.int64), + valid_logical_indices) + # NB - Seems like block 0 is always empty so we reset it manually physical_to_logical[:, 0] = -1 return physical_to_logical +def unique_static_unsorted( + x: torch.Tensor, + *, + M: int, # maximum positive value (0 is ā€œskip meā€) + dim: int = -1, # axis along which to deduplicate + ignored_val: int = 0, # value to ignore + pad_val: int = -1, # sentinel for unused slots +) -> torch.Tensor: + """ + - Keeps the first occurrence of each non-zero value while preserving order, + then left-packs those uniques and fills the rest with `pad_val`. + - Returns (packed, keep_mask) with the *same shape* as `x`. + - Requires that all values be in the range [0, M] + - Skips ignored_val + + Works on CPU or GPU, no Python loops, O(BĀ·N) time / O(BĀ·M) memory. + + Example: + x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1] + """ + if not (-1 <= pad_val <= M): + raise ValueError("`pad_val` must lie in [-1, M]") + + # ── move `dim` to the end so we can treat tensor as [B, N] ────────── + dim = dim % x.ndim + x_perm = x.movedim(dim, -1) # shape [..., N] + B, N = x_perm.numel() // x_perm.shape[-1], x_perm.shape[-1] + x_flat = x_perm.reshape(B, N) # [B, N] + + device = x.device + idx = torch.arange(N, device=device).expand(B, N) # per-row indices + + # ── build first-occurrence table for every v ∈ [0, M] ─────────────── + first_idx = torch.full((B, M + 1), N, device=device) # ā€œāˆžā€ + # scatter_reduce_: first_idx[b, v] = min(first_idx[b, v], i)ā€ƒfor each i + first_idx.scatter_reduce_(1, x_flat, idx, reduce="amin") + + # ── keep mask: first occurrence *and* value ≠ 0 ───────────────────── + keep = (x_flat != ignored_val) & (idx == first_idx.gather(1, x_flat) + ) # [B, N] + + # ── left-pack uniques into a fresh tensor ─────────────────────────── + dest_pos = torch.cumsum(keep.to(torch.long), dim=1) - 1 # where to go + packed_flat = torch.full_like(x_flat, pad_val) + + rows, src_cols = torch.nonzero(keep, as_tuple=True) + packed_flat[rows, dest_pos[rows, src_cols]] = x_flat[rows, src_cols] + + # ── restore original layout ───────────────────────────────────────── + packed = packed_flat.reshape(x_perm.shape).movedim(-1, dim) + return packed + + def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor): return q_idx >= kv_idx @@ -170,6 +277,7 @@ class FlexAttentionMetadata: num_reqs: int physical_to_logical: torch.Tensor decode_offset: torch.Tensor + num_blocks_per_seq: torch.Tensor # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
@@ -179,6 +287,46 @@ class FlexAttentionMetadata: block_mask: Optional[BlockMask] = None score_mod: Optional[_score_mod_signature] = None logical_mask_mod: _mask_mod_signature = causal_mask_mod + doc_ids: Optional[torch.Tensor] = None + direct_build: bool = True + q_block_size: int = 16 + kv_block_size: int = 16 + transformed_score_mod: Optional[_score_mod_signature] = None + + def _convert_physical_to_logical( + self, + request_lookup: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert physical indices to logical indices for both query and kv. + + NB is_within_lower_bound: do sequences start on block_boundaries? + + Returns: + tuple of (is_valid, logical_q_idx, logical_kv_idx) + """ + # Map query indices to corresponding request indices + q_req = request_lookup[q_idx] + + # Convert physical KV indices to logical indices + physical_kv_block = physical_kv_idx // self.block_size + physical_kv_offset = physical_kv_idx % self.block_size + logical_block_idx = self.physical_to_logical[q_req, physical_kv_block] + logical_kv_idx = (logical_block_idx * self.block_size + + physical_kv_offset) + + # Determine valid kv indices + live_block = logical_block_idx >= 0 + within_upper_bound = logical_kv_idx < self.seq_lens[q_req] + within_lower_bound = logical_kv_idx >= 0 + is_valid = live_block & within_upper_bound & within_lower_bound + + # Convert physical query indices to logical indices + local_q_idx = q_idx - self.query_start_loc[q_req] + logical_q_idx = local_q_idx + self.decode_offset[q_req] + + return is_valid, logical_q_idx, logical_kv_idx def get_causal_mask_mod(self) -> _mask_mod_signature: """Creates the mask_mod function for FlexAttention. @@ -191,11 +339,8 @@ def get_causal_mask_mod(self) -> _mask_mod_signature: With this info we create the "logical" indices that are passed to mask_mod functions. This allows mask mod functions to be agnostic to layout of the query and key/value tensors. - - TODO is_within_lower_bound: do sequences start on block_boundaries? 
""" - # Create a lookup mapping from query indices -> request number - request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + assert self.doc_ids is not None def final_mask_mod( b: torch.Tensor, @@ -203,27 +348,9 @@ def final_mask_mod( q_idx: torch.Tensor, physical_kv_idx: torch.Tensor, ) -> torch.Tensor: - # Map query indices to corresponding request indices - q_req = request_lookup[q_idx] - - # Convert physical KV indices to logical indices - physical_kv_block = physical_kv_idx // self.block_size - physical_kv_offset = physical_kv_idx % self.block_size - logical_block_idx = self.physical_to_logical[q_req, - physical_kv_block] - logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset # noqa: E501 - - # Determine valid kv indices - live_block = logical_block_idx >= 0 - within_upper_bound = logical_kv_idx < self.seq_lens[q_req] - within_lower_bound = logical_kv_idx >= 0 - - is_valid = live_block & within_upper_bound & within_lower_bound - - # Convert physical query indices to logical indices - local_q_idx = q_idx - self.query_start_loc[q_req] - logical_q_idx = local_q_idx + self.decode_offset[q_req] - + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + self.doc_ids, q_idx, physical_kv_idx) # Apply mask modification only for valid indices return torch.where( is_valid, @@ -236,7 +363,7 @@ def final_mask_mod( def get_bidirectional_mask_mod(self) -> _mask_mod_signature: """Creates the encoder mask_mod function for FlexAttention. - Since the encoder bidirectional attention doesn't run with + Since the encoder bidirectional attention doesn't run with KV cache, this function creates a mask based on the packed query sequences. """ @@ -253,6 +380,97 @@ def final_mask_mod( return final_mask_mod + def get_transformed_score_mod(self) -> Optional[_score_mod_signature]: + """Creates the transformed score_mod function for FlexAttention. + + This function wraps the user's score_mod to handle physical-to-logical + index conversion, similar to how get_mask_mod works for mask functions. + """ + if self.score_mod is None: + return None + + # Create a lookup mapping from query indices -> request number + request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) + user_score_mod = self.score_mod + + def transformed_score_mod( + score: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + q_idx: torch.Tensor, + physical_kv_idx: torch.Tensor, + ) -> torch.Tensor: + (is_valid, logical_q_idx, + logical_kv_idx) = self._convert_physical_to_logical( + request_lookup, q_idx, physical_kv_idx) + + return torch.where( + is_valid, + user_score_mod(score, + b, + h, + logical_q_idx, + logical_kv_idx, + physical_q=q_idx), -float('inf')) + + return transformed_score_mod + + def _build_block_mask_direct(self) -> BlockMask: + """Direct block mask construction for standard causal attention. + + This method constructs the block mask directly using + BlockMask.from_kv_blocks which is much more efficient than the + generic create_block_mask approach. + + The direct path works as follows: + 1. For each query token, fetch blocks from block_table using max_seq_len + (this fetches more blocks than needed for shorter sequences) + 2. Group query tokens into chunks of q_block_size + 3. For each group, deduplicate the blocks using unique_static_unsorted + 4. Create BlockMask using the deduplicated block indices + + Over-estimation occurs when a group of q_block_size tokens contains + multiple sequence IDs (doc_ids). 
In this case, we fetch ALL blocks for + each sequence represented in the group, even though individual query + tokens may only need a subset of those blocks based on causal masking + and their position. + + """ + page_to_block_ratio = self.kv_block_size // self.block_size + if page_to_block_ratio != 1: + raise ValueError( + f"FlexAttention currently requires the cache block size " + f"({self.block_size}) to be equal to the kv_block_size " + f"({self.kv_block_size}). Please check your model's " + f"configuration.") + + used_pages = self.block_table[ + self.doc_ids, :cdiv(self.max_seq_len, self.block_size)] + used_pages_padded = pad_to_multiple(used_pages, + multiple=self.q_block_size, + dim=0) + used_pages_padded = used_pages_padded.reshape( + used_pages_padded.shape[0] // self.q_block_size, -1) + used_pages_padded = used_pages_padded // page_to_block_ratio + kv_indices = unique_static_unsorted((used_pages_padded.long()), + M=self.num_blocks).to(torch.int32) + + kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) + block_mask_kwargs = { + "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), + "kv_num_blocks": kv_num_blocks[None, None], + "kv_indices": kv_indices[None, None], + "full_kv_num_blocks": None, + "full_kv_indices": None, + "BLOCK_SIZE": (self.q_block_size, self.kv_block_size), + "mask_mod": self.mask_mod, + } + + # compute_q_blocks parameter is available in PyTorch 2.9+ + if is_torch_equal_or_newer("2.9.0.dev0"): + block_mask_kwargs["compute_q_blocks"] = False + return BlockMask.from_kv_blocks(**block_mask_kwargs) + def build_block_mask(self) -> BlockMask: if self.causal: mask_mod = self.get_causal_mask_mod() @@ -267,6 +485,7 @@ def build_block_mask(self) -> BlockMask: self.num_actual_tokens, kv_len, device=self.block_table.device, + BLOCK_SIZE=(self.q_block_size, self.kv_block_size), ) def __post_init__(self): @@ -275,8 +494,21 @@ def __post_init__(self): assert self.cu_prefix_query_lens is None, "Not implemented yet." assert self.prefix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet." 
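To make the over-estimation note in `_build_block_mask_direct` concrete, here is a toy sketch (invented page ids, two requests, `q_block_size=2`); it only mirrors steps 1-3 of the docstring, and the final dedup/padding is what `unique_static_unsorted` then does per row.

```python
# Hedged illustration of the direct path: gather pages per query token,
# group rows into q_block_size chunks, and observe the mixed-request group.
import torch

block_size, q_block_size, max_seq_len = 4, 2, 10
block_table = torch.tensor([[7, 8, 9, 0],      # request 0 -> pages 7, 8, 9
                            [3, 5, 0, 0]])     # request 1 -> pages 3, 5
doc_ids = torch.tensor([0, 0, 0, 1, 1, 1])     # three query tokens per request

num_pages = -(-max_seq_len // block_size)      # cdiv(10, 4) = 3 pages per token
used_pages = block_table[doc_ids, :num_pages]  # one row of pages per query token

pad = (-used_pages.shape[0]) % q_block_size    # pad rows to a q_block_size multiple
used_pages = torch.cat([used_pages, used_pages.new_zeros(pad, num_pages)])
groups = used_pages.reshape(-1, q_block_size * num_pages)
print(groups)
# tensor([[7, 8, 9, 7, 8, 9],
#         [7, 8, 9, 3, 5, 0],   <- this q-block spans both requests (over-estimation)
#         [3, 5, 0, 3, 5, 0]])
# unique_static_unsorted() then left-packs the unique page ids and pads with -1.
```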
+ # Create a lookup mapping from query indices -> request number + self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.num_blocks = self.total_cache_tokens // self.block_size - self.block_mask = self.build_block_mask() + + if self.causal: + self.mask_mod = self.get_causal_mask_mod() + else: + self.mask_mod = self.get_bidirectional_mask_mod() + + self.transformed_score_mod = self.get_transformed_score_mod() + + if self.direct_build and self.causal: + self.block_mask = self._build_block_mask_direct() + else: + self.block_mask = self.build_block_mask() class FlexAttentionMetadataBuilder( @@ -287,15 +519,24 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.model_config = vllm_config.model_config self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config + self.device = device self.num_heads_q = self.model_config.get_num_attention_heads( - vllm_config.parallel_config) + self.parallel_config) self.num_heads_kv = self.model_config.get_num_kv_heads( - vllm_config.parallel_config) + self.parallel_config) self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.device = device + self.direct_build: bool = is_torch_equal_or_newer("2.9.0.dev0") + self.q_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + self.kv_block_size: int = 16 if is_torch_equal_or_newer( + "2.9.0.dev0") else 128 + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False def build(self, common_prefix_len: int, @@ -310,6 +551,7 @@ def build(self, seq_lens = common_attn_metadata.seq_lens block_table_tensor = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping + num_blocks_per_seq = cdiv(seq_lens, self.block_size) use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -320,12 +562,15 @@ def build(self, block_size = self.kv_cache_spec.block_size max_possible_seq_len = self.model_config.max_model_len - total_cache_tokens = self.cache_config.num_gpu_blocks * block_size + num_gpu_blocks = self.cache_config.num_gpu_blocks + + assert num_gpu_blocks is not None, \ + "FlexAttention requires num_gpu_blocks to be set" + total_cache_tokens = (num_gpu_blocks * block_size) inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.cache_config.num_gpu_blocks) + block_table_tensor, seq_lens, block_size, num_gpu_blocks) - # Get the original offset tensor offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( self.device, non_blocking=True) @@ -349,9 +594,16 @@ def build(self, physical_to_logical=inverse_block_table, total_cache_tokens=total_cache_tokens, decode_offset=offset_tensor, + num_blocks_per_seq=num_blocks_per_seq, + direct_build=self.direct_build, + q_block_size=self.q_block_size, + kv_block_size=self.kv_block_size, ) return out + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class FlexAttentionImpl(AttentionImpl): sliding_window: Optional[tuple[int, int]] @@ -370,6 +622,7 @@ def __init__( logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, + **kwargs, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -398,6 +651,7 @@ def __init__( raise NotImplementedError( "FlexAttention does not support logits soft cap yet.") + assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = 
self.num_heads // self.num_kv_heads if kv_sharing_target_layer_name is not None: @@ -405,7 +659,6 @@ def __init__( "FlexAttention does not support kv sharing yet.") FlexAttentionBackend.validate_head_size(head_size) - if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "FlexAttention does not support quantized kv-cache. Yet") @@ -428,6 +681,7 @@ def forward( attn_metadata: FlexAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FLexAttention. @@ -441,7 +695,7 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlexAttentionImpl") @@ -492,35 +746,48 @@ def forward( # Doesn't work for now -> constraint violation # torch._dynamo.try_mark_dynamic(query, 2) - # default M=64, N=64 may run out of shared memory on some GPUs - # TODO: Explicit configs for each GPU? - # Not sure how to calculate the shared memory requirement - extra_kernel_options = defaultdict[str, int](lambda: 64) - if query.dtype == torch.float32: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 - if current_platform.is_cuda(): - device_props = torch.cuda.get_device_properties() - max_shared_memory = device_props.shared_memory_per_block_optin - if max_shared_memory < 144 * 1024: - extra_kernel_options["BLOCK_M"] //= 2 - extra_kernel_options["BLOCK_N"] //= 2 + assert attn_metadata.block_mask is not None + block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE + kernel_options = get_kernel_options(query, block_m, block_n, + attn_metadata.direct_build) out = flex_attention_compiled( query, key_tensor, value_tensor, - attn_metadata.score_mod, + attn_metadata.transformed_score_mod, attn_metadata.block_mask, self.scale, enable_gqa=enable_gqa, - kernel_options={ - "FORCE_USE_FLEX_ATTENTION": True, - **extra_kernel_options - }, + kernel_options=kernel_options, ) # Flex doesn't have an out variant today, rely on epilogue fusion out = out.permute(0, 2, 1, 3).squeeze(0) output[:num_actual_tokens, :, :].copy_(out) return output + + +def get_kernel_options(query, block_m, block_n, + use_direct_build: bool) -> dict[str, Union[int, bool]]: + kernel_options: dict[str, Union[int, bool]] = { + "FORCE_USE_FLEX_ATTENTION": True, + } + if use_direct_build: + kernel_options["BLOCK_M"] = block_m + kernel_options["BLOCK_N"] = block_n + return kernel_options + else: + kernel_options["BLOCK_M"] = 64 + kernel_options["BLOCK_N"] = 64 + if query.dtype == torch.float32: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + # if current_platform.is_cuda(): + if torch.cuda.is_available(): + device_props = torch.cuda.get_device_properties() + max_shared_memory = device_props.shared_memory_per_block_optin + if max_shared_memory < 144 * 1024: + kernel_options["BLOCK_M"] = 32 + kernel_options["BLOCK_N"] = 32 + return kernel_options diff --git a/vllm/v1/attention/backends/mamba_selectors.py b/vllm/v1/attention/backends/mamba_selectors.py deleted file mode 100644 index fb1844508211..000000000000 --- a/vllm/v1/attention/backends/mamba_selectors.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from 
vllm.attention.backends.abstract import AttentionBackend -from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend -from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend -from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.short_conv_attn import ( - ShortConvAttentionBackend) - - -def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]: - if mamba_type == "mamba1": - return Mamba1AttentionBackend - if mamba_type == "mamba2": - return Mamba2AttentionBackend - if mamba_type == "linear_attention": - return LinearAttentionBackend - if mamba_type == "short_conv": - return ShortConvAttentionBackend - - raise NotImplementedError(f"Mamba Attention type {mamba_type} is not " - "supported yet.") diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5c0ff943794a..ce45b34f6435 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -1138,10 +1138,11 @@ def forward( attn_metadata: M, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for MLACommonImpl") diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 3eb4a0e7a574..fd97db0abb84 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -227,6 +227,7 @@ def forward( attn_metadata: PallasMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -239,7 +240,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for PallasAttentionBackendImpl") diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index b9ff113573a1..403ad8e88a95 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -421,6 +421,7 @@ def forward( attn_metadata: AiterFlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with AiterFlashAttention. @@ -438,7 +439,7 @@ def forward( """ assert output is not None, "Output tensor must be provided." 
- if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for FlashAttentionImpl") diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 2a0c52377cc7..c93223a34083 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -354,6 +354,7 @@ def forward( attn_metadata: TreeAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with TreeAttention. @@ -368,7 +369,7 @@ def forward( """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TreeAttentionImpl") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index c69dd8415f92..b12036c59979 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -277,6 +277,7 @@ def forward( attn_metadata: FlashAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -291,7 +292,7 @@ def forward( """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for TritonAttentionImpl") diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index b305bc153908..e0eb7d8be974 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -322,6 +322,7 @@ def forward( attn_metadata: XFormersAttentionMetadata, output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with XFormers. @@ -336,7 +337,7 @@ def forward( """ assert output is not None, "Output tensor must be provided." - if output_scale is not None: + if output_scale is not None or output_block_scale is not None: raise NotImplementedError( "fused output quantization is not yet supported" " for XFormersAttentionImpl") diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py index 0b9da60c67de..70af419fcb95 100644 --- a/vllm/v1/core/encoder_cache_manager.py +++ b/vllm/v1/core/encoder_cache_manager.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import OrderedDict from collections.abc import Mapping from typing import TYPE_CHECKING @@ -31,34 +33,52 @@ class EncoderCacheManager: within requests, allowing for fine-grained memory management and enabling chunked processing of multimodal inputs. - Note that no caching is shared between requests at this time. If the same - input is used across multiple requests, it will be reprocessed for each - request. 
+ Cache is enabled to share embeddings of same multimodal data + item (identified by their hash value) between different requests, + and eviction takes place at allocation time when there's no free + space for new embeddings. + Oldest cached embeddings with no request referenced will be first evicted. Args: cache_size: Limit the size of the cache, measured by the number of tokens from the input sequence. Attributes: - cache_size: Total cache capacity in encoder tokens - num_free_slots: Current available cache capacity in encoder tokens - cached: Mapping from request_id to set of cached input_ids for that - request - freed: List of (request_id, input_id) pairs that were recently freed. - This is cleared after every call to get_freed_ids(). + cache_size: Total cache capacity in encoder tokens. + num_free_slots: Current available cache capacity in encoder tokens. + num_freeable_slots: Capacity that can be immediately reclaimed by + evicting entries with zero references (in encoder tokens). + cached: Mapping from mm_hash to a set of request IDs that currently + reference the cached entry. If the set is empty, the entry exists + but is not referenced by any request and is eligible for + reclamation. + freeable: List of tuples (mm_hash, num_tokens) representing entries + whose no current running request is needed and that can be freed to + make space when needed. + freed: List of mm_hash strings that were actually evicted since the + last call to get_freed_mm_hashes(). This list is cleared on return. """ def __init__(self, cache_size: int): self.cache_size = cache_size self.num_free_slots = cache_size - # req_id -> cached input ids - self.cached: dict[str, set[int]] = {} - # list of [req_id, input_id] - self.freed: list[tuple[str, int]] = [] + self.num_freeable_slots = cache_size + + # mm_hash of mm_data => ids of requests that reference the mm_data + self.cached: dict[str, set[str]] = {} + + # mm_hash of mm_data => num_encoder_tokens of the mm_data + self.freeable: OrderedDict[str, int] = OrderedDict() + self.freed: list[str] = [] - def has_cache(self, request: Request, input_id: int) -> bool: + def check_and_update_cache(self, request: Request, input_id: int) -> bool: """Check if encoder output for a specific multimodal input is cached. + If the encoder output is cached, update `cached` to add the request id + to the set of request ids that reference the cached encoder output. + If the encoder output was previously not referenced by any request, + update `freeable` and `num_freeable_slots` accordingly. + Args: request: The request containing the multimodal input input_id: Index of the multimodal input within the request @@ -66,103 +86,151 @@ def has_cache(self, request: Request, input_id: int) -> bool: Returns: True if the encoder output for this input is already cached """ - req_id = request.request_id - return req_id in self.cached and input_id in self.cached[req_id] - - def can_allocate(self, request: Request, input_id: int) -> bool: - """Check if there's sufficient cache space for a multimodal input. 
+ mm_hash = request.mm_hashes[input_id] + # Not cached at all + if mm_hash not in self.cached: + return False + + # Cached but currently not referenced by any request + if not self.cached[mm_hash]: + num_tokens = self.freeable.pop(mm_hash) + self.num_freeable_slots -= num_tokens + + self.cached[mm_hash].add(request.request_id) + return True + + def try_allocate(self, request: Request, input_id: int, + encoder_budget: int) -> bool: + """Check if there's sufficient cache space for a multimodal input. + If there is, return True and update EncoderCacheManager state. + + If there is not enough free space in `num_free_slots` but there is + enough reclaimable space in `num_freeable_slots`, entries will be + evicted from `freeable` (their mm_hash appended to `freed`) until + enough space is available, and then this method returns True. + Older entries are evicted first. + + Returns False only if the requested number of tokens exceeds both + the free and reclaimable capacities combined. Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + request: The request containing the multimodal input. + input_id: Index of the multimodal input within the request. Returns: - True if there's enough free cache space to store the encoder output - for this multimodal input + True if there's enough capacity to hold the encoder output for this + input (possibly after reclaiming `freeable` entries); otherwise + False. + + Note: This method does not allocate physical memory for the encoder + output but only the state of EncoderCacheManager. """ num_tokens = request.get_num_encoder_tokens(input_id) - return num_tokens <= self.num_free_slots + + # Not enough compute budget + if num_tokens > encoder_budget: + return False + + # Enough free slots + if num_tokens <= self.num_free_slots: + self.num_free_slots -= num_tokens + self.num_freeable_slots -= num_tokens + return True + + # Not enough reclaimable slots + if num_tokens > self.num_freeable_slots: + return False + + # Not enough free slots but enough reclaimable slots + # NOTE: Eviction takes place here, but physical memory is not freed + # until model runner is notified by the scheduler output. + while num_tokens > self.num_free_slots: + mm_hash, num_free_token = self.freeable.popitem(last=False) + del self.cached[mm_hash] + self.freed.append(mm_hash) + self.num_free_slots += num_free_token + self.num_free_slots -= num_tokens + self.num_freeable_slots -= num_tokens + return True def allocate(self, request: Request, input_id: int) -> None: """Allocate cache space for a multimodal input's encoder output. - This method reserves cache space for storing the encoder output of - the specified multimodal input. The actual encoder output storage - happens in the model runner, but this method ensures the cache - manager tracks the allocation. - - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input within the request + This reserves cache space for storing the encoder output of the + specified multimodal input. The actual encoder output storage happens in + the model runner; this method updates the manager's bookkeeping. Note: - This method assumes can_allocate() returned True for the same - request and input_id. It will reduce available cache space. + This method assumes try_allocate() returned True for the same input. 
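A small numeric walkthrough of the accounting in `try_allocate` when eviction is required; the numbers are made up, and the real implementation pops the oldest entry from an `OrderedDict` via `popitem(last=False)`.

```python
# Toy walkthrough (assumed numbers) of try_allocate() when eviction is needed.
# cache_size = 10; "img_a" (6 tokens) is cached but unreferenced.
num_free_slots, num_freeable_slots = 4, 10
freeable, freed = {"img_a": 6}, []            # real code: OrderedDict, FIFO eviction

num_tokens = 7                                # new multimodal input to place
assert num_tokens <= num_freeable_slots       # reclaimable space exists, proceed
while num_tokens > num_free_slots:            # evict oldest unreferenced entries
    mm_hash, n = freeable.popitem()
    freed.append(mm_hash)
    num_free_slots += n
num_free_slots -= num_tokens
num_freeable_slots -= num_tokens
assert (num_free_slots, num_freeable_slots, freed) == (3, 3, ["img_a"])
```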
""" - req_id = request.request_id - if req_id not in self.cached: - self.cached[req_id] = set() - self.cached[req_id].add(input_id) - self.num_free_slots -= request.get_num_encoder_tokens(input_id) + # Encoder cache space budget should be already updated for the + # multimodal input and non-negative after try_allocate() is called. + assert self.num_free_slots >= 0 + assert self.num_freeable_slots >= 0 + + mm_hash = request.mm_hashes[input_id] + request_id = request.request_id + if mm_hash not in self.cached: + self.cached[mm_hash] = set() + + self.cached[mm_hash].add(request_id) def get_cached_input_ids(self, request: Request) -> set[int]: """Get all cached multimodal input IDs for a request. - Args: - request: The request to query - - Returns: - Set of input_ids that have cached encoder outputs for this request. - Returns empty set if no inputs are cached for this request. + Returns the set of input IDs whose `mm_hash` exists in the cache map. + This includes entries that are currently unreferenced (and thus present + in `freeable`); for such entries, freeing for this request will be a + no-op. """ - return self.cached.get(request.request_id, set()) + return { + input_id + for input_id in range(len(request.mm_hashes)) + if request.mm_hashes[input_id] in self.cached + } def free_encoder_input(self, request: Request, input_id: int) -> None: - """Free cache space for a single multimodal input's encoder output. + """Free the request's reference to the encoder input (`mm_data`) - This method is called when: - - The encoder output has been fully consumed by the decoder and is - no longer needed (e.g., in vision-language models after image - tokens are processed) - - A request is being cancelled or aborted + When the reference set for the corresponding `mm_hash` becomes empty, + the entry is appended to `freeable` and `num_freeable_slots` is + increased by the number of encoder tokens for that input. - Args: - request: The request containing the multimodal input - input_id: Index of the multimodal input to free from cache + The entry is NOT physically freed until capacity is needed (e.g., by + `can_allocate`). """ req_id = request.request_id - if req_id not in self.cached: + mm_hash = request.mm_hashes[input_id] + # The mm_hash not in cache or the req_id set is empty + if not self.cached.get(mm_hash, None): return - - self.cached[req_id].discard(input_id) - if len(self.cached[req_id]) == 0: - del self.cached[req_id] - self.num_free_slots += request.get_num_encoder_tokens(input_id) - self.freed.append((req_id, input_id)) + self.cached[mm_hash].discard(req_id) + if not self.cached[mm_hash]: + num_tokens = request.get_num_encoder_tokens(input_id) + self.freeable[mm_hash] = num_tokens + self.num_freeable_slots += num_tokens def free(self, request: Request) -> None: - """Free all cached encoder outputs for a request. + """Free all encoder input cache reference held by *request*. - This method is typically called when a request is finished, cancelled, - or aborted, and all its encoder outputs should be freed from cache. + For each cached input ID, `free_encoder_input` is invoked. + The data stays in memory until eviction is triggered by a future + attempt allocation called by 'can_allocate'. - Args: - request: The request whose encoder outputs should be freed + Typically called when a request is finished, cancelled, or aborted. 
""" input_ids = self.get_cached_input_ids(request).copy() for input_id in input_ids: self.free_encoder_input(request, input_id) - def get_freed_ids(self) -> list[tuple[str, int]]: + def get_freed_mm_hashes(self) -> list[str]: """Get and clear the list of recently freed encoder cache entries. - This method returns all encoder cache entries that were freed since - the last call to this method. It's used by the scheduler to notify - workers about which encoder outputs can be removed from their caches. - Returns: - List of (request_id, input_id) tuples that were freed since the - last call. The internal freed list is cleared after this call. + List of mm_hash strings that were actually evicted since the last + call to be used by the scheduler to notify workers about which + encoder outputs can be removed from their caches. The internal + list is cleared after this call. """ freed = self.freed self.freed = [] @@ -177,16 +245,11 @@ def compute_encoder_budget( """Compute the encoder cache budget based on the model and scheduler configurations. - Args: - model_config: Model configuration. - scheduler_config: Scheduler configuration. - mm_registry: Provides information about the token cost. - Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. """ if mm_registry.supports_multimodal_inputs(model_config): max_tokens_by_modality = mm_registry \ @@ -231,10 +294,10 @@ def compute_mm_encoder_budget( non-text modality. Returns: - - Compute budget for encoder execution, in unit of number of tokens - in the input sequence. - - Space budget for encoder cache size, in unit of number of tokens - in the input sequence. + - Compute budget for encoder execution, measured in number of tokens + from the input sequence. + - Space budget for encoder cache size, measured in number of tokens + from the input sequence. """ if not max_tokens_by_modality: diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 9ba7ec9d9693..b5cd6c5c8af5 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -143,9 +143,9 @@ class SchedulerOutput: # steps. This is used to notify the workers about the finished requests # so that they can free the cached states for those requests. finished_req_ids: set[str] - # list of (req_id, encoder_input_index) tuples. - # Used to free the encoder cache. - free_encoder_input_ids: list[tuple[str, int]] + # list of mm_hash strings associated with the encoder outputs to be + # freed from the encoder cache. 
+ free_encoder_mm_hashes: list[str] # Dict of request ids to their index within the batch # for filling the next token bitmask diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 60d5720b6bef..956e23afa0d7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -252,6 +252,7 @@ def schedule(self) -> SchedulerOutput: preempted_req = self.running.pop() self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 if self.log_stats: @@ -550,7 +551,8 @@ def schedule(self) -> SchedulerOutput: # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, - free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), structured_output_request_ids=structured_output_request_ids, grammar_bitmask=grammar_bitmask, ) @@ -698,7 +700,7 @@ def _try_schedule_encoder_inputs( # in the decoder's KV cache. continue - if self.encoder_cache_manager.has_cache(request, i): + if self.encoder_cache_manager.check_and_update_cache(request, i): # The encoder input is already computed and cached. continue @@ -712,8 +714,8 @@ def _try_schedule_encoder_inputs( num_new_tokens = start_pos - num_computed_tokens break - if (not self.encoder_cache_manager.can_allocate(request, i) - or num_encoder_tokens > encoder_budget): + if not self.encoder_cache_manager.try_allocate( + request, i, encoder_budget): # The encoder cache is full or the encoder budget is exhausted. # NOTE(woosuk): We assume that the encoder input tokens should # be processed altogether, as the encoder usually uses diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 69f8e531e01b..300b0713b2ff 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -21,6 +21,8 @@ from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) +from vllm.v1.structured_output.backend_lm_format_enforcer import ( + validate_structured_output_request_lm_format_enforcer) from vllm.v1.structured_output.backend_outlines import ( validate_structured_output_request_outlines) from vllm.v1.structured_output.backend_xgrammar import ( @@ -200,6 +202,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None: elif engine_level_backend == "outlines": # outlines backend validate_structured_output_request_outlines(params) + elif engine_level_backend == "lm-format-enforcer": + # lm format enforcer backend + validate_structured_output_request_lm_format_enforcer(params) else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. @@ -389,7 +394,7 @@ def _validate_model_input( assert isinstance(mm_processor, EncDecMultiModalProcessor) if mm_processor.pad_dummy_encoder_prompt: - return # Skip encoder length check for Whisper + return # Skip encoder length check for Whisper and Donut if model_config.is_multimodal_model: suggestion = ( diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 24387ab79390..00dd757489ca 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -53,29 +53,37 @@ def update_state(self, batch_update: Optional[BatchUpdate]): # Process added requests. 
for index, params, _, _ in batch_update.added: min_p = params.min_p - if self.min_p_cpu[index] != min_p: + min_p_before = self.min_p_cpu[index] + if min_p_before != min_p: needs_update = True self.min_p_cpu[index] = min_p - if min_p: - self.min_p_count += 1 + if min_p and not min_p_before: + self.min_p_count += 1 + elif not min_p and min_p_before: + self.min_p_count -= 1 if self.min_p_count: # Process removed requests. - needs_update |= bool(batch_update.removed) - for index in batch_update.removed: - if self.min_p_cpu[index]: - self.min_p_count -= 1 + if batch_update.removed: + needs_update = True + for index in batch_update.removed: + if self.min_p_cpu[index]: + self.min_p_cpu[index] = 0 + self.min_p_count -= 1 - # Process moved requests, unidirectional (a->b) and swap (a<->b) + # Process moved requests, unidirectional (a->b) and swap (a<->b). for adx, bdx, direct in batch_update.moved: - change = (min_p_a := - self.min_p_cpu[adx]) != (min_p_b := - self.min_p_cpu[bdx]) - needs_update |= change - if change: + min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx] + if min_p_a != min_p_b: + needs_update = True self.min_p_cpu[bdx] = min_p_a if direct == MoveDirectionality.SWAP: self.min_p_cpu[adx] = min_p_b + if direct == MoveDirectionality.UNIDIRECTIONAL: + if min_p_a: + self.min_p_cpu[adx] = 0 + if min_p_b: + self.min_p_count -= 1 # Update tensors if needed. size = batch_update.batch_size diff --git a/vllm/v1/sample/logits_processor/state.py b/vllm/v1/sample/logits_processor/state.py index 0f58b5249695..31cece58c7db 100644 --- a/vllm/v1/sample/logits_processor/state.py +++ b/vllm/v1/sample/logits_processor/state.py @@ -50,6 +50,10 @@ def __init__( self.added = added or [] self._is_removed_sorted = False + # Used to track changes in the pooling case + # where we don't populate the added list. + self.batch_changed = False + def _ensure_removed_sorted(self) -> None: """Sort removed request indices in descending order. 
@@ -80,6 +84,7 @@ def removed_append(self, index: int) -> None: raise RuntimeError("Cannot register new removed request after" " self.removed has been read.") self._removed.append(index) + self.batch_changed = True def has_removed(self) -> bool: return bool(self._removed) @@ -98,9 +103,15 @@ def pop_removed(self) -> Optional[int]: return self._removed.pop() return None - def _is_update(self) -> bool: - """True if there is a batch state change""" - return any((self._removed, self.moved, self.added)) + def reset(self) -> bool: + """Returns True if there were any changes to the batch.""" + self._is_removed_sorted = False + self._removed.clear() + self.moved.clear() + self.added.clear() + batch_changed = self.batch_changed + self.batch_changed = False + return batch_changed def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: """Generate a logitsprocs batch update data structure and reset @@ -114,7 +125,8 @@ def get_and_reset(self, batch_size: int) -> Optional[BatchUpdate]: """ # Reset removal-sorting logic self._is_removed_sorted = False - if not self._is_update(): + self.batch_changed = False + if not any((self._removed, self.moved, self.added)): # No update; short-circuit return None # Build batch state update diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3bafa61044ab..57854cc11204 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -108,6 +108,14 @@ def grammar_init(self, request: Request) -> None: tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == "lm-format-enforcer": + from vllm.v1.structured_output.backend_lm_format_enforcer import ( # noqa: E501 + LMFormatEnforcerBackend) + self.backend = LMFormatEnforcerBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") diff --git a/vllm/v1/structured_output/backend_lm_format_enforcer.py b/vllm/v1/structured_output/backend_lm_format_enforcer.py new file mode 100644 index 000000000000..2279a1c8c8a0 --- /dev/null +++ b/vllm/v1/structured_output/backend_lm_format_enforcer.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import ast +import json +from dataclasses import dataclass, field +from functools import lru_cache +from typing import TYPE_CHECKING + +import torch +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import lmformatenforcer + import lmformatenforcer.integrations.vllm as lmfe_vllm +else: + lmformatenforcer = LazyLoader("lmformatenforcer", globals(), + "lmformatenforcer") + lmfe_vllm = LazyLoader("lmformatenforcer.integrations.vllm", globals(), + "lmformatenforcer.integrations.vllm") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase, + vocab_size: int) -> lmfe_vllm.TokenEnforcerTokenizerData: + return lmfe_vllm.build_vllm_token_enforcer_tokenizer_data( + tokenizer, use_bitmask=True, vocab_size=vocab_size) + + +@dataclass +class LMFormatEnforcerGrammar(StructuredOutputGrammar): + token_enforcer: lmformatenforcer.TokenEnforcer + current_tokens_prefix: list[int] = field(default_factory=list) + + 
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + original_len = len(self.current_tokens_prefix) + for token in tokens: + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix).is_token_allowed(token): + # Rollback partial updates to ensure atomicity. + del self.current_tokens_prefix[original_len:] + return False + self.current_tokens_prefix.append(token) + return True + + def validate_tokens(self, tokens: list[int]) -> list[int]: + for prefix_length in range(len(tokens)): + prefix = tokens[:prefix_length] + next_token = tokens[prefix_length] + if not self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix + + prefix).is_token_allowed(next_token): + break + else: + return tokens + + return tokens[:prefix_length] + + def rollback(self, num_tokens: int) -> None: + self.current_tokens_prefix = self.current_tokens_prefix[:-num_tokens] + + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + allowed_tokens = self.token_enforcer.get_allowed_tokens( + self.current_tokens_prefix) + bitmask[batch_index] = allowed_tokens.allowed_tokens + + def is_terminated(self) -> bool: + # We are considered terminated if the prefix ends with eos_token_id + return_value = len( + self.current_tokens_prefix) > 0 and self.current_tokens_prefix[ + -1] == self.token_enforcer.eos_token_id + return return_value + + def reset(self): + self.current_tokens_prefix = [] + + +@dataclass +class LMFormatEnforcerBackend(StructuredOutputBackend): + + def __post_init__(self): + self.tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + self.tokenizer, self.vocab_size) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + character_level_parser: lmformatenforcer.CharacterLevelParser + if request_type == StructuredOutputOptions.JSON: + spec_dict = json.loads(grammar_spec) + character_level_parser = lmformatenforcer.JsonSchemaParser( + spec_dict) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + character_level_parser = lmformatenforcer.JsonSchemaParser(None) + elif request_type == StructuredOutputOptions.REGEX: + character_level_parser = lmformatenforcer.RegexParser(grammar_spec) + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + character_level_parser = lmformatenforcer.UnionParser( + [lmformatenforcer.StringParser(choice) for choice in choices]) + else: + raise ValueError( + "Invalid request type for LM Format Enforcer backend" + f"({request_type!s})") + max_rollback_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config is not None else 0) + + if max_rollback_tokens > 0: + raise ValueError( + "LM Format Enforcer backend does not support speculative tokens" + ) + + token_enforcer = lmformatenforcer.TokenEnforcer( + tokenizer_data=self.tokenizer_data, + parser=character_level_parser, + ) + return LMFormatEnforcerGrammar(token_enforcer) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +def validate_structured_output_request_lm_format_enforcer( + params: SamplingParams): + if params.guided_decoding is None: + return + + gd_params = params.guided_decoding + + if gd_params.regex: + return + elif gd_params.json: + if isinstance(gd_params.json, str): + try: + # make sure 
schema is valid json + json.loads(gd_params.json) + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + try: + json.dumps(gd_params.json) + except Exception as e: + raise ValueError( + f"Error serializing guided decoding jsonschema: {e}" + ) from e + return + elif gd_params.choice: + return + elif gd_params.grammar: + raise ValueError("LM Format Enforcer guided decoding backend " + "does not support grammar specifications") diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index e45d1ef31f60..284af6bfedce 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -33,6 +33,7 @@ class CachedRequestState: prompt_token_ids: list[int] mm_kwargs: list[MultiModalKwargsItem] mm_positions: list[PlaceholderRange] + mm_hashes: list[str] sampling_params: Optional[SamplingParams] pooling_params: Optional[PoolingParams] generator: Optional[torch.Generator] @@ -65,8 +66,7 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]: def get_token_id(self, idx: int) -> int: if idx < self.num_prompt_tokens: return self.prompt_token_ids[idx] - else: - return self.output_token_ids[idx - self.num_prompt_tokens] + return self.output_token_ids[idx - self.num_prompt_tokens] class InputBatch: @@ -261,30 +261,27 @@ def _register_add_request(self, request: "CachedRequestState") -> int: Not applicable to pooling models. """ - # Detailed added request metadata is only required for non-pooling - # models, to support logitsprocs - assert request.sampling_params - # Fill the next empty index if there is one. if (new_req_index := self.batch_update_builder.pop_removed()) is None: # Append to end otherwise. new_req_index = self.num_reqs assert new_req_index < self.max_num_reqs - self.batch_update_builder.added.append( - (new_req_index, request.sampling_params, request.prompt_token_ids, - request.output_token_ids)) + self.batch_update_builder.batch_changed = True + if request.sampling_params: + # Detailed added request metadata is only required for non-pooling + # models, to support logitsprocs. + self.batch_update_builder.added.append( + (new_req_index, request.sampling_params, + request.prompt_token_ids, request.output_token_ids)) + return new_req_index def add_request( self, request: "CachedRequestState", ) -> int: - if not self.is_pooling_model: - # New request index bookkeeping for autoregressive models. - req_index = self._register_add_request(request) - else: - req_index = self.num_reqs + req_index = self._register_add_request(request) req_id = request.req_id if req_index == len(self._req_ids): @@ -389,7 +386,7 @@ def add_request( self.logits_processing_needs_token_ids[req_index] = ( pooling_params.requires_token_ids) else: - raise NotImplementedError(request) + raise NotImplementedError("Unrecognized request type") # Add request lora ID if request.lora_request: @@ -419,13 +416,25 @@ def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: return None - if not self.is_pooling_model: - # Autoregressive models require bookkeeping of removed requests to - # support logitsprocs. 
- self.batch_update_builder.removed_append(req_index) + + self.batch_update_builder.removed_append(req_index) self._req_ids[req_index] = None self.req_output_token_ids[req_index] = None + # LoRA + lora_id = self.request_lora_mapping[req_index] + if lora_id != 0: + lora_req_ids = self.lora_id_to_request_ids[lora_id] + lora_req_ids.discard(req_id) + if not lora_req_ids: + del self.lora_id_to_request_ids[lora_id] + del self.lora_id_to_lora_request[lora_id] + self.request_lora_mapping[req_index] = 0 + + if self.is_pooling_model: + self.pooling_params.pop(req_id, None) + return req_index + self.greedy_reqs.discard(req_id) self.random_reqs.discard(req_id) self.top_p_reqs.discard(req_id) @@ -439,29 +448,14 @@ def remove_request(self, req_id: str) -> Optional[int]: self.num_prompt_logprobs.pop(req_id, None) self.in_progress_prompt_logprobs_cpu.pop(req_id, None) - # LoRA - lora_id = self.request_lora_mapping[req_index] - if lora_id != 0: - lora_req_ids = self.lora_id_to_request_ids[lora_id] - lora_req_ids.discard(req_id) - if not lora_req_ids: - del self.lora_id_to_request_ids[lora_id] - del self.lora_id_to_lora_request[lora_id] - self.request_lora_mapping[req_index] = 0 - self.has_allowed_token_ids.discard(req_id) if self.allowed_token_ids_mask_cpu_tensor is not None: # False means we don't fill with -inf. self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.bad_words_token_ids.pop(req_index, None) - self.pooling_params.pop(req_id, None) return req_index def swap_states(self, i1: int, i2: int) -> None: - # For autoregressive models, track detailed request reordering info - # to support logitsprocs - self.batch_update_builder.moved.append( - (i1, i2, MoveDirectionality.SWAP)) old_id_i1 = self._req_ids[i1] old_id_i2 = self._req_ids[i2] self._req_ids[i1], self._req_ids[i2] =\ @@ -479,18 +473,6 @@ def swap_states(self, i1: int, i2: int) -> None: self.num_prompt_tokens[i2], self.num_prompt_tokens[i1] self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\ self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1] - self.temperature_cpu[i1], self.temperature_cpu[i2] =\ - self.temperature_cpu[i2], self.temperature_cpu[i1] - self.top_p_cpu[i1], self.top_p_cpu[i2] =\ - self.top_p_cpu[i2], self.top_p_cpu[i1] - self.top_k_cpu[i1], self.top_k_cpu[i2] =\ - self.top_k_cpu[i2], self.top_k_cpu[i1] - self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\ - self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] - self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\ - self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] - self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\ - self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] # NOTE: the following is unsafe # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ @@ -501,18 +483,41 @@ def swap_states(self, i1: int, i2: int) -> None: self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] self.token_ids_cpu[i2, ...] = tmp - swap_dict_values(self.generators, i1, i2) - swap_dict_values(self.bad_words_token_ids, i1, i2) + self.block_table.swap_row(i1, i2) - self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ + self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \ self.request_lora_mapping[i2], self.request_lora_mapping[i1] + if self.is_pooling_model: + # Sampling and logits parameters don't apply to pooling models. 
+ return + + # For autoregressive models, track detailed request reordering info + # to support logitsprocs. + self.batch_update_builder.moved.append( + (i1, i2, MoveDirectionality.SWAP)) + + self.temperature_cpu[i1], self.temperature_cpu[i2] = \ + self.temperature_cpu[i2], self.temperature_cpu[i1] + self.top_p_cpu[i1], self.top_p_cpu[i2] = \ + self.top_p_cpu[i2], self.top_p_cpu[i1] + self.top_k_cpu[i1], self.top_k_cpu[i2] = \ + self.top_k_cpu[i2], self.top_k_cpu[i1] + self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \ + self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1] + self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \ + self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1] + self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \ + self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1] + + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) + if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[i1], \ self.allowed_token_ids_mask_cpu_tensor[i2] =\ self.allowed_token_ids_mask_cpu_tensor[i2], \ self.allowed_token_ids_mask_cpu_tensor[i1] - self.block_table.swap_row(i1, i2) def condense(self) -> None: """Slide non-empty requests down into lower, empty indices. @@ -529,12 +534,6 @@ def condense(self) -> None: """ num_reqs = self.num_reqs - if self.is_pooling_model: - # Will be contiguous in pooling case, just trim the lists. - del self._req_ids[num_reqs:] - del self.req_output_token_ids[num_reqs:] - return - if not (empty_req_indices := self.batch_update_builder.removed): # All removed requests were replaced by added requests, or else no # requests were removed at all. No condense() needed @@ -562,11 +561,6 @@ def condense(self) -> None: # Move active request down into empty request # index. self.batch_update_builder.pop_removed() - # Autoregressive models require detailed tracking of condense - # operations to support logitsprocs - self.batch_update_builder.moved.append( - (last_req_index, empty_index, - MoveDirectionality.UNIDIRECTIONAL)) req_id = self._req_ids[last_req_index] output_token_ids = self.req_output_token_ids[last_req_index] assert req_id is not None @@ -587,6 +581,21 @@ def condense(self) -> None: self.num_computed_tokens_cpu[ empty_index] = self.num_computed_tokens_cpu[last_req_index] self.block_table.move_row(last_req_index, empty_index) + + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + + if self.is_pooling_model: + last_req_index -= 1 + # Samping state not used by pooling models. + continue + + # Autoregressive models require detailed tracking of condense + # operations to support logitsprocs + self.batch_update_builder.moved.append( + (last_req_index, empty_index, + MoveDirectionality.UNIDIRECTIONAL)) + self.temperature_cpu[empty_index] = self.temperature_cpu[ last_req_index] self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] @@ -601,9 +610,6 @@ def condense(self) -> None: if generator is not None: self.generators[empty_index] = generator - self.request_lora_mapping[empty_index] = self.request_lora_mapping[ - last_req_index] - # TODO convert these to LogitsProcessors if self.allowed_token_ids_mask_cpu_tensor is not None: self.allowed_token_ids_mask_cpu_tensor[ @@ -626,8 +632,9 @@ def refresh_metadata(self): """Apply any batch updates to sampling metadata.""" if self.is_pooling_model: - # Batch changes every step for pooling models. 
- self.sampling_metadata = self._make_sampling_metadata() + batch_changed = self.batch_update_builder.reset() + if batch_changed: + self.sampling_metadata = self._make_sampling_metadata() return # For non-pooling models - generate and apply logitsprocs update; @@ -720,7 +727,8 @@ def pooling_metadata(self) -> PoolingMetadata: ) def _make_prompt_token_ids_tensor(self) -> torch.Tensor: - max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max() + num_reqs = self.num_reqs + max_prompt_len = self.num_prompt_tokens[:num_reqs].max() prompt_token_ids_cpu_tensor = torch.empty( (self.num_reqs, max_prompt_len), device="cpu", @@ -728,11 +736,10 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: pin_memory=self.pin_memory, ) prompt_token_ids = prompt_token_ids_cpu_tensor.numpy() - prompt_token_ids[:] = self.token_ids_cpu[:self. - num_reqs, :max_prompt_len] + prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len] # Use the value of vocab_size as a pad since we don't have a # token_id of this value. - for i in range(self.num_reqs): + for i in range(num_reqs): prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7160894b4acd..73117c75b9af 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,7 +35,8 @@ from vllm.forward_context import (BatchDescriptor, DPMetadata, set_forward_context) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.models.interfaces import (is_mixture_of_experts, @@ -55,7 +56,6 @@ GiB_bytes, LazyLoader, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up, supports_dynamo) -from vllm.v1.attention.backends.mamba_selectors import get_mamba_attn_backend from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, make_kv_sharing_fast_prefill_attention_metadata, @@ -176,8 +176,8 @@ def __init__( self.attn_groups: list[list[AttentionGroup]] = [] # self.kv_cache_config: KVCacheConfig - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} self.use_aux_hidden_state_outputs = False # Set up speculative decoding. @@ -254,9 +254,6 @@ def __init__( self.seq_lens = torch.zeros(self.max_num_reqs, dtype=torch.int32, device=self.device) - self.slot_mapping = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -439,7 +436,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove finished requests from the cached states. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) - self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. 
This happens when a request is aborted and @@ -450,12 +446,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: self.input_batch.remove_request(req_id) # Free the cached encoder outputs. - for req_id, input_id in scheduler_output.free_encoder_input_ids: - encoder_outputs = self.encoder_cache.get(req_id) - if encoder_outputs is not None: - encoder_outputs.pop(input_id, None) - if not encoder_outputs: - self.encoder_cache.pop(req_id, None) + for mm_hash in scheduler_output.free_encoder_mm_hashes: + self.encoder_cache.pop(mm_hash, None) # Remove the unscheduled requests from the persistent batch. # NOTE(woosuk): The unscheduled requests are either preempted requests @@ -499,6 +491,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: prompt_token_ids=new_req_data.prompt_token_ids, mm_kwargs=new_req_data.mm_kwargs, mm_positions=new_req_data.mm_positions, + mm_hashes=new_req_data.mm_hashes, sampling_params=sampling_params, pooling_params=pooling_params, generator=generator, @@ -1164,17 +1157,18 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs if not scheduled_encoder_inputs: return - # Batch the multi-modal inputs. mm_kwargs = list[MultiModalKwargsItem]() - req_ids_pos = list[tuple[str, int, PlaceholderRange]]() + # list of tuple (mm_hash, position_info) + mm_hashes_pos = list[tuple[str, PlaceholderRange]]() for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): req_state = self.requests[req_id] for mm_input_id in encoder_input_ids: + mm_hash = req_state.mm_hashes[mm_input_id] mm_kwargs.append(req_state.mm_kwargs[mm_input_id]) - req_ids_pos.append( - (req_id, mm_input_id, req_state.mm_positions[mm_input_id])) + mm_hashes_pos.append( + (mm_hash, req_state.mm_positions[mm_input_id])) # Batch mm inputs as much as we can: if a request in the batch has # multiple modalities or a different modality than the previous one, @@ -1207,15 +1201,9 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): for output in curr_group_outputs: encoder_outputs.append(output) - # Cache the encoder outputs. - for (req_id, input_id, pos_info), output in zip( - req_ids_pos, - encoder_outputs, - ): - if req_id not in self.encoder_cache: - self.encoder_cache[req_id] = {} - - self.encoder_cache[req_id][input_id] = scatter_mm_placeholders( + # Cache the encoder outputs by mm_hash + for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): + self.encoder_cache[mm_hash] = scatter_mm_placeholders( output, is_embed=pos_info.is_embed, ) @@ -1233,6 +1221,7 @@ def _gather_mm_embeddings( num_computed_tokens = \ req_state.num_computed_tokens + shift_computed_tokens mm_positions = req_state.mm_positions + mm_hashes = req_state.mm_hashes for i, pos_info in enumerate(mm_positions): start_pos = pos_info.offset num_encoder_tokens = pos_info.length @@ -1252,11 +1241,14 @@ def _gather_mm_embeddings( start_idx = max(num_computed_tokens - start_pos, 0) end_idx = min( num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens) + num_encoder_tokens, + ) assert start_idx < end_idx - assert req_id in self.encoder_cache - assert i in self.encoder_cache[req_id] - encoder_output = self.encoder_cache[req_id][i] + + mm_hash = mm_hashes[i] + encoder_output = self.encoder_cache.get(mm_hash, None) + assert encoder_output is not None,\ + f"Encoder cache miss for {mm_hash}." 
if (is_embed := pos_info.is_embed) is not None: is_embed = is_embed[start_idx:end_idx] @@ -1489,10 +1481,8 @@ def _pool( for raw_output, seq_len, prompt_len in zip( raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens): - if seq_len == prompt_len: - pooler_output.append(raw_output.data) - else: - pooler_output.append(None) + output = raw_output.data if seq_len == prompt_len else None + pooler_output.append(output) return ModelRunnerOutput( req_ids=self.input_batch.req_ids, @@ -1522,7 +1512,7 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = (self._prepare_inputs(scheduler_output)) + max_query_len) = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -2757,11 +2747,13 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ assert len(self.attn_groups) == 0, \ "Attention backends are already initialized" - attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) def get_attn_backends_for_layers( layer_names: list[str] ) -> dict[type[AttentionBackend], list[str]]: + layers = get_layers_from_vllm_config(self.vllm_config, + AttentionLayerBase, + layer_names) attn_backends = {} attn_backend_layers = defaultdict(list) # Dedupe based on full class name; this is a bit safer than using @@ -2770,7 +2762,7 @@ def get_attn_backends_for_layers( # they are cached correctly, there will be different objects per # layer. for layer_name in layer_names: - attn_backend = attn_layers[layer_name].get_attn_backend() + attn_backend = layers[layer_name].get_attn_backend() key = attn_backend.full_cls_name() attn_backends[key] = attn_backend attn_backend_layers[key].append(layer_name) @@ -2799,20 +2791,8 @@ def create_attn_groups( for kv_cache_group_spec in kv_cache_config.kv_cache_groups: kv_cache_spec = kv_cache_group_spec.kv_cache_spec - if isinstance(kv_cache_spec, AttentionSpec): - attn_backends = get_attn_backends_for_layers( - kv_cache_group_spec.layer_names) - # TODO(lucas): move `get_mamba_attn_backend` into the mamba - # layers like above - elif isinstance(kv_cache_spec, MambaSpec): - attn_backends = { - get_mamba_attn_backend(kv_cache_spec.mamba_type): - kv_cache_group_spec.layer_names - } - else: - raise ValueError( - f"Unknown KV cache spec type: {type(kv_cache_spec)}") - + attn_backends = get_attn_backends_for_layers( + kv_cache_group_spec.layer_names) self.attn_groups.append( create_attn_groups(attn_backends, kv_cache_spec)) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 2a8d65948d57..4a485b7e077d 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -208,8 +208,8 @@ def __init__( # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] - # req_id -> (input_id -> encoder_output) - self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {} + # mm_hash -> encoder_output + self.encoder_cache: dict[str, torch.Tensor] = {} # Request states. self.requests: dict[str, CachedRequestState] = {} @@ -342,7 +342,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: # Remove finished requests from the cached states. 
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)

         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -357,12 +356,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 removed_req_indices.append(req_index)

         # Free the cached encoder outputs.
-        for req_id, input_id in scheduler_output.free_encoder_input_ids:
-            encoder_outputs = self.encoder_cache.get(req_id)
-            if encoder_outputs is not None:
-                encoder_outputs.pop(input_id, None)
-                if not encoder_outputs:
-                    self.encoder_cache.pop(req_id, None)
+        for mm_hash in scheduler_output.free_encoder_mm_hashes:
+            self.encoder_cache.pop(mm_hash, None)

         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
@@ -394,6 +389,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
                 prompt_token_ids=new_req_data.prompt_token_ids,
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
+                mm_hashes=new_req_data.mm_hashes,
                 sampling_params=sampling_params,
                 pooling_params=None,
                 generator=None,
@@ -845,14 +841,16 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):

         # Batch the multi-modal inputs.
         mm_kwargs = list[MultiModalKwargsItem]()
-        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+        # List of tuple (mm_hash, pos_info)
+        mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
             for mm_input_id in encoder_input_ids:
+                mm_hash = req_state.mm_hashes[mm_input_id]
                 mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                req_ids_pos.append(
-                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
+                mm_hashes_pos.append(
+                    (mm_hash, req_state.mm_positions[mm_input_id]))

         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
@@ -895,15 +893,13 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
         # NOTE (NickLucche) here we diverge from logic in other runners, as we
         # assume to only have whole mm items to process. Hence we avoid the
         # intrinsic dynamism that `scatter_mm_placeholders` introduces.
-        for (req_id, input_id, pos_info), output in zip(
-                req_ids_pos,
+        for (mm_hash, pos_info), output in zip(
+                mm_hashes_pos,
                 encoder_outputs,
         ):
-            if req_id not in self.encoder_cache:
-                self.encoder_cache[req_id] = {}
             assert pos_info.is_embed is None, "Expected all positions to be"\
                 " contiguous and embeddings."
-            self.encoder_cache[req_id][input_id] = output
+            self.encoder_cache[mm_hash] = output

     def _gather_mm_embeddings(
         self,
@@ -916,6 +912,7 @@ def _gather_mm_embeddings(
             req_state = self.requests[req_id]
             num_computed_tokens = req_state.num_computed_tokens
             mm_positions = req_state.mm_positions
+            mm_hashes = req_state.mm_hashes
             # TODO unroll loop and assume/enforce --disable_chunked_mm_input
             # NOTE (NickLucche) here we diverge from logic in other runners, as
             # we assume to only have whole mm items to process. Hence we avoid
@@ -936,11 +933,12 @@ def _gather_mm_embeddings(
                     # in the decoder's KV cache.
                     continue

-                assert req_id in self.encoder_cache
-                assert i in self.encoder_cache[req_id]
+                mm_hash = mm_hashes[i]
+                encoder_output = self.encoder_cache.get(mm_hash, None)
+                assert encoder_output is not None,\
+                    f"Encoder cache miss for {mm_hash}."
                 assert pos_info.is_embed is None, "Expected all positions to"\
                     " be contiguous and embeddings."
-                encoder_output = self.encoder_cache[req_id][i]
                 mm_embeds.append(encoder_output)

         return mm_embeds
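
The `swap_states`/`condense` hunks in `gpu_input_batch.py` above split per-request bookkeeping into state every model needs (request ids, block table, LoRA mapping) and sampling-only state that autoregressive models additionally record as moves for the logits processors, while pooling models bail out early. The following is a minimal standalone sketch of that pattern; `TinyBatch` and this cut-down `BatchUpdateBuilder` are illustrative stand-ins, not the real vLLM `InputBatch` API.

```python
from dataclasses import dataclass, field
from enum import Enum, auto

import numpy as np


class MoveDirectionality(Enum):
    SWAP = auto()
    UNIDIRECTIONAL = auto()


@dataclass
class BatchUpdateBuilder:
    # (from_index, to_index, directionality) records consumed at the next
    # sampling-metadata refresh.
    moved: list[tuple[int, int, MoveDirectionality]] = field(
        default_factory=list)


class TinyBatch:

    def __init__(self, max_reqs: int, is_pooling_model: bool) -> None:
        self.is_pooling_model = is_pooling_model
        self.batch_update_builder = BatchUpdateBuilder()
        self.req_ids: list[str | None] = [None] * max_reqs
        # Sampling-only state, kept in parallel CPU arrays.
        self.temperature_cpu = np.ones(max_reqs, dtype=np.float32)
        self.top_p_cpu = np.ones(max_reqs, dtype=np.float32)

    def swap_states(self, i1: int, i2: int) -> None:
        # State needed by every model type is swapped unconditionally.
        self.req_ids[i1], self.req_ids[i2] = self.req_ids[i2], self.req_ids[i1]

        if self.is_pooling_model:
            # Sampling state is not used by pooling models.
            return

        # Autoregressive models record the move so logits processors can
        # reorder their own per-request state lazily.
        self.batch_update_builder.moved.append(
            (i1, i2, MoveDirectionality.SWAP))
        self.temperature_cpu[[i1, i2]] = self.temperature_cpu[[i2, i1]]
        self.top_p_cpu[[i1, i2]] = self.top_p_cpu[[i2, i1]]


batch = TinyBatch(max_reqs=4, is_pooling_model=False)
batch.req_ids[:2] = ["req-a", "req-b"]
batch.temperature_cpu[:2] = [0.7, 1.0]
batch.swap_states(0, 1)
print(batch.req_ids[:2], batch.temperature_cpu[:2],
      batch.batch_update_builder.moved)
```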
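`_make_prompt_token_ids_tensor` (touched above only to hoist `num_reqs` into a local) packs variable-length prompts into a single right-padded tensor, using `vocab_size` as the pad value because no real token id can equal it. A small self-contained sketch with made-up sizes:

```python
import numpy as np
import torch

# Hypothetical sizes, for illustration only.
vocab_size = 32_000
num_reqs = 3
num_prompt_tokens = np.array([4, 2, 3], dtype=np.int64)
token_ids_cpu = np.zeros((num_reqs, 8), dtype=np.int64)
token_ids_cpu[0, :4] = [11, 12, 13, 14]
token_ids_cpu[1, :2] = [21, 22]
token_ids_cpu[2, :3] = [31, 32, 33]

max_prompt_len = int(num_prompt_tokens[:num_reqs].max())
prompt_token_ids_cpu_tensor = torch.empty(
    (num_reqs, max_prompt_len), device="cpu", dtype=torch.int64)
# The numpy view shares memory with the CPU tensor, so writes land in both.
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = token_ids_cpu[:num_reqs, :max_prompt_len]
for i in range(num_reqs):
    # Pad with vocab_size since no token_id of this value exists.
    prompt_token_ids[i, num_prompt_tokens[i]:] = vocab_size
print(prompt_token_ids_cpu_tensor)
# Expected rows: [11, 12, 13, 14], [21, 22, 32000, 32000], [31, 32, 33, 32000]
```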
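Both runners now key `encoder_cache` by `mm_hash` instead of `(req_id, input_id)`, so identical multimodal items shared by different requests resolve to one cached encoder output, and the scheduler frees entries by hash via `free_encoder_mm_hashes`. A standalone sketch of that lifecycle, with a toy encoder standing in for the real multimodal tower (none of these helpers are vLLM APIs):

```python
import torch

# mm_hash -> encoder_output, mirroring the new cache layout.
encoder_cache: dict[str, torch.Tensor] = {}


def fake_encoder(pixel_values: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real multimodal encoder.
    return pixel_values.mean(dim=0, keepdim=True)


def run_encoder_if_needed(mm_hash: str,
                          pixel_values: torch.Tensor) -> torch.Tensor:
    cached = encoder_cache.get(mm_hash)
    if cached is None:
        cached = fake_encoder(pixel_values)
        encoder_cache[mm_hash] = cached
    return cached


def free_encoder_outputs(free_encoder_mm_hashes: list[str]) -> None:
    # Matches the freeing loop in _update_states above.
    for mm_hash in free_encoder_mm_hashes:
        encoder_cache.pop(mm_hash, None)


img = torch.randn(3, 8)
# Two requests carrying the same image hash hit the same cache entry.
out_a = run_encoder_if_needed("sha256:abc", img)
out_b = run_encoder_if_needed("sha256:abc", img)
assert out_a is out_b
free_encoder_outputs(["sha256:abc"])
assert "sha256:abc" not in encoder_cache
```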
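The `_gather_mm_embeddings` arithmetic above only gathers the slice of a cached encoder output that overlaps the tokens scheduled in the current step, which matters under chunked prefill. A hedged sketch of just that windowing (the variable names mirror the diff, but the helper itself is illustrative, not a vLLM function):

```python
def encoder_slice_for_step(start_pos: int, num_encoder_tokens: int,
                           num_computed_tokens: int,
                           num_scheduled_tokens: int) -> tuple[int, int] | None:
    if start_pos >= num_computed_tokens + num_scheduled_tokens:
        # The placeholder has not been reached yet in this step.
        return None
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        # The encoder output was fully consumed by earlier steps.
        return None
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    assert start_idx < end_idx
    return start_idx, end_idx


# A 16-token image placeholder at prompt offset 10, with 8 prompt tokens
# already computed and 12 scheduled this step -> slice (0, 10).
print(encoder_slice_for_step(10, 16, 8, 12))
```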
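With `AttentionLayerBase`, `get_attn_backends_for_layers` can bucket every layer of a KV-cache group, attention and Mamba alike, by the backend the layer itself reports, deduplicating on the backend's fully qualified class name. A simplified sketch of that grouping; the backend classes below are hypothetical stand-ins, and in vLLM the layers come from `get_layers_from_vllm_config` and expose `get_attn_backend()`:

```python
from collections import defaultdict


class FlashAttnBackend:
    pass


class MambaBackend:
    pass


def full_cls_name(cls: type) -> str:
    # Dedupe key: the backend's fully qualified class name.
    return f"{cls.__module__}.{cls.__qualname__}"


def group_layers_by_backend(
        layer_to_backend: dict[str, type]) -> dict[type, list[str]]:
    backends: dict[str, type] = {}
    backend_layers: dict[str, list[str]] = defaultdict(list)
    for layer_name, backend in layer_to_backend.items():
        key = full_cls_name(backend)
        backends[key] = backend
        backend_layers[key].append(layer_name)
    return {backends[key]: names for key, names in backend_layers.items()}


layers = {
    "model.layers.0.attn": FlashAttnBackend,
    "model.layers.1.attn": FlashAttnBackend,
    "model.layers.2.mixer": MambaBackend,
}
# -> {FlashAttnBackend: ['model.layers.0.attn', 'model.layers.1.attn'],
#     MambaBackend: ['model.layers.2.mixer']}
print(group_layers_by_backend(layers))
```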