diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index 65be3c5d93b2..b98d42aa7b82 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do done lm_eval --model vllm \ - --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \ + --model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \ --tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \ --batch_size "$BATCH_SIZE" diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 930adfaf3e19..ceea01166b7f 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -18,12 +18,14 @@ def launch_lm_eval(eval_config, tp_size): trust_remote_code = eval_config.get("trust_remote_code", False) + max_model_len = eval_config.get("max_model_len", 4096) model_args = ( f"pretrained={eval_config['model_name']}," f"tensor_parallel_size={tp_size}," f"enforce_eager=true," f"add_bos_token=true," - f"trust_remote_code={trust_remote_code}" + f"trust_remote_code={trust_remote_code}," + f"max_model_len={max_model_len}" ) results = lm_eval.simple_evaluate( model="vllm", diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 156456c92e63..5e5a532cb57d 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -108,7 +108,6 @@ fi if [[ $commands == *" kernels/attention"* ]]; then commands="${commands} \ --ignore=kernels/attention/test_attention_selector.py \ - --ignore=kernels/attention/test_blocksparse_attention.py \ --ignore=kernels/attention/test_encoder_decoder_attn.py \ --ignore=kernels/attention/test_flash_attn.py \ --ignore=kernels/attention/test_flashinfer.py \ diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index 42506730e868..90cc9c844622 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -6,6 +6,7 @@ set -ex # allow to bind to different cores CORE_RANGE=${CORE_RANGE:-48-95} +# used for TP/PP E2E test OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95} NUMA_NODE=${NUMA_NODE:-1} @@ -24,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. 
-docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e @@ -48,10 +49,16 @@ function cpu_tests() { # Run basic model test docker exec cpu-test-"$NUMA_NODE" bash -c " set -e - pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model - pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model - pytest -v -s tests/models/language/generation -m cpu_model - VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model + # Note: disable until supports V1 + # pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model + # pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model + + # Note: disable Bart until supports V1 + pytest -v -s tests/models/language/generation -m cpu_model \ + --ignore=tests/models/language/generation/test_bart.py + VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \ + --ignore=tests/models/language/generation/test_bart.py + pytest -v -s tests/models/language/pooling -m cpu_model pytest -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ @@ -62,33 +69,26 @@ function cpu_tests() { docker exec cpu-test-"$NUMA_NODE" bash -c " set -e pytest -s -v \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]" + # Note: disable it until supports V1 # Run AWQ test - docker exec cpu-test-"$NUMA_NODE" bash -c " - set -e - VLLM_USE_V1=0 pytest -s -v \ - tests/quantization/test_ipex_quant.py" - - # Run chunked-prefill and prefix-cache test - docker exec cpu-test-"$NUMA_NODE" bash -c " - set -e - pytest -s -v -k cpu_model \ - tests/basic_correctness/test_chunked_prefill.py" + # docker exec cpu-test-"$NUMA_NODE" bash -c " + # set -e + # VLLM_USE_V1=0 pytest -s -v \ + # tests/quantization/test_ipex_quant.py" # online serving - docker exec cpu-test-"$NUMA_NODE" bash -c " + docker exec cpu-test-"$NUMA_NODE" bash -c ' set -e - python3 -m 
vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half & - timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 - VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \ + VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & + timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1 + python3 benchmarks/benchmark_serving.py \ --backend vllm \ --dataset-name random \ - --model facebook/opt-125m \ + --model meta-llama/Llama-3.2-3B-Instruct \ --num-prompts 20 \ - --endpoint /v1/completions \ - --tokenizer facebook/opt-125m" + --endpoint /v1/completions' # Run multi-lora tests docker exec cpu-test-"$NUMA_NODE" bash -c " diff --git a/.buildkite/scripts/hardware_ci/run-hpu-test.sh b/.buildkite/scripts/hardware_ci/run-hpu-test.sh index ae5b35a9ac6b..dc9f2d39ba77 100644 --- a/.buildkite/scripts/hardware_ci/run-hpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-hpu-test.sh @@ -6,19 +6,17 @@ set -exuo pipefail # Try building the docker image cat <."' language: system verbose: true pass_filenames: false diff --git a/.tekton/vllm-cuda-pull-request.yaml b/.tekton/vllm-cuda-pull-request.yaml index 08214f95aaff..0ab7e9b11f7c 100644 --- a/.tekton/vllm-cuda-pull-request.yaml +++ b/.tekton/vllm-cuda-pull-request.yaml @@ -48,7 +48,7 @@ spec: value: - max_jobs=6 - nvcc_threads=2 - - VLLM_VERSION=0.9.0.1 + - VLLM_VERSION=0.10.0 - name: fetch-git-tags value: true - name: clone-depth diff --git a/CMakeLists.txt b/CMakeLists.txt index 0129f85123fb..98ed682fee7d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1") set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0") # @@ -171,7 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() - # # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. # setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. @@ -232,7 +231,6 @@ endif() set(VLLM_EXT_SRC "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" - "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/cache_kernels.cu" "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" @@ -298,7 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" - "csrc/attention/mla/cutlass_mla_entry.cu") + "csrc/attention/mla/cutlass_mla_entry.cu" + "csrc/quantization/fp8/per_token_group_quant.cu") set_gencode_flags_for_srcs( SRCS "${VLLM_EXT_SRC}" @@ -393,7 +392,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. 
CUTLASS 3.x) require # CUDA 12.0 or later cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu" @@ -409,7 +408,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS) + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " "later if you intend on running FP8 quantized models on " @@ -424,7 +423,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require # CUDA 12.8 or later cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" @@ -438,7 +437,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " "not >= 12.8, we recommend upgrading to CUDA 12.8 or " "later if you intend on running FP8 quantized models on " @@ -453,7 +452,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) # require CUDA 12.8 or later cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" @@ -468,7 +467,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " "not >= 12.8, we recommend upgrading to CUDA 12.8 or " "later if you intend on running FP8 quantized models on " @@ -511,7 +510,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor # require CUDA 12.2 or later (and only work on Hopper). 
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -520,7 +519,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1") message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}") else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS) + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS) message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is " "not >= 12.2, we recommend upgrading to CUDA 12.2 or later " "if you intend on running FP8 sparse quantized models on Hopper.") @@ -532,7 +531,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # FP4 Archs and flags cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" "csrc/quantization/fp4/nvfp4_experts_quant.cu" @@ -553,9 +552,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUTLASS MLA Archs and flags cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) set(SRCS - "csrc/attention/mla/cutlass_mla_kernels.cu") + "csrc/attention/mla/cutlass_mla_kernels.cu" + "csrc/attention/mla/sm100_cutlass_mla_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${MLA_ARCHS}") @@ -578,7 +578,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -596,6 +596,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + # moe_data.cu is used by all CUTLASS MoE kernels. 
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) @@ -642,7 +662,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The machete kernels only work on hopper and require CUDA 12.0 or later. # Only build Machete kernels if we are building for something compatible with sm90a cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS) + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS) # # For the Machete kernels we automatically generate sources for various # preselected input type pairs and schedules. @@ -694,7 +714,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS) message(STATUS "Not building Machete kernels as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " diff --git a/Dockerfile.rocm.ubi b/Dockerfile.rocm.ubi index ccad15a07353..49eb429dcfb1 100644 --- a/Dockerfile.rocm.ubi +++ b/Dockerfile.rocm.ubi @@ -5,7 +5,7 @@ ARG VLLM_VERSION # Default ROCm ARCHes to build vLLM for. ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" ARG MAX_JOBS=12 -ARG VLLM_TGIS_ADAPTER_VERSION=0.7.1 +ARG VLLM_TGIS_ADAPTER_VERSION=0.8.0 FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base @@ -118,7 +118,7 @@ FROM rocm_devel AS build_flashattention ARG FA_GFX_ARCHS="gfx90a;gfx942" # the FA_BRANCH commit belongs to the ROCm/flash-attention fork, `main_perf` branch -ARG FA_BRANCH="3cea2fb" +ARG FA_BRANCH="1a7f4dfa" ARG MAX_JOBS ENV MAX_JOBS=${MAX_JOBS} @@ -162,27 +162,6 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ SETUPTOOLS_SCM_PRETEND_VERSION="$VLLM_VERSION" \ python3 setup.py bdist_wheel --dist-dir=dist -#################### libsodium Build IMAGE #################### -FROM rocm_base as libsodium-builder - -RUN microdnf install -y --nodocs gcc gzip tar \ - && microdnf clean all - -WORKDIR /usr/src/libsodium - -ARG LIBSODIUM_VERSION -RUN curl -LO https://github.com/jedisct1/libsodium/releases/download/${LIBSODIUM_VERSION}-RELEASE/libsodium-${LIBSODIUM_VERSION}.tar.gz \ - && tar -xzvf libsodium*.tar.gz \ - && rm -f libsodium*.tar.gz \ - && mv libsodium*/* ./ - -RUN CFLAGS="-O3 -Wall -Werror=format-security -Wno-unused-function -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection" \ - ./configure \ - --prefix="/usr/" \ - --libdir=/usr/lib64 && \ - make -j $(nproc) && \ - make check - ################################################################################################## FROM rocm_base AS vllm-openai @@ -199,11 +178,6 @@ ENV PATH=$VIRTUAL_ENV/bin:$PATH RUN microdnf install -y --setopt=install_weak_deps=0 --nodocs gcc rsync && \ microdnf clean all -# Install libsodium for Tensorizer encryption -RUN --mount=type=bind,from=libsodium-builder,src=/usr/src/libsodium,target=/usr/src/libsodium \ - cd /usr/src/libsodium \ - && make install - RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install/amdsmi/ \ --mount=type=bind,from=build_flashattention,src=/install,target=/install/flashattention \ --mount=type=bind,from=build_vllm,src=/workspace/dist,target=/install/vllm/ \ @@ -215,14 +189,14 @@ RUN 
--mount=type=bind,from=build_amdsmi,src=/install,target=/install/amdsmi/ \ --extra-index-url "https://download.pytorch.org/whl/nightly/rocm${version}" \ /install/amdsmi/*.whl\ /install/flashattention/*.whl\ - "$(echo /install/vllm/*.whl)[audio,video,tensorizer]" + "$(echo /install/vllm/*.whl)[audio,video,tensorizer]" \ + && uv pip install blobfile ENV HF_HUB_OFFLINE=1 \ HOME=/home/vllm \ # Allow requested max length to exceed what is extracted from the # config.json # see: https://github.com/vllm-project/vllm/pull/7080 - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ VLLM_USAGE_SOURCE=production-docker-image \ VLLM_WORKER_MULTIPROC_METHOD=spawn \ VLLM_NO_USAGE_STATS=1 \ @@ -230,7 +204,6 @@ ENV HF_HUB_OFFLINE=1 \ TOKENIZERS_PARALLELISM=false \ RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 \ VLLM_USE_TRITON_FLASH_ATTN=0 \ - VLLM_USE_V1=1 \ HIP_FORCE_DEV_KERNARG=1 \ OUTLINES_CACHE_DIR=/tmp/outlines \ NUMBA_CACHE_DIR=/tmp/numba \ @@ -244,6 +217,10 @@ RUN umask 002 && \ COPY LICENSE /licenses/vllm.md COPY examples/*.jinja /app/data/template/ +COPY examples/*.jinja /opt/app-root/template/ +RUN chown -R vllm /opt/app-root/template && chmod -R g+r /opt/app-root/template +# RUN mkdir -p /opt/app-root && \ +# ln -s /app/data/template /opt/app-root/template USER 2000 WORKDIR /home/vllm @@ -273,4 +250,4 @@ ENV GRPC_PORT=8033 \ DISABLE_LOGPROBS_DURING_SPEC_DECODING=false USER 2000 -ENTRYPOINT ["python3", "-m", "vllm_tgis_adapter", "--uvicorn-log-level=warning"] +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.ubi b/Dockerfile.ubi index 12db83a3132c..8e3b24cfb7a9 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -2,9 +2,8 @@ ARG BASE_UBI_IMAGE_TAG=9.5-1742914212 ARG PYTHON_VERSION=3.12 ARG VLLM_VERSION -ARG VLLM_TGIS_ADAPTER_VERSION="0.7.1" +ARG VLLM_TGIS_ADAPTER_VERSION="0.8.0" ARG TORCH_CUDA_ARCH_LIST="7.5 8.0 8.6 8.9 9.0 10.0 12.0+PTX" -ARG vllm_fa_cmake_gpu_arches='80-real;90-real' ARG max_jobs=2 ARG nvcc_threads=8 @@ -94,8 +93,6 @@ COPY . . 
ARG TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST -ARG vllm_fa_cmake_gpu_arches -ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches} # max jobs used by Ninja to build extensions ARG max_jobs @@ -141,19 +138,20 @@ RUN microdnf install -y --nodocs gcc \ && microdnf clean all # install vllm wheel first, so that torch etc will be installed +# nccl install is temp until torch is upgraded: https://github.com/vllm-project/vllm/issues/19166 RUN --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ uv pip install \ --extra-index-url="https://download.pytorch.org/whl/cu128" --index-strategy='unsafe-best-match' \ "$(echo dist/*.whl)[audio,video,tensorizer]" --verbose \ - "https://storage.googleapis.com/neuralmagic-public-pypi/dist/flashinfer_python-0.2.5-cp38-abi3-linux_x86_64.whl" + "https://storage.googleapis.com/nm-public-pypi/dist/flashinfer_python-0.2.8-cp39-abi3-linux_x86_64.whl" \ + && uv pip install -U nvidia-nccl-cu12==2.26.5 blobfile ENV HF_HUB_OFFLINE=1 \ HOME=/home/vllm \ # Allow requested max length to exceed what is extracted from the # config.json # see: https://github.com/vllm-project/vllm/pull/7080 - VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \ VLLM_USAGE_SOURCE=production-docker-image \ VLLM_WORKER_MULTIPROC_METHOD=fork \ VLLM_NO_USAGE_STATS=1 \ @@ -176,6 +174,8 @@ RUN umask 002 && \ COPY LICENSE /licenses/vllm.md COPY examples/*.jinja /app/data/template/ +COPY examples/*.jinja /opt/app-root/template/ +RUN chown -R vllm /opt/app-root/template && chmod -R g+r /opt/app-root/template USER 2000 WORKDIR /home/vllm @@ -195,6 +195,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \ "$(echo /workspace/dist/*.whl)[audio,video,tensorizer]" \ vllm-tgis-adapter==${VLLM_TGIS_ADAPTER_VERSION} +# Upgrade NCCL back to required version after vllm-tgis-adapter installation +RUN --mount=type=cache,target=/root/.cache/uv \ + HOME=/root uv pip install -U nvidia-nccl-cu12==2.26.5 + ENV GRPC_PORT=8033 \ PORT=8000 \ # As an optimization, vLLM disables logprobs when using spec decoding by @@ -204,4 +208,3 @@ ENV GRPC_PORT=8033 \ DISABLE_LOGPROBS_DURING_SPEC_DECODING=false USER 2000 -ENTRYPOINT ["python3", "-m", "vllm_tgis_adapter", "--uvicorn-log-level=warning"] \ No newline at end of file diff --git a/README.md b/README.md index 3e6ae2acab2a..dc2f0afbe353 100644 --- a/README.md +++ b/README.md @@ -63,13 +63,11 @@ vLLM is fast with: - Speculative decoding - Chunked prefill -**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script. 
- vLLM is flexible and easy to use with: - Seamless integration with popular Hugging Face models - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more -- Tensor parallelism and pipeline parallelism support for distributed inference +- Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron diff --git a/RELEASE.md b/RELEASE.md index 7f5270715212..9352e7ef706c 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -52,3 +52,36 @@ After branch cut, we approach finalizing the release branch with clear criteria * Release branch specific changes (e.g. change version identifiers or CI fixes) Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes. + +## Manual validations + +### E2E Performance Validation + +Before each release, we perform end-to-end performance validation to ensure no regressions are introduced. This validation uses the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) on PyTorch CI. + +**Current Coverage:** +* Models: Llama3, Llama4, and Mixtral +* Hardware: NVIDIA H100 and AMD MI300x +* *Note: Coverage may change based on new model releases and hardware availability* + +**Performance Validation Process:** + +**Step 1: Get Access** +Request write access to the [pytorch/pytorch-integration-testing](https://github.com/pytorch/pytorch-integration-testing) repository to run the benchmark workflow. + +**Step 2: Review Benchmark Setup** +Familiarize yourself with the benchmark configurations: +* [CUDA setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/cuda) +* [ROCm setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/rocm) + +**Step 3: Run the Benchmark** +Navigate to the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) and configure: +* **vLLM branch**: Set to the release branch (e.g., `releases/v0.9.2`) +* **vLLM commit**: Set to the RC commit hash + +**Step 4: Review Results** +Once the workflow completes, benchmark results will be available on the [vLLM benchmark dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) under the corresponding branch and commit. + +**Step 5: Performance Comparison** +Compare the current results against the previous release to verify no performance regressions have occurred. Here is an +example of [v0.9.1 vs v0.9.2](https://hud.pytorch.org/benchmark/llms?startTime=Thu%2C%2017%20Apr%202025%2021%3A43%3A50%20GMT&stopTime=Wed%2C%2016%20Jul%202025%2021%3A43%3A50%20GMT&granularity=week&lBranch=releases/v0.9.1&lCommit=b6553be1bc75f046b00046a4ad7576364d03c835&rBranch=releases/v0.9.2&rCommit=a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f&repoName=vllm-project%2Fvllm&benchmarkName=&modelName=All%20Models&backendName=All%20Backends&modeName=All%20Modes&dtypeName=All%20DType&deviceName=All%20Devices&archName=All%20Platforms). 
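As a small illustration of Step 3 above (not part of the release process text itself), one way to look up the commit hash to paste into the "vLLM commit" workflow input is to read the tip of the release branch directly, without cloning; the branch name below is the example used in the text, and whether the branch tip is the actual RC is an assumption to verify:

```bash
# Print the tip commit of the example release branch; the SHA in the first
# column is what goes into the "vLLM commit" workflow input.
git ls-remote https://github.com/vllm-project/vllm.git refs/heads/releases/v0.9.2
```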
diff --git a/benchmarks/auto_tune/README.md b/benchmarks/auto_tune/README.md new file mode 100644 index 000000000000..7732f50b1d22 --- /dev/null +++ b/benchmarks/auto_tune/README.md @@ -0,0 +1,137 @@ +# Automated vLLM Server Parameter Tuning + +This script automates the process of finding the optimal server parameter combination (`max-num-seqs` and `max-num-batched-tokens`) to maximize throughput for a vLLM server. It also supports additional constraints such as E2E latency and prefix cache hit rate. + +## Table of Contents +- [Prerequisites](#prerequisites) +- [Configuration](#configuration) +- [How to Run](#how-to-run) +- [Example Use Cases](#example-use-cases) +- [Output](#output) +- [How It Works](#how-it-works) + +## Prerequisites + +Before running the script, please ensure the following steps are completed: + +1. **Clone vLLM & Set Up Branch**: Clone the vLLM repository and check out to your desired branch. + +```bash +git clone https://github.com/vllm-project/vllm.git +cd vllm +# git checkout +``` + +1. **Install Environment**: Install or update the correct running environment. For TPU usage, activate your `conda` environment and install the corresponding `torch` and `torch_xla` versions. + +2. **Model Configuration**: If you are using a customized model, ensure its configuration files are correctly placed and accessible. + +## Configuration + +You must set the following variables at the top of the script before execution. + +| Variable | Description | Example Value | +| --- | --- | --- | +| `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` | +| `MODEL` | **Required.** The Hugging Face model identifier to be served by vllm. | `"meta-llama/Llama-3.1-8B-Instruct"` | +| `SYSTEM`| **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (For other systems, it might not support saving profiles) | `"TPU"` | +| `TP` | **Required.** The tensor-parallelism size. | `1` | +| `DOWNLOAD_DIR` | **Required.** Directory to download and load model weights from. | `""` (default download path) | +| `INPUT_LEN` | **Required.** Request input length. | `4000` | +| `OUTPUT_LEN` | **Required.** Request output length. | `16` | +| `MIN_CACHE_HIT_PCT` | Prefix cache hit rate in percentage (0-100). Set to `0` to disable. | `60` | +| `MAX_LATENCY_ALLOWED_MS` | The maximum allowed P99 end-to-end latency in milliseconds. Set to a very large number (e.g., `100000000000`) to effectively ignore the latency constraint. | `500` | +| `NUM_SEQS_LIST` | A space-separated string of `max-num-seqs` values to test. | `"128 256"` | +| `NUM_BATCHED_TOKENS_LIST` | A space-separated string of `max-num-batched-tokens` values to test. | `"1024 2048 4096"` | + +**Note**: The default `NUM_SEQS_LIST` and `NUM_BATCHED_TOKENS_LIST` are set for medium-sized inputs/outputs. For very short contexts (e.g., 20 input, 20 output tokens), you may need to test larger values for `max-num-seqs`. + +## How to Run + +1. **Configure**: Edit the script and set the variables in the [Configuration](#configuration) section. +2. **Execute**: Run the script. Since the process can take a long time, it is highly recommended to use a terminal multiplexer like `tmux` or `screen` to prevent the script from stopping if your connection is lost. + +``` +cd +bash auto_tune.sh +``` + + Please note that the `bash auto_tune.sh` command cannot contain full or partial path with keyword `vllm`, otherwise `pkill -f vllm` command will also kill this script itself. 
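For example, one way to satisfy this constraint (the paths below are placeholders, not part of the script) is to copy the script into a directory whose path does not contain `vllm` and launch it from there; the script is driven by the `BASE`, `MODEL`, and other variables set at its top, so the location it is launched from only matters for the `pkill -f vllm` interaction:

```bash
# Placeholder paths: the only requirement is that the long-running command
# ("bash auto_tune.sh") does not contain the substring "vllm".
mkdir -p "$HOME/tuning"
cp "$HOME/vllm/benchmarks/auto_tune/auto_tune.sh" "$HOME/tuning/"
cd "$HOME/tuning"
bash auto_tune.sh   # safe: `pkill -f vllm` inside the script will not match this command
```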
+ +## Example Use Cases + +Here are a few examples of how to configure the script for different goals: + +### 1. Maximize Throughput (No Latency Constraint) +- **Goal**: Find the best `max-num-seqs` and `max-num-batched-tokens` to get the highest possible throughput for 1800 input tokens and 20 output tokens. +- **Configuration**: + +```bash +INPUT_LEN=1800 +OUTPUT_LEN=20 +MIN_CACHE_HIT_PCT=0 +MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number +``` + +#### 2. Maximize Throughput with a Latency Requirement +- **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms. +- **Configuration**: + +```bash +INPUT_LEN=1800 +OUTPUT_LEN=20 +MIN_CACHE_HIT_PCT=0 +MAX_LATENCY_ALLOWED_MS=500 +``` + +#### 3. Maximize Throughput with Prefix Caching and Latency Requirements +- **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms. +- **Configuration**: + +```bash +INPUT_LEN=1800 +OUTPUT_LEN=20 +MIN_CACHE_HIT_PCT=60 +MAX_LATENCY_ALLOWED_MS=500 +``` + +## Output + +After the script finishes, you will find the results in a new, timestamped directory created inside `$BASE/auto-benchmark/`. + +- **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run: + - `vllm_log_...txt`: The log output from the vLLM server for each parameter combination. + - `bm_log_...txt`: The log output from the `benchmark_serving.py` script for each benchmark run. + +- **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found. + +``` +# Example result.txt content +hash:a1b2c3d4... +max_num_seqs: 128, max_num_batched_tokens: 2048, request_rate: 10.0, e2el: 450.5, throughput: 9.8, goodput: 9.8 +max_num_seqs: 128, max_num_batched_tokens: 4096 does not meet latency requirement 500 +... +best_max_num_seqs: 256, best_num_batched_tokens: 2048, best_throughput: 12.5, profile saved in: /home/user/vllm/auto-benchmark/2024_08_01_10_30/profile +``` + + If it cannot find the best parameters, the final row will be `best_max_num_seqs: 0, best_num_batched_tokens: 0, best_throughput: 0`. This can be due to either the server not starting properly, or the latency requirement being too strict. + +- **Profiler Trace**: A directory named `profile` is created inside the log directory. It contains the profiler trace file (e.g., `.xplane.pb` for TPU or a `.json` trace for GPU) from the single best-performing run. + +## How It Works + +The script follows a systematic process to find the optimal parameters: + +1. **Find Max GPU Memory Utilization**: The script first determines the highest safe `gpu-memory-utilization` (starting from 0.98 and decreasing) that does not cause an Out-Of-Memory (OOM) error when launching the server. This ensures the benchmark runs use the maximum available memory without crashing. + +2. **Iterate and Benchmark**: It then enters a nested loop, iterating through every combination of `max-num-seqs` and `max-num-batched-tokens` provided in the configuration lists. + +3. **Latency-Aware Throughput Search**: For each parameter combination: + - The vLLM server is started. + - A benchmark is first run with an infinite request rate (`--request-rate inf`). + - If the resulting P99 E2E latency is within the `MAX_LATENCY_ALLOWED_MS` limit, this throughput is considered the maximum for this configuration. 
+ - If the latency is too high, the script performs a search by iteratively decreasing the request rate until the latency constraint is met. This finds the highest sustainable throughput for the given parameters and latency requirement. + +4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far. + +5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard. diff --git a/benchmarks/auto_tune.sh b/benchmarks/auto_tune/auto_tune.sh similarity index 79% rename from benchmarks/auto_tune.sh rename to benchmarks/auto_tune/auto_tune.sh index b257b57ce06f..eaa28ea5c92b 100644 --- a/benchmarks/auto_tune.sh +++ b/benchmarks/auto_tune/auto_tune.sh @@ -1,36 +1,7 @@ #!/bin/bash # This script aims to tune the best server parameter combinations to maximize throughput for given requirement. -# The current server parameter combination is max_num_seqs and max_num_batched_tokens -# It also supports additional requirement: e2e latency and prefix cache. - -# Pre-requisite: -# 1. Checkout to your branch, install/ update the correct running env. For TPU, activate conda env and install the corresponding torch, xla version. -# 2. If the model is customized, replace the MODEL's config with the customized config. -# 3. Set variables (ALL REQUIRED) -# BASE: your directory for vllm repo -# MODEL: the model served by vllm -# SYSTEM: the hardware, choice TPU or GPU, for other systems, "get best profile" might not support. -# TP: ways of tensor parallelism -# DOWNLOAD_DIR: directory to download and load model weights. -# INPUT_LEN: request input len -# OUTPUT_LEN: request output len -# MIN_CACHE_HIT_PCT: prefix cache rate -# MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000 -# NUM_SEQS_LIST: a list of `max-num-seqs` you want to loop with. -# NUM_BATCHED_TOKENS_LIST: a list of `max-num-batched-tokens` you want to loop with. -# Note that the default NUM_SEQS_LIST and NUM_BATCHED_TOKENS_LIST are set for medium size input/output len, for extra short context (such as 20:20), you might need to include larger numbers in NUM_SEQS_LIST. -# 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens. -# 5. The final result will be saved in RESULT file. - - -# Example use cases -# 1. Given input_len=1800, output_len=20, what's the best max_num_seqs and max_num_batched_tokens to get highest throughput? -# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=100000000000 -# 2. If we have latency requirement to be lower than 500ms, what's the best server parameter? -# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=500 -# 3. If we want to reach 60% prefix cache, what's the best server parameter? -# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=60, MAX_LATENCY_ALLOWED_MS=500 +# See details in README (benchmarks/auto_tune/README.md). 
TAG=$(date +"%Y_%m_%d_%H_%M") BASE="" @@ -155,11 +126,12 @@ run_benchmark() { # get a basic qps by using request-rate inf bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt" prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 )) - python benchmarks/benchmark_serving.py \ +adjusted_input_len=$(( INPUT_LEN - prefix_len )) + python3 benchmarks/benchmark_serving.py \ --backend vllm \ --model $MODEL \ --dataset-name random \ - --random-input-len $INPUT_LEN \ + --random-input-len $adjusted_input_len \ --random-output-len $OUTPUT_LEN \ --ignore-eos \ --disable-tqdm \ @@ -188,11 +160,11 @@ run_benchmark() { curl -X POST http://0.0.0.0:8004/reset_prefix_cache sleep 5 bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt" - python benchmarks/benchmark_serving.py \ + python3 benchmarks/benchmark_serving.py \ --backend vllm \ --model $MODEL \ --dataset-name random \ - --random-input-len $INPUT_LEN \ + --random-input-len $adjusted_input_len \ --random-output-len $OUTPUT_LEN \ --ignore-eos \ --disable-tqdm \ diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 55c0cf851264..1ad6cef7a9db 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -324,6 +324,9 @@ def sample( input_low = int(real_input_len * (1 - range_ratio)) input_high = int(real_input_len * (1 + range_ratio)) output_low = int(output_len * (1 - range_ratio)) + # Ensure the lower bound for output length is at least 1 to prevent + # sampling 0 tokens, which can cause request failures. + output_low = max(output_low, 1) output_high = int(output_len * (1 + range_ratio)) # Add logging for debugging @@ -701,6 +704,7 @@ def __init__( self, dataset_path: str, dataset_split: str, + no_stream: bool = False, dataset_subset: Optional[str] = None, **kwargs, ) -> None: @@ -708,6 +712,7 @@ def __init__( self.dataset_split = dataset_split self.dataset_subset = dataset_subset + self.load_stream = not no_stream self.load_data() def load_data(self) -> None: @@ -716,7 +721,7 @@ def load_data(self) -> None: self.dataset_path, name=self.dataset_subset, split=self.dataset_split, - streaming=True, + streaming=self.load_stream, ) self.data = self.data.shuffle(seed=self.random_seed) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 9b235266dff1..c597fb1068ab 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -30,7 +30,7 @@ import random import time import warnings -from collections.abc import AsyncGenerator, Iterable +from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime from typing import Any, Literal, Optional @@ -73,6 +73,7 @@ VisionArenaDataset, ) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json +from vllm.benchmarks.serve import get_request MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -107,101 +108,6 @@ class BenchmarkMetrics: percentiles_e2el_ms: list[tuple[float, float]] -def _get_current_request_rate( - ramp_up_strategy: Optional[Literal["linear", "exponential"]], - ramp_up_start_rps: Optional[int], - ramp_up_end_rps: Optional[int], - request_index: int, - total_requests: int, - request_rate: float, -) -> float: - if ( - ramp_up_strategy - and ramp_up_start_rps is not None - and ramp_up_end_rps is not None - ): - progress = request_index / max(total_requests - 1, 1) - if ramp_up_strategy == "linear": - increase = (ramp_up_end_rps - ramp_up_start_rps) * progress - return 
ramp_up_start_rps + increase - elif ramp_up_strategy == "exponential": - ratio = ramp_up_end_rps / ramp_up_start_rps - return ramp_up_start_rps * (ratio**progress) - else: - raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") - return request_rate - - -async def get_request( - input_requests: list[SampleRequest], - request_rate: float, - burstiness: float = 1.0, - ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, - ramp_up_start_rps: Optional[int] = None, - ramp_up_end_rps: Optional[int] = None, -) -> AsyncGenerator[tuple[SampleRequest, float], None]: - """ - Asynchronously generates requests at a specified rate - with OPTIONAL burstiness and OPTIONAL ramp-up strategy. - - Args: - input_requests: - A list of input requests, each represented as a SampleRequest. - request_rate: - The rate at which requests are generated (requests/s). - burstiness (optional): - The burstiness factor of the request generation. - Only takes effect when request_rate is not inf. - Default value is 1, which follows a Poisson process. - Otherwise, the request intervals follow a gamma distribution. - A lower burstiness value (0 < burstiness < 1) results - in more bursty requests, while a higher burstiness value - (burstiness > 1) results in a more uniform arrival of requests. - ramp_up_strategy (optional): - The ramp-up strategy. Can be "linear" or "exponential". - If None, uses constant request rate (specified by request_rate). - ramp_up_start_rps (optional): - The starting request rate for ramp-up. - ramp_up_end_rps (optional): - The ending request rate for ramp-up. - """ - assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}." - ) - # Convert to list to get length for ramp-up calculations - if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): - input_requests = list(input_requests) - - total_requests = len(input_requests) - request_index = 0 - - for request in input_requests: - current_request_rate = _get_current_request_rate( - ramp_up_strategy, - ramp_up_start_rps, - ramp_up_end_rps, - request_index, - total_requests, - request_rate, - ) - - yield request, current_request_rate - - request_index += 1 - - if current_request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - - theta = 1.0 / (current_request_rate * burstiness) - - # Sample the request interval from the gamma distribution. - # If burstiness is 1, it follows exponential distribution. - interval = np.random.gamma(shape=burstiness, scale=theta) - # The next request will be sent after the interval. - await asyncio.sleep(interval) - - def calculate_metrics( input_requests: list[SampleRequest], outputs: list[RequestFuncOutput], @@ -825,6 +731,7 @@ def main(args: argparse.Namespace): dataset_subset=args.hf_subset, dataset_split=args.hf_split, random_seed=args.seed, + no_stream=args.no_stream, ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, @@ -1033,6 +940,11 @@ def create_argument_parser(): help="Path to the sharegpt/sonnet dataset. 
" "Or the huggingface dataset ID if using HF dataset.", ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Do not load the dataset in streaming mode.", + ) parser.add_argument( "--max-concurrency", type=int, diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 0ded34c70bad..14461121fece 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -356,6 +356,7 @@ def get_requests(args, tokenizer): elif args.dataset_name == "burstgpt": dataset_cls = BurstGPTDataset elif args.dataset_name == "hf": + common_kwargs["no_stream"] = args.no_stream if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: dataset_cls = VisionArenaDataset common_kwargs["dataset_subset"] = None @@ -610,6 +611,11 @@ def create_argument_parser(): help="Name of the dataset to benchmark on.", default="sharegpt", ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Do not load the dataset in streaming mode.", + ) parser.add_argument( "--dataset", type=str, diff --git a/benchmarks/kernels/bench_nvfp4_gemm.py b/benchmarks/kernels/bench_nvfp4_gemm.py new file mode 100644 index 000000000000..9e832c9faa8e --- /dev/null +++ b/benchmarks/kernels/bench_nvfp4_gemm.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types +from vllm.triton_utils import triton + +if not current_platform.has_device_capability(100): + raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)") + + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "nvfp4": dict(no_a_quant=False, enabled=True), + "nvfp4-noquant": dict(no_a_quant=True, enabled=True), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def _quant_weight_nvfp4(b: torch.Tensor, device: str): + # Compute global scale for weight + b_amax = torch.abs(b).max().to(torch.float32) + b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax + b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale) + return b_fp4, scale_b_fp4, b_global_scale + + +def build_nvfp4_runner(cfg, a, b, dtype, device): + b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device) + + # Compute global scale for activation + # NOTE: This is generally provided ahead-of-time by the model checkpoint. 
+ a_amax = torch.abs(a).max().to(torch.float32) + a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax + + # Alpha for the GEMM operation + alpha = 1.0 / (a_global_scale * b_global_scale) + + if cfg["no_a_quant"]: + # Pre-quantize activation + a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale) + + def run(): + return ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype + ) + + return run + + # Quantize activation on-the-fly + def run(): + a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale) + return ops.cutlass_scaled_fp4_mm( + a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype + ) + + return run + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=_enabled, + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs NVFP4 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_nvfp4_runner(cfg, a, b, dtype, device) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + out.append(KN) + return out + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_nvfp4_res_n{N}_k{K}", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/bench_per_token_quant_fp8.py b/benchmarks/kernels/bench_per_token_quant_fp8.py new file mode 100644 index 000000000000..923d678f1f2d --- /dev/null +++ b/benchmarks/kernels/bench_per_token_quant_fp8.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from typing import Callable + +import torch + +from vllm import _custom_ops as ops +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.triton_utils import triton + + +# TODO(luka): use standalone_compile utility +def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int): + def inner(*args): + torch._dynamo.mark_dynamic(args[arg_index], dim_index) + return fn(*args) + + return inner + + 
+torch._dynamo.config.recompile_limit = 8888 +compilation_config = CompilationConfig(custom_ops=["none"]) +with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)): + torch_per_token_quant_fp8 = torch.compile( + QuantFP8(False, GroupShape.PER_TOKEN), + fullgraph=True, + dynamic=False, # recompile for different shapes + ) + + # First dim is explicitly dynamic to simulate vLLM usage + torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0) + + +def cuda_per_token_quant_fp8( + input: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return ops.scaled_fp8_quant(input) + + +def calculate_diff(batch_size: int, seq_len: int): + """Calculate difference between Triton and CUDA implementations.""" + device = torch.device("cuda") + x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device) + + torch_out, torch_scale = torch_per_token_quant_fp8(x) + cuda_out, cuda_scale = cuda_per_token_quant_fp8(x) + + if torch.allclose( + cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5 + ) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [1, 16, 32, 64, 128] +seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096] + +configs = list(itertools.product(batch_size_range, seq_len_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=configs, + line_arg="provider", + line_vals=["torch", "cuda"], + line_names=["Torch", "CUDA"], + styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="per-token-dynamic-quant-fp8-performance", + args={}, + ) +) +def benchmark_quantization(batch_size, seq_len, provider): + dtype = torch.float16 + device = torch.device("cuda") + + x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch": + fn = lambda: torch_per_token_quant_fp8(x.clone()) + elif provider == "cuda": + fn = lambda: cuda_per_token_quant_fp8(x.clone()) + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + calculate_diff(batch_size=4, seq_len=4096) + benchmark_quantization.run(print_data=True) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 07af58d81c68..c350aaf5d3ad 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -86,6 +86,9 @@ def benchmark_config( (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_deep_gemm: + # we use the default block shape for deepgemm + block_quant_shape = [128, 128] if use_fp8_w8a8: if block_quant_shape: block_n, block_k = block_quant_shape[0], block_quant_shape[1] @@ -573,7 +576,11 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"): + elif config.architectures[0] in ( + "DeepseekV3ForCausalLM", + "DeepseekV2ForCausalLM", + "Glm4MoeForCausalLM", + ): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -583,6 +590,11 @@ def main(args: argparse.Namespace): topk = 
config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"): + E = config.num_experts + topk = config.moe_topk[0] + intermediate_size = config.moe_intermediate_size[0] + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Support for llama4 config = config.get_text_config() diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py index 5170ac09dc42..1af5a21caf46 100644 --- a/benchmarks/kernels/benchmark_moe_align_block_size.py +++ b/benchmarks/kernels/benchmark_moe_align_block_size.py @@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8): sorted_ids_triton = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device="cuda" ) - sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value - expert_ids_triton = torch.zeros( + expert_ids_triton = torch.empty( (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda" ) num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda") sorted_ids_vllm = torch.empty_like(sorted_ids_triton) - sorted_ids_vllm.fill_(topk_ids.numel()) - expert_ids_vllm = torch.zeros_like(expert_ids_triton) + expert_ids_vllm = torch.empty_like(expert_ids_triton) num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton) # 2. run implementations @@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider): max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = max_num_tokens_padded // block_size expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda") diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index dba1f3943b96..04d2205aa372 100644 --- a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -8,12 +8,13 @@ import torch from transformers import AutoConfig -from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute, _moe_unpermute_and_reduce, + moe_permute, + moe_unpermute, ) -from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -63,18 +64,19 @@ def prepare(i: int): def run(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + ( + permuted_hidden_states, + a1q_scale, + first_token_off, + inv_perm_idx, + m_indices, + ) = moe_permute( + qhidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=num_experts, + expert_map=None, + 
align_block_size=align_block_size, ) else: ( @@ -150,18 +152,19 @@ def benchmark_unpermute( def prepare(): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( - moe_permute( - qhidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - token_expert_indices=token_expert_indices, - topk=topk, - n_expert=num_experts, - n_local_expert=num_experts, - expert_map=None, - align_block_size=align_block_size, - ) + ( + permuted_hidden_states, + a1q_scale, + first_token_off, + inv_perm_idx, + m_indices, + ) = moe_permute( + qhidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, ) # convert to fp16/bf16 as gemm output return ( @@ -191,16 +194,19 @@ def prepare(): def run(input: tuple): if use_customized_permute: - (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input + ( + permuted_hidden_states, + first_token_off, + inv_perm_idx, + m_indices, + ) = input + output = torch.empty_like(hidden_states) moe_unpermute( + output, permuted_hidden_states, topk_weights, - topk_ids, inv_perm_idx, first_token_off, - topk, - num_experts, - num_experts, ) else: ( @@ -211,7 +217,11 @@ def run(input: tuple): inv_perm, ) = input _moe_unpermute_and_reduce( - output_hidden_states, permuted_hidden_states, inv_perm, topk_weights + output_hidden_states, + permuted_hidden_states, + inv_perm, + topk_weights, + True, ) # JIT compilation & warmup @@ -318,6 +328,7 @@ def main(args: argparse.Namespace): elif ( config.architectures[0] == "DeepseekV3ForCausalLM" or config.architectures[0] == "DeepseekV2ForCausalLM" + or config.architectures[0] == "Glm4MoeForCausalLM" ): E = config.n_routed_experts topk = config.num_experts_per_tok diff --git a/benchmarks/kernels/benchmark_trtllm_attention.py b/benchmarks/kernels/benchmark_trtllm_attention.py new file mode 100644 index 000000000000..8c980f930366 --- /dev/null +++ b/benchmarks/kernels/benchmark_trtllm_attention.py @@ -0,0 +1,240 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import os +import random +from datetime import datetime + +import flashinfer +import torch + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@torch.no_grad() +def benchmark_decode( + num_seqs, + max_seq_len, + page_size=16, + dtype=torch.bfloat16, + kv_layout="HND", + num_kv_heads=8, + kv_cache_dtype="auto", + head_dim=128, + warmup=10, + trials=20, +): + torch.set_default_device("cuda") + device = "cuda" + torch.manual_seed(0) + + # Currently only HEAD_GRP_SIZE == 8 is supported + HEAD_GRP_SIZE = 8 + MAX_SEQ_LEN = max_seq_len + + # large number to reduce kv_cache reuse + NUM_BLOCKS = int(256000 / page_size) + + workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device) + + # For decode, batch_size is num_decode_token + num_qo_heads = num_kv_heads * HEAD_GRP_SIZE + sm_scale = float(1.0 / (head_dim**0.5)) + q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype) + kv_lens = [random.randint(1, 
MAX_SEQ_LEN) for _ in range(num_seqs)] + + max_kv_len = max(kv_lens) + kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device) + max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size + + block_tables = torch.randint( + 0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32 + ) + + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim) + kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype) + k_scale = v_scale = 1.0 + + if kv_cache_dtype.startswith("fp8"): + kv_cache, _ = to_float8(kv_cache) + + # Benchmark TRT decode + def trt_decode(): + return flashinfer.decode.trtllm_batch_decode_with_kv_cache( + q, + kv_cache, + workspace_buffer, + num_qo_heads, + num_kv_heads, + sm_scale, + block_tables, + kv_lens_tensor, + page_size, + max_kv_len, + kv_cache_dtype, + k_scale, + v_scale, + ) + + def time_fn(fn, warmup=10, trials=20): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + times = [] + for i in range(warmup): + fn() + for i in range(trials): + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) # ms + return sum(times) / len(times), torch.std(torch.tensor(times)) + + # TRT Decode + trt_mean, trt_std = time_fn(trt_decode) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + page_size - 1) // page_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % page_size + if kv_last_page_len == 0: + kv_last_page_len = page_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + kv_layout, + use_tensor_cores=((num_qo_heads // num_kv_heads) > 4), + ) + + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + "NONE", + q_data_type=dtype, + kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype, + ) + + def baseline_decode(): + return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale) + + baseline_mean, baseline_std = time_fn(baseline_decode) + + # Calculate percentage speedup (positive means TRT is faster) + speedup_percent = (baseline_mean - trt_mean) / baseline_mean + + print( + f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}" + f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}" + ) + + # Return results for CSV writing + return { + "num_seqs": num_seqs, + "trt_mean": trt_mean, + "trt_std": trt_std.item(), + "baseline_mean": baseline_mean, + "baseline_std": baseline_std.item(), + "speedup_percent": speedup_percent, + "q_dtype": str(dtype), + "kv_cache_dtype": kv_cache_dtype, + "page_size": page_size, + "num_kv_heads": num_kv_heads, + "head_dim": head_dim, + "max_seq_len": max_seq_len, + } + + +def write_results_to_csv(results, filename=None): + """Write benchmark results to CSV file.""" + if filename is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv" + + fieldnames = [ + "num_seqs", + "trt_mean", + "trt_std", + "baseline_mean", + "baseline_std", + 
"speedup_percent", + "q_dtype", + "kv_cache_dtype", + "page_size", + "num_kv_heads", + "head_dim", + "max_seq_len", + ] + + file_exists = os.path.exists(filename) + + with open(filename, "a", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + + if not file_exists: + writer.writeheader() + + for result in results: + writer.writerow(result) + + print(f"Results written to {filename}") + + +if __name__ == "__main__": + num_seqs = [1, 4, 8, 16, 32, 64, 128, 256] + max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + all_results = [] + + print("Running benchmark for kv_cache_dtype: bfloat16") + print( + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in num_seqs: + result = benchmark_decode( + bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto" + ) + all_results.append(result) + + print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8") + print( + "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent" + ) + for max_seq_len in max_seq_lens: + for bs in num_seqs: + result = benchmark_decode( + bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8" + ) + all_results.append(result) + + # Write all results to CSV + write_results_to_csv(all_results) diff --git a/benchmarks/kv_cache/benchmark_block_pool.py b/benchmarks/kv_cache/benchmark_block_pool.py new file mode 100644 index 000000000000..134551bb6128 --- /dev/null +++ b/benchmarks/kv_cache/benchmark_block_pool.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import gc +import time +from typing import Optional + +from tabulate import tabulate + +from vllm.utils import FlexibleArgumentParser +from vllm.v1.core.block_pool import BlockPool + + +class Metric: + def __init__(self) -> None: + self.cnt: int = 0 + self.sum_v: int = 0 + self.max_v: Optional[int] = None + + def update(self, v: int) -> None: + self.cnt += 1 + self.sum_v += v + if self.max_v is None: + self.max_v = v + else: + self.max_v = max(self.max_v, v) + + def avg_v(self) -> float: + return self.sum_v * 1.0 / self.cnt + + +def main(args): + rows = [] + for allocate_block in args.allocate_blocks: + # Enforce a GC collect ahead to minimize the impact among runs + gc.collect() + block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True) + + get_blocks_metric: Metric = Metric() + free_blocks_metric: Metric = Metric() + for _ in range(args.num_iteration): + t1 = time.monotonic_ns() + blocks = block_pool.get_new_blocks(allocate_block) + t2 = time.monotonic_ns() + block_pool.free_blocks(blocks) + t3 = time.monotonic_ns() + get_blocks_metric.update(t2 - t1) + free_blocks_metric.update(t3 - t2) + + if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None: + rows.append( + [ + get_blocks_metric.cnt, + args.num_gpu_blocks, + allocate_block, + get_blocks_metric.avg_v() / 1000000, + get_blocks_metric.max_v / 1000000.0, + free_blocks_metric.avg_v() / 1000000, + free_blocks_metric.max_v / 1000000.0, + ] + ) + else: + print( + "No valid metrics found." 
+ f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}" + ) + + print( + tabulate( + rows, + headers=[ + "Iterations", + "Total\nBlocks", + "Allocated\nBlocks", + "Get Blocks\nAvg (ms)", + "Get Blocks\nMax (ms)", + "Free Blocks\nAvg (ms)", + "Free Blocks\nMax (ms)", + ], + tablefmt="grid", + floatfmt=".6f", + ) + ) + + +def invoke_main() -> None: + parser = FlexibleArgumentParser( + description="Benchmark the performance of BlockPool for KV Cache." + ) + parser.add_argument("--num-gpu-blocks", type=int, default=100000) + parser.add_argument( + "--num-iteration", + type=int, + default=1000, + help="Number of iterations to run to stablize final data readings", + ) + parser.add_argument( + "--allocate-blocks", + type=int, + nargs="*", + default=[10, 50, 100, 500, 1000], + help="Number of blocks to allocate", + ) + args = parser.parse_args() + main(args) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index fc7291972309..21fcee66d603 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -165,17 +165,32 @@ else() endif() # -# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) -# -if (AVX512_FOUND AND NOT AVX512_DISABLED) +# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms) +# Flag to enable ACL kernels for AARCH64 platforms +if ( VLLM_BUILD_ACL STREQUAL "ON") + set(USE_ACL ON) +else() + set(USE_ACL OFF) +endif() + +if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND) FetchContent_Declare( oneDNN GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git - GIT_TAG v3.7.1 + GIT_TAG v3.8.1 GIT_PROGRESS TRUE GIT_SHALLOW TRUE ) + if(USE_ACL) + find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/) + if(NOT ARM_COMPUTE_LIBRARY) + message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR") + endif() + set(ONEDNN_AARCH64_USE_ACL "ON") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/") + endif() + set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_BUILD_DOC "OFF") set(ONEDNN_BUILD_EXAMPLES "OFF") @@ -264,6 +279,11 @@ elseif(POWER10_FOUND) "csrc/cpu/quant.cpp" ${VLLM_EXT_SRC}) endif() +if (ASIMD_FOUND) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) +endif() message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 79a546554fa1..57382c1ddc65 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -24,6 +24,7 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "../cuda_compat.h" #ifdef USE_ROCM #include @@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16; #include "../quantization/fp8/nvidia/quant_utils.cuh" #endif -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif - #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp new file mode 100644 index 000000000000..95e32559cd54 --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. 
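+
+  This header defines cutlass::fmha::device::MLA<Kernel>, a host-side handle
+  that pairs the SM100 FMHA MLA mainloop kernel with a split-KV reduction
+  kernel (Sm100FmhaMlaReductionKernel) and exposes the usual CUTLASS 3.x
+  device-level API: can_implement(), get_workspace_size(), initialize(), run().
+
+  Typical host-side usage, sketched under the assumption that Kernel is an
+  instantiation of Sm100FmhaMlaKernelTmaWarpspecialized and that args,
+  workspace and stream have been prepared by the caller:
+
+    using Op = cutlass::fmha::device::MLA<Kernel>;
+    Op::set_split_kv(args);              // heuristic choice if args.split_kv < 1
+    if (Op::can_implement(args) == cutlass::Status::kSuccess) {
+      size_t ws_bytes = Op::get_workspace_size(args);
+      // ... allocate ws_bytes of device memory and pass it as workspace ...
+      Op op;
+      op.initialize(args, workspace, stream);
+      op.run(stream);
+    }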
+*/ + +// clang-format off +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + // printf("set_split_kv start"); + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + // std::cout << H << " " << K << " " << D << " " << B << "\n"; + int sm_count = args.hw_info.sm_count; + // printf(" sm_count = %d\n", sm_count); + int max_splits = ceil_div(K, 128); + max_splits = min(16, max_splits); + // printf(" max_splits = %d\n", max_splits); + int sms_per_batch = max(1, sm_count / B); + // printf(" sms_per_batch = %d\n", sms_per_batch); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + args.split_kv = split_wave_aware; + // printf(" args.split_kv = %d\n", args.split_kv); + + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! 
ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = Kernel::get_grid_shape(params.fmha_params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms.fmha_params}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params.fmha_params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess != result or Status::kSuccess != launch_result) { + //return Status::kSuccess; + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + if (params.reduction_params.split_kv > 1) { + // launch reduction kernel + dim3 const block = ReductionKernel::get_block_shape(); + dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params); + device_kernel<<>>(params.reduction_params); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + else { + return Status::kSuccess; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
+ Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 000000000000..7b6e1dd2657d --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? 
gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = max(lse_max, local_lse[i]); + } + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset)); + } + lse_max = lse_max == -std::numeric_limits::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 000000000000..2cbc2379579e --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2023 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 
0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? 
StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. 
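+  // Both the load warps and the compute warps take part in it
+  // ((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp threads): the load
+  // warps arrive_and_wait on it after issuing a tile's loads, so with the
+  // persistent scheduler the next Q tile cannot be staged into smem_q while
+  // the compute warps are still consuming the current one.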
+ static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + alignas(2048) cute::array> smem_p; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + 
EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + + if (args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + 
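+    // The returned workspace holds the split-KV partial results that the
+    // reduction kernel later combines: H * D_latent ElementAcc accumulator
+    // values plus H ElementLSE values per (batch, split).
+    // to_underlying_arguments() carves ptr_o_acc and ptr_lse_acc out of this
+    // buffer when split_kv > 1.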
} + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + return false; + } + if (get<1>(args.problem_shape) <= 0) { + return false; + } + if (args.split_kv <= 0) { + return false; + } + return true; + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename 
PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if (role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? 
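+    // Pipeline graph per tile: page-table warp -> load warps (load_page_table),
+    // load warps -> MMA warp (load_qk), MMA -> compute warps for S = Q@K (mma_s),
+    // compute -> MMA for the P operand of the second GEMM (p_mma), and
+    // MMA -> compute for the output rescale / epilogue (mma_o).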
+ pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); 
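+          // Wait until the compute warps have drained this tile's epilogue
+          // before the persistent scheduler advances; otherwise the next
+          // iteration's loads could overwrite smem_q while it is still being
+          // read (see kNamedBarrierEpilogue).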
+ cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + //tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + 
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + +#if 1 + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } +#endif + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int 
batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = 
make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + 
k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + 
group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! + Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier 
= pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! 
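+            // (tma_load_c_latent_transpose views the same c_latent pages as (D_latent, K),
+            //  so the latent tile can be consumed directly as the B operand of the PV GEMM)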
+ // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + 
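+    // S accumulators ping-pong between TmemAllocation::kS0 and kS1, selected by
+    // pipeline_mma_s_producer_state.index(); O accumulators start at kO0 and advance by
+    // kSizeAccO per PV_N iteration, so softmax/rescale/epilogue address the same tmem offsets.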
pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tItI.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tItI); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tItI.data() = uint32_t(TmemAllocation::kO0) + j * 
uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tItI); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + +#ifndef B2B + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } +#endif + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); 
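+    // The converted P fragment is staged in smem (stage selected by smem_p_index) so the MMA
+    // warp can read it back as the A operand of the PV GEMM via TiledMmaPV::make_fragment_A(sP).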
+ + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing +#ifndef B2B + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tItI.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tItI) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tItI) == _1{}); + Tensor tAcc = tItI(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); +#endif + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tItI = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tItI.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tItI) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tItI) == _1{}); + Tensor tAcc = tItI(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + if (epilogue_args.ptr_o_acc != nullptr) { + using 
ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + #endif + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + +#ifndef B2B + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! 
kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } +#endif + } + } + + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + +#ifdef B2B + row_sum = 1; +#else + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } +#endif + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv + ); + } + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 000000000000..c990ee2d856f --- /dev/null +++ b/csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,165 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ + +// clang-format off +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + 
} + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/csrc/attention/mla/sm100_cutlass_mla_kernel.cu b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu new file mode 100644 index 000000000000..e0e95d06290d --- /dev/null +++ b/csrc/attention/mla/sm100_cutlass_mla_kernel.cu @@ -0,0 +1,283 @@ +/* +Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +/* + * Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929 + * by Alcanderian JieXin Liang + */ +#include "core/registration.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include "cutlass_sm100_mla/device/sm100_mla.hpp" +#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" + +// clang-format off +#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 +void sm100_cutlass_mla_decode( + torch::Tensor const& out, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, + torch::Tensor const& workspace, + int64_t num_kv_splits) { + TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode"); +} +int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { + TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size"); +} +#else + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \ + } + +using namespace cute; +using namespace cutlass::fmha::kernel; + +template +struct IsPersistent { + static const bool value = v; +}; + +template > +struct MlaSm100 { + using Element = T; + using ElementAcc = float; + using ElementOut = T; + + using TileShape = Shape<_128, _128, Shape<_512, _64>>; + using TileShapeH = cute::tuple_element_t<0, TileShape>; + using TileShapeD = cute::tuple_element_t<2, TileShape>; + + // H K (D_latent D_rope) B + using ProblemShape = cute::tuple; + + using StrideQ = cute::tuple; // H D B + using StrideK = cute::tuple; // K D B + using StrideO = StrideK; // H D B + using StrideLSE = cute::tuple<_1, int>; // H B + + using TileScheduler = + 
std::conditional_t; + + using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized< + TileShape, + Element, + ElementAcc, + ElementOut, + ElementAcc, + TileScheduler, + /*kIsCpAsync=*/!IsPaged128>; + using Fmha = cutlass::fmha::device::MLA; +}; + +template +typename T::Fmha::Arguments args_from_options( + at::Tensor const& out, + at::Tensor const& q_nope, + at::Tensor const& q_pe, + at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, + at::Tensor const& page_table, + double sm_scale, + int64_t num_kv_splits) { + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = q_nope.device().index(); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + int batches = q_nope.sizes()[0]; + int page_count_per_seq = page_table.sizes()[1]; + int page_count_total = kv_c_and_k_pe_cache.sizes()[0]; + int page_size = kv_c_and_k_pe_cache.sizes()[1]; + int max_seq_len = page_size * page_count_per_seq; + using TileShapeH = typename T::TileShapeH; + using TileShapeD = typename T::TileShapeD; + auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + float scale = float(sm_scale); + + using StrideQ = typename T::StrideQ; + using StrideK = typename T::StrideK; + using StrideO = typename T::StrideO; + using StrideLSE = typename T::StrideLSE; + + StrideQ stride_Q_nope = cute::make_tuple( + static_cast(q_nope.stride(1)), _1{}, static_cast(q_nope.stride(0))); + StrideQ stride_Q_pe = cute::make_tuple( + static_cast(q_pe.stride(1)), _1{}, static_cast(q_pe.stride(0))); + + StrideK stride_C = cute::make_tuple( + static_cast(0 + D_latent + D_rope), _1{}, static_cast(page_size * (D_latent + D_rope))); + StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq); + StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H); + StrideO stride_O = cute::make_tuple(static_cast(0 + D_latent), _1{}, static_cast(0 + H * D_latent)); + + using Element = typename T::Element; + using ElementOut = typename T::ElementOut; + using ElementAcc = typename T::ElementAcc; + auto Q_nope_ptr = static_cast(q_nope.data_ptr()); + auto Q_pe_ptr = static_cast(q_pe.data_ptr()); + auto C_ptr = static_cast(kv_c_and_k_pe_cache.data_ptr()); + typename T::Fmha::Arguments arguments{ + problem_shape, + {scale, + Q_nope_ptr, + stride_Q_nope, + Q_pe_ptr, + stride_Q_pe, + C_ptr, + stride_C, + C_ptr + D_latent, + stride_C, + static_cast(seq_lens.data_ptr()), + static_cast(page_table.data_ptr()), + stride_PT, + page_count_total, + page_size}, + {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, + hw_info, + // TODO(trevor-m): Change split_kv back to -1 when + // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will + // perform worse with larger context length and smaller batch sizes. + num_kv_splits, // split_kv + nullptr, // is_var_split_kv + }; + // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute + // split_kv automatically based on batch size and sequence length to balance + // workload across available SMs. Consider using var_split_kv for manual + // control if needed. 
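+  // For illustration (assuming the 128-wide KV tile from TileShape above): max_seq_len = 1024
+  // gives ceil(1024 / 128) = 8 KV tiles, and split_kv = 4 assigns ceil(8 / 4) = 2 tiles to each
+  // of the 4 splits along the last block coordinate, mirroring the k_tile_per_cta computation
+  // in the kernel's load/mma/compute loops.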
+ T::Fmha::set_split_kv(arguments); + return arguments; +} + +template +void runMla( + at::Tensor const& out, + at::Tensor const& q_nope, + at::Tensor const& q_pe, + at::Tensor const& kv_c_and_k_pe_cache, + at::Tensor const& seq_lens, + at::Tensor const& page_table, + at::Tensor const& workspace, + double sm_scale, + int64_t num_kv_splits, + cudaStream_t stream) { + using MlaSm100Type = MlaSm100; + typename MlaSm100Type::Fmha fmha; + auto arguments = args_from_options(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits); + + CUTLASS_CHECK(fmha.can_implement(arguments)); + + CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream)); + + CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream)); +} + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +void sm100_cutlass_mla_decode( + torch::Tensor const& out, + torch::Tensor const& q_nope, + torch::Tensor const& q_pe, + torch::Tensor const& kv_c_and_k_pe_cache, + torch::Tensor const& seq_lens, + torch::Tensor const& page_table, + torch::Tensor const& workspace, + double sm_scale, + int64_t num_kv_splits) { + auto in_dtype = q_nope.dtype(); + at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); + const int page_size = kv_c_and_k_pe_cache.sizes()[1]; + + // NOTE(alcanderian): IsPersistent has bug with manual split_kv. + // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8) + // Maybe per batch split kv will fix this. + DISPATCH_BOOL(page_size == 128, IsPaged128, [&] { + DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] { + if (in_dtype == at::ScalarType::Half) { + runMla>( + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + runMla>( + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + } else if (in_dtype == at::ScalarType::Float8_e4m3fn) { + runMla>( + out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream); + } else { + TORCH_CHECK(false, "Unsupported input data type of MLA"); + } + return true; + }); + return true; + }); +} + +int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { + // Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc) + // which are float, so Element type here doesn't matter. + using MlaSm100Type = MlaSm100; + + // Get split kv. Requires problem shape and sm_count only. + typename MlaSm100Type::Fmha::Arguments arguments; + using TileShapeH = typename MlaSm100Type::TileShapeH; + using TileShapeD = typename MlaSm100Type::TileShapeD; + arguments.problem_shape = + cute::make_tuple(TileShapeH{}, static_cast(max_seq_len), TileShapeD{}, static_cast(num_batches)); + // Assumes device 0 when getting sm_count. + arguments.hw_info.sm_count = + sm_count <= 0 ? 
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count; + arguments.split_kv = num_kv_splits; + MlaSm100Type::Fmha::set_split_kv(arguments); + + return MlaSm100Type::Fmha::get_workspace_size(arguments); +} + +#endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) { + m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size); +} + +// clang-format on diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 46108a32d719..307300e55666 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -16,14 +16,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "attention_kernels.cuh" - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif +#include "../cuda_compat.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -80,7 +74,7 @@ void paged_attention_v1_launcher( const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_seq_len * sizeof(float); @@ -187,7 +181,6 @@ void paged_attention_v1( CALL_V1_LAUNCHER_BLOCK_SIZE) } -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9358c0d9f6a2..eb9b4feb4a89 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -16,14 +16,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "attention_kernels.cuh" - -#ifndef USE_ROCM - #define WARP_SIZE 32 -#else - #define WARP_SIZE warpSize -#endif +#include "../cuda_compat.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) @@ -84,7 +78,7 @@ void paged_attention_v2_launcher( const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); @@ -197,7 +191,6 @@ void paged_attention_v2( CALL_V2_LAUNCHER_BLOCK_SIZE) } -#undef WARP_SIZE #undef MAX #undef MIN #undef DIVIDE_ROUND_UP diff --git a/csrc/cpu/cpu_types_arm.hpp b/csrc/cpu/cpu_types_arm.hpp index 65ffe524af73..2251aac45e6f 100644 --- a/csrc/cpu/cpu_types_arm.hpp +++ b/csrc/cpu/cpu_types_arm.hpp @@ -33,6 +33,8 @@ namespace vec_op { #endif #define FORCE_INLINE __attribute__((always_inline)) inline +// Number of elements in single ASIMD vector of given Datatype +#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0])) namespace { template @@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec { } void save(void* ptr, const int elem_num) const { - int full_blocks = elem_num / 8; - int remainder = elem_num % 8; + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); if (full_blocks > 0) { vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]); @@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec { vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {}; void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }; + void save(void* ptr, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); + for (int i = 0; i < full_blocks; i++) + vst1q_bf16( + reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i, + reg.val[i]); + if (remainder > 0) { + bfloat16x8_t temp = reg.val[full_blocks]; + bfloat16_t* base = reinterpret_cast(ptr) + full_blocks * 8; + if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0); + if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1); + if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2); + if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3); + if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4); + if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5); + if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6); + } + }; }; struct BF16Vec32 : public Vec { @@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec { : reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {}; void save(void* ptr) const { *reinterpret_cast(ptr) = reg; }; + void save(void* ptr, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); + for (int i = 0; i < full_blocks; i++) + vst1q_bf16( + reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i, + reg.val[i]); + if (remainder > 0) { + bfloat16x8_t temp = reg.val[full_blocks]; + bfloat16_t* base = reinterpret_cast(ptr) + full_blocks * 8; + base[0] = vgetq_lane_bf16(temp, 0); + if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1); + if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2); + if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3); + if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4); + if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5); + if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6); + } + }; }; #endif @@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec { } }; 
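The save(ptr, elem_num) overloads added below (and the BF16 ones above) all follow the same full-blocks-plus-tail pattern: store as many whole registers as fit, then spill the remaining lanes one at a time. A minimal scalar sketch of that pattern, with plain arrays standing in for the NEON registers and the same 4-lanes-per-register layout; everything else is illustrative.

#include <array>
#include <cstdio>
#include <cstring>

constexpr int kLanes = 4;                        // elements per "register"
using Reg = std::array<float, kLanes>;

// Write the first elem_num elements of a 4-register group to ptr:
// bulk copies for full registers (vst1q_* above), lane-by-lane for the tail
// (vgetq_lane_* above).
static void partial_save(const std::array<Reg, 4>& regs, float* ptr,
                         int elem_num) {
  const int full_blocks = elem_num / kLanes;
  const int remainder = elem_num % kLanes;
  for (int i = 0; i < full_blocks; ++i)
    std::memcpy(ptr + i * kLanes, regs[i].data(), sizeof(Reg));
  for (int lane = 0; lane < remainder; ++lane)
    ptr[full_blocks * kLanes + lane] = regs[full_blocks][lane];
}

int main() {
  std::array<Reg, 4> regs{};
  for (int i = 0; i < 16; ++i) regs[i / kLanes][i % kLanes] = float(i);
  float out[16] = {};
  partial_save(regs, out, /*elem_num=*/11);      // 2 full blocks + 3 tail lanes
  for (float v : out) std::printf("%.0f ", v);
  std::printf("\n");
  return 0;
}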
+struct INT32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + int32x4x4_t reg; + int32_t values[VEC_ELEM_NUM]; + }; + int32x4x4_t reg; + + explicit INT32Vec16(const void* ptr) { + reg.val[0] = vld1q_s32(reinterpret_cast(ptr)); + reg.val[1] = vld1q_s32(reinterpret_cast(ptr) + 4); + reg.val[2] = vld1q_s32(reinterpret_cast(ptr) + 8); + reg.val[3] = vld1q_s32(reinterpret_cast(ptr) + 12); + } + + void save(int32_t* ptr) const { + vst1q_s32(ptr, reg.val[0]); + vst1q_s32(ptr + 4, reg.val[1]); + vst1q_s32(ptr + 8, reg.val[2]); + vst1q_s32(ptr + 12, reg.val[3]); + }; + + void save(int32_t* ptr, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); + + for (int i = 0; i < full_blocks; i++) + vst1q_s32( + reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i, + reg.val[i]); + + if (remainder > 0) { + int32x4_t temp = reg.val[full_blocks]; + int32_t* base = reinterpret_cast(ptr) + full_blocks * 4; + if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0); + if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1); + if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2); + if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3); + } + } +}; + struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec { reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1])); reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1])); }; - + explicit FP32Vec16(const INT32Vec16& v) { + reg.val[0] = vcvtq_f32_s32(v.reg.val[0]); + reg.val[1] = vcvtq_f32_s32(v.reg.val[1]); + reg.val[2] = vcvtq_f32_s32(v.reg.val[2]); + reg.val[3] = vcvtq_f32_s32(v.reg.val[3]); + }; FP32Vec16 operator+(const FP32Vec16& b) const { return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]), vaddq_f32(reg.val[1], b.reg.val[1]), @@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec { vdivq_f32(reg.val[3], b.reg.val[3])})); }; + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(float32x4x4_t( + {vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])), + vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])), + vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])), + vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))})); + }; + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]), + vmaxq_f32(b.reg.val[1], reg.val[1]), + vmaxq_f32(b.reg.val[2], reg.val[2]), + vmaxq_f32(b.reg.val[3], reg.val[3])})); + }; + + FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); + float32x4x4_t temp; + + for (int i = 0; i < full_blocks; i++) + temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]); + + if (remainder > 0) { + float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0), + vgetq_lane_f32(b.reg.val[full_blocks], 0)); + temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0); + } + if (remainder > 1) { + float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1), + vgetq_lane_f32(b.reg.val[full_blocks], 1)); + temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1); + } + if (remainder > 2) { + float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2), + vgetq_lane_f32(b.reg.val[full_blocks], 2)); + temp.val[full_blocks] = 
vsetq_lane_f32(max_v, temp.val[full_blocks], 2); + } + return FP32Vec16(temp); + }; + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(float32x4x4_t({ + vminq_f32(b.reg.val[0], reg.val[0]), + vminq_f32(b.reg.val[1], reg.val[1]), + vminq_f32(b.reg.val[2], reg.val[2]), + vminq_f32(b.reg.val[3], reg.val[3]), + })); + }; + FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); + float32x4x4_t temp; + for (int i = 0; i < full_blocks; i++) + temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]); + + if (remainder > 0) { + float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0), + vgetq_lane_f32(b.reg.val[full_blocks], 0)); + temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0); + } + if (remainder > 1) { + float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1), + vgetq_lane_f32(b.reg.val[full_blocks], 1)); + temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1); + } + if (remainder > 2) { + float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2), + vgetq_lane_f32(b.reg.val[full_blocks], 2)); + temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2); + } + + return FP32Vec16(temp); + }; + FP32Vec16 abs() const { + return FP32Vec16( + float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]), + vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])})); + } float reduce_sum() const { AliasReg ar; ar.reg = reg; @@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec { return answer; }; + float reduce_max() const { + AliasReg ar; + ar.reg = reg; + float max_v = std::numeric_limits::lowest(); + unroll_loop( + [&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); }); + return max_v; + } + + float reduce_min() const { + AliasReg ar; + ar.reg = reg; + float min_v = std::numeric_limits::max(); + unroll_loop( + [&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); }); + return min_v; + } + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); @@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec { vst1q_f32(ptr + 8, reg.val[2]); vst1q_f32(ptr + 12, reg.val[3]); }; + + void save(float* ptr, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]); + int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]); + + for (int i = 0; i < full_blocks; i++) + vst1q_f32( + reinterpret_cast(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i, + reg.val[i]); + + if (remainder > 0) { + float32x4_t temp = reg.val[full_blocks]; + float* base = reinterpret_cast(ptr) + + full_blocks * NUM_ELEMENTS_REG(reg.val[0]); + if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0); + if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1); + if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2); + } + } +}; + +struct INT8Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + int8x16_t reg; + int8_t values[VEC_ELEM_NUM]; + }; + int8x16_t reg; + + explicit INT8Vec16(const FP32Vec16& vec) { + // Convert each 128-bit float32 vector to int32 + int32x4_t part0 = + vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block + int32x4_t part1 = + vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block + int32x4_t part2 = + vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block + int32x4_t part3 = + vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block + + // Narrow each 32-bit vector to 8 bits and combine + int8x8_t lower = + 
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1))); + int8x8_t upper = + vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3))); + reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector + } + + void save(int8_t* ptr) const { vst1q_s8(ptr, reg); }; + + void save(int8_t* ptr, const int elem_num) const { + int full_blocks = elem_num / NUM_ELEMENTS_REG(reg); + int remainder = elem_num % NUM_ELEMENTS_REG(reg); + + for (int i = 0; i < full_blocks; i++) + vst1q_s8(reinterpret_cast(ptr) + NUM_ELEMENTS_REG(reg) * i, reg); + if (remainder > 0) { + int8x16_t temp = reg; + int8_t* base = + reinterpret_cast(ptr) + full_blocks * NUM_ELEMENTS_REG(reg); + if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0); + if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1); + if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2); + if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3); + if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4); + if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5); + if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6); + if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7); + if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8); + if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9); + if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10); + if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11); + if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12); + if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13); + if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14); + } + }; }; template diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp index 8b5011dc065f..1cb8dc5b25a6 100644 --- a/csrc/cpu/dnnl_helper.hpp +++ b/csrc/cpu/dnnl_helper.hpp @@ -57,6 +57,7 @@ class DNNLPrimitiveHelper { // Note: Due to the limitation of oneDNN // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is // not supported. + template static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, @@ -90,6 +91,27 @@ class DNNLPrimitiveHelper { } dnnl::matmul::primitive_desc matmul_pd; +// Create memory descriptors with format_tag::any for the primitive. This +// enables the matmul primitive to choose memory layouts for an +// optimized primitive implementation, and these layouts may differ from the +// ones provided by the user. 
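For readers unfamiliar with the format_tag::any idiom used in the __aarch64__ branch below, here is the general pattern in isolation: ask the matmul primitive which weights layout it prefers, then reorder the user's plain row-major buffer only when the chosen descriptor differs. This is a minimal f32 sketch written against the oneDNN v3 C++ API as understood here (no scales, no bias, `any` only on the weights descriptor; the include path may differ by install), not a drop-in replacement for the int8 path in this file.

#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  const memory::dim M = 4, K = 8, N = 16;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  std::vector<float> a(M * K, 1.f), b(K * N, 1.f), c(M * N, 0.f);

  // User-visible layouts are plain row-major ("ab").
  memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);
  memory::desc b_user({K, N}, memory::data_type::f32, memory::format_tag::ab);
  // format_tag::any lets the primitive pick an optimized (possibly blocked)
  // weights layout; the code above uses `any` for src/weights/dst alike.
  memory::desc b_any({K, N}, memory::data_type::f32, memory::format_tag::any);

  matmul::primitive_desc pd(eng, a_md, b_any, c_md);

  memory a_m(a_md, eng, a.data());
  memory b_m(b_user, eng, b.data());
  memory c_m(c_md, eng, c.data());

  // Reorder the weights only when the primitive chose a different layout,
  // as the __aarch64__ branch does before matmul.execute().
  memory weights_m = b_m;
  if (pd.weights_desc() != b_m.get_desc()) {
    weights_m = memory(pd.weights_desc(), eng);
    reorder(b_m, weights_m).execute(strm, b_m, weights_m);
  }

  matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_m},
                            {DNNL_ARG_WEIGHTS, weights_m},
                            {DNNL_ARG_DST, c_m}});
  strm.wait();
  return 0;
}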
+#ifdef __aarch64__ + auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8, + dnnl::memory::format_tag::any); + auto mat_weights_md = dnnl::memory::desc( + {K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any); + auto mat_dst_md = + dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any); + if (bias) { + dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md, + mat_weights_md, bias_md, + mat_dst_md, attr); + } else { + matmul_pd = dnnl::matmul::primitive_desc( + default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr); + } +#else if (bias) { dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, @@ -98,6 +120,7 @@ class DNNLPrimitiveHelper { matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md, attr); } +#endif dnnl::matmul matmul(matmul_pd); auto& engine = default_engine(); @@ -111,24 +134,34 @@ class DNNLPrimitiveHelper { (void*)b_scales); auto& stream = default_stream(); + + auto mat_src_mem = a_m; + auto mat_weights_mem = b_m; + auto mat_dst_mem = c_m; +#ifdef __aarch64__ + if (matmul_pd.weights_desc() != b_m.get_desc()) { + mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine); + dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem); + } +#endif if constexpr (InputNoScale) { if (bias) { dnnl::memory::desc bias_md({N}, BiasType, {1}); dnnl::memory bias_m(bias_md, engine, (void*)bias); matmul.execute( stream, { - {DNNL_ARG_SRC, a_m}, - {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_SRC, mat_src_mem}, + {DNNL_ARG_WEIGHTS, mat_weights_mem}, {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, c_m}, + {DNNL_ARG_DST, mat_dst_mem}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, }); } else { matmul.execute( stream, { - {DNNL_ARG_SRC, a_m}, - {DNNL_ARG_WEIGHTS, b_m}, - {DNNL_ARG_DST, c_m}, + {DNNL_ARG_SRC, mat_src_mem}, + {DNNL_ARG_WEIGHTS, mat_weights_mem}, + {DNNL_ARG_DST, mat_dst_mem}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, }); } @@ -138,19 +171,19 @@ class DNNLPrimitiveHelper { dnnl::memory bias_m(bias_md, engine, (void*)bias); matmul.execute( stream, { - {DNNL_ARG_SRC, a_m}, - {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_SRC, mat_src_mem}, + {DNNL_ARG_WEIGHTS, mat_weights_mem}, {DNNL_ARG_BIAS, bias_m}, - {DNNL_ARG_DST, c_m}, + {DNNL_ARG_DST, mat_dst_mem}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, }); } else { matmul.execute( stream, { - {DNNL_ARG_SRC, a_m}, - {DNNL_ARG_WEIGHTS, b_m}, - {DNNL_ARG_DST, c_m}, + {DNNL_ARG_SRC, mat_src_mem}, + {DNNL_ARG_WEIGHTS, mat_weights_mem}, + {DNNL_ARG_DST, mat_dst_mem}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, }); @@ -170,5 +203,4 @@ class DNNLPrimitiveHelper { return stream; } }; - #endif diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index f61dbcc948e8..c1f7c64ea2f4 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -36,7 +36,7 @@ struct KernelVecType { using cvt_vec_type = vec_op::FP32Vec16; }; -#ifdef __AVX512F__ +#if defined(__AVX512F__) || defined(__aarch64__) template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const float* scale, const int32_t* azp, @@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const float* scale, const int32_t* azp, const int num_tokens, const int hidden_size) { - 
TORCH_CHECK( - false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.") + TORCH_CHECK(false, + "static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 " + "support.") } template @@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, float* scale, int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK( - false, - "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.") + TORCH_CHECK(false, + "dynamic_scaled_int8_quant_impl requires " + "AVX512/powerpc64/AArch64 support.") } template @@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output, const float a_scale, const float* b_scale, const int32_t* azp_with_adj, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.") + TORCH_CHECK( + false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.") } template @@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, const int32_t* azp, const int32_t* azp_with_adj, const scalar_t* bias, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, - "dynamic_quant_epilogue requires AVX512/powerpc64 support.") + TORCH_CHECK( + false, + "dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.") } #endif } // namespace diff --git a/csrc/cpu/sgl-kernels/common.h b/csrc/cpu/sgl-kernels/common.h index 20261c1ef3e8..b96037e82c19 100644 --- a/csrc/cpu/sgl-kernels/common.h +++ b/csrc/cpu/sgl-kernels/common.h @@ -58,7 +58,7 @@ namespace { #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_LAST_DIM_CONTIGUOUS(x) \ - TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") #define CHECK_INPUT(x) \ CHECK_CPU(x); \ diff --git a/csrc/cpu/sgl-kernels/gemm.h b/csrc/cpu/sgl-kernels/gemm.h index afae19721ae9..fba5673323f5 100644 --- a/csrc/cpu/sgl-kernels/gemm.h +++ b/csrc/cpu/sgl-kernels/gemm.h @@ -126,7 +126,7 @@ void fused_experts_int4_w4a16_kernel_impl( int64_t topk, int64_t num_tokens_post_pad); -// shared expert implememntation for int8 w8a8 +// shared expert implementation for int8 w8a8 template void shared_expert_int8_kernel_impl( scalar_t* __restrict__ output, diff --git a/csrc/cpu/sgl-kernels/gemm_int8.cpp b/csrc/cpu/sgl-kernels/gemm_int8.cpp index 5a0f65a9200d..9a5ca0642e7a 100644 --- a/csrc/cpu/sgl-kernels/gemm_int8.cpp +++ b/csrc/cpu/sgl-kernels/gemm_int8.cpp @@ -41,7 +41,7 @@ struct tinygemm_kernel_nn { __m512 vd0; __m512 vd1[COLS]; - // oops! 4x4 spills but luckly we use 4x2 + // oops! 4x4 spills but luckily we use 4x2 __m512 vbias[COLS]; // [NOTE]: s8s8 igemm compensation in avx512-vnni diff --git a/csrc/cpu/sgl-kernels/vec.h b/csrc/cpu/sgl-kernels/vec.h index 87955cfb2922..160845c9b1cb 100644 --- a/csrc/cpu/sgl-kernels/vec.h +++ b/csrc/cpu/sgl-kernels/vec.h @@ -37,7 +37,7 @@ inline Vectorized convert_from_float_ext(const Vecto #define CVT_FP16_TO_FP32(a) \ _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) -// this doesn't hanel NaN. +// this doesn't handle NaN. 
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp index 9adb6f27ec41..7e64e1c52198 100644 --- a/csrc/cpu/shm.cpp +++ b/csrc/cpu/shm.cpp @@ -7,7 +7,7 @@ namespace { #define MAX_SHM_RANK_NUM 8 -#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) +#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); #define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) #define MIN_THREAD_PROCESS_SIZE (256) @@ -34,9 +34,10 @@ struct KernelVecType { }; struct ThreadSHMContext { - volatile char _curr_thread_stamp; - volatile char _ready_thread_stamp; - char _padding1[6]; + volatile char _curr_thread_stamp[2]; + volatile char _ready_thread_stamp[2]; + int local_stamp_buffer_idx; + int remote_stamp_buffer_idx; int thread_id; int thread_num; int rank; @@ -45,23 +46,28 @@ struct ThreadSHMContext { int swizzled_ranks[MAX_SHM_RANK_NUM]; void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; - size_t _thread_buffer_mask; - char _padding2[56]; + size_t _thread_buffer_mask[2]; + char _padding2[40]; ThreadSHMContext(const int thread_id, const int thread_num, const int rank, const int group_size, void* thread_shm_ptr) - : _curr_thread_stamp(1), - _ready_thread_stamp(0), + : local_stamp_buffer_idx(0), + remote_stamp_buffer_idx(0), thread_id(thread_id), thread_num(thread_num), rank(rank), group_size(group_size), - _spinning_count(0), - _thread_buffer_mask(0) { + _spinning_count(0) { static_assert(sizeof(ThreadSHMContext) % 64 == 0); TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); TORCH_CHECK((size_t)this % 64 == 0); TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0); + _curr_thread_stamp[0] = 1; + _curr_thread_stamp[1] = 1; + _ready_thread_stamp[0] = 0; + _ready_thread_stamp[1] = 0; + _thread_buffer_mask[0] = 0; + _thread_buffer_mask[1] = 0; for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) { shm_contexts[i] = nullptr; thread_shm_ptrs[i] = nullptr; @@ -70,6 +76,11 @@ struct ThreadSHMContext { set_context(rank, this, thread_shm_ptr); } + void set_stamp_buffer_idx(int local, int remote) { + local_stamp_buffer_idx = local; + remote_stamp_buffer_idx = remote; + } + void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) { TORCH_CHECK(rank < MAX_SHM_RANK_NUM); TORCH_CHECK(ptr); @@ -84,23 +95,27 @@ struct ThreadSHMContext { T* get_thread_shm_ptr(int rank) { return reinterpret_cast( reinterpret_cast(thread_shm_ptrs[rank]) + - (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); + (PER_THREAD_SHM_BUFFER_OFFSET & + _thread_buffer_mask[local_stamp_buffer_idx])); } - void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; } + void next_buffer() { + _thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF; + } - char get_curr_stamp() const { return _curr_thread_stamp; } + char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; } - char get_ready_stamp() const { return _ready_thread_stamp; } + char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; } void next_stamp() { _mm_mfence(); - _curr_thread_stamp += 1; + _curr_thread_stamp[local_stamp_buffer_idx] += 1; } void commit_ready_stamp() { _mm_mfence(); - _ready_thread_stamp = _curr_thread_stamp; + _ready_thread_stamp[local_stamp_buffer_idx] = + _curr_thread_stamp[local_stamp_buffer_idx]; } int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } @@ -117,10 +132,11 @@ struct ThreadSHMContext { void wait_for_one(int 
rank, Cond&& cond) { ThreadSHMContext* rank_ctx = shm_contexts[rank]; for (;;) { - char local_curr_stamp = get_curr_stamp(); - char local_ready_stamp = get_ready_stamp(); - char rank_curr_stamp = rank_ctx->get_curr_stamp(); - char rank_ready_stamp = rank_ctx->get_ready_stamp(); + char local_curr_stamp = get_curr_stamp(local_stamp_buffer_idx); + char local_ready_stamp = get_ready_stamp(local_stamp_buffer_idx); + char rank_curr_stamp = rank_ctx->get_curr_stamp(remote_stamp_buffer_idx); + char rank_ready_stamp = + rank_ctx->get_ready_stamp(remote_stamp_buffer_idx); if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, rank_ready_stamp)) { break; @@ -361,6 +377,15 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { } } } + +void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local, + int remote) { + int thread_num = ctx->thread_num; + for (int i = 0; i < thread_num; ++i) { + ThreadSHMContext* thread_ctx = ctx + i; + thread_ctx->set_stamp_buffer_idx(local, remote); + } +} }; // namespace shm_cc_ops namespace shm_cc_ops { @@ -632,6 +657,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta(); metadata->bind_tensor_list(tensor_list_with_metadata); + shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 0, 1); shm_cc_ops::shm_cc_loop( ctx, metadata->total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, @@ -659,6 +685,7 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, torch::Tensor metadata_tensor = torch::empty({sizeof(TensorListMeta)}, options); + shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 1, 0); ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); shm_cc_ops::memcpy(metadata_tensor.data_ptr(), ctx->get_thread_shm_ptr(src), @@ -677,7 +704,7 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, ctx, metadata.total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, int64_t data_elem_num, bool fast_mode) { - ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); + thread_ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); int64_t curr_shm_offset = 0; while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index ebfc81f85836..f1738aee980b 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); // Quantization -#ifdef __AVX512F__ +#if defined(__AVX512F__) || defined(__aarch64__) at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! 
out, Tensor input, Tensor scale," diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 82e55613d915..d7d589db62cf 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -4,10 +4,37 @@ #include #endif -#ifndef USE_ROCM - #define WARP_SIZE 32 +#ifdef USE_ROCM +struct Utils { + static __host__ int get_warp_size() { + static bool is_cached = false; + static int result; + + if (!is_cached) { + int device_id; + cudaDeviceProp deviceProp; + cudaGetDevice(&device_id); + cudaGetDeviceProperties(&deviceProp, device_id); + + result = deviceProp.warpSize; + is_cached = true; + } + + return result; + } + + static __device__ constexpr int get_warp_size() { + #ifdef __GFX9__ + return 64; + #else + return 32; + #endif + } +}; + + #define WARP_SIZE Utils::get_warp_size() #else - #define WARP_SIZE warpSize + #define WARP_SIZE 32 #endif #ifndef USE_ROCM diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index 64b7ddae3d2d..ad8c0067d4a9 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -153,7 +153,7 @@ struct ScaledEpilogueBias cutlass::epilogue::threadblock::Sm80EVT; using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: @@ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp EVTComputeAzp>; using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: @@ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken EVTComputeAcc>; using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 62b848a0a963..cf79507e1997 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -195,7 +195,7 @@ struct ScaledEpilogueBias cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: @@ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias cutlass::epilogue::fusion::Sm90EVT; using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: @@ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: @@ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken cutlass::epilogue::fusion::Sm90EVT; using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, + cutlass::homogeneous_multiply_add, ElementD, float, cutlass::FloatRoundStyle::round_to_nearest>; public: diff 
--git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index d073dd6d2dee..f051eb070222 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -15,15 +15,16 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. template __global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const int64_t input_stride, const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { __shared__ float s_variance; float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * input_stride + idx]; variance += x * x; } @@ -37,7 +38,7 @@ __global__ void rms_norm_kernel( __syncthreads(); for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * hidden_size + idx]; + float x = (float)input[blockIdx.x * input_stride + idx]; out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance)) * weight[idx]; } @@ -50,7 +51,8 @@ __global__ void rms_norm_kernel( template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const int64_t input_stride, scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -59,6 +61,7 @@ fused_add_rms_norm_kernel( static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); const int vec_hidden_size = hidden_size / width; + const int64_t vec_input_stride = input_stride / width; __shared__ float s_variance; float variance = 0.0f; /* These and the argument pointers are all declared `restrict` as they are @@ -73,7 +76,8 @@ fused_add_rms_norm_kernel( for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; - _f16Vec temp = input_v[id]; + int64_t strided_id = blockIdx.x * vec_input_stride + idx; + _f16Vec temp = input_v[strided_id]; temp += residual_v[id]; variance += temp.sum_squares(); residual_v[id] = temp; @@ -90,10 +94,11 @@ fused_add_rms_norm_kernel( for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { int id = blockIdx.x * vec_hidden_size + idx; + int64_t strided_id = blockIdx.x * vec_input_stride + idx; _f16Vec temp = residual_v[id]; temp *= s_variance; temp *= weight_v[idx]; - input_v[id] = temp; + input_v[strided_id] = temp; } } @@ -103,7 +108,8 @@ fused_add_rms_norm_kernel( template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> fused_add_rms_norm_kernel( - scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const int64_t input_stride, scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float epsilon, const int num_tokens, const int hidden_size) { @@ -111,7 +117,7 @@ fused_add_rms_norm_kernel( float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - scalar_t z = input[blockIdx.x * hidden_size + idx]; + scalar_t z = input[blockIdx.x * input_stride + idx]; z += residual[blockIdx.x * hidden_size 
+ idx]; float x = (float)z; variance += x * x; @@ -129,7 +135,7 @@ fused_add_rms_norm_kernel( for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)residual[blockIdx.x * hidden_size + idx]; - input[blockIdx.x * hidden_size + idx] = + input[blockIdx.x * input_stride + idx] = ((scalar_t)(x * s_variance)) * weight[idx]; } } @@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(input.stride(-1) == 1); TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; + int64_t input_stride = input.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); @@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { vllm::rms_norm_kernel<<>>( - out.data_ptr(), input.data_ptr(), + out.data_ptr(), input.data_ptr(), input_stride, weight.data_ptr(), epsilon, num_tokens, hidden_size); }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ - vllm::fused_add_rms_norm_kernel \ - <<>>(input.data_ptr(), \ - residual.data_ptr(), \ - weight.data_ptr(), epsilon, \ - num_tokens, hidden_size); \ +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>( \ + input.data_ptr(), input_stride, \ + residual.data_ptr(), weight.data_ptr(), \ + epsilon, num_tokens, hidden_size); \ }); void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] torch::Tensor& residual, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(residual.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); int hidden_size = input.size(-1); + int64_t input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] auto inp_ptr = reinterpret_cast(input.data_ptr()); auto res_ptr = reinterpret_cast(residual.data_ptr()); auto wt_ptr = reinterpret_cast(weight.data_ptr()); - bool ptrs_are_aligned = - inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; - if (ptrs_are_aligned && hidden_size % 8 == 0) { + constexpr int vector_width = 8; + constexpr int req_alignment_bytes = + vector_width * 2; // vector_width * sizeof(bfloat16 or float16) (float32 + // falls back to non-vectorized version anyway) + bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 && + res_ptr % req_alignment_bytes == 0 && + wt_ptr % req_alignment_bytes == 0; + bool offsets_are_multiple_of_vector_width = + hidden_size % vector_width == 0 && input_stride % vector_width == 0; + if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index d595b9e889c8..0fd5849d9626 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -23,8 +23,9 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. 
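To make the stride plumbing above concrete, the following standalone CPU sketch (illustrative values only) applies RMS norm to rows whose input buffer may be padded: reads advance by the row stride while the contiguous output is addressed with hidden_size, which is exactly the distinction the new input_stride argument introduces in these kernels.

#include <cmath>
#include <cstdio>
#include <vector>

// RMS-normalize num_tokens rows; the input rows are input_stride apart
// (input_stride >= hidden_size), the output is densely packed.
static void rms_norm_strided(const float* in, float* out, const float* weight,
                             int num_tokens, int hidden_size,
                             long input_stride, float eps = 1e-6f) {
  for (int t = 0; t < num_tokens; ++t) {
    const float* row = in + t * input_stride;   // strided read
    float ss = 0.f;
    for (int i = 0; i < hidden_size; ++i) ss += row[i] * row[i];
    const float inv_rms = 1.f / std::sqrt(ss / hidden_size + eps);
    for (int i = 0; i < hidden_size; ++i)
      out[t * hidden_size + i] = row[i] * inv_rms * weight[i];  // dense write
  }
}

int main() {
  const int tokens = 2, hidden = 4;
  const long stride = 6;                        // 2 elements of row padding
  std::vector<float> in(tokens * stride, 0.f), out(tokens * hidden, 0.f);
  std::vector<float> w(hidden, 1.f);
  for (int t = 0; t < tokens; ++t)
    for (int i = 0; i < hidden; ++i) in[t * stride + i] = float(i + 1);
  rms_norm_strided(in.data(), out.data(), w.data(), tokens, hidden, stride);
  for (float v : out) std::printf("%.3f ", v);
  std::printf("\n");
  return 0;
}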
template __global__ void rms_norm_static_fp8_quant_kernel( - fp8_type* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] + fp8_type* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const int input_stride, const scalar_t* __restrict__ weight, // [hidden_size] const float* __restrict__ scale, // [1] const float epsilon, const int num_tokens, const int hidden_size) { @@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - const float x = (float)input[blockIdx.x * hidden_size + idx]; + const float x = (float)input[blockIdx.x * input_stride + idx]; variance += x * x; } @@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( float const scale_inv = 1.0f / *scale; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - float x = (float)input[blockIdx.x * hidden_size + idx]; + float x = (float)input[blockIdx.x * input_stride + idx]; float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion(out_norm, scale_inv); @@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel( template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> fused_add_rms_norm_static_fp8_quant_kernel( - fp8_type* __restrict__ out, // [..., hidden_size] - scalar_t* __restrict__ input, // [..., hidden_size] + fp8_type* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const int input_stride, scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float* __restrict__ scale, // [1] @@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); const int vec_hidden_size = hidden_size / width; + const int vec_input_stride = input_stride / width; __shared__ float s_variance; float variance = 0.0f; /* These and the argument pointers are all declared `restrict` as they are @@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel( reinterpret_cast*>(weight); for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int stride_id = blockIdx.x * vec_input_stride + idx; int id = blockIdx.x * vec_hidden_size + idx; - _f16Vec temp = input_v[id]; + _f16Vec temp = input_v[stride_id]; temp += residual_v[id]; variance += temp.sum_squares(); residual_v[id] = temp; @@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel( template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> fused_add_rms_norm_static_fp8_quant_kernel( - fp8_type* __restrict__ out, // [..., hidden_size] - scalar_t* __restrict__ input, // [..., hidden_size] + fp8_type* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const int input_stride, scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float* __restrict__ scale, // [1] @@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( float variance = 0.0f; for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { - scalar_t z = input[blockIdx.x * hidden_size + idx]; + scalar_t z = input[blockIdx.x * input_stride + idx]; z += residual[blockIdx.x * hidden_size + idx]; float x = (float)z; variance += x * x; @@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., 
hidden_size] torch::Tensor& weight, // [hidden_size] torch::Tensor& scale, // [1] double epsilon) { + TORCH_CHECK(out.is_contiguous()); int hidden_size = input.size(-1); + int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] vllm::rms_norm_static_fp8_quant_kernel <<>>( out.data_ptr(), input.data_ptr(), - weight.data_ptr(), scale.data_ptr(), - epsilon, num_tokens, hidden_size); + input_stride, weight.data_ptr(), + scale.data_ptr(), epsilon, num_tokens, + hidden_size); }); }); } @@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] width, fp8_t> \ <<>>( \ out.data_ptr(), input.data_ptr(), \ - residual.data_ptr(), \ + input_stride, residual.data_ptr(), \ weight.data_ptr(), scale.data_ptr(), \ epsilon, num_tokens, hidden_size); \ }); \ @@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant( torch::Tensor& weight, // [hidden_size] torch::Tensor& scale, // [1] double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(residual.is_contiguous()); int hidden_size = input.size(-1); + int input_stride = input.stride(-2); int num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); @@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant( auto wt_ptr = reinterpret_cast(weight.data_ptr()); bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; - if (ptrs_are_aligned && hidden_size % 8 == 0) { + if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) { LAUNCH_FUSED_ADD_RMS_NORM(8); } else { LAUNCH_FUSED_ADD_RMS_NORM(0); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu deleted file mode 100644 index c83d72751a55..000000000000 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ /dev/null @@ -1,656 +0,0 @@ -// clang-format off -// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu -// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu -#include -#include -#include - -#include "causal_conv1d.h" -#include -#include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include -#include - -#ifdef USE_ROCM - namespace cub = hipcub; -#endif - -#include "static_switch.h" - - - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) 
\ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - -void set_conv_params_fwd(ConvParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t width, - // device pointers - const at::Tensor x, - const at::Tensor weight, - const at::Tensor out, - const std::optional& bias, - bool silu_activation, - int64_t pad_slot_id, - const std::optional& query_start_loc = std::nullopt, - const std::optional& cache_indices = std::nullopt, - const std::optional& has_initial_state = std::nullopt) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.width = width; - params.pad_slot_id = pad_slot_id; - - params.silu_activation = silu_activation; - - // Set the pointers and strides. - params.x_ptr = x.data_ptr(); - params.weight_ptr = weight.data_ptr(); - params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; - params.out_ptr = out.data_ptr(); - // All stride are in elements, not bytes. - params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; - params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; - params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; - const bool varlen = params.query_start_loc_ptr != nullptr; - params.x_batch_stride = x.stride(varlen ? 1 : 0); - params.x_c_stride = x.stride(varlen ? 0 : 1); - params.x_l_stride = x.stride(varlen ? 1 : -1); - params.weight_c_stride = weight.stride(0); - params.weight_width_stride = weight.stride(1); - params.out_batch_stride = out.stride(varlen ? 1 : 0); - params.out_c_stride = out.stride(varlen ? 0 : 1); - params.out_l_stride = out.stride(varlen ? 1 : -1); -} - - -void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, - const std::optional &bias_, - const std::optional &conv_states, - const std::optional &query_start_loc, - const std::optional &cache_indices, - const std::optional &has_initial_state, - bool silu_activation, - // used to identify padding entries if cache_indices provided - // in case of padding, the kernel will return early - int64_t pad_slot_id) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const bool varlen = query_start_loc.has_value() ? true : false; - const auto sizes = x.sizes(); - const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; - const int dim = varlen ? sizes[0] : sizes[1]; - const int seqlen = varlen ? 
sizes[1] : sizes[2]; - const int width = weight.size(-1); - if (varlen){ - CHECK_SHAPE(x, dim, seqlen); - } - else { - CHECK_SHAPE(x, batch_size, dim, seqlen); - } - CHECK_SHAPE(weight, dim, width); - - - - if (bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - - if (has_initial_state.has_value()) { - auto has_initial_state_ = has_initial_state.value(); - TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); - TORCH_CHECK(has_initial_state_.is_cuda()); - CHECK_SHAPE(has_initial_state_, batch_size); - } - - - if (query_start_loc.has_value()) { - auto query_start_loc_ = query_start_loc.value(); - TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(query_start_loc_.is_cuda()); - } - - - if (cache_indices.has_value()) { - auto cache_indices_ = cache_indices.value(); - TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(cache_indices_.is_cuda()); - CHECK_SHAPE(cache_indices_, batch_size); - } - - at::Tensor out = x; - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_, - silu_activation, - pad_slot_id, - query_start_loc, - cache_indices, - has_initial_state - ); - - if (conv_states.has_value()) { - auto conv_states_ = conv_states.value(); - TORCH_CHECK(conv_states_.scalar_type() == input_type); - TORCH_CHECK(conv_states_.is_cuda()); - params.conv_states_ptr = conv_states_.data_ptr(); - params.conv_states_batch_stride = conv_states_.stride(0); - params.conv_states_c_stride = conv_states_.stride(1); - params.conv_states_l_stride = conv_states_.stride(2); - } else { - params.conv_states_ptr = nullptr; - } - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { - causal_conv1d_fwd_cuda(params, stream); - }); -} - - -void causal_conv1d_update(const at::Tensor &x, - const at::Tensor &conv_state, - const at::Tensor &weight, - const std::optional &bias_, - bool silu_activation, - const std::optional &cache_seqlens_, - const std::optional &conv_state_indices_, - // used to identify padding entries if cache_indices provided - // in case of padding, the kernel will return early - int64_t pad_slot_id) { - auto input_type = x.scalar_type(); - auto weight_type = weight.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); - TORCH_CHECK(conv_state.scalar_type() == input_type); - - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(conv_state.is_cuda()); - TORCH_CHECK(weight.is_cuda()); - - const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int width = weight.size(-1); - const int conv_state_len = conv_state.size(2); - TORCH_CHECK(conv_state_len >= width - 1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); - CHECK_SHAPE(weight, dim, width); - - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); - - if 
(bias_.has_value()) { - auto bias = bias_.value(); - TORCH_CHECK(bias.scalar_type() == weight_type); - TORCH_CHECK(bias.is_cuda()); - TORCH_CHECK(bias.stride(-1) == 1); - CHECK_SHAPE(bias, dim); - } - - at::Tensor out = x; - - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_, - silu_activation, - pad_slot_id); - params.conv_state_ptr = conv_state.data_ptr(); - params.conv_state_len = conv_state_len; - // All stride are in elements, not bytes. - params.conv_state_batch_stride = conv_state.stride(0); - params.conv_state_c_stride = conv_state.stride(1); - params.conv_state_l_stride = conv_state.stride(2); - - if (cache_seqlens_.has_value()) { - auto cache_seqlens = cache_seqlens_.value(); - TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); - TORCH_CHECK(cache_seqlens.is_cuda()); - TORCH_CHECK(cache_seqlens.stride(-1) == 1); - CHECK_SHAPE(cache_seqlens, batch_size); - params.cache_seqlens = cache_seqlens.data_ptr(); - } else { - params.cache_seqlens = nullptr; - } - - if (conv_state_indices_.has_value()) { - auto conv_state_indices = conv_state_indices_.value(); - TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) - TORCH_CHECK(conv_state_indices.is_cuda()); - TORCH_CHECK(conv_state_indices.stride(0) == 1) - CHECK_SHAPE(conv_state_indices, batch_size); - - int conv_state_entries = conv_state.size(0); - CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); - - params.conv_state_indices_ptr = conv_state_indices.data_ptr(); - } else { - CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); - params.conv_state_indices_ptr = nullptr; - } - - const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { - causal_conv1d_update_cuda(params, stream); - }); -} - -template -struct Causal_conv1d_fwd_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static_assert(kWidth <= kNElts); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - static constexpr int kSmemIOSize = kIsVecLoad - ? 0 - : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); - static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. 
- extern __shared__ char smem_[]; - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_vec = reinterpret_cast(smem_); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_store_vec = reinterpret_cast(smem_); - vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - - const bool kVarlen = params.query_start_loc_ptr != nullptr; - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y; - const int *query_start_loc = kVarlen ? reinterpret_cast(params.query_start_loc_ptr) : nullptr; - const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; - const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; - - input_t *x = reinterpret_cast(params.x_ptr) + sequence_start_index * params.x_batch_stride - + channel_id * params.x_c_stride; - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - bool has_initial_state = params.has_initial_state_ptr == nullptr ? false - : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; - - int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr - : reinterpret_cast(params.cache_indices_ptr); - int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - // cache_index == params.pad_slot_id is defined as padding, so we exit early - if (cache_index == params.pad_slot_id){ - return; - } - input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr - : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; - - // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. - if (tidx == 0) { - input_t initial_state[kNElts] = {0}; - if (has_initial_state) { - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } - } - smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; - } - - float weight_vals[kWidth]; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; - for (int chunk = 0; chunk < n_chunks; ++chunk) { - input_t x_vals_load[2 * kNElts] = {0}; - if constexpr(kIsVecLoad) { - typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); - } else { - __syncthreads(); - typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); - } - x += kChunkSize; - __syncthreads(); - // Thread kNThreads - 1 don't write yet, so that thread 0 can read - // the last elements of the previous chunk. - if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - __syncthreads(); - reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; - __syncthreads(); - // Now thread kNThreads - 1 can write the last elements of the current chunk. 
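    // At this point thread 0 has already read smem_exchange[kNThreads - 1] (the previous
    // chunk's tail, or the initial state on the first chunk) as its left context, and the
    // __syncthreads() above guarantees that read has completed, so the last thread may now
    // safely overwrite that slot with the current chunk's tail.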
- if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } - - float x_vals[2 * kNElts]; - #pragma unroll - for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } - - float out_vals[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = bias_val; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; - } - } - - if (params.silu_activation) { - #pragma unroll - for (int i = 0; i < kNElts; ++i) { - out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); - } - } - - input_t out_vals_store[kNElts]; - #pragma unroll - for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } - if constexpr(kIsVecLoad) { - typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); - } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); - } - out += kChunkSize; - - int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); - // in case the final state is separated between the last "smem_exchange" and - // and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), - // (which occurs when `final_state_position` is a non-positive index) - // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it - if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){ - input_t vals_load[kNElts] = {0}; - if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ - // chunk = n_chunks - 2, a segment of the final state sits in the last index - reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; - #pragma unroll - for (int w = 0; w < -final_state_position; ++w){ - conv_states[w] = vals_load[kNElts + final_state_position + w]; - } - } - if ((chunk == n_chunks - 1) && tidx == 0){ - // chunk = n_chunks - 1, the second segment of the final state first positions - reinterpret_cast(vals_load)[0] = smem_exchange[0]; - for (int w = -final_state_position; w < kWidth - 1; ++w){ - conv_states[w] = vals_load[w + final_state_position]; - } - return; - } - } - } - // Final state is stored in the smem_exchange last token slot, - // in case seqlen < kWidth, we would need to take the final state from the - // initial state which is stored in conv_states - // in case seqlen > kWidth, we would need to load the last kWidth - 1 data - // and load it into conv_state accordingly - int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; - if (conv_states != nullptr && tidx == last_thread) { - input_t x_vals_load[kNElts * 2] = {0}; - // in case we are on the first kWidth tokens - if (last_thread == 0 && seqlen < kWidth){ - // Need to take the initial state - reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; - const int offset = seqlen - (kWidth - 1); - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ - // pad the existing state - if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } - else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } - } - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ - if (offset + w >= 0) - conv_states[w] = x_vals_load[offset + w ]; - } - } - else { - // in case the final state is in between the threads data - const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); - if ((offset + kWidth - 2) >= 
kNElts && (last_thread + 1 < kNThreads)){ - // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a - // illegal access error on H100. - // Therefore, we access last_thread + 1, only if the final state data sits there - reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; - } - reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; - #pragma unroll - for (int w = 0; w < kWidth - 1; ++w){ - conv_states[w] = x_vals_load[offset + w ]; - } - } - - } -} - - -template -void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - const bool kVarlen = params.query_start_loc_ptr != nullptr; - BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { - using Ktraits = Causal_conv1d_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize; - dim3 grid(params.batch, params.dim); - - auto kernel = &causal_conv1d_fwd_kernel; - - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; - } - kernel<<>>(params); - - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -template -void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} - - -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); - - - - -template -struct Causal_conv1d_update_kernel_traits { - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kWidth = kWidth_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_update_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - - const int tidx = threadIdx.x; - const int batch_id = blockIdx.x; - const int channel_id = blockIdx.y * kNThreads + tidx; - if (channel_id >= params.dim) return; - - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + channel_id * params.x_c_stride; - - // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor - // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. - const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr - ? 
batch_id - : params.conv_state_indices_ptr[batch_id]; - // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early - if (conv_state_batch_coord == params.pad_slot_id){ - return; - } - input_t *conv_state = reinterpret_cast(params.conv_state_ptr) - + conv_state_batch_coord * params.conv_state_batch_stride - + channel_id * params.conv_state_c_stride; - - weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + channel_id * params.out_c_stride; - float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); - - int state_len = params.conv_state_len; - int advance_len = params.seqlen; - int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; - int update_idx = cache_seqlen - (kWidth - 1); - update_idx = update_idx < 0 ? update_idx + state_len : update_idx; - - float weight_vals[kWidth] = {0}; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - - float x_vals[kWidth] = {0}; - if constexpr (!kIsCircularBuffer) { - #pragma unroll 2 - for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { - conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; - } - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { - input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; - if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { - conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; - } - x_vals[i] = float(state_val); - } - } else { - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { - input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; - x_vals[i] = float(state_val); - } - } - #pragma unroll 2 - for (int i = 0; i < params.seqlen; ++i) { - input_t x_val = x[i * params.x_l_stride]; - if constexpr (!kIsCircularBuffer) { - if (i < advance_len && state_len - advance_len + i >= 0) { - conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; - } - } else { - conv_state[update_idx * params.conv_state_l_stride] = x_val; - ++update_idx; - update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; - } - x_vals[kWidth - 1] = float(x_val); - float out_val = bias_val; - #pragma unroll - for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } - if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - out[i * params.out_l_stride] = input_t(out_val); - // Shift the input buffer by 1 - #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } - } -} - -template -void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - using Ktraits = Causal_conv1d_update_kernel_traits; - dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = params.cache_seqlens == nullptr - ? 
&causal_conv1d_update_kernel - : &causal_conv1d_update_kernel; - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); - } -} - -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h deleted file mode 100644 index e26684a2b98b..000000000000 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ /dev/null @@ -1,159 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ -// clang-format off -// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h -#pragma once - -#include -#include -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ConvParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, width; - int64_t pad_slot_id; - bool silu_activation; - - index_t x_batch_stride; - index_t x_c_stride; - index_t x_l_stride; - index_t weight_c_stride; - index_t weight_width_stride; - index_t out_batch_stride; - index_t out_c_stride; - index_t out_l_stride; - - int conv_state_len; - index_t conv_state_batch_stride; - index_t conv_state_c_stride; - index_t conv_state_l_stride; - - // Common data pointers. - void *__restrict__ x_ptr; - void *__restrict__ weight_ptr; - void *__restrict__ bias_ptr; - void *__restrict__ out_ptr; - - void *__restrict__ conv_state_ptr; - void *__restrict__ query_start_loc_ptr; - void *__restrict__ has_initial_state_ptr; - void *__restrict__ cache_indices_ptr; - int32_t *__restrict__ cache_seqlens; - - // For the continuous batching case. Makes it so that the mamba state for - // the current batch doesn't need to be a contiguous tensor. - int32_t *__restrict__ conv_state_indices_ptr; - - void *__restrict__ seq_idx_ptr; - - // No __restrict__ since initial_states could be the same as final_states. 
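 // The state pointer/stride groups below carry convolution context across calls; as with
 // the strides above, they are element strides, not byte strides. In the forward kernel,
 // conv_states is addressed by cache slot (via cache_indices, with pad_slot_id entries
 // skipped) rather than directly by batch index.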
- void * initial_states_ptr; - index_t initial_states_batch_stride; - index_t initial_states_l_stride; - index_t initial_states_c_stride; - - void * final_states_ptr; - index_t final_states_batch_stride; - index_t final_states_l_stride; - index_t final_states_c_stride; - - void * conv_states_ptr; - index_t conv_states_batch_stride; - index_t conv_states_l_stride; - index_t conv_states_c_stride; -}; - - -#ifndef USE_ROCM - #include - - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor_sync(uint32_t(-1), val, offset); - } - - constexpr size_t custom_max(std::initializer_list ilist) - { - return std::max(ilist); - } - - template - constexpr T constexpr_min(T a, T b) { - return std::min(a, b); - } - -#else - #include - - template - __device__ inline T shuffle_xor(T val, int offset) { - return __shfl_xor(val, offset); - } - constexpr size_t custom_max(std::initializer_list ilist) - { - return *std::max_element(ilist.begin(), ilist.end()); - } - - template - constexpr T constexpr_min(T a, T b) { - return a < b ? a : b; - } -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } -}; - -template -struct Allreduce { - static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); - template - static __device__ inline T run(T x, Operator &op) { - constexpr int OFFSET = THREADS / 2; - x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); - return Allreduce::run(x, op); - } -}; - -template<> -struct Allreduce<2> { -template -static __device__ inline T run(T x, Operator &op) { - x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); - return x; -} -}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h deleted file mode 100644 index ef74bf447f84..000000000000 --- a/csrc/mamba/causal_conv1d/static_switch.h +++ /dev/null @@ -1,28 +0,0 @@ -// Inspired by -// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h -// clang-format off -// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ - [&] { \ - if (COND) { \ - static constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - static constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 785d316025ec..5766fbab4e87 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -7,7 +7,11 @@ #include #include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#ifdef USE_ROCM + #include // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK +#else + #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#endif #ifndef USE_ROCM #include @@ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size constexpr bool kIsVariableB = true; constexpr bool kIsVariableC = true; - constexpr bool kHasZ = true; BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { +#ifdef USE_ROCM + C10_HIP_CHECK(hipFuncSetAttribute( + reinterpret_cast(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); +#else + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); +#endif + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); } @@ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, at::Tensor z, out_z; const bool has_z = z_.has_value(); - TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - if (varlen){ - CHECK_SHAPE(z, dim, seqlen); - } else { - CHECK_SHAPE(z, batch_size, dim, seqlen); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + if (varlen){ + CHECK_SHAPE(z, dim, seqlen); + } else { + CHECK_SHAPE(z, batch_size, dim, seqlen); + } + + out_z = z; } - out_z = z; - // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; TORCH_CHECK(ssm_states.scalar_type() == input_type); @@ -653,4 +664,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, selective_scan_fwd_cuda(params, stream); }); } - diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 462dbd1f8b38..8bbcf5a673fd 
100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -1,6 +1,7 @@ #include #include #include +#include #include #include @@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel( int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, - size_t numel, int32_t* __restrict__ cumsum) { + size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) { extern __shared__ int32_t shared_counts[]; + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[it] = numel; + } + const int warp_id = threadIdx.x / WARP_SIZE; const int my_expert_start = warp_id * experts_per_warp; @@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel( __syncthreads(); - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - int expert_count = 0; - int warp_idx = (i - 1) / experts_per_warp; - int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + // Compute prefix sum over token counts per expert + using BlockScan = cub::BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; - cumsum[i] = - cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; + int expert_count = 0; + int expert_id = threadIdx.x; + if (expert_id < num_experts) { + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; + expert_count = CEILDIV(expert_count, block_size) * block_size; + } + + int cumsum_val; + BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val); + if (expert_id <= num_experts) { + cumsum[expert_id] = cumsum_val; + } + + if (expert_id == num_experts) { + *total_tokens_post_pad = cumsum_val; } __syncthreads(); @@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel( expert_ids[i / block_size] = threadIdx.x; } } + + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; + const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); + for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { + expert_ids[i] = 0; + } } template @@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel) { + int32_t block_size, size_t numel, int32_t max_num_tokens_padded) { + // Initialize sorted_token_ids with numel + for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) { + sorted_token_ids[it] = numel; + } + const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } + // Fill remaining expert_ids with 0 + const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x; + const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size); + for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) { + expert_ids[i] = 0; + } + for (size_t i = tid; i < numel; i += stride) { int32_t 
expert_id = topk_ids[i]; int32_t rank_post_pad = @@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int threads = 1024; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + // BlockScan uses 1024 threads and assigns one thread per expert. + TORCH_CHECK(padded_num_experts < 1024, + "padded_num_experts must be less than 1024"); + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); torch::Tensor cumsum_buffer = - torch::zeros({num_experts + 1}, options_int); + torch::empty({num_experts + 1}, options_int); bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); + topk_ids.numel(), sorted_token_ids.size(0)); } else { auto align_kernel = vllm::moe::moe_align_block_size_kernel; @@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, padded_num_experts, experts_per_warp, block_size, - topk_ids.numel(), cumsum_buffer.data_ptr()); + topk_ids.numel(), cumsum_buffer.data_ptr(), + sorted_token_ids.size(0)); const int block_threads = std::min(256, (int)threads); const int num_blocks = diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index a77471a7f207..2922352a3f7c 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -10,32 +10,28 @@ void moe_permute( const torch::Tensor& input, // [n_token, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& topk_ids, // [n_token, topk] const torch::Tensor& token_expert_indices, // [n_token, topk] const std::optional& expert_map, // [n_expert] int64_t n_expert, int64_t n_local_expert, int64_t topk, const std::optional& align_block_size, - torch::Tensor& - permuted_input, // [topk * n_token/align_block_size_m, hidden] + torch::Tensor& permuted_input, // [permuted_size, hidden] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] - torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + torch::Tensor& inv_permuted_idx, // [n_token, topk] + torch::Tensor& permuted_idx, // [permute_size] torch::Tensor& m_indices) { // [align_expand_m] - TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, - "topk_weights must be float32"); TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, "expert_first_token_offset must be int64"); TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, "topk_ids must be int32"); TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, "token_expert_indices must be int32"); - TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, - "src_row_id2dst_row_id_map must be int32"); + TORCH_CHECK(inv_permuted_idx.scalar_type() == at::ScalarType::Int, + "inv_permuted_idx must be int32"); TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, "expert_first_token_offset shape != n_local_expert+1") - TORCH_CHECK( - src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), - "token_expert_indices 
shape must be same as src_row_id2dst_row_id_map"); + TORCH_CHECK(inv_permuted_idx.sizes() == token_expert_indices.sizes(), + "token_expert_indices shape must be same as inv_permuted_idx"); auto n_token = input.sizes()[0]; auto n_hidden = input.sizes()[1]; auto align_block_size_value = @@ -46,8 +42,9 @@ void moe_permute( auto sort_workspace = torch::empty( {sorter_size}, torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + auto copy_topk_ids = topk_ids.clone(); // copy topk_ids for preprocess auto permuted_experts_id = torch::empty_like(topk_ids); - auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); + auto sorted_row_idx = torch::empty_like(inv_permuted_idx); auto align_expert_first_token_offset = torch::zeros_like(expert_first_token_offset); @@ -67,24 +64,22 @@ void moe_permute( const int* expert_map_ptr = get_ptr(expert_map.value()); valid_num_ptr = get_ptr(expert_first_token_offset) + n_local_expert; - preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, + preprocessTopkIdLauncher(get_ptr(copy_topk_ids), n_token * topk, expert_map_ptr, n_expert, stream); } // expert sort topk expert id and scan expert id get expert_first_token_offset - sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indices), - get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(expert_first_token_offset), n_token, - n_expert, n_local_expert, topk, sorter, - get_ptr(sort_workspace), stream); + sortAndScanExpert( + get_ptr(copy_topk_ids), get_ptr(token_expert_indices), + get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(expert_first_token_offset), n_token, n_expert, + n_local_expert, topk, sorter, get_ptr(sort_workspace), stream); // dispatch expandInputRowsKernelLauncher MOE_DISPATCH(input.scalar_type(), [&] { expandInputRowsKernelLauncher( get_ptr(input), get_ptr(permuted_input), - get_ptr(topk_weights), get_ptr(permuted_experts_id), - get_ptr(dst_row_id2src_row_id_map), - get_ptr(src_row_id2dst_row_id_map), + get_ptr(permuted_experts_id), get_ptr(sorted_row_idx), + get_ptr(inv_permuted_idx), get_ptr(permuted_idx), get_ptr(expert_first_token_offset), n_token, valid_num_ptr, n_hidden, topk, n_local_expert, align_block_size_value, stream); }); @@ -101,32 +96,34 @@ void moe_permute( } void moe_unpermute( - const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] - const torch::Tensor& topk_weights, //[n_token, topk] - const torch::Tensor& topk_ids, // [n_token, topk] - const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] - const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] - int64_t n_expert, int64_t n_local_expert, int64_t topk, + const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] + const torch::Tensor& topk_weights, // [n_token, topk] + const torch::Tensor& inv_permuted_idx, // [n_token, topk] + const std::optional& + expert_first_token_offset, // [n_local_expert+1] + int64_t topk, torch::Tensor& hidden_states // [n_token, hidden] ) { - TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), - "topk_ids shape must be same as src_row_id2dst_row_id_map"); - TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, - "topk_ids must be int32"); TORCH_CHECK( permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), - "topk_ids dtype must be same as src_row_id2dst_row_id_map"); + "permuted_hidden_states dtype must be same as hidden_states"); auto n_token = hidden_states.size(0); auto n_hidden = hidden_states.size(1); auto stream = 
at::cuda::getCurrentCUDAStream().stream(); - const int64_t* valid_ptr = - get_ptr(expert_first_token_offset) + n_local_expert; + + int64_t const* valid_ptr = nullptr; + if (expert_first_token_offset.has_value()) { + int n_local_expert = expert_first_token_offset.value().size(0) - 1; + valid_ptr = + get_ptr(expert_first_token_offset.value()) + n_local_expert; + } + MOE_DISPATCH(hidden_states.scalar_type(), [&] { finalizeMoeRoutingKernelLauncher( get_ptr(permuted_hidden_states), get_ptr(hidden_states), get_ptr(topk_weights), - get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), - n_token, n_hidden, topk, valid_ptr, stream); + get_ptr(inv_permuted_idx), n_token, n_hidden, topk, valid_ptr, + stream); }); } diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu index de2c153882d9..2271c1bc75b1 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset, int tidx = threadIdx.x; extern __shared__ int64_t smem_expert_first_token_offset[]; for (int i = tidx; i <= num_local_expert; i += blockDim.x) { - smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); + smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i); } __syncthreads(); auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h index 43c29721cd16..108091efbefa 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream); -// Final kernel to unpermute and scale -// This kernel unpermutes the original data, does the k-way reduction and -// performs the final skip connection. 
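// In the updated interface, only the launcher remains declared in this header (the
// __global__ kernel definition lives in the .inl), the expert_for_source_row argument is
// dropped, and a token's k expanded rows are numbered contiguously as
// expanded_source_row = original_row * k + k_idx instead of being strided by num_rows.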
-template -__global__ void finalizeMoeRoutingKernel( - T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, - float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, - int64_t const* num_valid_ptr); - template void finalizeMoeRoutingKernelLauncher( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const num_rows, - int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, - cudaStream_t stream); + int64_t const num_rows, int64_t const cols, int64_t const k, + int64_t const* num_valid_ptr, cudaStream_t stream); void preprocessTopkIdLauncher(int* topk_id_ptr, int size, const int* expert_map_ptr, int num_experts, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index ad0d390665a0..449243b92a28 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -2,10 +2,9 @@ template __global__ void expandInputRowsKernel( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_dest_rows, int64_t const cols, int64_t k, int num_local_experts, int align_block_size) { @@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel( assert(expanded_dest_row <= INT32_MAX); expanded_source_row_to_expanded_dest_row[expanded_source_row] = static_cast(expanded_dest_row); + // skip non local expert token + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { + permuted_idx[expanded_dest_row] = expanded_source_row; + } } if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { @@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel( using DataElem = cutlass::Array; // Duplicate and permute rows - int64_t const source_row = expanded_source_row % num_rows; + int64_t const source_row = expanded_source_row / k; auto const* source_row_ptr = reinterpret_cast(unpermuted_input + source_row * cols); @@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel( template void expandInputRowsKernelLauncher( - T const* unpermuted_input, T* permuted_output, - const float* unpermuted_scales, int* sorted_experts, + T const* unpermuted_input, T* permuted_output, int* sorted_experts, int const* expanded_dest_row_to_expanded_source_row, - int* expanded_source_row_to_expanded_dest_row, + int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int64_t* expert_first_token_offset, int64_t const num_rows, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int num_local_experts, const int& align_block_size, cudaStream_t stream) { @@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher( int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); func<<>>( - unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, + unpermuted_input, permuted_output, sorted_experts, expanded_dest_row_to_expanded_source_row, - expanded_source_row_to_expanded_dest_row, expert_first_token_offset, - num_rows, 
num_valid_tokens_ptr, cols, k, num_local_experts, - align_block_size); + expanded_source_row_to_expanded_dest_row, permuted_idx, + expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k, + num_local_experts, align_block_size); } template @@ -128,11 +130,9 @@ template __global__ void finalizeMoeRoutingKernel( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, - int64_t const* num_valid_ptr) { + int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) { assert(orig_cols % 4 == 0); int64_t const original_row = blockIdx.x; - int64_t const num_rows = gridDim.x; auto const offset = original_row * orig_cols; OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; int64_t const num_valid = *num_valid_ptr; @@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel( ComputeElem thread_output; thread_output.fill(0); for (int k_idx = 0; k_idx < k; ++k_idx) { - int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_original_row = original_row * k + k_idx; int64_t const expanded_permuted_row = expanded_source_row_to_expanded_dest_row[expanded_original_row]; int64_t const k_offset = original_row * k + k_idx; float const row_scale = scales[k_offset]; - // Check after row_rescale has accumulated if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { continue; } @@ -189,9 +188,8 @@ template void finalizeMoeRoutingKernelLauncher( T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, float const* scales, int const* expanded_source_row_to_expanded_dest_row, - int const* expert_for_source_row, int64_t const num_rows, - int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, - cudaStream_t stream) { + int64_t const num_rows, int64_t const cols, int64_t const k, + int64_t const* num_valid_ptr, cudaStream_t stream) { int64_t const blocks = num_rows; int64_t const threads = 256; bool const check_finished = num_valid_ptr != nullptr; @@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher( auto* const kernel = func_map[check_finished]; kernel<<>>( expanded_permuted_rows, reduced_unpermuted_output, scales, - expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, - num_valid_ptr); + expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr); } diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 064b76c9cd42..0b505d2e04a2 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK( 2) This implementation assumes k is small, but will work for any k. */ -template -__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__ void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert) { @@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ // Restrictions based on previous section. 
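    // WARP_SIZE_PARAM is a compile-time copy of the hardware warp width, so the same kernel
    // template can be built for both 32-lane warps and 64-lane wavefronts; the LAUNCH_SOFTMAX
    // macro below switches on the runtime warpSize to pick the matching instantiation.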
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); - static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); - static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size"); // We have NUM_EXPERTS elements per row. We specialize for small #experts - static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT; static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; @@ -393,41 +393,51 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ namespace detail { // Constructs some constants needed to partition the work across threads at compile time. -template +template struct TopkConstants { static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); - static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); - static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM)); static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; static constexpr int THREADS_PER_ROW = EXPERTS / VPT; - static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; + static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW; }; } // namespace detail -template +template void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); - using Constants = detail::TopkConstants; + using Constants = detail::TopkConstants; static constexpr int VPT = Constants::VPT; static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; - dim3 block_dim(WARP_SIZE, WARPS_PER_TB); - topkGatingSoftmax<<>>( + dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB); + topkGatingSoftmax<<>>( input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert); } -#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ - topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indices, \ - token_expert_indices, num_tokens, topk, 0, num_experts, \ - stream); +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + switch (warpSize) { \ + case 32: \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); \ + break; \ + case 64: \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, 
stream); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported warp size: ", warpSize); \ + } template void topkGatingSoftmaxKernelLauncher( @@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher( const int topk, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; + auto warpSize = WARP_SIZE; switch (num_experts) { case 1: LAUNCH_SOFTMAX(1, WARPS_PER_TB); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 97df311d0440..d96e082f6ef1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " -> Tensor"); m.def( - "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," + "moe_permute(Tensor input, Tensor topk_ids," "Tensor token_expert_indices, Tensor? expert_map, int n_expert," "int n_local_expert," "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " - "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " - "m_indices)->()"); + "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " + "permuted_idx, Tensor! m_indices)->()"); m.def( "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," - "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " - "expert_first_token_offset, int n_expert, int n_local_expert,int " - "topk, Tensor! hidden_states)->()"); + "Tensor inv_permuted_idx, Tensor? expert_first_token_offset, " + "int topk, Tensor! hidden_states)->()"); m.def("moe_permute_unpermute_supported() -> bool"); m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported); diff --git a/csrc/ops.h b/csrc/ops.h index 52c264d64cca..97a247d9d628 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -287,6 +287,11 @@ void scaled_fp4_experts_quant( torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input_offset_by_experts, torch::Tensor const& output_scale_offset_by_experts); + +void per_token_group_quant_fp8(const torch::Tensor& input, + torch::Tensor& output_q, torch::Tensor& output_s, + int64_t group_size, double eps, double fp8_min, + double fp8_max, bool scale_ue8m0); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, @@ -326,22 +331,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const std::optional& has_initial_state, const torch::Tensor& ssm_states, int64_t pad_slot_id); -void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, - const at::Tensor& weight, - const std::optional& bias_, - bool silu_activation, - const std::optional& cache_seqlens_, - const std::optional& conv_state_indices_, - int64_t pad_slot_id); - -void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const std::optional& bias_, - const std::optional& conv_states, - const std::optional& query_start_loc, - const std::optional& cache_indices, - const std::optional& has_initial_state, - bool silu_activation, int64_t pad_slot_id); - using fptr_t = int64_t; fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, torch::Tensor& rank_data, int64_t rank, diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu index 67e9149c1379..8bc2b9bff3d5 100644 --- a/csrc/quantization/activation_kernels.cu +++ b/csrc/quantization/activation_kernels.cu @@ -4,7 +4,7 @@ #include #include "core/math.hpp" -#include "cuda_compat.h" +#include "../cuda_compat.h" #include "dispatch_utils.h" #include "quantization/fp8/common.cuh" diff --git 
a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu index 236d76ed5208..6c8f6309ef43 100644 --- a/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu @@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm( reinterpret_cast( layout_sfb.data_ptr())}; - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = a_ptrs.get_device(); - hw_info.sm_count = - cutlass::KernelHardwareInfo::query_device_multiprocessor_count( - hw_info.device_id); + int device_id = a_ptrs.device().index(); + static const cutlass::KernelHardwareInfo hw_info{ + device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; // Epilogue Arguments typename GemmKernel::EpilogueArguments epilogue_args{ diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index bbd82d72e95b..659941de182e 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -18,28 +18,34 @@ using ProblemShape = cutlass::gemm::GroupProblemShape>; using ElementAccumulator = float; -using ArchTag = cutlass::arch::Sm90; using OperatorClass = cutlass::arch::OpClassTensorOp; using LayoutA = cutlass::layout::RowMajor; +using LayoutA_Transpose = + typename cutlass::layout::LayoutTranspose::type; using LayoutB = cutlass::layout::ColumnMajor; -using LayoutC = cutlass::layout::RowMajor; - -template ::type; +using LayoutD = cutlass::layout::RowMajor; +using LayoutD_Transpose = + typename cutlass::layout::LayoutTranspose::type; +using LayoutC = LayoutD; +using LayoutC_Transpose = LayoutD_Transpose; + +template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, - typename EpilogueSchedule> + typename EpilogueSchedule, bool swap_ab_ = false> struct cutlass_3x_group_gemm { + static constexpr bool swap_ab = swap_ab_; using ElementAB = ElementAB_; using ElementC = void; using ElementD = ElementC_; using ElementAccumulator = float; + using ArchTag = ArchTag_; using Epilogue = Epilogue_; - using StrideC = - cute::remove_pointer_t, cute::Int<0>>>; - static constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; @@ -50,21 +56,28 @@ struct cutlass_3x_group_gemm { typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, - LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; + ElementAccumulator, ElementC, + conditional_t, AlignmentC, + ElementD, conditional_t, + AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp; static constexpr size_t CEStorageSize = sizeof(typename CollectiveEpilogue::SharedStorage); using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout< static_cast(CEStorageSize)>; - using CollectiveMainloop = + using CollectiveMainloop = conditional_t< + swap_ab, + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB, + ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator, + TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp, typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, 
ElementAB, LayoutA*, AlignmentAB, ElementAB, LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape, - Stages, KernelSchedule>::CollectiveOp; + Stages, KernelSchedule>::CollectiveOp>; - using KernelType = enable_sm90_only>; struct GemmKernel : public KernelType {}; @@ -78,12 +91,12 @@ void cutlass_group_gemm_caller( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { + static constexpr bool swap_ab = Gemm::swap_ab; + using ElementAB = typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; int num_experts = static_cast(expert_offsets.size(0)); - int k_size = a_tensors.size(1); - int n_size = out_tensors.size(1); auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); @@ -110,26 +123,47 @@ void cutlass_group_gemm_caller( problem_sizes.data_ptr()); ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr}; - typename GemmKernel::MainloopArguments mainloop_args{ - static_cast(a_ptrs.data_ptr()), - static_cast(a_strides.data_ptr()), - static_cast(b_ptrs.data_ptr()), - static_cast(b_strides.data_ptr())}; + typename GemmKernel::MainloopArguments mainloop_args; + if constexpr (swap_ab) { + mainloop_args = typename GemmKernel::MainloopArguments{ + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr()), + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr())}; + } else { + mainloop_args = typename GemmKernel::MainloopArguments{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides.data_ptr())}; + } // Currently, we are only able to do broadcast on either all or none a_scales // and on either all or none b_scales typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( - static_cast(a_scales_ptrs.data_ptr()), - static_cast(b_scales_ptrs.data_ptr()), - per_act_token, per_out_ch), + swap_ab ? static_cast( + b_scales_ptrs.data_ptr()) + : static_cast( + a_scales_ptrs.data_ptr()), + swap_ab ? static_cast( + a_scales_ptrs.data_ptr()) + : static_cast( + b_scales_ptrs.data_ptr()), + swap_ab ? per_out_ch : per_act_token, + swap_ab ? 
per_act_token : per_out_ch), nullptr, static_cast(c_strides.data_ptr()), static_cast(out_ptrs.data_ptr()), static_cast(c_strides.data_ptr())}; + int device_id = a_tensors.device().index(); + static const cutlass::KernelHardwareInfo hw_info{ + device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + device_id)}; + typename GemmKernel::Arguments args{ cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args, - epilogue_args}; + epilogue_args, hw_info}; using GemmOp = cutlass::gemm::device::GemmUniversalAdapter; GemmOp gemm_op; diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu new file mode 100644 index 000000000000..641e5997f0fd --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu @@ -0,0 +1,140 @@ +#include + +#include +#include + +#include "cutlass/cutlass.h" +#include "grouped_mm_c3x.cuh" + +using namespace cute; + +namespace { + +template typename Epilogue> +struct sm100_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm100_fp8_config_M64 { + // M in [1,64] + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm100_fp8_config_N8192 { + // N in [8192, inf) + static_assert(std::is_same()); + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm100; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; +}; + +template +void run_cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); + TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); + TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); + + TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn, + "A tensors must be of type float8_e4m3fn."); + TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, + "B tensors must be of type float8_e4m3fn."); + + using Cutlass3xGemmDefault = typename sm100_fp8_config_default< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM64 = typename sm100_fp8_config_M64< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + + uint32_t const m = 
a_tensors.size(0); + uint32_t const n = out_tensors.size(1); + + if (m <= 64) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (n >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } +} +} // namespace + +void dispatch_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + if (out_tensors.dtype() == torch::kBFloat16) { + run_cutlass_moe_mm_sm100( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else { + run_cutlass_moe_mm_sm100( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } +} + +void cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { + dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); +} diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu similarity index 72% rename from csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu rename to csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu index c88e134ae406..8f21623b52fa 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu @@ -21,27 +21,49 @@ struct sm90_fp8_config_default { cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; using TileShape = cute::Shape; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template typename Epilogue> -struct sm90_fp8_config_M16 { - // M in [1, 16] +struct sm90_fp8_config_M4 { + // M in [1, 4] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; - using TileShape = cute::Shape; - using ClusterShape = cute::Shape; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; +}; + +template typename Epilogue> +struct sm90_fp8_config_M64 { + // M in (4, 64] + static_assert(std::is_same()); + using KernelSchedule = + 
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; + + using Cutlass3xGemm = + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template ; using ClusterShape = cute::Shape; + using ArchTag = cutlass::arch::Sm90; using Cutlass3xGemm = - cutlass_3x_group_gemm; + cutlass_3x_group_gemm; }; template @@ -95,14 +119,13 @@ void run_cutlass_moe_mm_sm90( TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn, "B tensors must be of type float8_e4m3fn."); - TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn); - using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; - using Cutlass3xGemmM16 = typename sm90_fp8_config_M16< + using Cutlass3xGemmM4 = typename sm90_fp8_config_M4< + InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; + using Cutlass3xGemmM64 = typename sm90_fp8_config_M64< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; using Cutlass3xGemmDefault = typename sm90_fp8_config_default< InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm; @@ -111,18 +134,24 @@ void run_cutlass_moe_mm_sm90( uint32_t const n = out_tensors.size(1); uint32_t const k = a_tensors.size(1); - if (n >= 8192) { - cutlass_group_gemm_caller( + // Use swap_ab for M <= 64 by default to reduce padding + if (m <= 4) { + cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); - } else if (k >= 8192) { - cutlass_group_gemm_caller( + } else if (m <= 64) { + cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); - } else if (m <= 16) { - cutlass_group_gemm_caller( + } else if (n >= 8192) { + cutlass_group_gemm_caller( + out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); + } else if (k >= 8192) { + cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, c_strides, per_act_token, per_out_ch); diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index 32254641cc38..993c30c48c84 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -6,8 +6,11 @@ #include constexpr uint64_t THREADS_PER_EXPERT = 512; +// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90() +constexpr int SWAP_AB_THRESHOLD = 64; -__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids, +template +__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, @@ -24,45 +27,58 @@ __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids, if (threadIdx.x == 0) { int 
final_occurrences = atomic_buffer[expert_id]; - problem_sizes1[expert_id * 3] = final_occurrences; - problem_sizes1[expert_id * 3 + 1] = 2 * n; - problem_sizes1[expert_id * 3 + 2] = k; - problem_sizes2[expert_id * 3] = final_occurrences; - problem_sizes2[expert_id * 3 + 1] = k; - problem_sizes2[expert_id * 3 + 2] = n; + if constexpr (!SWAP_AB) { + problem_sizes1[expert_id * 3] = final_occurrences; + problem_sizes1[expert_id * 3 + 1] = 2 * n; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = final_occurrences; + problem_sizes2[expert_id * 3 + 1] = k; + problem_sizes2[expert_id * 3 + 2] = n; + } else { + problem_sizes1[expert_id * 3] = 2 * n; + problem_sizes1[expert_id * 3 + 1] = final_occurrences; + problem_sizes1[expert_id * 3 + 2] = k; + problem_sizes2[expert_id * 3] = k; + problem_sizes2[expert_id * 3 + 1] = final_occurrences; + problem_sizes2[expert_id * 3 + 2] = n; + } } } __global__ void compute_expert_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* atomic_buffer, const int num_experts) { + int32_t* atomic_buffer, const int num_experts, const int topk_length) { int32_t tot_offset = 0; expert_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { atomic_buffer[i] = tot_offset; - tot_offset += problem_sizes1[i * 3]; + tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3] + : problem_sizes1[i * 3 + 1]; expert_offsets[i + 1] = tot_offset; } } __global__ void compute_expert_blockscale_offsets( const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets, - int32_t* blockscale_offsets, int32_t* atomic_buffer, - const int num_experts) { + int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts, + const int topk_length) { int32_t tot_offset = 0; int32_t tot_offset_round = 0; expert_offsets[0] = 0; blockscale_offsets[0] = 0; for (int i = 0; i < num_experts; ++i) { + int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD + ? 
problem_sizes1[i * 3] + : problem_sizes1[i * 3 + 1]; atomic_buffer[i] = tot_offset; - tot_offset += problem_sizes1[i * 3]; + tot_offset += cur_offset; expert_offsets[i + 1] = tot_offset; - tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128; + tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128; blockscale_offsets[i + 1] = tot_offset_round; } } -__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids, +__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids, const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, @@ -102,25 +118,39 @@ void get_cutlass_moe_mm_data_caller( torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); - compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), - static_cast(problem_sizes1.data_ptr()), - static_cast(problem_sizes2.data_ptr()), - static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); + + if (topk_ids.numel() > SWAP_AB_THRESHOLD) { + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); + } else { + compute_problem_sizes<<>>( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, + k); + } + if (blockscale_offsets.has_value()) { compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(blockscale_offsets.value().data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); + static_cast(atomic_buffer.data_ptr()), num_experts, + topk_ids.numel()); } else { compute_expert_offsets<<<1, 1, 0, stream>>>( static_cast(problem_sizes1.data_ptr()), static_cast(expert_offsets.data_ptr()), - static_cast(atomic_buffer.data_ptr()), num_experts); + static_cast(atomic_buffer.data_ptr()), num_experts, + topk_ids.numel()); } compute_arg_sorts<<>>( - static_cast(topk_ids.data_ptr()), + static_cast(topk_ids.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), @@ -160,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, static_cast(problem_sizes2.data_ptr()), static_cast(expert_num_tokens.data_ptr()), padded_m, n, k); -} +} \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 31b60488dfb7..106bacb4883c 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90( #endif +#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 +void cutlass_moe_mm_sm100( + torch::Tensor& out_tensors, torch::Tensor const& a_tensors, + torch::Tensor const& b_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, + torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch); +#endif + #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120 void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, @@ -130,10 
+140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { // and at least SM90 (Hopper) #if defined CUDA_VERSION - if (cuda_device_capability >= 90 && cuda_device_capability < 100) { - return CUDA_VERSION >= 12000; - } else if (cuda_device_capability >= 100) { + if (cuda_device_capability >= 100) { return CUDA_VERSION >= 12080; + } else if (cuda_device_capability >= 90) { + return CUDA_VERSION >= 12000; } #endif @@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { } bool cutlass_group_gemm_supported(int64_t cuda_device_capability) { - // CUTLASS grouped FP8 kernels need at least CUDA 12.3 - // and SM90 (Hopper) + // CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper) + // or CUDA 12.8 and SM100 (Blackwell) #if defined CUDA_VERSION - if (cuda_device_capability == 90) { + if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; + } + if (cuda_device_capability >= 90) { return CUDA_VERSION >= 12030; } #endif @@ -234,16 +247,26 @@ void cutlass_moe_mm( torch::Tensor const& b_strides, torch::Tensor const& c_strides, bool per_act_token, bool per_out_ch) { int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100 + if (version_num >= 100) { + cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; + } +#endif #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 - cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, - expert_offsets, problem_sizes, a_strides, b_strides, - c_strides, per_act_token, per_out_ch); - return; + if (version_num >= 90) { + cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, + expert_offsets, problem_sizes, a_strides, b_strides, + c_strides, per_act_token, per_out_ch); + return; + } #endif TORCH_CHECK_NOT_IMPLEMENTED( false, "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num, - ". Required capability: 90"); + ". 
Required capability: 90 or 100"); } void get_cutlass_moe_mm_data( diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 7572a7eb3122..5bc4c38a275c 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -30,35 +30,40 @@ #include "cutlass/util/packed_stride.hpp" +#include "core/math.hpp" + using namespace cute; #if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) -// Kernel Perf config -template -struct KernelTraits; -template <> -struct KernelTraits { - using MmaTileShape = Shape<_128, _128, _256>; - using ClusterShape = Shape<_1, _1, _1>; - using PerSmTileShape_MNK = Shape<_128, _128, _256>; +// Configuration for M in (256, inf) +struct sm100_fp4_config_default { + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_256, _256, _256>; + using ClusterShape = Shape<_2, _1, _1>; + using PerSmTileShape_MNK = Shape<_128, _256, _256>; }; -template <> -struct KernelTraits { - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape<_4, _4, _1>; - using PerSmTileShape_MNK = Shape<_128, _256, _256>; +// Configuration for M in (16, 256] +struct sm100_fp4_config_M256 { + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_256, _128, _256>; + using ClusterShape = Shape<_2, _1, _1>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; }; -template <> -struct KernelTraits { - using MmaTileShape = Shape<_256, _256, _256>; - using ClusterShape = Shape<_4, _4, _1>; - using PerSmTileShape_MNK = Shape<_128, _256, _256>; +// Configuration for M in [1, 16] +struct sm100_fp4_config_M16 { + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _256>; + using ClusterShape = Shape<_1, _1, _1>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; }; -template +template struct Fp4GemmSm100 { // A matrix configuration using ElementA = cutlass::nv_float4_t; @@ -71,21 +76,22 @@ struct Fp4GemmSm100 { static constexpr int AlignmentB = 32; // C/D matrix configuration - using ElementD = T; - using ElementC = T; + using ElementD = OutType; + using ElementC = OutType; using LayoutCTag = cutlass::layout::RowMajor; using LayoutDTag = cutlass::layout::RowMajor; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + // Kernel functional config using ElementAccumulator = float; using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; - // Kernel Perf config - using MmaTileShape = typename KernelTraits::MmaTileShape; - using ClusterShape = typename KernelTraits::ClusterShape; - using PerSmTileShape_MNK = typename KernelTraits::PerSmTileShape_MNK; + // Use config's tile shapes + using MmaTileShape = typename Config::TileShape; + using ClusterShape = typename Config::ClusterShape; + using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -119,22 +125,22 @@ struct Fp4GemmSm100 { using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{})); }; -template -typename 
T::Gemm::Arguments args_from_options( +template +typename Config::Gemm::Arguments args_from_options( at::Tensor& D, at::Tensor const& A, at::Tensor const& B, at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, int64_t M, int64_t N, int64_t K) { - using ElementA = typename T::Gemm::ElementA; - using ElementB = typename T::Gemm::ElementB; + using ElementA = typename Config::Gemm::ElementA; + using ElementB = typename Config::Gemm::ElementB; using ElementSFA = cutlass::float_ue4m3_t; using ElementSFB = cutlass::float_ue4m3_t; - using ElementD = typename T::Gemm::ElementD; + using ElementD = typename Config::Gemm::ElementD; using ElementCompute = float; - using StrideA = typename T::StrideA; - using StrideB = typename T::StrideB; - using StrideD = typename T::StrideD; - using Sm100BlkScaledConfig = - typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using StrideA = typename Config::StrideA; + using StrideB = typename Config::StrideB; + using StrideD = typename Config::StrideD; + using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel:: + CollectiveMainloop::Sm1xxBlkScaledConfig; int m = static_cast(M); int n = static_cast(N); @@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options( auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB( cute::make_shape(m, n, k, 1)); - typename T::Gemm::Arguments arguments{ + typename Config::Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, {// Mainloop arguments @@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options( return arguments; } -template +template void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, cudaStream_t stream) { - typename Fp4GemmSm100::Gemm gemm; + typename Config::Gemm gemm; auto arguments = - args_from_options>(D, A, B, A_sf, B_sf, alpha, m, n, k); + args_from_options(D, A, B, A_sf, B_sf, alpha, m, n, k); - size_t workspace_size = Fp4GemmSm100::Gemm::get_workspace_size(arguments); + size_t workspace_size = Config::Gemm::get_workspace_size(arguments); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device()); auto workspace = torch::empty(workspace_size, workspace_options); @@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream)); } + +// Dispatch function to select appropriate config based on M +template +void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + torch::Tensor const& B_sf, + torch::Tensor const& alpha, int64_t m, int64_t n, + int64_t k, cudaStream_t stream) { + uint32_t const mp2 = std::max(static_cast(16), next_pow_2(m)); + + if (mp2 <= 16) { + // m in [1, 16] + runGemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { + // m in (16, 256] + runGemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else { + // m in (256, inf) + runGemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } +} + #else -template -void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, - at::Tensor const& A_sf, at::Tensor const& B_sf, - at::Tensor const& alpha, int64_t m, int64_t n, int64_t k, - cudaStream_t stream) { +template +void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A, + torch::Tensor const& B, + torch::Tensor const& A_sf, + 
torch::Tensor const& B_sf, + torch::Tensor const& alpha, int64_t m, int64_t n, + int64_t k, cudaStream_t stream) { TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to " "a CUTLASS 3.8 source directory to enable support."); @@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device()); if (out_dtype == at::ScalarType::Half) { - runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlass_fp4_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, m, n, + k, stream); } else if (out_dtype == at::ScalarType::BFloat16) { - runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); - } else if (out_dtype == at::ScalarType::Float) { - runGemm(D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + cutlass_fp4_gemm_dispatch(D, A, B, A_sf, B_sf, alpha, + m, n, k, stream); } else { - TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm"); + TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype, + ")"); } } diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index f3f9f669e00a..0e1eab66f0b9 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor const& scale) // [1] { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); int const block_size = 256; int const num_tokens = input.numel() / input.size(-1); int const num_elems = input.numel(); @@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] torch::Tensor const& input, // [..., d] torch::Tensor& scale) // [1] { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); int const block_size = 256; int const num_tokens = input.numel() / input.size(-1); int const num_elems = input.numel(); diff --git a/csrc/quantization/fp8/per_token_group_quant.cu b/csrc/quantization/fp8/per_token_group_quant.cu new file mode 100644 index 000000000000..afc41faeca90 --- /dev/null +++ b/csrc/quantization/fp8/per_token_group_quant.cu @@ -0,0 +1,213 @@ +#include +#include + +#include + +#include +#include + +#include + +#include "../vectorization.cuh" +#include "../vectorization_utils.cuh" +#include "../../dispatch_utils.h" + +__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { + unsigned mask = 0xffff; + + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + return val; +} + +template +__global__ void per_token_group_quant_8bit_kernel( + const T* __restrict__ input, void* __restrict__ output_q, + scale_packed_t* __restrict__ output_s, const int group_size, + const int num_groups, const int groups_per_block, const float eps, + const float min_8bit, const float max_8bit, const int scale_num_rows = 0, + const int scale_stride = 0) { + const int threads_per_group = 16; + const int64_t local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; + + float local_absmax = eps; + + using scale_element_t = float; + static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); + 
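// --- Editor's aside (not part of the patch): the GroupReduceMax helper above
// performs a 16-lane butterfly max-reduction with XOR shuffles; after the
// log2(16) = 4 exchange steps, every lane in the group holds the group maximum.
// The self-contained sketch below shows the same pattern in isolation so the
// reduction used a few lines further down (to compute local_absmax) is easier
// to follow. All names here (group_reduce_max16, group_max_demo, main) are
// illustrative assumptions, not code from this patch; compile it separately,
// e.g. with `nvcc demo.cu`.
#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ float group_reduce_max16(float val) {
  // A full-warp mask is safe here because every lane of the warp participates
  // and XOR offsets <= 8 keep the exchanges inside each 16-lane half.
  constexpr unsigned kMask = 0xffffffffu;
  val = fmaxf(val, __shfl_xor_sync(kMask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(kMask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(kMask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(kMask, val, 1));
  return val;
}

__global__ void group_max_demo(const float* in, float* out) {
  // One warp holds two 16-lane groups; lane 0 of each group writes its result.
  float m = group_reduce_max16(fabsf(in[threadIdx.x]));
  if (threadIdx.x % 16 == 0) out[threadIdx.x / 16] = m;
}

int main() {
  float h_in[32], h_out[2] = {0.f, 0.f};
  for (int i = 0; i < 32; ++i) h_in[i] = (i < 16) ? -float(i) : float(i) / 10.f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  group_max_demo<<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  // Expected output: group 0 absmax = 15.0, group 1 absmax = 3.1
  printf("group 0 absmax = %f, group 1 absmax = %f\n", h_out[0], h_out[1]);
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
// --- end of aside ---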
+ const T* group_input = input + block_group_offset; + DST_DTYPE* group_output = + static_cast(output_q) + block_group_offset; + scale_element_t* scale_output; + + if constexpr (IS_COLUMN_MAJOR) { + const int num_elems_per_pack = + static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); + const int scale_num_rows_element = scale_num_rows * num_elems_per_pack; + const int row_idx = global_group_id / scale_num_rows_element; + const int col_idx_raw = global_group_id % scale_num_rows_element; + const int col_idx = col_idx_raw / num_elems_per_pack; + const int pack_idx = col_idx_raw % num_elems_per_pack; + scale_output = reinterpret_cast(output_s) + + (col_idx * scale_stride * num_elems_per_pack + + row_idx * num_elems_per_pack + pack_idx); + } else { + scale_output = output_s + global_group_id; + } + + // shared memory to cache each group's data to avoid double DRAM reads. + extern __shared__ __align__(16) char smem_raw[]; + T* smem = reinterpret_cast(smem_raw); + T* smem_group = smem + local_group_id * group_size; + + constexpr int vec_size = 16 / sizeof(T); + using vec_t = vllm::vec_n_t; + + // copy global -> shared & compute absmax + auto scalar_op_cache = [&] __device__(T & dst, const T& src) { + float abs_v = fabsf(static_cast(src)); + local_absmax = fmaxf(local_absmax, abs_v); + dst = src; + }; + + vllm::vectorize_with_alignment( + group_input, // in + smem_group, // out (shared) + group_size, // elements per group + lane_id, // thread id + threads_per_group, // stride in group + scalar_op_cache); // scalar handler + + local_absmax = GroupReduceMax(local_absmax, lane_id); + + float y_s = local_absmax / max_8bit; + if constexpr (SCALE_UE8M0) { + y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + } + + scale_element_t y_s_quant = y_s; + + if (lane_id == 0) { + *scale_output = y_s_quant; + } + + __syncthreads(); + + // quantize shared -> global 8-bit + auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) { + float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); + dst = DST_DTYPE(q); + }; + + vllm::vectorize_with_alignment( + smem_group, // in (shared) + group_output, // out (global quant tensor) + group_size, // elements + lane_id, // tid + threads_per_group, // stride + scalar_op_quant); // scalar handler +} + +void per_token_group_quant_8bit(const torch::Tensor& input, + torch::Tensor& output_q, + torch::Tensor& output_s, int64_t group_size, + double eps, double min_8bit, double max_8bit, + bool scale_ue8m0 = false) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(output_q.is_contiguous()); + + const int num_groups = input.numel() / group_size; + + TORCH_CHECK(input.numel() % group_size == 0); + TORCH_CHECK(output_s.dim() == 2); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + constexpr int THREADS_PER_GROUP = 16; + + int groups_per_block = 1; + + if (num_groups % 16 == 0) { + groups_per_block = 16; + } else if (num_groups % 8 == 0) { + groups_per_block = 8; + } else if (num_groups % 4 == 0) { + groups_per_block = 4; + } else if (num_groups % 2 == 0) { + groups_per_block = 2; + } + + auto dst_type = output_q.scalar_type(); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + + const bool is_column_major = output_s.stride(0) < output_s.stride(1); + const int scale_num_rows = output_s.size(1); + const int scale_stride = output_s.stride(1); + +#define LAUNCH_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + size_t 
smem_bytes = \ + static_cast(groups_per_block) * group_size * sizeof(T); \ + if (is_column_major) { \ + if (scale_ue8m0) { \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit, scale_num_rows, scale_stride); \ + } else { \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit, scale_num_rows, scale_stride); \ + } \ + } else { \ + if (scale_ue8m0) { \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit); \ + } else { \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input.data_ptr()), output_q.data_ptr(), \ + static_cast(output_s.data_ptr()), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit); \ + } \ + } \ + } while (0) + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "per_token_group_quant_8bit", ([&] { + if (dst_type == at::ScalarType::Float8_e4m3fn) { + LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn); + } + })); + +#undef LAUNCH_KERNEL +} + +void per_token_group_quant_fp8(const torch::Tensor& input, + torch::Tensor& output_q, torch::Tensor& output_s, + int64_t group_size, double eps, double fp8_min, + double fp8_max, bool scale_ue8m0) { + per_token_group_quant_8bit(input, output_q, output_s, group_size, eps, + fp8_min, fp8_max, scale_ue8m0); +} diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 3b5180b51623..76fe73e95040 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -4,7 +4,7 @@ #include #include -#include "cuda_compat.h" +#include "../../cuda_compat.h" #include "dispatch_utils.h" #include "ggml-common.h" diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 3bddd12cad07..65cb1c1d1478 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -19,7 +19,7 @@ #include #include #include -#include "cuda_compat.h" +#include "../cuda_compat.h" #include #include "../attention/dtype_fp8.cuh" diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 6212570c79d1..eb47139208c9 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -9,7 +9,7 @@ #include #include -#include "cuda_compat.h" +#include "../cuda_compat.h" #include "dispatch_utils.h" #include "quantization/fp8/common.cuh" diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 9414e26196b2..95f8541bc9e2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops // - // The default behavior in PyTorch 2.6 is "requires_contiguous", so we need + // The default behavior in PyTorch 2.6 was changed to "requires_contiguous", + // so we need // to override this for many GEMMs with the following tag. Otherwise, // torch.compile will force all input tensors to be contiguous(), which // will break many custom ops that require column-major weight matrices. - // TODO: remove this for PyTorch 2.8, when the default is planned to switch - // to match exact eager-mode strides. 
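// --- Editor's aside (not part of the patch): for readers of the new
// per_token_group_quant.cu above, the per-group math that kernel implements
// reduces to the host-side reference below: take the group's absolute maximum
// (floored at eps), derive the scale as absmax / max_8bit, optionally round
// the scale up to a power of two (the scale_ue8m0 branch), then clamp each
// value / scale into [min_8bit, max_8bit] before casting to the 8-bit dtype.
// The function name quantize_group_reference and its signature are
// illustrative assumptions, not part of vLLM's API.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> quantize_group_reference(const std::vector<float>& group,
                                            float eps, float min_8bit,
                                            float max_8bit, bool scale_ue8m0,
                                            float* scale_out) {
  float absmax = eps;
  for (float v : group) absmax = std::max(absmax, std::fabs(v));

  float y_s = absmax / max_8bit;
  if (scale_ue8m0) {
    // Round the scale up to the nearest power of two, mirroring the kernel's
    // exp2f(ceilf(log2f(...))) branch.
    y_s = std::exp2(std::ceil(std::log2(std::max(std::fabs(y_s), 1e-10f))));
  }
  *scale_out = y_s;

  std::vector<float> quantized;
  quantized.reserve(group.size());
  for (float v : group) {
    quantized.push_back(std::min(std::max(v / y_s, min_8bit), max_8bit));
  }
  return quantized;  // each entry is then cast to the 8-bit destination type
}
// For FP8 E4M3 the caller would typically pass roughly -448.0f / 448.0f for
// min_8bit / max_8bit (the finite range of that format); the exact values the
// Python side forwards are an assumption here, not taken from this patch.
// --- end of aside ---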
- at::Tag stride_tag = at::Tag::needs_fixed_stride_order; + // This was a bug and PyTorch 2.7 has since fixed this. +#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6 + #define stride_tag at::Tag::needs_fixed_stride_order +#else + #define stride_tag +#endif ops.def("weak_ref_tensor(Tensor input) -> Tensor"); ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor); @@ -514,6 +518,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor page_table, float scale) -> ()"); ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + // SM100 CUTLASS MLA decode + ops.def( + "sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," + " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," + " Tensor page_table, Tensor workspace, float " + "scale," + " int num_kv_splits) -> ()"); + // conditionally compiled so impl in source file + + // SM100 CUTLASS MLA workspace + ops.def( + "sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches," + " int sm_count, int num_kv_splits) " + "-> int"); + // conditionally compiled so impl in source file + // Compute NVFP4 block quantized tensor. ops.def( "scaled_fp4_quant(Tensor! output, Tensor input," @@ -594,29 +614,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); +#ifndef USE_ROCM + // Compute per-token-group FP8 quantized tensor and scaling factor. ops.def( - "causal_conv1d_update(Tensor! x," - "Tensor! conv_state," - "Tensor! weight," - "Tensor? bias_," - "bool silu_activation," - "Tensor? cache_seqlens_," - "Tensor? conv_state_indices," - "int pad_slot_id) -> ()"); - ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); - - ops.def( - "causal_conv1d_fwd(Tensor! x, Tensor! weight," - "Tensor? bias_," - "Tensor!? conv_states," - "Tensor? query_start_loc," - "Tensor? cache_indices," - "Tensor? has_initial_state," - "bool silu_activation," - "int pad_slot_id) -> ()"); - ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); + "per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! 
" + "output_s, " + "int group_size, float eps, float fp8_min, float fp8_max, bool " + "scale_ue8m0) -> ()"); + ops.impl("per_token_group_fp8_quant", torch::kCUDA, + &per_token_group_quant_fp8); -#ifndef USE_ROCM // reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel ops.def( "rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, " diff --git a/docker/Dockerfile b/docker/Dockerfile index c49b5da2714c..868b81704662 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -63,7 +63,7 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly ARG PIP_KEYRING_PROVIDER=disabled ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER} -# Flag enables build-in KV-connector dependency libs into docker images +# Flag enables built-in KV-connector dependency libs into docker images ARG INSTALL_KV_CONNECTORS=false #################### BASE BUILD IMAGE #################### @@ -207,6 +207,19 @@ ARG SCCACHE_ENDPOINT ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 + +# Flag to control whether to use pre-built vLLM wheels +ARG VLLM_USE_PRECOMPILED +# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed +ENV VLLM_USE_PRECOMPILED="" +RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \ + export VLLM_USE_PRECOMPILED=1 && \ + echo "Using precompiled wheels"; \ + else \ + unset VLLM_USE_PRECOMPILED && \ + echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \ + fi + # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=.git,target=.git \ @@ -252,7 +265,7 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ #################### EXTENSION Build IMAGE #################### #################### DEV IMAGE #################### -FROM base as dev +FROM base AS dev ARG PIP_INDEX_URL UV_INDEX_URL ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL @@ -375,48 +388,33 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist # -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl # $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl -# Allow specifying a version, Git revision or local .whl file -ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer" -ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl" +# Install FlashInfer from source ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git" -ARG FLASHINFER_GIT_REF="v0.2.6.post1" +ARG FLASHINFER_GIT_REF="v0.2.8rc1" RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH' . /etc/environment - if [ "$TARGETPLATFORM" != "linux/arm64" ]; then - # FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use - if [[ "$CUDA_VERSION" == 12.8* ]]; then - uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} - else - export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' - git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive - # Needed to build AOT kernels - (cd flashinfer && \ - python3 -m flashinfer.aot && \ - uv pip install --system --no-build-isolation . 
\ - ) - rm -rf flashinfer - - # Default arches (skipping 10.0a and 12.0 since these need 12.8) - # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. - TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" - if [[ "${CUDA_VERSION}" == 11.* ]]; then - TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" - fi - echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}" - - git clone --depth 1 --recursive --shallow-submodules \ - --branch v0.2.6.post1 \ - https://github.com/flashinfer-ai/flashinfer.git flashinfer - - pushd flashinfer + git clone --depth 1 --recursive --shallow-submodules \ + --branch ${FLASHINFER_GIT_REF} \ + ${FLASHINFER_GIT_REPO} flashinfer + # Exclude CUDA arches for older versions (11.x and 12.0-12.7) + # TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg. + if [[ "${CUDA_VERSION}" == 11.* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9" + elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a" + else + # CUDA 12.8+ supports 10.0a and 12.0 + FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0" + fi + echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}" + # Needed to build AOT kernels + pushd flashinfer + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ python3 -m flashinfer.aot - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \ - uv pip install --system --no-build-isolation . - popd - - rm -rf flashinfer - fi \ - fi + TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \ + uv pip install --system --no-build-isolation . + popd + rm -rf flashinfer BASH COPY examples examples COPY benchmarks benchmarks @@ -508,10 +506,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/kv_connectors.txt; \ fi; \ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \ - uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ + BITSANDBYTES_VERSION="0.42.0"; \ else \ - uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.46.1' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \ - fi + BITSANDBYTES_VERSION="0.46.1"; \ + fi; \ + uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3] ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 5da2c9467bfc..982c1ddf2743 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -95,7 +95,7 @@ WORKDIR /workspace/vllm RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \ cp requirements/test.in requirements/cpu-test.in && \ sed -i '/mamba_ssm/d' requirements/cpu-test.in && \ - sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \ + sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \ sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \ sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \ uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu diff --git a/docker/Dockerfile.hpu b/docker/Dockerfile.hpu deleted file mode 100644 index 224f142b5ff4..000000000000 --- a/docker/Dockerfile.hpu +++ /dev/null @@ -1,21 +0,0 @@ -FROM vault.habana.ai/gaudi-docker/1.20.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest - -COPY ./ /workspace/vllm - -WORKDIR /workspace/vllm - -RUN pip install -v 
-r requirements/hpu.txt - -ENV no_proxy=localhost,127.0.0.1 -ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true - -RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install - -# install development dependencies (for testing) -RUN python3 -m pip install -e tests/vllm_test_utils - -WORKDIR /workspace/ - -RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks - -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index dc8ec5f1a15e..3414c0aa845c 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="6487649" +ARG AITER_BRANCH="916bf3c" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index 295270d29f76..3474ff50de7b 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -1,5 +1,5 @@ -ARG NIGHTLY_DATE="20250124" -ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" +ARG NIGHTLY_DATE="20250714" +ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE WORKDIR /workspace/vllm diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 466ba9833363..7d5a589eb1d7 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' + pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope ENV VLLM_USAGE_SOURCE production-docker-image \ TRITON_XPU_PROFILE 1 diff --git a/docs/.nav.yml b/docs/.nav.yml index 06bfcc3f1eff..ab54dc3e535b 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -55,6 +55,7 @@ nav: - contributing/model/registration.md - contributing/model/tests.md - contributing/model/multimodal.md + - CI: contributing/ci - Design Documents: - V0: design - V1: design/v1 diff --git a/docs/README.md b/docs/README.md index e1d1046951a5..6823008ed336 100644 --- a/docs/README.md +++ b/docs/README.md @@ -36,7 +36,7 @@ vLLM is flexible and easy to use with: - Seamless integration with popular HuggingFace models - High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more -- Tensor parallelism and pipeline parallelism support for distributed inference +- Tensor, pipeline, data and expert parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators. @@ -48,4 +48,4 @@ For more information, check out the following: - [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention) - [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023) - [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al. 
-- [vLLM Meetups][meetups] +- [vLLM Meetups](community/meetups.md) diff --git a/docs/api/README.md b/docs/api/README.md index 5c7b2ca79ee2..db4dab0ae534 100644 --- a/docs/api/README.md +++ b/docs/api/README.md @@ -8,14 +8,12 @@ API documentation for vLLM's configuration classes. - [vllm.config.ModelConfig][] - [vllm.config.CacheConfig][] -- [vllm.config.TokenizerPoolConfig][] - [vllm.config.LoadConfig][] - [vllm.config.ParallelConfig][] - [vllm.config.SchedulerConfig][] - [vllm.config.DeviceConfig][] - [vllm.config.SpeculativeConfig][] - [vllm.config.LoRAConfig][] -- [vllm.config.PromptAdapterConfig][] - [vllm.config.MultiModalConfig][] - [vllm.config.PoolerConfig][] - [vllm.config.DecodingConfig][] @@ -64,7 +62,7 @@ vLLM provides experimental support for multi-modal models through the [vllm.mult Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models] via the `multi_modal_data` field in [vllm.inputs.PromptType][]. -Looking to add your own multi-modal model? Please follow the instructions listed [here][supports-multimodal]. +Looking to add your own multi-modal model? Please follow the instructions listed [here](../contributing/model/multimodal.md). - [vllm.multimodal.MULTIMODAL_REGISTRY][] diff --git a/docs/assets/deployment/dp_external_lb.png b/docs/assets/deployment/dp_external_lb.png new file mode 100644 index 000000000000..a5d3a2f31db7 Binary files /dev/null and b/docs/assets/deployment/dp_external_lb.png differ diff --git a/docs/assets/deployment/dp_internal_lb.png b/docs/assets/deployment/dp_internal_lb.png new file mode 100644 index 000000000000..6d6a78a03f03 Binary files /dev/null and b/docs/assets/deployment/dp_internal_lb.png differ diff --git a/docs/assets/deployment/open_webui.png b/docs/assets/deployment/open_webui.png index fe9a7e15ea71..7018b4dff6bb 100644 Binary files a/docs/assets/deployment/open_webui.png and b/docs/assets/deployment/open_webui.png differ diff --git a/docs/cli/README.md b/docs/cli/README.md index b2587a5e7cd2..dfb6051a8c8a 100644 --- a/docs/cli/README.md +++ b/docs/cli/README.md @@ -1,3 +1,7 @@ +--- +toc_depth: 4 +--- + # vLLM CLI Guide The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with: @@ -16,7 +20,7 @@ vllm {chat,complete,serve,bench,collect-env,run-batch} Start the vLLM OpenAI Compatible API server. -??? Examples +??? console "Examples" ```bash # Start with a model @@ -37,8 +41,15 @@ Start the vLLM OpenAI Compatible API server. # To search by keyword vllm serve --help=max + + # To view full help with pager (less/more) + vllm serve --help=page ``` +### Options + +--8<-- "docs/argparse/serve.md" + ## chat Generate chat completions via the running API server. diff --git a/docs/community/contact_us.md b/docs/community/contact_us.md index a10e6bfc9b0a..04c28cde5f6b 100644 --- a/docs/community/contact_us.md +++ b/docs/community/contact_us.md @@ -1,6 +1,3 @@ ---- -title: Contact Us ---- -[](){ #contactus } +# Contact Us --8<-- "README.md:contact-us" diff --git a/docs/community/meetups.md b/docs/community/meetups.md index 8ea42e3cad18..e8b3a9c9c8e6 100644 --- a/docs/community/meetups.md +++ b/docs/community/meetups.md @@ -1,7 +1,4 @@ ---- -title: Meetups ---- -[](){ #meetups } +# Meetups We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. 
Please find the materials of our previous meetups below: diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index e2303067e3ee..4d5c961af98f 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -33,7 +33,7 @@ Quantized models take less memory at the cost of lower precision. Statically quantized models can be downloaded from HF Hub (some popular ones are available at [Red Hat AI](https://huggingface.co/RedHatAI)) and used directly without extra configuration. -Dynamic quantization is also supported via the `quantization` option -- see [here][quantization-index] for more details. +Dynamic quantization is also supported via the `quantization` option -- see [here](../features/quantization/README.md) for more details. ## Context length and batch size @@ -57,7 +57,7 @@ By default, we optimize model inference using CUDA graphs which take up extra me You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage: -??? Code +??? code ```python from vllm import LLM @@ -129,7 +129,7 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory. Here are some examples: -??? Code +??? code ```python from vllm import LLM diff --git a/docs/configuration/engine_args.md b/docs/configuration/engine_args.md index e02c7090d373..c3c1d5a1c362 100644 --- a/docs/configuration/engine_args.md +++ b/docs/configuration/engine_args.md @@ -1,18 +1,20 @@ --- -title: Engine Arguments +toc_depth: 3 --- -[](){ #engine-args } + +# Engine Arguments Engine arguments control the behavior of the vLLM engine. -- For [offline inference][offline-inference], they are part of the arguments to [LLM][vllm.LLM] class. -- For [online serving][serving-openai-compatible-server], they are part of the arguments to `vllm serve`. +- For [offline inference](../serving/offline_inference.md), they are part of the arguments to [LLM][vllm.LLM] class. +- For [online serving](../serving/openai_compatible_server.md), they are part of the arguments to `vllm serve`. + +The engine argument classes, [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs], are a combination of the configuration classes defined in [vllm.config][]. Therefore, if you are interested in developer documentation, we recommend looking at these configuration classes as they are the source of truth for types, defaults and docstrings. -You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments. +## `EngineArgs` -However, these classes are a combination of the configuration classes defined in [vllm.config][]. Therefore, we would recommend you read about them there where they are best documented. +--8<-- "docs/argparse/engine_args.md" -For offline inference you will have access to these configuration classes and for online serving you can cross-reference the configs with `vllm serve --help`, which has its arguments grouped by config. +## `AsyncEngineArgs` -!!! note - Additional arguments are available to the [AsyncLLMEngine][vllm.engine.async_llm_engine.AsyncLLMEngine] which is used for online serving. 
These can be found by running `vllm serve --help` +--8<-- "docs/argparse/async_engine_args.md" diff --git a/docs/configuration/env_vars.md b/docs/configuration/env_vars.md index c875931c305b..2c0a898754fa 100644 --- a/docs/configuration/env_vars.md +++ b/docs/configuration/env_vars.md @@ -7,7 +7,7 @@ vLLM uses the following environment variables to configure the system: All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables). -??? Code +??? code ```python --8<-- "vllm/envs.py:env-vars-definition" diff --git a/docs/configuration/model_resolution.md b/docs/configuration/model_resolution.md index 8757c257d3e9..49576a8217d0 100644 --- a/docs/configuration/model_resolution.md +++ b/docs/configuration/model_resolution.md @@ -14,10 +14,10 @@ For example: ```python from vllm import LLM -model = LLM( +llm = LLM( model="cerebras/Cerebras-GPT-1.3B", hf_overrides={"architectures": ["GPT2LMHeadModel"]}, # GPT-2 ) ``` -Our [list of supported models][supported-models] shows the model architectures that are recognized by vLLM. +Our [list of supported models](../models/supported_models.md) shows the model architectures that are recognized by vLLM. diff --git a/docs/configuration/serve_args.md b/docs/configuration/serve_args.md index 16b4b29f45d9..c1cc5577bc7a 100644 --- a/docs/configuration/serve_args.md +++ b/docs/configuration/serve_args.md @@ -1,19 +1,16 @@ ---- -title: Server Arguments ---- -[](){ #serve-args } +# Server Arguments The `vllm serve` command is used to launch the OpenAI-compatible server. ## CLI Arguments The `vllm serve` command is used to launch the OpenAI-compatible server. -To see the available CLI arguments, run `vllm serve --help`! +To see the available options, take a look at the [CLI Reference](../cli/README.md#options)! ## Configuration file You can load CLI arguments via a [YAML](https://yaml.org/) config file. -The argument names must be the long form of those outlined [above][serve-args]. +The argument names must be the long form of those outlined [above](serve_args.md). For example: diff --git a/docs/contributing/README.md b/docs/contributing/README.md index 83525436be13..f2d439e37ccc 100644 --- a/docs/contributing/README.md +++ b/docs/contributing/README.md @@ -95,7 +95,7 @@ For additional features and advanced configurations, refer to the official [MkDo ## Testing -??? note "Commands" +??? console "Commands" ```bash pip install -r requirements/dev.txt diff --git a/docs/contributing/benchmarks.md b/docs/contributing/benchmarks.md index 00505fc6f2a9..0ebd99ba5ae1 100644 --- a/docs/contributing/benchmarks.md +++ b/docs/contributing/benchmarks.md @@ -1,7 +1,4 @@ ---- -title: Benchmark Suites ---- -[](){ #benchmarks } +# Benchmark Suites vLLM contains two sets of benchmarks: diff --git a/docs/contributing/ci-failures.md b/docs/contributing/ci/failures.md similarity index 81% rename from docs/contributing/ci-failures.md rename to docs/contributing/ci/failures.md index 7caaf10ceb5c..573efb3b05f6 100644 --- a/docs/contributing/ci-failures.md +++ b/docs/contributing/ci/failures.md @@ -6,9 +6,9 @@ the failure? 
- Check the dashboard of current CI test failures: 👉 [CI Failures Dashboard](https://github.com/orgs/vllm-project/projects/20) -- If your failure **is already listed**, it's likely unrelated to your PR. - Help fixing it is always welcome! - - Leave comments with links to additional instances of the failure. +- If your failure **is already listed**, it's likely unrelated to your PR. + Help fixing it is always welcome! + - Leave comments with links to additional instances of the failure. - React with a 👍 to signal how many are affected. - If your failure **is not listed**, you should **file an issue**. @@ -19,25 +19,25 @@ the failure? 👉 [New CI Failure Report](https://github.com/vllm-project/vllm/issues/new?template=450-ci-failure.yml) - **Use this title format:** - + ``` [CI Failure]: failing-test-job - regex/matching/failing:test ``` - **For the environment field:** - + ``` Still failing on main as of commit abcdef123 ``` - **In the description, include failing tests:** - + ``` - FAILED failing/test.py:failing_test1 - Failure description - FAILED failing/test.py:failing_test2 - Failure description - https://github.com/orgs/vllm-project/projects/20 - https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml - FAILED failing/test.py:failing_test3 - Failure description + FAILED failing/test.py:failing_test1 - Failure description + FAILED failing/test.py:failing_test2 - Failure description + https://github.com/orgs/vllm-project/projects/20 + https://github.com/vllm-project/vllm/issues/new?template=400-bug-report.yml + FAILED failing/test.py:failing_test3 - Failure description ``` - **Attach logs** (collapsible section example): @@ -45,17 +45,17 @@ the failure? Logs: ```text - ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data + ERROR 05-20 03:26:38 [dump_input.py:68] Dumping input data --- Logging error --- Traceback (most recent call last): File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 203, in execute_model - return self.model_executor.execute_model(scheduler_output) + return self.model_executor.execute_model(scheduler_output) ... - FAILED failing/test.py:failing_test1 - Failure description - FAILED failing/test.py:failing_test2 - Failure description - FAILED failing/test.py:failing_test3 - Failure description + FAILED failing/test.py:failing_test1 - Failure description + FAILED failing/test.py:failing_test2 - Failure description + FAILED failing/test.py:failing_test3 - Failure description ``` - + ## Logs Wrangling @@ -78,7 +78,7 @@ tail -525 ci_build.log | wl-copy ## Investigating a CI Test Failure -1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main) +1. Go to 👉 [Buildkite main branch](https://buildkite.com/vllm/ci/builds?branch=main) 2. Bisect to find the first build that shows the issue. 3. Add your findings to the GitHub issue. 4. If you find a strong candidate PR, mention it in the issue and ping contributors. @@ -97,9 +97,9 @@ CI test failures may be flaky. Use a bash loop to run repeatedly: If you submit a PR to fix a CI failure: -- Link the PR to the issue: +- Link the PR to the issue: Add `Closes #12345` to the PR description. -- Add the `ci-failure` label: +- Add the `ci-failure` label: This helps track it in the [CI Failures GitHub Project](https://github.com/orgs/vllm-project/projects/20). 
## Other Resources diff --git a/docs/ci/update_pytorch_version.md b/docs/contributing/ci/update_pytorch_version.md similarity index 55% rename from docs/ci/update_pytorch_version.md rename to docs/contributing/ci/update_pytorch_version.md index 69fdc82ef971..1fe18d5d8856 100644 --- a/docs/ci/update_pytorch_version.md +++ b/docs/contributing/ci/update_pytorch_version.md @@ -1,15 +1,12 @@ ---- -title: Update PyTorch version on vLLM OSS CI/CD ---- +# Update PyTorch version on vLLM OSS CI/CD vLLM's current policy is to always use the latest PyTorch stable release in CI/CD. It is standard practice to submit a PR to update the PyTorch version as early as possible when a new [PyTorch stable release](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-cadence) becomes available. This process is non-trivial due to the gap between PyTorch -releases. Using [#16859](https://github.com/vllm-project/vllm/pull/16859) as -an example, this document outlines common steps to achieve this update along with -a list of potential issues and how to address them. +releases. Using as an example, this document outlines common steps to achieve this +update along with a list of potential issues and how to address them. ## Test PyTorch release candidates (RCs) @@ -19,11 +16,12 @@ by waiting for the next release or by implementing hacky workarounds in vLLM. The better solution is to test vLLM with PyTorch release candidates (RC) to ensure compatibility before each release. -PyTorch release candidates can be downloaded from PyTorch test index at https://download.pytorch.org/whl/test. -For example, torch2.7.0+cu12.8 RC can be installed using the following command: +PyTorch release candidates can be downloaded from [PyTorch test index](https://download.pytorch.org/whl/test). +For example, `torch2.7.0+cu12.8` RC can be installed using the following command: -``` -uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 +```bash +uv pip install torch torchvision torchaudio \ + --index-url https://download.pytorch.org/whl/test/cu128 ``` When the final RC is ready for testing, it will be announced to the community @@ -31,13 +29,28 @@ on the [PyTorch dev-discuss forum](https://dev-discuss.pytorch.org/c/release-ann After this announcement, we can begin testing vLLM integration by drafting a pull request following this 3-step process: -1. Update requirements files in https://github.com/vllm-project/vllm/tree/main/requirements -to point to the new releases for torch, torchvision, and torchaudio. -2. Use `--extra-index-url https://download.pytorch.org/whl/test/` to -get the final release candidates' wheels. Some common platforms are `cpu`, `cu128`, -and `rocm6.2.4`. -3. As vLLM uses uv, make sure that `unsafe-best-match` strategy is set either -via `UV_INDEX_STRATEGY` env variable or via `--index-strategy unsafe-best-match`. +1. Update [requirements files](https://github.com/vllm-project/vllm/tree/main/requirements) +to point to the new releases for `torch`, `torchvision`, and `torchaudio`. + +2. Use the following option to get the final release candidates' wheels. Some common platforms are `cpu`, `cu128`, and `rocm6.2.4`. + + ```bash + --extra-index-url https://download.pytorch.org/whl/test/ + ``` + +3. 
Since vLLM uses `uv`, ensure the following index strategy is applied: + + - Via environment variable: + + ```bash + export UV_INDEX_STRATEGY=unsafe-best-match + ``` + + - Or via CLI flag: + + ```bash + --index-strategy unsafe-best-match + ``` If failures are found in the pull request, raise them as issues on vLLM and cc the PyTorch release team to initiate discussion on how to address them. @@ -45,20 +58,25 @@ cc the PyTorch release team to initiate discussion on how to address them. ## Update CUDA version The PyTorch release matrix includes both stable and experimental [CUDA versions](https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix). Due to limitations, only the latest stable CUDA version (for example, -torch2.7.0+cu12.6) is uploaded to PyPI. However, vLLM may require a different CUDA version, +`torch2.7.0+cu12.6`) is uploaded to PyPI. However, vLLM may require a different CUDA version, such as 12.8 for Blackwell support. This complicates the process as we cannot use the out-of-the-box `pip install torch torchvision torchaudio` command. The solution is to use `--extra-index-url` in vLLM's Dockerfiles. -1. Use `--extra-index-url https://download.pytorch.org/whl/cu128` to install torch+cu128. -2. Other important indexes at the moment include: - 1. CPU ‒ https://download.pytorch.org/whl/cpu - 2. ROCm ‒ https://download.pytorch.org/whl/rocm6.2.4 and https://download.pytorch.org/whl/rocm6.3 - 3. XPU ‒ https://download.pytorch.org/whl/xpu -3. Update .buildkite/release-pipeline.yaml and .buildkite/scripts/upload-wheels.sh to -match the CUDA version from step 1. This makes sure that the release vLLM wheel is tested -on CI. +- Important indexes at the moment include: + +| Platform | `--extra-index-url` | +|----------|-----------------| +| CUDA 12.8| [https://download.pytorch.org/whl/cu128](https://download.pytorch.org/whl/cu128)| +| CPU | [https://download.pytorch.org/whl/cpu](https://download.pytorch.org/whl/cpu)| +| ROCm 6.2 | [https://download.pytorch.org/whl/rocm6.2.4](https://download.pytorch.org/whl/rocm6.2.4) | +| ROCm 6.3 | [https://download.pytorch.org/whl/rocm6.3](https://download.pytorch.org/whl/rocm6.3) | +| XPU | [https://download.pytorch.org/whl/xpu](https://download.pytorch.org/whl/xpu) | + +- Update the below files to match the CUDA version from step 1. This makes sure that the release vLLM wheel is tested on CI. + - `.buildkite/release-pipeline.yaml` + - `.buildkite/scripts/upload-wheels.sh` ## Address long vLLM build time @@ -68,8 +86,8 @@ and timeout. Additionally, since vLLM's fastcheck pipeline runs in read-only mod it doesn't populate the cache, so re-running it to warm up the cache is ineffective. -While ongoing efforts like [#17419](https://github.com/vllm-project/vllm/issues/17419) -address the long build time at its source, the current workaround is to set VLLM_CI_BRANCH +While ongoing efforts like [#17419](gh-issue:17419) +address the long build time at its source, the current workaround is to set `VLLM_CI_BRANCH` to a custom branch provided by @khluu (`VLLM_CI_BRANCH=khluu/use_postmerge_q`) when manually triggering a build on Buildkite. This branch accomplishes two things: @@ -89,17 +107,18 @@ releases (which would take too much time), they can be built from source to unblock the update process. 
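Before moving on to individual dependencies, here is how the RC-testing steps above might be combined into a single install command. This is only a sketch, shown for the `cu128` test index; pick the index that matches your platform:

```bash
# Combine the PyTorch test index (step 2) with uv's index strategy (step 3).
export UV_INDEX_STRATEGY=unsafe-best-match
uv pip install torch torchvision torchaudio \
    --extra-index-url https://download.pytorch.org/whl/test/cu128
```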
### FlashInfer -Here is how to build and install it from source with torch2.7.0+cu128 in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271): +Here is how to build and install it from source with `torch2.7.0+cu128` in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271): ```bash export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX' export FLASHINFER_ENABLE_SM90=1 -uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1" +uv pip install --system \ + --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1" ``` One caveat is that building FlashInfer from source adds approximately 30 minutes to the vLLM build time. Therefore, it's preferable to cache the wheel in a -public location for immediate installation, such as https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl. For future releases, contact the PyTorch release +public location for immediate installation, such as [this FlashInfer wheel link](https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl). For future releases, contact the PyTorch release team if you want to get the package published there. ### xFormers @@ -107,13 +126,15 @@ Similar to FlashInfer, here is how to build and install xFormers from source: ```bash export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX' -MAX_JOBS=16 uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30" +MAX_JOBS=16 uv pip install --system \ + --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30" ``` ### Mamba ```bash -uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" +uv pip install --system \ + --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" ``` ### causal-conv1d @@ -128,7 +149,6 @@ Rather than attempting to update all vLLM platforms in a single pull request, it to handle some platforms separately. The separation of requirements and Dockerfiles for different platforms in vLLM CI/CD allows us to selectively choose which platforms to update. For instance, updating XPU requires the corresponding -release from https://github.com/intel/intel-extension-for-pytorch by Intel. -While https://github.com/vllm-project/vllm/pull/16859 updated vLLM to PyTorch -2.7.0 on CPU, CUDA, and ROCm, https://github.com/vllm-project/vllm/pull/17444 -completed the update for XPU. +release from [Intel Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch) by Intel. +While updated vLLM to PyTorch 2.7.0 on CPU, CUDA, and ROCm, + completed the update for XPU. diff --git a/docs/contributing/dockerfile/dockerfile.md b/docs/contributing/dockerfile/dockerfile.md index a39f335c87b8..a7ff99aa26d5 100644 --- a/docs/contributing/dockerfile/dockerfile.md +++ b/docs/contributing/dockerfile/dockerfile.md @@ -1,7 +1,7 @@ # Dockerfile We provide a to construct the image for running an OpenAI compatible server with vLLM. -More information about deploying with Docker can be found [here][deployment-docker]. +More information about deploying with Docker can be found [here](../../deployment/docker.md). Below is a visual representation of the multi-stage Dockerfile. 
The build graph contains the following nodes: diff --git a/docs/contributing/incremental_build.md b/docs/contributing/incremental_build.md index 33584fdd5d40..0e34e69245af 100644 --- a/docs/contributing/incremental_build.md +++ b/docs/contributing/incremental_build.md @@ -84,6 +84,7 @@ Below is an example of what the generated `CMakeUserPresets.json` might look lik ``` **What do the various configurations mean?** + - `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically. - `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default. - `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable. @@ -98,16 +99,16 @@ Once your `CMakeUserPresets.json` is configured: 1. **Initialize the CMake build environment:** This step configures the build system according to your chosen preset (e.g., `release`) and creates the build directory at `binaryDir` - ```console - cmake --preset release - ``` + ```console + cmake --preset release + ``` 2. **Build and install the vLLM components:** This command compiles the code and installs the resulting binaries into your vLLM source directory, making them available to your editable Python installation. - ```console - cmake --build --preset release --target install - ``` + ```console + cmake --build --preset release --target install + ``` 3. **Make changes and repeat!** Now you start using your editable install of vLLM, testing and making changes as needed. If you need to build again to update based on changes, simply run the CMake command again to build only the affected files. diff --git a/docs/contributing/model/README.md b/docs/contributing/model/README.md index 63abb7991050..0ca77fa499db 100644 --- a/docs/contributing/model/README.md +++ b/docs/contributing/model/README.md @@ -1,12 +1,9 @@ ---- -title: Summary ---- -[](){ #new-model } +# Summary !!! important Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve ` works first! -vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features][compatibility-matrix] to optimize their performance. +vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features](../../features/compatibility_matrix.md) to optimize their performance. The complexity of integrating a model into vLLM depends heavily on the model's architecture. The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM. diff --git a/docs/contributing/model/basic.md b/docs/contributing/model/basic.md index d552cd06be20..edd9a47e132f 100644 --- a/docs/contributing/model/basic.md +++ b/docs/contributing/model/basic.md @@ -1,7 +1,4 @@ ---- -title: Basic Model ---- -[](){ #new-model-basic } +# Basic Model This guide walks you through the steps to implement a basic vLLM model. @@ -27,7 +24,7 @@ All vLLM modules within the model must include a `prefix` argument in their cons The initialization code should look like this: -??? Code +??? 
code ```python from torch import nn @@ -76,6 +73,8 @@ def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: ... ``` @@ -108,7 +107,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a ## 5. Register your model -See [this page][new-model-registration] for instructions on how to register your new model to be used by vLLM. +See [this page](registration.md) for instructions on how to register your new model to be used by vLLM. ## Frequently Asked Questions diff --git a/docs/contributing/model/multimodal.md b/docs/contributing/model/multimodal.md index ed1cd46dd858..3295b8c711c0 100644 --- a/docs/contributing/model/multimodal.md +++ b/docs/contributing/model/multimodal.md @@ -1,18 +1,15 @@ ---- -title: Multi-Modal Support ---- -[](){ #supports-multimodal } +# Multi-Modal Support -This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs][multimodal-inputs]. +This document walks you through the steps to extend a basic model so that it accepts [multi-modal inputs](../../features/multimodal_inputs.md). ## 1. Update the base vLLM model -It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic]. +It is assumed that you have already implemented the model in vLLM according to [these steps](basic.md). Further update the model as follows: - Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model. - ??? Code + ??? code ```python class YourModelForImage2Seq(nn.Module): @@ -41,7 +38,7 @@ Further update the model as follows: - Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs. - ??? Code + ??? code ```python class YourModelForImage2Seq(nn.Module): @@ -71,7 +68,7 @@ Further update the model as follows: - Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings. - ??? Code + ??? code ```python from .utils import merge_multimodal_embeddings @@ -155,7 +152,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in Looking at the code of HF's `LlavaForConditionalGeneration`: - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544 @@ -179,7 +176,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in The number of placeholder feature tokens per image is `image_features.shape[1]`. `image_features` is calculated inside the `get_image_features` method: - ??? Code + ??? 
code ```python # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300 @@ -217,7 +214,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`: - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257 @@ -244,7 +241,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in Overall, the number of placeholder feature tokens for an image can be calculated as: - ??? Code + ??? code ```python def get_num_image_tokens( @@ -269,7 +266,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in Notice that the number of image tokens doesn't depend on the image width and height. We can simply use a dummy `image_size` to calculate the multimodal profiling data: - ??? Code + ??? code ```python # NOTE: In actuality, this is usually implemented as part of the @@ -314,7 +311,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in Looking at the code of HF's `FuyuForCausalLM`: - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322 @@ -344,7 +341,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`, returning the dimensions after resizing (but before padding) as metadata. - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544 @@ -382,7 +379,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata: - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425 @@ -420,7 +417,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`: - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562 @@ -457,7 +454,7 @@ Assuming that the memory usage increases with the number of tokens, the dummy in For the multimodal image profiling data, the logic is very similar to LLaVA: - ??? Code + ??? code ```python def get_dummy_mm_data( @@ -483,7 +480,7 @@ Afterwards, create a subclass of [BaseMultiModalProcessor][vllm.multimodal.proce to fill in the missing details about HF processing. !!! info - [Multi-Modal Data Processing][mm-processing] + [Multi-Modal Data Processing](../../design/mm_processing.md) ### Multi-modal fields @@ -546,7 +543,7 @@ return a schema of the tensors outputted by the HF processor that are related to In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA, we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]: - ??? Code + ??? 
code ```python def _call_hf_processor( @@ -623,7 +620,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`). Based on this, we override [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] as follows: - ??? Code + ??? code ```python def _get_prompt_updates( @@ -668,7 +665,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies We define a helper function to return `ncols` and `nrows` directly: - ??? Code + ??? code ```python def get_image_feature_grid_size( @@ -698,7 +695,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies Based on this, we can initially define our replacement tokens as: - ??? Code + ??? code ```python def get_replacement(item_idx: int): @@ -718,7 +715,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called, a BOS token (``) is also added to the promopt: - ??? Code + ??? code ```python # https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435 @@ -745,7 +742,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies To assign the vision embeddings to only the image tokens, instead of a string you can return an instance of [PromptUpdateDetails][vllm.multimodal.processing.PromptUpdateDetails]: - ??? Code + ??? code ```python hf_config = self.info.get_hf_config() @@ -772,7 +769,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt, we can search for it to conduct the replacement at the start of the string: - ??? Code + ??? code ```python def _get_prompt_updates( @@ -819,7 +816,7 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies After you have defined [BaseProcessingInfo][vllm.multimodal.processing.BaseProcessingInfo] (Step 2), [BaseDummyInputsBuilder][vllm.multimodal.profiling.BaseDummyInputsBuilder] (Step 3), and [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] (Step 4), -decorate the model class with {meth}`MULTIMODAL_REGISTRY.register_processor ` +decorate the model class with [MULTIMODAL_REGISTRY.register_processor][vllm.multimodal.processing.MultiModalRegistry.register_processor] to register them to the multi-modal registry: ```diff @@ -846,7 +843,7 @@ Examples: ### Handling prompt updates unrelated to multi-modal data -[_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override [_apply_hf_processor_tokens_only][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only] so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design][mm-processing]. 
+[_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] assumes that each application of prompt update corresponds to one multi-modal item. If the HF processor performs additional processing regardless of how many multi-modal items there are, you should override [_apply_hf_processor_tokens_only][vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_tokens_only] so that the processed token inputs are consistent with the result of applying the HF processor on text inputs. This is because token inputs bypass the HF processor according to [our design](../../design/mm_processing.md). Examples: diff --git a/docs/contributing/model/registration.md b/docs/contributing/model/registration.md index 758caa72cd4a..35f35ffa4cde 100644 --- a/docs/contributing/model/registration.md +++ b/docs/contributing/model/registration.md @@ -1,10 +1,7 @@ ---- -title: Registering a Model ---- -[](){ #new-model-registration } +# Registering a Model vLLM relies on a model registry to determine how to run each model. -A list of pre-registered architectures can be found [here][supported-models]. +A list of pre-registered architectures can be found [here](../../models/supported_models.md). If your model is not on this list, you must register it to vLLM. This page provides detailed instructions on how to do so. @@ -14,16 +11,16 @@ This page provides detailed instructions on how to do so. To add a model directly to the vLLM library, start by forking our [GitHub repository](https://github.com/vllm-project/vllm) and then [build it from source][build-from-source]. This gives you the ability to modify the codebase and test your model. -After you have implemented your model (see [tutorial][new-model-basic]), put it into the directory. +After you have implemented your model (see [tutorial](basic.md)), put it into the directory. Then, add your model class to `_VLLM_MODELS` in so that it is automatically registered upon importing vLLM. -Finally, update our [list of supported models][supported-models] to promote your model! +Finally, update our [list of supported models](../../models/supported_models.md) to promote your model! !!! important The list of models in each section should be maintained in alphabetical order. ## Out-of-tree models -You can load an external model [using a plugin][plugin-system] without modifying the vLLM codebase. +You can load an external model [using a plugin](../../design/plugin_system.md) without modifying the vLLM codebase. To register the model, use the following code: @@ -51,4 +48,4 @@ def register(): !!! important If your model is a multimodal model, ensure the model class implements the [SupportsMultiModal][vllm.model_executor.models.interfaces.SupportsMultiModal] interface. - Read more about that [here][supports-multimodal]. + Read more about that [here](multimodal.md). diff --git a/docs/contributing/model/tests.md b/docs/contributing/model/tests.md index c7bcc02a8b80..1206ad36771e 100644 --- a/docs/contributing/model/tests.md +++ b/docs/contributing/model/tests.md @@ -1,7 +1,4 @@ ---- -title: Unit Testing ---- -[](){ #new-model-tests } +# Unit Testing This page explains how to write unit tests to verify the implementation of your model. diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md index 20f4867057d3..a5851cfe963d 100644 --- a/docs/contributing/profiling.md +++ b/docs/contributing/profiling.md @@ -125,7 +125,7 @@ to manually kill the profiler and generate your `nsys-rep` report. 
You can view these profiles either as summaries in the CLI, using `nsys stats [profile-file]`, or in the GUI by installing Nsight [locally following the directions here](https://developer.nvidia.com/nsight-systems/get-started). -??? CLI example +??? console "CLI example" ```bash nsys stats report1.nsys-rep diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md index 5f6a22c28c28..e500751896b3 100644 --- a/docs/deployment/docker.md +++ b/docs/deployment/docker.md @@ -1,7 +1,4 @@ ---- -title: Using Docker ---- -[](){ #deployment-docker } +# Using Docker [](){ #deployment-docker-pre-built-image } @@ -32,7 +29,7 @@ podman run --gpus all \ --model mistralai/Mistral-7B-v0.1 ``` -You can add any other [engine-args][engine-args] you need after the image tag (`vllm/vllm-openai:latest`). +You can add any other [engine-args](../configuration/engine_args.md) you need after the image tag (`vllm/vllm-openai:latest`). !!! note You can either use the `ipc=host` flag or `--shm-size` flag to allow the @@ -97,7 +94,7 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `-- flags to speed up build process. However, ensure your `max_jobs` is substantially larger than `nvcc_threads` to get the most benefits. Keep an eye on memory usage with parallel jobs as it can be substantial (see example below). -??? Command +??? console "Command" ```bash # Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB) diff --git a/docs/deployment/frameworks/anyscale.md b/docs/deployment/frameworks/anyscale.md new file mode 100644 index 000000000000..9957c5b14134 --- /dev/null +++ b/docs/deployment/frameworks/anyscale.md @@ -0,0 +1,17 @@ +# Anyscale + +[](){ #deployment-anyscale } + +[Anyscale](https://www.anyscale.com) is a managed, multi-cloud platform developed by the creators of Ray. + +Anyscale automates the entire lifecycle of Ray clusters in your AWS, GCP, or Azure account, delivering the flexibility of open-source Ray +without the operational overhead of maintaining Kubernetes control planes, configuring autoscalers, managing observability stacks, or manually managing head and worker nodes with helper scripts like . + +When serving large language models with vLLM, Anyscale can rapidly provision [production-ready HTTPS endpoints](https://docs.anyscale.com/examples/deploy-ray-serve-llms) or [fault-tolerant batch inference jobs](https://docs.anyscale.com/examples/ray-data-llm). 
+ +## Production-ready vLLM on Anyscale quickstarts + +- [Offline batch inference](https://console.anyscale.com/template-preview/llm_batch_inference?utm_source=vllm_docs) +- [Deploy vLLM services](https://console.anyscale.com/template-preview/llm_serving?utm_source=vllm_docs) +- [Curate a dataset](https://console.anyscale.com/template-preview/audio-dataset-curation-llm-judge?utm_source=vllm_docs) +- [Finetune an LLM](https://console.anyscale.com/template-preview/entity-recognition-with-llms?utm_source=vllm_docs) diff --git a/docs/deployment/frameworks/anything-llm.md b/docs/deployment/frameworks/anything-llm.md index 4633c2946cde..d6b28a358cc3 100644 --- a/docs/deployment/frameworks/anything-llm.md +++ b/docs/deployment/frameworks/anything-llm.md @@ -1,7 +1,4 @@ ---- -title: Anything LLM ---- -[](){ #deployment-anything-llm } +# Anything LLM [Anything LLM](https://github.com/Mintplex-Labs/anything-llm) is a full-stack application that enables you to turn any document, resource, or piece of content into context that any LLM can use as references during chatting. diff --git a/docs/deployment/frameworks/autogen.md b/docs/deployment/frameworks/autogen.md index 13930e67ab2f..c255a85d3840 100644 --- a/docs/deployment/frameworks/autogen.md +++ b/docs/deployment/frameworks/autogen.md @@ -1,7 +1,4 @@ ---- -title: AutoGen ---- -[](){ #deployment-autogen } +# AutoGen [AutoGen](https://github.com/microsoft/autogen) is a framework for creating multi-agent AI applications that can act autonomously or work alongside humans. @@ -30,7 +27,7 @@ python -m vllm.entrypoints.openai.api_server \ - Call it with AutoGen: -??? Code +??? code ```python import asyncio diff --git a/docs/deployment/frameworks/bentoml.md b/docs/deployment/frameworks/bentoml.md index 7e64b6eb6fb0..9c8f2527f2e2 100644 --- a/docs/deployment/frameworks/bentoml.md +++ b/docs/deployment/frameworks/bentoml.md @@ -1,7 +1,4 @@ ---- -title: BentoML ---- -[](){ #deployment-bentoml } +# BentoML [BentoML](https://github.com/bentoml/BentoML) allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. You can serve the model locally or containerize it as an OCI-compliant image and deploy it on Kubernetes. diff --git a/docs/deployment/frameworks/cerebrium.md b/docs/deployment/frameworks/cerebrium.md index 5c5f2f48d50b..1f233c3204a1 100644 --- a/docs/deployment/frameworks/cerebrium.md +++ b/docs/deployment/frameworks/cerebrium.md @@ -1,7 +1,4 @@ ---- -title: Cerebrium ---- -[](){ #deployment-cerebrium } +# Cerebrium

vLLM_plus_cerebrium @@ -34,7 +31,7 @@ vllm = "latest" Next, let us add our code to handle inference for the LLM of your choice (`mistralai/Mistral-7B-Instruct-v0.1` for this example), add the following code to your `main.py`: -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -64,7 +61,7 @@ cerebrium deploy If successful, you should be returned a CURL command that you can call inference against. Just remember to end the url with the function name you are calling (in our case`/run`) -??? Command +??? console "Command" ```python curl -X POST https://api.cortex.cerebrium.ai/v4/p-xxxxxx/vllm/run \ @@ -82,7 +79,7 @@ If successful, you should be returned a CURL command that you can call inference You should get a response like: -??? Response +??? console "Response" ```python { diff --git a/docs/deployment/frameworks/chatbox.md b/docs/deployment/frameworks/chatbox.md index b1b50b55146c..15f92ed1e34d 100644 --- a/docs/deployment/frameworks/chatbox.md +++ b/docs/deployment/frameworks/chatbox.md @@ -1,7 +1,4 @@ ---- -title: Chatbox ---- -[](){ #deployment-chatbox } +# Chatbox [Chatbox](https://github.com/chatboxai/chatbox) is a desktop client for LLMs, available on Windows, Mac, Linux. diff --git a/docs/deployment/frameworks/dify.md b/docs/deployment/frameworks/dify.md index a0e40784f0ea..a3063194fb51 100644 --- a/docs/deployment/frameworks/dify.md +++ b/docs/deployment/frameworks/dify.md @@ -1,7 +1,4 @@ ---- -title: Dify ---- -[](){ #deployment-dify } +# Dify [Dify](https://github.com/langgenius/dify) is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. diff --git a/docs/deployment/frameworks/dstack.md b/docs/deployment/frameworks/dstack.md index 8b4bc459683b..23dc58c974ed 100644 --- a/docs/deployment/frameworks/dstack.md +++ b/docs/deployment/frameworks/dstack.md @@ -1,7 +1,4 @@ ---- -title: dstack ---- -[](){ #deployment-dstack } +# dstack

vLLM_plus_dstack @@ -26,7 +23,7 @@ dstack init Next, to provision a VM instance with LLM of your choice (`NousResearch/Llama-2-7b-chat-hf` for this example), create the following `serve.dstack.yml` file for the dstack `Service`: -??? Config +??? code "Config" ```yaml type: service @@ -48,7 +45,7 @@ Next, to provision a VM instance with LLM of your choice (`NousResearch/Llama-2- Then, run the following CLI for provisioning: -??? Command +??? console "Command" ```console $ dstack run . -f serve.dstack.yml @@ -79,7 +76,7 @@ Then, run the following CLI for provisioning: After the provisioning, you can interact with the model by using the OpenAI SDK: -??? Code +??? code ```python from openai import OpenAI diff --git a/docs/deployment/frameworks/haystack.md b/docs/deployment/frameworks/haystack.md index 7a4cab4c2ee3..a18d68142cab 100644 --- a/docs/deployment/frameworks/haystack.md +++ b/docs/deployment/frameworks/haystack.md @@ -1,7 +1,4 @@ ---- -title: Haystack ---- -[](){ #deployment-haystack } +# Haystack # Haystack @@ -27,7 +24,7 @@ vllm serve mistralai/Mistral-7B-Instruct-v0.1 - Use the `OpenAIGenerator` and `OpenAIChatGenerator` components in Haystack to query the vLLM server. -??? Code +??? code ```python from haystack.components.generators.chat import OpenAIChatGenerator diff --git a/docs/deployment/frameworks/helm.md b/docs/deployment/frameworks/helm.md index d929665e8a3d..e5d44945ba72 100644 --- a/docs/deployment/frameworks/helm.md +++ b/docs/deployment/frameworks/helm.md @@ -1,7 +1,4 @@ ---- -title: Helm ---- -[](){ #deployment-helm } +# Helm A Helm chart to deploy vLLM for Kubernetes diff --git a/docs/deployment/frameworks/litellm.md b/docs/deployment/frameworks/litellm.md index 8279613b1a27..c7e514f2276e 100644 --- a/docs/deployment/frameworks/litellm.md +++ b/docs/deployment/frameworks/litellm.md @@ -1,7 +1,4 @@ ---- -title: LiteLLM ---- -[](){ #deployment-litellm } +# LiteLLM [LiteLLM](https://github.com/BerriAI/litellm) call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, Groq etc.] @@ -34,7 +31,7 @@ vllm serve qwen/Qwen1.5-0.5B-Chat - Call it with litellm: -??? Code +??? code ```python import litellm diff --git a/docs/deployment/frameworks/lobe-chat.md b/docs/deployment/frameworks/lobe-chat.md index cd95c028155e..e3e7dbe6e1e8 100644 --- a/docs/deployment/frameworks/lobe-chat.md +++ b/docs/deployment/frameworks/lobe-chat.md @@ -1,7 +1,4 @@ ---- -title: Lobe Chat ---- -[](){ #deployment-lobe-chat } +# Lobe Chat [Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLMs UI/Framework. diff --git a/docs/deployment/frameworks/lws.md b/docs/deployment/frameworks/lws.md index 9df952876906..3319dc6c90e1 100644 --- a/docs/deployment/frameworks/lws.md +++ b/docs/deployment/frameworks/lws.md @@ -1,7 +1,4 @@ ---- -title: LWS ---- -[](){ #deployment-lws } +# LWS LeaderWorkerSet (LWS) is a Kubernetes API that aims to address common deployment patterns of AI/ML inference workloads. A major use case is for multi-host/multi-node distributed inference. @@ -17,7 +14,7 @@ vLLM can be deployed with [LWS](https://github.com/kubernetes-sigs/lws) on Kuber Deploy the following yaml file `lws.yaml` -??? Yaml +??? code "Yaml" ```yaml apiVersion: leaderworkerset.x-k8s.io/v1 @@ -177,7 +174,7 @@ curl http://localhost:8080/v1/completions \ The output should be similar to the following -??? Output +??? 
console "Output" ```text { diff --git a/docs/deployment/frameworks/modal.md b/docs/deployment/frameworks/modal.md index dbdb739a1000..0ab5ed92fe6b 100644 --- a/docs/deployment/frameworks/modal.md +++ b/docs/deployment/frameworks/modal.md @@ -1,7 +1,4 @@ ---- -title: Modal ---- -[](){ #deployment-modal } +# Modal vLLM can be run on cloud GPUs with [Modal](https://modal.com), a serverless computing platform designed for fast auto-scaling. diff --git a/docs/deployment/frameworks/open-webui.md b/docs/deployment/frameworks/open-webui.md index 676a0f58b54f..eaa51bb61328 100644 --- a/docs/deployment/frameworks/open-webui.md +++ b/docs/deployment/frameworks/open-webui.md @@ -1,29 +1,42 @@ ---- -title: Open WebUI ---- -[](){ #deployment-open-webui } +# Open WebUI -1. Install the [Docker](https://docs.docker.com/engine/install/) +[Open WebUI](https://github.com/open-webui/open-webui) is an extensible, feature-rich, +and user-friendly self-hosted AI platform designed to operate entirely offline. +It supports various LLM runners like Ollama and OpenAI-compatible APIs, +with built-in RAG capabilities, making it a powerful AI deployment solution. -2. Start the vLLM server with the supported chat completion model, e.g. +To get started with Open WebUI using vLLM, follow these steps: -```bash -vllm serve qwen/Qwen1.5-0.5B-Chat -``` +1. Install the [Docker](https://docs.docker.com/engine/install/). -1. Start the [Open WebUI](https://github.com/open-webui/open-webui) docker container (replace the vllm serve host and vllm serve port): +2. Start the vLLM server with a supported chat completion model: -```bash -docker run -d -p 3000:8080 \ ---name open-webui \ --v open-webui:/app/backend/data \ --e OPENAI_API_BASE_URL=http://:/v1 \ ---restart always \ -ghcr.io/open-webui/open-webui:main -``` + ```console + vllm serve Qwen/Qwen3-0.6B-Chat + ``` -1. Open it in the browser: + !!! note + When starting the vLLM server, be sure to specify the host and port using the `--host` and `--port` flags. + For example: -On the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`. + ```console + python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 + ``` -![](../../assets/deployment/open_webui.png) +3. Start the Open WebUI Docker container: + + ```console + docker run -d \ + --name open-webui \ + -p 3000:8080 \ + -v open-webui:/app/backend/data \ + -e OPENAI_API_BASE_URL=http://0.0.0.0:8000/v1 \ + --restart always \ + ghcr.io/open-webui/open-webui:main + ``` + +4. Open it in the browser: + + At the top of the page, you should see the model `Qwen/Qwen3-0.6B-Chat`. + + ![Web portal of model Qwen/Qwen3-0.6B-Chat](../../assets/deployment/open_webui.png) diff --git a/docs/deployment/frameworks/retrieval_augmented_generation.md b/docs/deployment/frameworks/retrieval_augmented_generation.md index 851c31db32f2..96dd99e7118b 100644 --- a/docs/deployment/frameworks/retrieval_augmented_generation.md +++ b/docs/deployment/frameworks/retrieval_augmented_generation.md @@ -1,7 +1,4 @@ ---- -title: Retrieval-Augmented Generation ---- -[](){ #deployment-retrieval-augmented-generation } +# Retrieval-Augmented Generation [Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. 
It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement information from its pre-existing training data. This allows LLMs to use domain-specific and/or updated information. Use cases include providing chatbot access to internal company data or generating responses based on authoritative sources. diff --git a/docs/deployment/frameworks/skypilot.md b/docs/deployment/frameworks/skypilot.md index ecf987539ced..06e2fed38f05 100644 --- a/docs/deployment/frameworks/skypilot.md +++ b/docs/deployment/frameworks/skypilot.md @@ -1,7 +1,4 @@ ---- -title: SkyPilot ---- -[](){ #deployment-skypilot } +# SkyPilot

vLLM @@ -24,7 +21,7 @@ sky check See the vLLM SkyPilot YAML for serving, [serving.yaml](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm/serve.yaml). -??? Yaml +??? code "Yaml" ```yaml resources: @@ -95,7 +92,7 @@ HF_TOKEN="your-huggingface-token" \ SkyPilot can scale up the service to multiple service replicas with built-in autoscaling, load-balancing and fault-tolerance. You can do it by adding a services section to the YAML file. -??? Yaml +??? code "Yaml" ```yaml service: @@ -111,7 +108,7 @@ SkyPilot can scale up the service to multiple service replicas with built-in aut max_completion_tokens: 1 ``` -??? Yaml +??? code "Yaml" ```yaml service: @@ -186,7 +183,7 @@ vllm 2 1 xx.yy.zz.245 18 mins ago 1x GCP([Spot]{'L4': 1}) R After the service is READY, you can find a single endpoint for the service and access the service with the endpoint: -??? Commands +??? console "Commands" ```bash ENDPOINT=$(sky serve status --endpoint 8081 vllm) @@ -220,7 +217,7 @@ service: This will scale the service up to when the QPS exceeds 2 for each replica. -??? Yaml +??? code "Yaml" ```yaml service: @@ -285,7 +282,7 @@ sky serve down vllm It is also possible to access the Llama-3 service with a separate GUI frontend, so the user requests send to the GUI will be load-balanced across replicas. -??? Yaml +??? code "Yaml" ```yaml envs: diff --git a/docs/deployment/frameworks/streamlit.md b/docs/deployment/frameworks/streamlit.md index 5e998e3cca6e..af0f0690c68e 100644 --- a/docs/deployment/frameworks/streamlit.md +++ b/docs/deployment/frameworks/streamlit.md @@ -1,7 +1,4 @@ ---- -title: Streamlit ---- -[](){ #deployment-streamlit } +# Streamlit [Streamlit](https://github.com/streamlit/streamlit) lets you transform Python scripts into interactive web apps in minutes, instead of weeks. Build dashboards, generate reports, or create chat apps. diff --git a/docs/deployment/frameworks/triton.md b/docs/deployment/frameworks/triton.md index 082bc24d85aa..faff4a4263eb 100644 --- a/docs/deployment/frameworks/triton.md +++ b/docs/deployment/frameworks/triton.md @@ -1,6 +1,3 @@ ---- -title: NVIDIA Triton ---- -[](){ #deployment-triton } +# NVIDIA Triton The [Triton Inference Server](https://github.com/triton-inference-server) hosts a tutorial demonstrating how to quickly deploy a simple [facebook/opt-125m](https://huggingface.co/facebook/opt-125m) model using vLLM. Please see [Deploying a vLLM model in Triton](https://github.com/triton-inference-server/tutorials/blob/main/Quick_Deploy/vLLM/README.md#deploying-a-vllm-model-in-triton) for more details. diff --git a/docs/deployment/integrations/kserve.md b/docs/deployment/integrations/kserve.md index 754b983dee92..edf79fca4f93 100644 --- a/docs/deployment/integrations/kserve.md +++ b/docs/deployment/integrations/kserve.md @@ -1,7 +1,4 @@ ---- -title: KServe ---- -[](){ #deployment-kserve } +# KServe vLLM can be deployed with [KServe](https://github.com/kserve/kserve) on Kubernetes for highly scalable distributed model serving. diff --git a/docs/deployment/integrations/kubeai.md b/docs/deployment/integrations/kubeai.md index ba0a3c52cca7..89d072215e95 100644 --- a/docs/deployment/integrations/kubeai.md +++ b/docs/deployment/integrations/kubeai.md @@ -1,7 +1,4 @@ ---- -title: KubeAI ---- -[](){ #deployment-kubeai } +# KubeAI [KubeAI](https://github.com/substratusai/kubeai) is a Kubernetes operator that enables you to deploy and manage AI models on Kubernetes. It provides a simple and scalable way to deploy vLLM in production. 
Functionality such as scale-from-zero, load based autoscaling, model caching, and much more is provided out of the box with zero external dependencies. diff --git a/docs/deployment/integrations/kuberay.md b/docs/deployment/integrations/kuberay.md new file mode 100644 index 000000000000..1dcc98024e8d --- /dev/null +++ b/docs/deployment/integrations/kuberay.md @@ -0,0 +1,20 @@ +# KubeRay + +[KubeRay](https://github.com/ray-project/kuberay) provides a Kubernetes-native way to run vLLM workloads on Ray clusters. +A Ray cluster can be declared in YAML, and the operator then handles pod scheduling, networking configuration, restarts, and blue-green deployments — all while preserving the familiar Kubernetes experience. + +## Why KubeRay instead of manual scripts? + +| Feature | Manual scripts | KubeRay | +|---------|-----------------------------------------------------------|---------| +| Cluster bootstrap | Manually SSH into every node and run a script | One command to create or update the whole cluster: `kubectl apply -f cluster.yaml` | +| Autoscaling | Manual | Automatically patches CRDs for adjusting cluster size | +| Upgrades | Tear down & re-create manually | Blue/green deployment updates supported | +| Declarative config | Bash flags & environment variables | Git-ops-friendly YAML CRDs (RayCluster/RayService) | + +Using KubeRay reduces the operational burden and simplifies integration of Ray + vLLM with existing Kubernetes workflows (CI/CD, secrets, storage classes, etc.). + +## Learn more + +* ["Serve a Large Language Model using Ray Serve LLM on Kubernetes"](https://docs.ray.io/en/master/cluster/kubernetes/examples/rayserve-llm-example.html) - An end-to-end example of how to serve a model using vLLM, KubeRay, and Ray Serve. +* [KubeRay documentation](https://docs.ray.io/en/latest/cluster/kubernetes/index.html) diff --git a/docs/deployment/integrations/llamastack.md b/docs/deployment/integrations/llamastack.md index 9bbc6b5b296c..28031f01f85e 100644 --- a/docs/deployment/integrations/llamastack.md +++ b/docs/deployment/integrations/llamastack.md @@ -1,7 +1,4 @@ ---- -title: Llama Stack ---- -[](){ #deployment-llamastack } +# Llama Stack vLLM is also available via [Llama Stack](https://github.com/meta-llama/llama-stack) . diff --git a/docs/deployment/integrations/llmaz.md b/docs/deployment/integrations/llmaz.md index 03d284c34769..77730a26c24f 100644 --- a/docs/deployment/integrations/llmaz.md +++ b/docs/deployment/integrations/llmaz.md @@ -1,7 +1,4 @@ ---- -title: llmaz ---- -[](){ #deployment-llmaz } +# llmaz [llmaz](https://github.com/InftyAI/llmaz) is an easy-to-use and advanced inference platform for large language models on Kubernetes, aimed for production use. It uses vLLM as the default model serving backend. diff --git a/docs/deployment/integrations/production-stack.md b/docs/deployment/integrations/production-stack.md index 2b1cc6f6fee1..497f9f1a92a5 100644 --- a/docs/deployment/integrations/production-stack.md +++ b/docs/deployment/integrations/production-stack.md @@ -1,7 +1,4 @@ ---- -title: Production stack ---- -[](){ #deployment-production-stack } +# Production stack Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using the [vLLM production stack](https://github.com/vllm-project/production-stack). 
Born out of a Berkeley-UChicago collaboration, [vLLM production stack](https://github.com/vllm-project/production-stack) is an officially released, production-optimized codebase under the [vLLM project](https://github.com/vllm-project), designed for LLM deployment with: @@ -44,7 +41,8 @@ vllm-deployment-router-859d8fb668-2x2b7 1/1 Running 0 2m38 vllm-opt125m-deployment-vllm-84dfc9bd7-vb9bs 1/1 Running 0 2m38s ``` -**NOTE**: It may take some time for the containers to download the Docker images and LLM weights. +!!! note + It may take some time for the containers to download the Docker images and LLM weights. ### Send a Query to the Stack @@ -60,7 +58,7 @@ And then you can send out a query to the OpenAI-compatible API to check the avai curl -o- http://localhost:30080/models ``` -??? Output +??? console "Output" ```json { @@ -89,7 +87,7 @@ curl -X POST http://localhost:30080/completions \ }' ``` -??? Output +??? console "Output" ```json { @@ -121,7 +119,7 @@ sudo helm uninstall vllm The core vLLM production stack configuration is managed with YAML. Here is the example configuration used in the installation above: -??? Yaml +??? code "Yaml" ```yaml servingEngineSpec: @@ -152,6 +150,8 @@ In this YAML configuration: * **`requestGPU`**: Specifies the number of GPUs required. * **`pvcStorage`**: Allocates persistent storage for the model. -**NOTE:** If you intend to set up two pods, please refer to this [YAML file](https://github.com/vllm-project/production-stack/blob/main/tutorials/assets/values-01-2pods-minimal-example.yaml). +!!! note + If you intend to set up two pods, please refer to this [YAML file](https://github.com/vllm-project/production-stack/blob/main/tutorials/assets/values-01-2pods-minimal-example.yaml). -**NOTE:** vLLM production stack offers many more features (*e.g.* CPU offloading and a wide range of routing algorithms). Please check out these [examples and tutorials](https://github.com/vllm-project/production-stack/tree/main/tutorials) and our [repo](https://github.com/vllm-project/production-stack) for more details! +!!! tip + vLLM production stack offers many more features (*e.g.* CPU offloading and a wide range of routing algorithms). Please check out these [examples and tutorials](https://github.com/vllm-project/production-stack/tree/main/tutorials) and our [repo](https://github.com/vllm-project/production-stack) for more details! diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index f01e3d2fae0e..f244b0858eb6 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -1,7 +1,4 @@ ---- -title: Using Kubernetes ---- -[](){ #deployment-k8s } +# Using Kubernetes Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine learning models. This guide walks you through deploying vLLM using native Kubernetes. @@ -16,6 +13,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: - [Helm](frameworks/helm.md) - [InftyAI/llmaz](integrations/llmaz.md) - [KServe](integrations/kserve.md) +- [KubeRay](integrations/kuberay.md) - [kubernetes-sigs/lws](frameworks/lws.md) - [meta-llama/llama-stack](integrations/llamastack.md) - [substratusai/kubeai](integrations/kubeai.md) @@ -29,7 +27,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following: First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model: -??? Config +??? console "Config" ```bash cat < That code can be found in . 
-More details on the API server can be found in the [OpenAI-Compatible Server][serving-openai-compatible-server] document. +More details on the API server can be found in the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) document. ## LLM Engine @@ -132,7 +129,7 @@ input tensors and capturing cudagraphs. ## Model Every model runner object has one model object, which is the actual -`torch.nn.Module` instance. See [huggingface_integration][huggingface-integration] for how various +`torch.nn.Module` instance. See [huggingface_integration](huggingface_integration.md) for how various configurations affect the class we ultimately get. ## Class Hierarchy @@ -180,7 +177,7 @@ vision-language model. To avoid accidentally passing incorrect arguments, the constructor is now keyword-only. This ensures that the constructor will raise an error if old configurations are passed. vLLM developers have already made this change for all models within vLLM. For out-of-tree registered models, developers need to update their models, for example by adding shim code to adapt the old constructor signature to the new one: - ??? Code + ??? code ```python class MyOldModel(nn.Module): diff --git a/docs/design/automatic_prefix_caching.md b/docs/design/automatic_prefix_caching.md index 80883bb1d90d..60e21f6ad0fc 100644 --- a/docs/design/automatic_prefix_caching.md +++ b/docs/design/automatic_prefix_caching.md @@ -1,7 +1,4 @@ ---- -title: Automatic Prefix Caching ---- -[](){ #design-automatic-prefix-caching } +# Automatic Prefix Caching The core idea of [PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) is to partition the KV cache of each request into KV Blocks. Each block contains the attention keys and values for a fixed number of tokens. The PagedAttention algorithm allows these blocks to be stored in non-contiguous physical memory so that we can eliminate memory fragmentation by allocating the memory on demand. diff --git a/docs/design/huggingface_integration.md b/docs/design/huggingface_integration.md index 2d462ccb6535..7b01313ddb00 100644 --- a/docs/design/huggingface_integration.md +++ b/docs/design/huggingface_integration.md @@ -1,7 +1,4 @@ ---- -title: Integration with HuggingFace ---- -[](){ #huggingface-integration } +# Integration with HuggingFace This document describes how vLLM integrates with HuggingFace libraries. We will explain step by step what happens under the hood when we run `vllm serve`. diff --git a/docs/design/kernel/paged_attention.md b/docs/design/kernel/paged_attention.md index ff135a731960..94bfa97ee221 100644 --- a/docs/design/kernel/paged_attention.md +++ b/docs/design/kernel/paged_attention.md @@ -1,7 +1,4 @@ ---- -title: vLLM Paged Attention ---- -[](){ #design-paged-attention } +# vLLM Paged Attention Currently, vLLM utilizes its own implementation of a multi-head query attention kernel (`csrc/attention/attention_kernels.cu`). @@ -448,7 +445,7 @@ elements of the entire head for all context tokens. However, overall, all results for output have been calculated but are just stored in different thread register memory. -??? Code +??? 
code ```cpp float* out_smem = reinterpret_cast(shared_mem); diff --git a/docs/design/mm_processing.md b/docs/design/mm_processing.md index f3685ce76a4b..1e9b6ad6e821 100644 --- a/docs/design/mm_processing.md +++ b/docs/design/mm_processing.md @@ -1,9 +1,6 @@ ---- -title: Multi-Modal Data Processing ---- -[](){ #mm-processing } +# Multi-Modal Data Processing -To enable various optimizations in vLLM such as [chunked prefill][chunked-prefill] and [prefix caching][automatic-prefix-caching], we use [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] to provide the correspondence between placeholder feature tokens (e.g. ``) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor. +To enable various optimizations in vLLM such as [chunked prefill][chunked-prefill] and [prefix caching](../features/automatic_prefix_caching.md), we use [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor] to provide the correspondence between placeholder feature tokens (e.g. ``) and multi-modal inputs (e.g. the raw input image) based on the outputs of HF processor. Here are the main features of [BaseMultiModalProcessor][vllm.multimodal.processing.BaseMultiModalProcessor]: diff --git a/docs/design/plugin_system.md b/docs/design/plugin_system.md index 944f0e680de4..23a05ac719ce 100644 --- a/docs/design/plugin_system.md +++ b/docs/design/plugin_system.md @@ -1,19 +1,16 @@ ---- -title: vLLM's Plugin System ---- -[](){ #plugin-system } +# vLLM's Plugin System The community frequently requests the ability to extend vLLM with custom features. To facilitate this, vLLM includes a plugin system that allows users to add custom features without modifying the vLLM codebase. This document explains how plugins work in vLLM and how to create a plugin for vLLM. ## How Plugins Work in vLLM -Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview][arch-overview]), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work. +Plugins are user-registered code that vLLM executes. Given vLLM's architecture (see [Arch Overview](arch_overview.md)), multiple processes may be involved, especially when using distributed inference with various parallelism techniques. To enable plugins successfully, every process created by vLLM needs to load the plugin. This is done by the [load_general_plugins](https://github.com/vllm-project/vllm/blob/c76ac49d266e27aa3fea84ef2df1f813d24c91c7/vllm/plugins/__init__.py#L16) function in the `vllm.plugins` module. This function is called for every process created by vLLM before it starts any work. ## How vLLM Discovers Plugins vLLM's plugin system uses the standard Python `entry_points` mechanism. This mechanism allows developers to register functions in their Python packages for use by other packages. An example of a plugin: -??? Code +??? 
code ```python # inside `setup.py` file diff --git a/docs/design/v1/metrics.md b/docs/design/v1/metrics.md index 7156ee9dd3ec..52cd320dd4e1 100644 --- a/docs/design/v1/metrics.md +++ b/docs/design/v1/metrics.md @@ -5,17 +5,17 @@ Ensure the v1 LLM Engine exposes a superset of the metrics available in v0. ## Objectives - Achieve parity of metrics between v0 and v1. -- The priority use case is accessing these metrics via Prometheus as this is what we expect to be used in production environments. -- Logging support - i.e. printing metrics to the info log - is provided for more ad-hoc testing, debugging, development, and exploratory use cases. +- The priority use case is accessing these metrics via Prometheus, as this is what we expect to be used in production environments. +- Logging support (i.e. printing metrics to the info log) is provided for more ad-hoc testing, debugging, development, and exploratory use cases. ## Background Metrics in vLLM can be categorized as follows: -1. Server-level metrics: these are global metrics that track the state and performance of the LLM engine. These are typically exposed as Gauges or Counters in Prometheus. -2. Request-level metrics: these are metrics that track the characteristics - e.g. size and timing - of individual requests. These are typically exposed as Histograms in Prometheus, and are often the SLO that an SRE monitoring vLLM will be tracking. +1. Server-level metrics: Global metrics that track the state and performance of the LLM engine. These are typically exposed as Gauges or Counters in Prometheus. +2. Request-level metrics: Metrics that track the characteristics (e.g. size and timing) of individual requests. These are typically exposed as Histograms in Prometheus and are often the SLOs that an SRE monitoring vLLM will be tracking. -The mental model is that the "Server-level Metrics" explain why the "Request-level Metrics" are what they are. +The mental model is that server-level metrics help explain the values of request-level metrics. ### v0 Metrics @@ -61,24 +61,24 @@ These are documented under [Inferencing and Serving -> Production Metrics](../.. ### Grafana Dashboard -vLLM also provides [a reference example](https://docs.vllm.ai/en/latest/examples/prometheus_grafana.html) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. +vLLM also provides [a reference example](../../examples/online_serving/prometheus_grafana.md) for how to collect and store these metrics using Prometheus and visualize them using a Grafana dashboard. The subset of metrics exposed in the Grafana dashboard gives us an indication of which metrics are especially important: -- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds -- `vllm:prompt_tokens_total` - Prompt Tokens -- `vllm:generation_tokens_total` - Generation Tokens -- `vllm:time_per_output_token_seconds` - Inter token latency (Time Per Output Token, TPOT) in second. +- `vllm:e2e_request_latency_seconds_bucket` - End to end request latency measured in seconds. +- `vllm:prompt_tokens_total` - Prompt tokens. +- `vllm:generation_tokens_total` - Generation tokens. +- `vllm:time_per_output_token_seconds` - Inter-token latency (Time Per Output Token, TPOT) in seconds. - `vllm:time_to_first_token_seconds` - Time to First Token (TTFT) latency in seconds. 
-- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in RUNNING, WAITING, and SWAPPED state +- `vllm:num_requests_running` (also, `_swapped` and `_waiting`) - Number of requests in the RUNNING, WAITING, and SWAPPED states. - `vllm:gpu_cache_usage_perc` - Percentage of used cache blocks by vLLM. -- `vllm:request_prompt_tokens` - Request prompt length -- `vllm:request_generation_tokens` - request generation length -- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached -- `vllm:request_queue_time_seconds` - Queue Time -- `vllm:request_prefill_time_seconds` - Requests Prefill Time -- `vllm:request_decode_time_seconds` - Requests Decode Time -- `vllm:request_max_num_generation_tokens` - Max Generation Token in Sequence Group +- `vllm:request_prompt_tokens` - Request prompt length. +- `vllm:request_generation_tokens` - Request generation length. +- `vllm:request_success_total` - Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached. +- `vllm:request_queue_time_seconds` - Queue time. +- `vllm:request_prefill_time_seconds` - Requests prefill time. +- `vllm:request_decode_time_seconds` - Requests decode time. +- `vllm:request_max_num_generation_tokens` - Max generation tokens in a sequence group. See [the PR which added this Dashboard](gh-pr:2316) for interesting and useful background on the choices made here. @@ -103,7 +103,7 @@ In v0, metrics are collected in the engine core process and we use multi-process ### Built in Python/Process Metrics -The following metrics are supported by default by `prometheus_client`, but the are not exposed with multiprocess mode is used: +The following metrics are supported by default by `prometheus_client`, but they are not exposed when multi-process mode is used: - `python_gc_objects_collected_total` - `python_gc_objects_uncollectable_total` @@ -158,6 +158,7 @@ In v1, we wish to move computation and overhead out of the engine core process to minimize the time between each forward pass. The overall idea of V1 EngineCore design is: + - EngineCore is the inner loop. Performance is most critical here - AsyncLLM is the outer loop. This is overlapped with GPU execution (ideally), so this is where any "overheads" should be if @@ -178,7 +179,7 @@ time" (`time.time()`) to calculate intervals as the former is unaffected by system clock changes (e.g. from NTP). It's also important to note that monotonic clocks differ between -processes - each process has its own reference. point. So it is +processes - each process has its own reference point. So it is meaningless to compare monotonic timestamps from different processes. Therefore, in order to calculate an interval, we must compare two @@ -343,14 +344,15 @@ vllm:time_to_first_token_seconds_bucket{le="0.1",model_name="meta-llama/Llama-3. vllm:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 140.0 ``` -Note - the choice of histogram buckets to be most useful to users -across a broad set of use cases is not straightforward and will -require refinement over time. +!!! note + The choice of histogram buckets to be most useful to users + across a broad set of use cases is not straightforward and will + require refinement over time. 
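For illustration, here is a minimal sketch of how such a histogram could be declared with plain `prometheus_client`; the metric and label names mirror the output above, but the bucket boundaries are only an example of the kind of choice the note refers to, not the values vLLM actually ships with:

```python
from prometheus_client import Histogram

# Example bucket boundaries only -- spanning a few milliseconds (cached prefix,
# short prompt) up to tens of seconds (long prompt, queueing under load).
ttft_seconds = Histogram(
    name="vllm:time_to_first_token_seconds",
    documentation="Time to first token in seconds.",
    labelnames=["model_name"],
    buckets=[0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08,
             0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0],
)

# Each observation increments every bucket whose upper bound (`le`) is >= the value.
ttft_seconds.labels(model_name="meta-llama/Llama-3.1-8B-Instruct").observe(0.042)
```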
### Cache Config Info -`prometheus_client` has support for [Info -metrics](https://prometheus.github.io/client_python/instrumenting/info/) +`prometheus_client` has support for +[Info metrics](https://prometheus.github.io/client_python/instrumenting/info/) which are equivalent to a `Gauge` whose value is permanently set to 1, but exposes interesting key/value pair information via labels. This is used for information about an instance that does not change - so it @@ -363,14 +365,11 @@ We use this concept for the `vllm:cache_config_info` metric: # HELP vllm:cache_config_info Information of the LLMEngine CacheConfig # TYPE vllm:cache_config_info gauge vllm:cache_config_info{block_size="16",cache_dtype="auto",calculate_kv_scales="False",cpu_offload_gb="0",enable_prefix_caching="False",gpu_memory_utilization="0.9",...} 1.0 - ``` -However, `prometheus_client` has [never supported Info metrics in -multiprocessing -mode](https://github.com/prometheus/client_python/pull/300) - for -[unclear -reasons](gh-pr:7279#discussion_r1710417152). We +However, `prometheus_client` has +[never supported Info metrics in multiprocessing mode](https://github.com/prometheus/client_python/pull/300) - +for [unclear reasons](gh-pr:7279#discussion_r1710417152). We simply use a `Gauge` metric set to 1 and `multiprocess_mode="mostrecent"` instead. @@ -395,11 +394,9 @@ distinguish between per-adapter counts. This should be revisited. Note that `multiprocess_mode="livemostrecent"` is used - the most recent metric is used, but only from currently running processes. -This was added in - and there is -[at least one known -user](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/54). If -we revisit this design and deprecate the old metric, we should reduce +This was added in and there is +[at least one known user](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/54). +If we revisit this design and deprecate the old metric, we should reduce the need for a significant deprecation period by making the change in v0 also and asking this project to move to the new metric. @@ -442,23 +439,20 @@ suddenly (from their perspective) when it is removed, even if there is an equivalent metric for them to use. As an example, see how `vllm:avg_prompt_throughput_toks_per_s` was -[deprecated](gh-pr:2764) (with a -comment in the code), -[removed](gh-pr:12383), and then -[noticed by a -user](gh-issue:13218). +[deprecated](gh-pr:2764) (with a comment in the code), +[removed](gh-pr:12383), and then [noticed by a user](gh-issue:13218). In general: -1) We should be cautious about deprecating metrics, especially since +1. We should be cautious about deprecating metrics, especially since it can be hard to predict the user impact. -2) We should include a prominent deprecation notice in the help string +2. We should include a prominent deprecation notice in the help string that is included in the `/metrics' output. -3) We should list deprecated metrics in user-facing documentation and +3. We should list deprecated metrics in user-facing documentation and release notes. -4) We should consider hiding deprecated metrics behind a CLI argument - in order to give administrators [an escape - hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics) +4. 
We should consider hiding deprecated metrics behind a CLI argument + in order to give administrators + [an escape hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics) for some time before deleting them. See the [deprecation policy](../../contributing/deprecation_policy.md) for @@ -474,7 +468,7 @@ removed. The `vllm:time_in_queue_requests` Histogram metric was added by and its calculation is: -``` +```python self.metrics.first_scheduled_time = now self.metrics.time_in_queue = now - self.metrics.arrival_time ``` @@ -482,7 +476,7 @@ The `vllm:time_in_queue_requests` Histogram metric was added by Two weeks later, added `vllm:request_queue_time_seconds` leaving us with: -``` +```python if seq_group.is_finished(): if (seq_group.metrics.first_scheduled_time is not None and seq_group.metrics.first_token_time is not None): @@ -517,8 +511,7 @@ cache to complete other requests), we swap kv cache blocks out to CPU memory. This is also known as "KV cache offloading" and is configured with `--swap-space` and `--preemption-mode`. -In v0, [vLLM has long supported beam -search](gh-issue:6226). The +In v0, [vLLM has long supported beam search](gh-issue:6226). The SequenceGroup encapsulated the idea of N Sequences which all shared the same prompt kv blocks. This enabled KV cache block sharing between requests, and copy-on-write to do branching. CPU @@ -530,9 +523,8 @@ option than CPU swapping since blocks can be evicted slowly on demand and the part of the prompt that was evicted can be recomputed. SequenceGroup was removed in V1, although a replacement will be -required for "parallel sampling" (`n>1`). [Beam search was moved out of -the core (in -V0)](gh-issue:8306). There was a +required for "parallel sampling" (`n>1`). +[Beam search was moved out of the core (in V0)](gh-issue:8306). There was a lot of complex code for a very uncommon feature. In V1, with prefix caching being better (zero overhead) and therefore @@ -547,18 +539,18 @@ Some v0 metrics are only relevant in the context of "parallel sampling". This is where the `n` parameter in a request is used to request multiple completions from the same prompt. -As part of adding parallel sampling support in we should +As part of adding parallel sampling support in , we should also add these metrics. - `vllm:request_params_n` (Histogram) -Observes the value of the 'n' parameter of every finished request. + Observes the value of the 'n' parameter of every finished request. - `vllm:request_max_num_generation_tokens` (Histogram) -Observes the maximum output length of all sequences in every finished -sequence group. In the absence of parallel sampling, this is -equivalent to `vllm:request_generation_tokens`. + Observes the maximum output length of all sequences in every finished + sequence group. In the absence of parallel sampling, this is + equivalent to `vllm:request_generation_tokens`. ### Speculative Decoding @@ -576,26 +568,23 @@ There is a PR under review () to add "prompt lookup (ngram)" speculative decoding to v1. Other techniques will follow. We should revisit the v0 metrics in this context. -Note - we should probably expose acceptance rate as separate accepted -and draft counters, like we do for prefix caching hit rate. Efficiency -likely also needs similar treatment. +!!! note + We should probably expose acceptance rate as separate accepted + and draft counters, like we do for prefix caching hit rate. Efficiency + likely also needs similar treatment. 
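As a sketch of the "separate counters" idea (the metric names below are hypothetical, not vLLM's actual speculative decoding metrics), the acceptance rate would then be derived at query time rather than exported directly:

```python
from prometheus_client import Counter

# Hypothetical names for illustration; prometheus_client appends `_total`
# to counter names when exposing them.
draft_tokens = Counter(
    name="vllm:spec_decode_draft_tokens",
    documentation="Tokens proposed by the draft method.",
    labelnames=["model_name"],
)
accepted_tokens = Counter(
    name="vllm:spec_decode_accepted_tokens",
    documentation="Draft tokens accepted by the target model.",
    labelnames=["model_name"],
)

def record_spec_decode_step(model_name: str, num_draft: int, num_accepted: int) -> None:
    """Record one speculative decoding step."""
    draft_tokens.labels(model_name=model_name).inc(num_draft)
    accepted_tokens.labels(model_name=model_name).inc(num_accepted)

# Acceptance rate then becomes a PromQL ratio, e.g.:
#   rate(vllm:spec_decode_accepted_tokens_total[5m])
#     / rate(vllm:spec_decode_draft_tokens_total[5m])
```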
### Autoscaling and Load-balancing A common use case for our metrics is to support automated scaling of vLLM instances. -For related discussion from the [Kubernetes Serving Working -Group](https://github.com/kubernetes/community/tree/master/wg-serving), +For related discussion from the +[Kubernetes Serving Working Group](https://github.com/kubernetes/community/tree/master/wg-serving), see: -- [Standardizing Large Model Server Metrics in - Kubernetes](https://docs.google.com/document/d/1SpSp1E6moa4HSrJnS4x3NpLuj88sMXr2tbofKlzTZpk) -- [Benchmarking LLM Workloads for Performance Evaluation and - Autoscaling in - Kubernetes](https://docs.google.com/document/d/1k4Q4X14hW4vftElIuYGDu5KDe2LtV1XammoG-Xi3bbQ) -- [Inference - Perf](https://github.com/kubernetes-sigs/wg-serving/tree/main/proposals/013-inference-perf) +- [Standardizing Large Model Server Metrics in Kubernetes](https://docs.google.com/document/d/1SpSp1E6moa4HSrJnS4x3NpLuj88sMXr2tbofKlzTZpk) +- [Benchmarking LLM Workloads for Performance Evaluation and Autoscaling in Kubernetes](https://docs.google.com/document/d/1k4Q4X14hW4vftElIuYGDu5KDe2LtV1XammoG-Xi3bbQ) +- [Inference Perf](https://github.com/kubernetes-sigs/wg-serving/tree/main/proposals/013-inference-perf) - and . This is a non-trivial topic. Consider this comment from Rob: @@ -619,19 +608,16 @@ should judge an instance as approaching saturation: Our approach to naming metrics probably deserves to be revisited: -1. The use of colons in metric names seems contrary to ["colons are - reserved for user defined recording - rules"](https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels) +1. The use of colons in metric names seems contrary to + ["colons are reserved for user defined recording rules"](https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels). 2. Most of our metrics follow the convention of ending with units, but not all do. 3. Some of our metric names end with `_total`: -``` -If there is a suffix of `_total` on the metric name, it will be removed. When -exposing the time series for counter, a `_total` suffix will be added. This is -for compatibility between OpenMetrics and the Prometheus text format, as OpenMetrics -requires the `_total` suffix. -``` + If there is a suffix of `_total` on the metric name, it will be removed. When + exposing the time series for counter, a `_total` suffix will be added. This is + for compatibility between OpenMetrics and the Prometheus text format, as OpenMetrics + requires the `_total` suffix. ### Adding More Metrics @@ -642,8 +628,7 @@ There is no shortage of ideas for new metrics: - Proposals arising from specific use cases, like the Kubernetes auto-scaling topic above - Proposals that might arise out of standardisation efforts like - [OpenTelemetry Semantic Conventions for Gen - AI](https://github.com/open-telemetry/semantic-conventions/tree/main/docs/gen-ai). + [OpenTelemetry Semantic Conventions for Gen AI](https://github.com/open-telemetry/semantic-conventions/tree/main/docs/gen-ai). We should be cautious in our approach to adding new metrics. While metrics are often relatively straightforward to add: @@ -668,19 +653,14 @@ fall under the more general heading of "Observability". 
v0 has support for OpenTelemetry tracing: - Added by -- Configured with `--oltp-traces-endpoint` and - `--collect-detailed-traces` -- [OpenTelemetry blog - post](https://opentelemetry.io/blog/2024/llm-observability/) -- [User-facing - docs](https://docs.vllm.ai/en/latest/examples/opentelemetry.html) -- [Blog - post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f) -- [IBM product - docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview) +- Configured with `--oltp-traces-endpoint` and `--collect-detailed-traces` +- [OpenTelemetry blog post](https://opentelemetry.io/blog/2024/llm-observability/) +- [User-facing docs](../../examples/online_serving/opentelemetry.md) +- [Blog post](https://medium.com/@ronen.schaffer/follow-the-trail-supercharging-vllm-with-opentelemetry-distributed-tracing-aa655229b46f) +- [IBM product docs](https://www.ibm.com/docs/en/instana-observability/current?topic=mgaa-monitoring-large-language-models-llms-vllm-public-preview) -OpenTelemetry has a [Gen AI Working -Group](https://github.com/open-telemetry/community/blob/main/projects/gen-ai.md). +OpenTelemetry has a +[Gen AI Working Group](https://github.com/open-telemetry/community/blob/main/projects/gen-ai.md). Since metrics is a big enough topic on its own, we are going to tackle the topic of tracing in v1 separately. @@ -699,7 +679,7 @@ These metrics are only enabled when OpenTelemetry tracing is enabled and if `--collect-detailed-traces=all/model/worker` is used. The documentation for this option states: -> collect detailed traces for the specified "modules. This involves +> collect detailed traces for the specified modules. This involves > use of possibly costly and or blocking operations and hence might > have a performance impact. diff --git a/docs/design/v1/p2p_nccl_connector.md b/docs/design/v1/p2p_nccl_connector.md index 32cdaacf058a..9f6acf3291dd 100644 --- a/docs/design/v1/p2p_nccl_connector.md +++ b/docs/design/v1/p2p_nccl_connector.md @@ -31,7 +31,7 @@ Each P/D instance periodically sends a heartbeat packet to the Proxy/Router (cur ## KV Cache Transfer Methods -There are three methods for KVcache transfer: PUT, GET, and PUT_ASYNC. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. Both PUT and PUT_ASYNC involve the P instance actively sending KVcache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is an asynchronous transfer method. PUT_ASYNC uses a dedicated thread for sending KVcache, which means it does not block the main process. In contrast, the GET method involves the P instance saving the KVcache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVcache from the P instance once it has allocated space for the KVcache. +There are three methods for KVCache transfer: PUT, GET, and PUT_ASYNC. These methods can be specified using the `--kv-transfer-config` and `kv_connector_extra_config` parameters, specifically through the `send_type` field. Both PUT and PUT_ASYNC involve the P instance actively sending KVCache to the D instance. The difference is that PUT is a synchronous transfer method that blocks the main process, while PUT_ASYNC is an asynchronous transfer method. 
PUT_ASYNC uses a dedicated thread for sending KVCache, which means it does not block the main process. In contrast, the GET method involves the P instance saving the KVCache to the memory buffer after computing the prefill. The D instance then actively retrieves the computed KVCache from the P instance once it has allocated space for the KVCache. Experimental results have shown that the performance of these methods, from highest to lowest, is as follows: PUT_ASYNC → GET → PUT. @@ -39,13 +39,13 @@ Experimental results have shown that the performance of these methods, from high As long as the address of the counterpart is known, point-to-point KV cache transfer (using NCCL) can be performed, without being constrained by rank and world size. To support dynamic scaling (expansion and contraction) of instances with PD disaggregation. This means that adding or removing P/D instances does not require a full system restart. -Each P/D instance only needs to create a single `P2pNcclEngine` instance. This instance maintains a ZMQ Server, which runs a dedicated thread to listen on the `zmq_addr` address and receive control flow requests from other instances. These requests include requests to establish an NCCL connection and requests to send KVcache metadata (such as tensor shapes and data types). However, it does not actually transmit the KVcache data itself. +Each P/D instance only needs to create a single `P2pNcclEngine` instance. This instance maintains a ZMQ Server, which runs a dedicated thread to listen on the `zmq_addr` address and receive control flow requests from other instances. These requests include requests to establish an NCCL connection and requests to send KVCache metadata (such as tensor shapes and data types). However, it does not actually transmit the KVCache data itself. -When a P instance and a D instance transmit KVcache for the first time, they need to establish a ZMQ connection and an NCCL group. For subsequent KVcache transmissions, this ZMQ connection and NCCL group are reused. The NCCL group consists of only two ranks, meaning the world size is equal to 2. This design is intended to support dynamic scaling, which means that adding or removing P/D instances does not require a full system restart. As long as the address of the counterpart is known, point-to-point KVcache transmission can be performed, without being restricted by rank or world size. +When a P instance and a D instance transmit KVCache for the first time, they need to establish a ZMQ connection and an NCCL group. For subsequent KVCache transmissions, this ZMQ connection and NCCL group are reused. The NCCL group consists of only two ranks, meaning the world size is equal to 2. This design is intended to support dynamic scaling, which means that adding or removing P/D instances does not require a full system restart. As long as the address of the counterpart is known, point-to-point KVCache transmission can be performed, without being restricted by rank or world size. ## NCCL Group Topology -Currently, only symmetric TP (Tensor Parallelism) methods are supported for KVcache transmission. Asymmetric TP and PP (Pipeline Parallelism) methods will be supported in the future. Figure 2 illustrates the 1P2D setup, where each instance has a TP (Tensor Parallelism) degree of 2. There are a total of 7 NCCL groups: three vLLM instances each have one NCCL group with TP=2. Additionally, the 0th GPU card of the P instance establishes an NCCL group with the 0th GPU card of each D instance. 
Similarly, the 1st GPU card of the P instance establishes an NCCL group with the 1st GPU card of each D instance. +Currently, only symmetric TP (Tensor Parallelism) methods are supported for KVCache transmission. Asymmetric TP and PP (Pipeline Parallelism) methods will be supported in the future. Figure 2 illustrates the 1P2D setup, where each instance has a TP (Tensor Parallelism) degree of 2. There are a total of 7 NCCL groups: three vLLM instances each have one NCCL group with TP=2. Additionally, the 0th GPU card of the P instance establishes an NCCL group with the 0th GPU card of each D instance. Similarly, the 1st GPU card of the P instance establishes an NCCL group with the 1st GPU card of each D instance. ![image2](https://github.com/user-attachments/assets/837e61d6-365e-4cbf-8640-6dd7ab295b36) @@ -53,33 +53,17 @@ Each NCCL group occupies a certain amount of GPU memory buffer for communication ## GPU Memory Buffer and Tensor Memory Pool -The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVcache sent by P instances. If it is too large, it will reduce the KVcache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%~10% of the memory size. +The trade-off in the size of the memory buffer is as follows: For P instances, the memory buffer is not required in PUT and PUT_ASYNC modes, but it is necessary in GET mode. For D instances, a memory buffer is needed in all three modes. The memory buffer for D instances should not be too large. Similarly, for P instances in GET mode, the memory buffer should also not be too large. The memory buffer of D instances is used to temporarily store KVCache sent by P instances. If it is too large, it will reduce the KVCache space available for normal inference by D instances, thereby decreasing the inference batch size and ultimately leading to a reduction in output throughput. The size of the memory buffer is configured by the parameter `kv_buffer_size`, measured in bytes, and is typically set to 5%~10% of the memory size. -If the `--max-num-seqs` parameter for P instances is set to a large value, due to the large batch size, P instances will generate a large amount of KVcache simultaneously. This may exceed the capacity of the memory buffer of D instances, resulting in KVcache loss. Once KVcache is lost, D instances need to recompute Prefill, which is equivalent to performing Prefill twice. Consequently, the time-to-first-token (TTFT) will significantly increase, leading to degraded performance. +If the `--max-num-seqs` parameter for P instances is set to a large value, due to the large batch size, P instances will generate a large amount of KVCache simultaneously. This may exceed the capacity of the memory buffer of D instances, resulting in KVCache loss. Once KVCache is lost, D instances need to recompute Prefill, which is equivalent to performing Prefill twice. 
Consequently, the time-to-first-token (TTFT) will significantly increase, leading to degraded performance. -To address the above issues, I have designed and developed a local Tensor memory pool for storing KVcache, inspired by the buddy system used in Linux memory modules. Since the memory is sufficiently large, typically in the TB range on servers, there is no need to consider prefix caching or using block-based designs to reuse memory, thereby saving space. When the memory buffer is insufficient, KVcache can be directly stored in the Tensor memory pool, and D instances can subsequently retrieve KVcache from it. The read and write speed is that of PCIe, with PCIe 4.0 having a speed of approximately 21 GB/s, which is usually faster than the Prefill speed. Otherwise, solutions like Mooncake and lmcache would not be necessary. The Tensor memory pool acts as a flood diversion area, typically unused except during sudden traffic surges. In the worst-case scenario, my solution performs no worse than the normal situation with a Cache store. +To address the above issues, I have designed and developed a local Tensor memory pool for storing KVCache, inspired by the buddy system used in Linux memory modules. Since the memory is sufficiently large, typically in the TB range on servers, there is no need to consider prefix caching or using block-based designs to reuse memory, thereby saving space. When the memory buffer is insufficient, KVCache can be directly stored in the Tensor memory pool, and D instances can subsequently retrieve KVCache from it. The read and write speed is that of PCIe, with PCIe 4.0 having a speed of approximately 21 GB/s, which is usually faster than the Prefill speed. Otherwise, solutions like Mooncake and lmcache would not be necessary. The Tensor memory pool acts as a flood diversion area, typically unused except during sudden traffic surges. In the worst-case scenario, my solution performs no worse than the normal situation with a Cache store. # Install vLLM -??? Commands - - ```shell - # Enter the home directory or your working directory. - cd /home - - # Download the installation package, and I will update the commit-id in time. You can directly copy the command. - wget https://vllm-wheels.s3.us-west-2.amazonaws.com/9112b443a042d8d815880b8780633882ad32b183/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl - - # Download the code repository. - git clone -b xpyd-v1 https://github.com/Abatom/vllm.git - cd vllm - - # Set the installation package path. - export VLLM_PRECOMPILED_WHEEL_LOCATION=/home/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl - - # installation - pip install -e . -v - ``` +```shell +pip install "vllm>=0.9.2" +``` # Run xPyD @@ -90,7 +74,7 @@ To address the above issues, I have designed and developed a local Tensor memory - You may need to modify the `kv_buffer_size` and `port` in the following commands (if there is a conflict). - `PUT_ASYNC` offers the best performance and should be prioritized. - The `--port` must be consistent with the `http_port` in the `--kv-transfer-config`. -- The `disagg_prefill_proxy_xpyd.py` script will use port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances). +- The `disagg_proxy_p2p_nccl_xpyd.py` script will use port 10001 (for receiving client requests) and port 30001 (for receiving service discovery from P and D instances). - The node running the proxy must have `quart` installed. 
- Supports multiple nodes; you just need to modify the `proxy_ip` and `proxy_port` in `--kv-transfer-config`. - In the following examples, it is assumed that **the proxy's IP is 10.0.1.1**. @@ -100,18 +84,18 @@ To address the above issues, I have designed and developed a local Tensor memory ### Proxy (e.g. 10.0.1.1) ```shell -cd {your vllm directory}/examples/online_serving/disagg_xpyd/ -python3 disagg_prefill_proxy_xpyd.py & +cd {your vllm directory}/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/ +python3 disagg_proxy_p2p_nccl_xpyd.py & ``` ### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1) -??? Command +??? console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ --host 0.0.0.0 \ - --port 20005 \ + --port 20001 \ --tensor-parallel-size 1 \ --seed 1024 \ --served-model-name base_model \ @@ -123,17 +107,17 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.9 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20001"}}' > /var/vllm.log 2>&1 & ``` ### Decode1 (e.g. 10.0.1.3 or 10.0.1.1) -??? Command +??? console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ --host 0.0.0.0 \ - --port 20009 \ + --port 20002 \ --tensor-parallel-size 1 \ --seed 1024 \ --served-model-name base_model \ @@ -145,12 +129,12 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.7 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20002"}}' > /var/vllm.log 2>&1 & ``` ### Decode2 (e.g. 10.0.1.4 or 10.0.1.1) -??? Command +??? console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \ @@ -167,17 +151,17 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.7 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003"}}' > /var/vllm.log 2>&1 & ``` ### Decode3 (e.g. 10.0.1.5 or 10.0.1.1) -??? Command +??? 
console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ --host 0.0.0.0 \ - --port 20008 \ + --port 20004 \ --tensor-parallel-size 1 \ --seed 1024 \ --served-model-name base_model \ @@ -189,7 +173,7 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.7 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20004"}}' > /var/vllm.log 2>&1 & ``` ## Run 3P1D @@ -197,18 +181,18 @@ python3 disagg_prefill_proxy_xpyd.py & ### Proxy (e.g. 10.0.1.1) ```shell -cd {your vllm directory}/examples/online_serving/disagg_xpyd/ -python3 disagg_prefill_proxy_xpyd.py & +cd {your vllm directory}/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/ +python3 disagg_proxy_p2p_nccl_xpyd.py & ``` ### Prefill1 (e.g. 10.0.1.2 or 10.0.1.1) -??? Command +??? console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=0 vllm serve {your model directory} \ --host 0.0.0.0 \ - --port 20005 \ + --port 20001 \ --tensor-parallel-size 1 \ --seed 1024 \ --served-model-name base_model \ @@ -220,17 +204,17 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.9 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20005","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"21001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20001"}}' > /var/vllm.log 2>&1 & ``` ### Prefill2 (e.g. 10.0.1.3 or 10.0.1.1) -??? Command +??? console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=1 vllm serve {your model directory} \ --host 0.0.0.0 \ - --port 20009 \ + --port 20002 \ --tensor-parallel-size 1 \ --seed 1024 \ --served-model-name base_model \ @@ -242,12 +226,12 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.9 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20009","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"22001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20002"}}' > /var/vllm.log 2>&1 & ``` ### Prefill3 (e.g. 10.0.1.4 or 10.0.1.1) -??? Command +??? 
console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=2 vllm serve {your model directory} \ @@ -264,17 +248,17 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.9 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_buffer_size":"1e1","kv_port":"23001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20003"}}' > /var/vllm.log 2>&1 & ``` ### Decode1 (e.g. 10.0.1.5 or 10.0.1.1) -??? Command +??? console "Command" ```shell VLLM_USE_V1=1 CUDA_VISIBLE_DEVICES=3 vllm serve {your model directory} \ --host 0.0.0.0 \ - --port 20008 \ + --port 20004 \ --tensor-parallel-size 1 \ --seed 1024 \ --served-model-name base_model \ @@ -286,7 +270,7 @@ python3 disagg_prefill_proxy_xpyd.py & --gpu-memory-utilization 0.7 \ --disable-log-request \ --kv-transfer-config \ - '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20008","send_type":"PUT_ASYNC","nccl_num_channels":"16"}}' > /var/vllm.log 2>&1 & + '{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_buffer_size":"8e9","kv_port":"24001","kv_connector_extra_config":{"proxy_ip":"10.0.1.1","proxy_port":"30001","http_port":"20004"}}' > /var/vllm.log 2>&1 & ``` # Single request @@ -304,7 +288,7 @@ curl -X POST -s http://10.0.1.1:10001/v1/completions \ # Benchmark -??? Command +??? console "Command" ```shell python3 benchmark_serving.py \ @@ -334,24 +318,6 @@ pgrep python | xargs kill -9 && pkill -f python # Test data -## **Scenario 1**: 1K input & 1K output tokens, E2E P99 latency ~20s -- **1P5D (6×A800) vs vLLM (1×A800)**: - - Throughput ↑7.2% (1085 → 6979/6) - - ITL (P99) ↓81.3% (120ms → 22.9ms) - - TTFT (P99) ↑26.8% (175ms → 222ms) - - TPOT: No change - -- **1P6D (7×A800) vs vLLM (1×A800)**: - - Throughput ↑9.6% (1085 → 8329/7) - - ITL (P99) ↓81.0% (120ms → 22.7ms) - - TTFT (P99) ↑210% (175ms →543ms) - - TPOT: No change - -## **Scenario 2**: 1K input & 200 output tokens, E2E P99 latency ~4s -- **1P1D (2×A800) vs vLLM (1×A800)**: - - Throughput ↑37.4% (537 → 1476/2) - - ITL (P99) ↓81.8% (127ms → 23.1ms) - - TTFT (P99) ↑41.8% (160ms → 227ms) - - TPOT: No change - -![testdata](https://github.com/user-attachments/assets/f791bfc7-9f3d-4e5c-9171-a42f9f4da627) +## **Scenario**: 1K input & 200 output tokens, E2E P99 latency ~2s + +![testdata](https://github.com/user-attachments/assets/cef0953b-4567-4bf9-b940-405b92a28eb1) diff --git a/docs/design/v1/torch_compile.md b/docs/design/v1/torch_compile.md index b65099bd62a2..ea5d8ac212f7 100644 --- a/docs/design/v1/torch_compile.md +++ b/docs/design/v1/torch_compile.md @@ -28,7 +28,7 @@ A unique aspect of vLLM's `torch.compile` integration, is that we guarantee all In the very verbose logs, we can see: -??? Logs +??? console "Logs" ```text DEBUG 03-07 03:06:52 [decorators.py:203] Start compiling function @@ -110,7 +110,7 @@ Then it will also compile a specific kernel just for batch size `1, 2, 4, 8`. At When all the shapes are known, `torch.compile` can compare different configs, and often find some better configs to run the kernel. For example, we can see the following log: -??? Logs +??? 
console "Logs" ``` AUTOTUNE mm(8x2048, 2048x3072) diff --git a/docs/features/automatic_prefix_caching.md b/docs/features/automatic_prefix_caching.md index 5e92796ddda7..f3c4bdd85c37 100644 --- a/docs/features/automatic_prefix_caching.md +++ b/docs/features/automatic_prefix_caching.md @@ -1,14 +1,11 @@ ---- -title: Automatic Prefix Caching ---- -[](){ #automatic-prefix-caching } +# Automatic Prefix Caching ## Introduction Automatic Prefix Caching (APC in short) caches the KV cache of existing queries, so that a new query can directly reuse the KV cache if it shares the same prefix with one of the existing queries, allowing the new query to skip the computation of the shared part. !!! note - Technical details on how vLLM implements APC can be found [here][design-automatic-prefix-caching]. + Technical details on how vLLM implements APC can be found [here](../design/automatic_prefix_caching.md). ## Enabling APC in vLLM diff --git a/docs/features/compatibility_matrix.md b/docs/features/compatibility_matrix.md index 4f475ee4db83..8be1585f8e76 100644 --- a/docs/features/compatibility_matrix.md +++ b/docs/features/compatibility_matrix.md @@ -1,7 +1,4 @@ ---- -title: Compatibility Matrix ---- -[](){ #compatibility-matrix } +# Compatibility Matrix The tables below show mutually exclusive features and the support on some hardware. @@ -37,23 +34,22 @@ th:not(:first-child) { } -| Feature | [CP][chunked-prefill] | [APC][automatic-prefix-caching] | [LoRA][lora-adapter] | prmpt adptr | [SD][spec-decode] | CUDA graph | pooling | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | -|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| +| Feature | [CP][chunked-prefill] | [APC](automatic_prefix_caching.md) | [LoRA](lora.md) | [SD](spec_decode.md) | CUDA graph | pooling | enc-dec | logP | prmpt logP | async output | multi-step | mm | best-of | beam-search | +|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---| | [CP][chunked-prefill] | ✅ | | | | | | | | | | | | | | | -| [APC][automatic-prefix-caching] | ✅ | ✅ | | | | | | | | | | | | | | -| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | | | | | | | | | | | | | -| prmpt adptr | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | | | -| [SD][spec-decode] | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | | | | | -| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | -| pooling | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | | -| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | -| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | -| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | | | | | | -| async output | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | -| multi-step | ❌ | ✅ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | -| mm | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | -| best-of | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | -| beam-search | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | +| [APC](automatic_prefix_caching.md) | ✅ | ✅ | | | | | | | | | | | | | | +| [LoRA](lora.md) | ✅ | ✅ | ✅ | | | | | | | | | | | | | +| [SD](spec_decode.md) | ✅ | ✅ | ❌ | ✅ | | | | | | | | | | | +| CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | +| pooling | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | | | | | | | | | +| enc-dec | ❌ | [❌](gh-issue:7366) | ❌ | [❌](gh-issue:7366) | ✅ | ✅ | ✅ | | | | | | | | +| logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | | | | | | +| prmpt logP | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ 
| ✅ | ✅ | | | | | | +| async output | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | | | | | +| multi-step | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | | | | +| mm | ✅ | [🟠](gh-pr:8348) | [🟠](gh-pr:4194) | ❔ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ✅ | | | +| best-of | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ✅ | ✅ | | +| beam-search | ✅ | ✅ | ✅ | [❌](gh-issue:6137) | ✅ | ❌ | ✅ | ✅ | ✅ | ❔ | [❌](gh-issue:7968) | ❔ | ✅ | ✅ | [](){ #feature-x-hardware } @@ -62,10 +58,9 @@ th:not(:first-child) { | Feature | Volta | Turing | Ampere | Ada | Hopper | CPU | AMD | TPU | |-----------------------------------------------------------|---------------------|-----------|-----------|--------|------------|--------------------|--------|-----| | [CP][chunked-prefill] | [❌](gh-issue:2729) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [APC][automatic-prefix-caching] | [❌](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [LoRA][lora-adapter] | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| prmpt adptr | ✅ | ✅ | ✅ | ✅ | ✅ | [❌](gh-issue:8475) | ✅ | ❌ | -| [SD][spec-decode] | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | +| [APC](automatic_prefix_caching.md) | [❌](gh-issue:3687) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [LoRA](lora.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| [SD](spec_decode.md) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | | CUDA graph | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | | pooling | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❔ | ❌ | | enc-dec | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/features/disagg_prefill.md b/docs/features/disagg_prefill.md index 54be05647d94..c0c32594f266 100644 --- a/docs/features/disagg_prefill.md +++ b/docs/features/disagg_prefill.md @@ -1,7 +1,4 @@ ---- -title: Disaggregated Prefilling (experimental) ---- -[](){ #disagg-prefill } +# Disaggregated Prefilling (experimental) This page introduces you the disaggregated prefilling feature in vLLM. diff --git a/docs/features/lora.md b/docs/features/lora.md index 07fe93ab3754..d72c0bb4160c 100644 --- a/docs/features/lora.md +++ b/docs/features/lora.md @@ -1,7 +1,4 @@ ---- -title: LoRA Adapters ---- -[](){ #lora-adapter } +# LoRA Adapters This document shows you how to use [LoRA adapters](https://arxiv.org/abs/2106.09685) with vLLM on top of a base model. @@ -29,7 +26,7 @@ We can now submit the prompts and call `llm.generate` with the `lora_request` pa of `LoRARequest` is a human identifiable name, the second parameter is a globally unique ID for the adapter and the third parameter is the path to the LoRA adapter. -??? Code +??? code ```python sampling_params = SamplingParams( @@ -70,7 +67,7 @@ The server entrypoint accepts all other LoRA configuration parameters (`max_lora etc.), which will apply to all forthcoming requests. Upon querying the `/models` endpoint, we should see our LoRA along with its base model (if `jq` is not installed, you can follow [this guide](https://jqlang.org/download/) to install it.): -??? Command +??? console "Command" ```bash curl localhost:8000/v1/models | jq . @@ -172,7 +169,7 @@ Alternatively, follow these example steps to implement your own plugin: 1. Implement the LoRAResolver interface. - ??? Example of a simple S3 LoRAResolver implementation + ??? code "Example of a simple S3 LoRAResolver implementation" ```python import os @@ -238,7 +235,7 @@ The new format of `--lora-modules` is mainly to support the display of parent mo - The `parent` field of LoRA model `sql-lora` now links to its base model `meta-llama/Llama-2-7b-hf`. This correctly reflects the hierarchical relationship between the base model and the LoRA adapter. 
- The `root` field points to the artifact location of the lora adapter. -??? Command output +??? console "Command output" ```bash $ curl http://localhost:8000/v1/models diff --git a/docs/features/multimodal_inputs.md b/docs/features/multimodal_inputs.md index ed11d2836037..e820ace4f8fe 100644 --- a/docs/features/multimodal_inputs.md +++ b/docs/features/multimodal_inputs.md @@ -1,7 +1,4 @@ ---- -title: Multimodal Inputs ---- -[](){ #multimodal-inputs } +# Multimodal Inputs This page teaches you how to pass multi-modal inputs to [multi-modal models][supported-mm-models] in vLLM. @@ -20,7 +17,7 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]: You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples: -??? Code +??? code ```python from vllm import LLM @@ -68,7 +65,7 @@ Full example: To substitute multiple images inside the same text prompt, you can pass in a list of images instead: -??? Code +??? code ```python from vllm import LLM @@ -101,7 +98,7 @@ To substitute multiple images inside the same text prompt, you can pass in a lis Full example: -If using the [LLM.chat](https://docs.vllm.ai/en/stable/models/generative_models.html#llmchat) method, you can pass images directly in the message content using various formats: image URLs, PIL Image objects, or pre-computed embeddings: +If using the [LLM.chat](../models/generative_models.md#llmchat) method, you can pass images directly in the message content using various formats: image URLs, PIL Image objects, or pre-computed embeddings: ```python from vllm import LLM @@ -146,7 +143,7 @@ for o in outputs: Multi-image input can be extended to perform video captioning. We show this with [Qwen2-VL](https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct) as it supports videos: -??? Code +??? code ```python from vllm import LLM @@ -193,7 +190,7 @@ Full example: To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model, pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the corresponding field of the multi-modal dictionary. -??? Code +??? code ```python from vllm import LLM @@ -220,7 +217,7 @@ pass a tensor of shape `(num_items, feature_size, hidden_size of LM)` to the cor For Qwen2-VL and MiniCPM-V, we accept additional parameters alongside the embeddings: -??? Code +??? code ```python # Construct the prompt based on your model @@ -288,7 +285,7 @@ vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ Then, you can use the OpenAI client as follows: -??? Code +??? code ```python from openai import OpenAI @@ -366,7 +363,7 @@ vllm serve llava-hf/llava-onevision-qwen2-0.5b-ov-hf --task generate --max-model Then, you can use the OpenAI client as follows: -??? Code +??? code ```python from openai import OpenAI @@ -430,7 +427,7 @@ vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b Then, you can use the OpenAI client as follows: -??? Code +??? code ```python import base64 @@ -486,7 +483,7 @@ Then, you can use the OpenAI client as follows: Alternatively, you can pass `audio_url`, which is the audio counterpart of `image_url` for image input: -??? Code +??? code ```python chat_completion_from_url = client.chat.completions.create( @@ -531,7 +528,7 @@ pass a tensor of shape to the corresponding field of the multi-modal dictionary. For image embeddings, you can pass the base64-encoded tensor to the `image_embeds` field. 
The following example demonstrates how to pass image embeddings to the OpenAI server: -??? Code +??? code ```python image_embedding = torch.load(...) diff --git a/docs/features/quantization/README.md b/docs/features/quantization/README.md index 614b43dd0044..e8c3b1123078 100644 --- a/docs/features/quantization/README.md +++ b/docs/features/quantization/README.md @@ -1,7 +1,4 @@ ---- -title: Quantization ---- -[](){ #quantization-index } +# Quantization Quantization trades off model precision for smaller memory footprint, allowing large models to be run on a wider range of devices. @@ -13,6 +10,7 @@ Contents: - [BitBLAS](bitblas.md) - [GGUF](gguf.md) - [GPTQModel](gptqmodel.md) +- [INC](inc.md) - [INT4 W4A16](int4.md) - [INT8 W8A8](int8.md) - [FP8 W8A8](fp8.md) diff --git a/docs/features/quantization/auto_awq.md b/docs/features/quantization/auto_awq.md index 9f97ea406e25..fc998387d29a 100644 --- a/docs/features/quantization/auto_awq.md +++ b/docs/features/quantization/auto_awq.md @@ -1,7 +1,4 @@ ---- -title: AutoAWQ ---- -[](){ #auto-awq } +# AutoAWQ To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ). Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint. @@ -15,7 +12,7 @@ pip install autoawq After installing AutoAWQ, you are ready to quantize a model. Please refer to the [AutoAWQ documentation](https://casper-hansen.github.io/AutoAWQ/examples/#basic-quantization) for further details. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: -??? Code +??? code ```python from awq import AutoAWQForCausalLM @@ -51,7 +48,7 @@ python examples/offline_inference/llm_engine_example.py \ AWQ models are also supported directly through the LLM entrypoint: -??? Code +??? code ```python from vllm import LLM, SamplingParams diff --git a/docs/features/quantization/bitblas.md b/docs/features/quantization/bitblas.md index c8f874ff8414..6f53a448ee36 100644 --- a/docs/features/quantization/bitblas.md +++ b/docs/features/quantization/bitblas.md @@ -1,14 +1,11 @@ ---- -title: BitBLAS ---- -[](){ #bitblas } +# BitBLAS vLLM now supports [BitBLAS](https://github.com/microsoft/BitBLAS) for more efficient and flexible model inference. Compared to other quantization frameworks, BitBLAS provides more precision combinations. !!! note Ensure your hardware supports the selected `dtype` (`torch.bfloat16` or `torch.float16`). Most recent NVIDIA GPUs support `float16`, while `bfloat16` is more common on newer architectures like Ampere or Hopper. - For details see [supported hardware](https://docs.vllm.ai/en/latest/features/quantization/supported_hardware.html). + For details see [supported hardware](supported_hardware.md). Below are the steps to utilize BitBLAS with vLLM. @@ -43,7 +40,7 @@ llm = LLM( ## Read gptq format checkpoint -??? Code +??? code ```python from vllm import LLM diff --git a/docs/features/quantization/bnb.md b/docs/features/quantization/bnb.md index ca13ee107ef4..3b15a6072d47 100644 --- a/docs/features/quantization/bnb.md +++ b/docs/features/quantization/bnb.md @@ -1,7 +1,4 @@ ---- -title: BitsAndBytes ---- -[](){ #bits-and-bytes } +# BitsAndBytes vLLM now supports [BitsAndBytes](https://github.com/TimDettmers/bitsandbytes) for more efficient model inference. BitsAndBytes quantizes models to reduce memory usage and enhance performance without significantly sacrificing accuracy. 
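As a minimal sketch of in-flight BitsAndBytes quantization through the `LLM` entrypoint (the checkpoint name is only an example, and depending on your vLLM version you may also need to pass `load_format="bitsandbytes"`):

```python
from vllm import LLM, SamplingParams

# Example checkpoint only; any FP16/BF16 model supported by vLLM's
# BitsAndBytes loader should work. Weights are quantized at load time.
llm = LLM(
    model="huggyllama/llama-7b",
    quantization="bitsandbytes",
)

outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.8, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```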
diff --git a/docs/features/quantization/fp8.md b/docs/features/quantization/fp8.md index b9ed668b2ef3..0661933acd61 100644 --- a/docs/features/quantization/fp8.md +++ b/docs/features/quantization/fp8.md @@ -1,7 +1,4 @@ ---- -title: FP8 W8A8 ---- -[](){ #fp8 } +# FP8 W8A8 vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x. Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8. @@ -58,7 +55,7 @@ For FP8 quantization, we can recover accuracy with simple RTN quantization. We r Since simple RTN does not require data for weight quantization and the activations are quantized dynamically, we do not need any calibration data for this quantization flow. -??? Code +??? code ```python from llmcompressor.transformers import oneshot @@ -89,8 +86,9 @@ Load and run the model in `vllm`: ```python from vllm import LLM -model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic") -result = model.generate("Hello my name is") + +llm = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic") +result = llm.generate("Hello my name is") print(result[0].outputs[0].text) ``` @@ -128,9 +126,10 @@ In this mode, all Linear modules (except for the final `lm_head`) have their wei ```python from vllm import LLM -model = LLM("facebook/opt-125m", quantization="fp8") + +llm = LLM("facebook/opt-125m", quantization="fp8") # INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB -result = model.generate("Hello, my name is") +result = llm.generate("Hello, my name is") print(result[0].outputs[0].text) ``` diff --git a/docs/features/quantization/gguf.md b/docs/features/quantization/gguf.md index 102a3ee1cccc..2a1c3bdd775f 100644 --- a/docs/features/quantization/gguf.md +++ b/docs/features/quantization/gguf.md @@ -1,7 +1,4 @@ ---- -title: GGUF ---- -[](){ #gguf } +# GGUF !!! warning Please note that GGUF support in vLLM is highly experimental and under-optimized at the moment, it might be incompatible with other features. Currently, you can use GGUF as a way to reduce memory footprint. If you encounter any issues, please report them to the vLLM team. @@ -41,7 +38,7 @@ vllm serve ./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf \ You can also use the GGUF model directly through the LLM entrypoint: -??? Code +??? code ```python from vllm import LLM, SamplingParams diff --git a/docs/features/quantization/gptqmodel.md b/docs/features/quantization/gptqmodel.md index 37bb02d4fb5b..47cb2d65bae4 100644 --- a/docs/features/quantization/gptqmodel.md +++ b/docs/features/quantization/gptqmodel.md @@ -1,7 +1,4 @@ ---- -title: GPTQModel ---- -[](){ #gptqmodel } +# GPTQModel To create a new 4-bit or 8-bit GPTQ quantized model, you can leverage [GPTQModel](https://github.com/ModelCloud/GPTQModel) from ModelCloud.AI. @@ -31,7 +28,7 @@ After installing GPTQModel, you are ready to quantize a model. Please refer to t Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`: -??? Code +??? code ```python from datasets import load_dataset @@ -69,7 +66,7 @@ python examples/offline_inference/llm_engine_example.py \ GPTQModel quantized models are also supported directly through the LLM entrypoint: -??? Code +??? 
code ```python from vllm import LLM, SamplingParams diff --git a/docs/features/quantization/inc.md b/docs/features/quantization/inc.md new file mode 100644 index 000000000000..d97a462f5432 --- /dev/null +++ b/docs/features/quantization/inc.md @@ -0,0 +1,56 @@ +--- +title: FP8 INC +--- +[](){ #inc } + +vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators. +Currently, quantization is validated only in Llama models. + +Intel Gaudi supports quantization of various modules and functions, including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to: +[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules). + +!!! note + Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extention](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package. + +!!! note + `QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options). + The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference. + +## Run Online Inference Using FP8 + +Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command: + +```bash +export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json +vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor_paralel_size 8 +``` + +!!! tip + If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop. + +!!! tip + When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the below environment variables: + `VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes. + `VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in microseconds, e.g., 600000 equals 10 minutes. + +## Run Offline Inference Using FP8 + +To run offline inference (after completing the model calibration process): + +* Set the "QUANT_CONFIG" environment variable to point to a JSON configuration file with QUANTIZE mode. +* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object. +* Call shutdown method of the model_executor at the end of the run. + +```python +from vllm import LLM +llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc") +... +# Call llm.generate on the required prompts and sampling params. +... 
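+# Illustrative sketch of the elided generation step above: it might look like
+#   outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=64))
+# with SamplingParams imported from vllm and prompts being a list of strings.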
+llm.llm_engine.model_executor.shutdown() +``` + +## Device for the Model's Weights Uploading + +The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution. +This reduces the device memory footprint of model weights, as only quantized weights are stored in the device memory. diff --git a/docs/features/quantization/int4.md b/docs/features/quantization/int4.md index 2008bef5c8a2..1df32a11ed9d 100644 --- a/docs/features/quantization/int4.md +++ b/docs/features/quantization/int4.md @@ -1,7 +1,4 @@ ---- -title: INT4 W4A16 ---- -[](){ #int4 } +# INT4 W4A16 vLLM supports quantizing weights to INT4 for memory savings and inference acceleration. This quantization method is particularly useful for reducing model size and maintaining low latency in workloads with low queries per second (QPS). @@ -53,7 +50,7 @@ When quantizing weights to INT4, you need sample data to estimate the weight upd It's best to use calibration data that closely matches your deployment data. For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`: -??? Code +??? code ```python from datasets import load_dataset @@ -78,7 +75,7 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra Now, apply the quantization algorithms: -??? Code +??? code ```python from llmcompressor.transformers import oneshot @@ -111,7 +108,8 @@ After quantization, you can load and run the model in vLLM: ```python from vllm import LLM -model = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128") + +llm = LLM("./Meta-Llama-3-8B-Instruct-W4A16-G128") ``` To evaluate accuracy, you can use `lm_eval`: @@ -141,7 +139,7 @@ lm_eval --model vllm \ The following is an example of an expanded quantization recipe you can tune to your own use case: -??? Code +??? code ```python from compressed_tensors.quantization import ( diff --git a/docs/features/quantization/int8.md b/docs/features/quantization/int8.md index 3a8f855aa057..45fae58a6486 100644 --- a/docs/features/quantization/int8.md +++ b/docs/features/quantization/int8.md @@ -1,7 +1,4 @@ ---- -title: INT8 W8A8 ---- -[](){ #int8 } +# INT8 W8A8 vLLM supports quantizing weights and activations to INT8 for memory savings and inference acceleration. This quantization method is particularly useful for reducing model size while maintaining good performance. @@ -54,7 +51,7 @@ When quantizing activations to INT8, you need sample data to estimate the activa It's best to use calibration data that closely matches your deployment data. For a general-purpose instruction-tuned model, you can use a dataset like `ultrachat`: -??? Code +??? code ```python from datasets import load_dataset @@ -81,7 +78,7 @@ For a general-purpose instruction-tuned model, you can use a dataset like `ultra Now, apply the quantization algorithms: -??? Code +??? 
code ```python from llmcompressor.transformers import oneshot @@ -117,7 +114,8 @@ After quantization, you can load and run the model in vLLM: ```python from vllm import LLM -model = LLM("./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token") + +llm = LLM("./Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token") ``` To evaluate accuracy, you can use `lm_eval`: diff --git a/docs/features/quantization/modelopt.md b/docs/features/quantization/modelopt.md index 39f2a78e705f..39ae03b1bdac 100644 --- a/docs/features/quantization/modelopt.md +++ b/docs/features/quantization/modelopt.md @@ -14,7 +14,7 @@ You can quantize HuggingFace models using the example scripts provided in the Te Below is an example showing how to quantize a model using modelopt's PTQ API: -??? Code +??? code ```python import modelopt.torch.quantization as mtq @@ -50,7 +50,7 @@ with torch.inference_mode(): The quantized checkpoint can then be deployed with vLLM. As an example, the following code shows how to deploy `nvidia/Llama-3.1-8B-Instruct-FP8`, which is the FP8 quantized checkpoint derived from `meta-llama/Llama-3.1-8B-Instruct`, using vLLM: -??? Code +??? code ```python from vllm import LLM, SamplingParams diff --git a/docs/features/quantization/quantized_kvcache.md b/docs/features/quantization/quantized_kvcache.md index 323dcb7d052d..c54ec43658a4 100644 --- a/docs/features/quantization/quantized_kvcache.md +++ b/docs/features/quantization/quantized_kvcache.md @@ -1,7 +1,4 @@ ---- -title: Quantized KV Cache ---- -[](){ #quantized-kvcache } +# Quantized KV Cache ## FP8 KV Cache @@ -35,7 +32,7 @@ Studies have shown that FP8 E4M3 quantization typically only minimally degrades Here is an example of how to enable FP8 quantization: -??? Code +??? code ```python # To calculate kv cache scales on the fly enable the calculate_kv_scales @@ -73,7 +70,7 @@ pip install llmcompressor Here's a complete example using `meta-llama/Llama-3.1-8B-Instruct` (most models can use this same pattern): -??? Code +??? code ```python from datasets import load_dataset diff --git a/docs/features/quantization/quark.md b/docs/features/quantization/quark.md index 77e383495406..5abfae35eeec 100644 --- a/docs/features/quantization/quark.md +++ b/docs/features/quantization/quark.md @@ -1,7 +1,4 @@ ---- -title: AMD Quark ---- -[](){ #quark } +# AMD Quark Quantization can effectively reduce memory and bandwidth usage, accelerate computation and improve throughput while with minimal accuracy loss. vLLM can leverage [Quark](https://quark.docs.amd.com/latest/), @@ -42,7 +39,7 @@ The Quark quantization process can be listed for 5 steps as below: Quark uses [Transformers](https://huggingface.co/docs/transformers/en/index) to fetch model and tokenizer. -??? Code +??? code ```python from transformers import AutoTokenizer, AutoModelForCausalLM @@ -65,7 +62,7 @@ Quark uses the [PyTorch Dataloader](https://pytorch.org/tutorials/beginner/basic to load calibration data. For more details about how to use calibration datasets efficiently, please refer to [Adding Calibration Datasets](https://quark.docs.amd.com/latest/pytorch/calibration_datasets.html). -??? Code +??? code ```python from datasets import load_dataset @@ -98,7 +95,7 @@ kv-cache and the quantization algorithm is AutoSmoothQuant. AutoSmoothQuant config file for Llama is `examples/torch/language_modeling/llm_ptq/models/llama/autosmoothquant_config.json`. -??? Code +??? 
code ```python from quark.torch.quantization import (Config, QuantizationConfig, @@ -145,7 +142,7 @@ HuggingFace `safetensors`, you can refer to [HuggingFace format exporting](https://quark.docs.amd.com/latest/pytorch/export/quark_export_hf.html) for more exporting format details. -??? Code +??? code ```python import torch @@ -176,7 +173,7 @@ for more exporting format details. Now, you can load and run the Quark quantized model directly through the LLM entrypoint: -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -232,3 +229,28 @@ python3 quantize_quark.py --model_dir meta-llama/Llama-2-70b-chat-hf \ --model_export hf_format \ --tasks gsm8k ``` + +## Using MXFP4 models + +vLLM supports loading MXFP4 models quantized offline through AMD Quark, compliant with [Open Compute Project (OCP) specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). + +The scheme currently only supports dynamic quantization for activations. + +Example usage, after installing the latest AMD Quark release: + +```bash +vllm serve fxmarty/qwen_1.5-moe-a2.7b-mxfp4 --tensor-parallel-size 1 +``` + +A simulation of the matrix multiplication execution in MXFP4 can be run on devices that do not support MXFP4 operations natively (e.g. AMD Instinct MI325, MI300 and MI250), dequantizing weights from MXFP4 to half precision on the fly, using a fused kernel. This is useful e.g. to evaluate MXFP4 models using vLLM, or alternatively to benefit from the ~4x memory savings (compared to float16 and bfloat16). + +To generate offline models quantized using MXFP4 data type, the easiest approach is to use AMD Quark's [quantization script](https://quark.docs.amd.com/latest/pytorch/example_quark_torch_llm_ptq.html), as an example: + +```bash +python quantize_quark.py --model_dir Qwen/Qwen1.5-MoE-A2.7B-Chat \ + --quant_scheme w_mxfp4_a_mxfp4_sym \ + --output_dir qwen_1.5-moe-a2.7b-mxfp4 \ + --skip_evaluation \ + --model_export hf_format \ + --group_size 32 +``` diff --git a/docs/features/quantization/supported_hardware.md b/docs/features/quantization/supported_hardware.md index 6a585b1ccb2c..70a6a499562a 100644 --- a/docs/features/quantization/supported_hardware.md +++ b/docs/features/quantization/supported_hardware.md @@ -1,22 +1,20 @@ ---- -title: Supported Hardware ---- -[](){ #quantization-supported-hardware } +# Supported Hardware The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM: -| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU | -|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------| -| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ | -| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ | -| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | -| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ | -| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | -| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | +| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU | 
+|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------| +| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ | +| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | +| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ | +| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | +| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ | - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. - ✅︎ indicates that the quantization method is supported on the specified hardware. diff --git a/docs/features/quantization/torchao.md b/docs/features/quantization/torchao.md index f8df3c4b0809..ab6802177048 100644 --- a/docs/features/quantization/torchao.md +++ b/docs/features/quantization/torchao.md @@ -15,7 +15,7 @@ pip install \ ## Quantizing HuggingFace Models You can quantize your own huggingface model with torchao, e.g. [transformers](https://huggingface.co/docs/transformers/main/en/quantization/torchao) and [diffusers](https://huggingface.co/docs/diffusers/en/quantization/torchao), and save the checkpoint to huggingface hub like [this](https://huggingface.co/jerryzh168/llama3-8b-int8wo) with the following example code: -??? Code +??? code ```Python import torch diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 2e6afe61663c..6b84eca27530 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -1,7 +1,4 @@ ---- -title: Reasoning Outputs ---- -[](){ #reasoning-outputs } +# Reasoning Outputs vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions. @@ -17,6 +14,7 @@ vLLM currently supports the following reasoning models: | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | +| [Hunyuan A13B series](https://huggingface.co/collections/tencent/hunyuan-a13b-685ec38e5b46321e3ea7c4be) | `hunyuan_a13b` | `guided_json`, `guided_regex` | ✅ | !!! note IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. @@ -33,7 +31,7 @@ vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ Next, make a request to the model that should return the reasoning content in the response. -??? Code +??? code ```python from openai import OpenAI @@ -70,7 +68,7 @@ The `reasoning_content` field contains the reasoning steps that led to the final Streaming chat completions are also supported for reasoning models. 
The `reasoning_content` field is available in the `delta` field in [chat completion response chunks](https://platform.openai.com/docs/api-reference/chat/streaming). -??? Json +??? console "Json" ```json { @@ -95,7 +93,7 @@ Streaming chat completions are also supported for reasoning models. The `reasoni OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client supports extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: -??? Code +??? code ```python from openai import OpenAI @@ -152,7 +150,7 @@ Remember to check whether the `reasoning_content` exists in the response before The reasoning content is also available when both tool calling and the reasoning parser are enabled. Additionally, tool calling only parses functions from the `content` field, not from the `reasoning_content`. -??? Code +??? code ```python from openai import OpenAI @@ -200,7 +198,7 @@ For more examples, please refer to . -??? Code +??? code ```python # import the required packages @@ -258,7 +256,7 @@ You can add a new `ReasoningParser` similar to . -??? Code +??? code ```python @dataclass diff --git a/docs/features/spec_decode.md b/docs/features/spec_decode.md index abda7db53f91..be4b91feda7a 100644 --- a/docs/features/spec_decode.md +++ b/docs/features/spec_decode.md @@ -1,7 +1,4 @@ ---- -title: Speculative Decoding ---- -[](){ #spec-decode } +# Speculative Decoding !!! warning Please note that speculative decoding in vLLM is not yet optimized and does @@ -18,7 +15,7 @@ Speculative decoding is a technique which improves inter-token latency in memory The following code configures vLLM in an offline mode to use speculative decoding with a draft model, speculating 5 tokens at a time. -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -62,7 +59,7 @@ python -m vllm.entrypoints.openai.api_server \ Then use a client: -??? Code +??? code ```python from openai import OpenAI @@ -103,7 +100,7 @@ Then use a client: The following code configures vLLM to use speculative decoding where proposals are generated by matching n-grams in the prompt. For more information read [this thread.](https://x.com/joao_gante/status/1747322413006643259) -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -137,7 +134,7 @@ draft models that conditioning draft predictions on both context vectors and sam For more information see [this blog](https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/) or [this technical report](https://arxiv.org/abs/2404.19124). -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -185,7 +182,7 @@ A variety of speculative models of this type are available on HF hub: The following code configures vLLM to use speculative decoding where proposals are generated by an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https://arxiv.org/pdf/2401.15077) based draft model. A more detailed example for offline mode, including how to extract request level acceptance rate, can be found [here](gh-file:examples/offline_inference/eagle.py). -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -217,8 +214,8 @@ an [EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency)](https A few important things to consider when using the EAGLE based draft models: 1. 
The EAGLE draft models available in the [HF repository for EAGLE models](https://huggingface.co/yuhuili) should - be able to be loaded and used directly by vLLM after [PR 12304](https://github.com/vllm-project/vllm/pull/12304). - If you are using vllm version before [PR 12304](https://github.com/vllm-project/vllm/pull/12304), please use the + be able to be loaded and used directly by vLLM after . + If you are using vllm version before , please use the [script](https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d) to convert the speculative model, and specify `"model": "path/to/modified/eagle/model"` in `speculative_config`. If weight-loading problems still occur when using the latest version of vLLM, please leave a comment or raise an issue. @@ -228,7 +225,7 @@ A few important things to consider when using the EAGLE based draft models: 3. When using EAGLE-based speculators with vLLM, the observed speedup is lower than what is reported in the reference implementation [here](https://github.com/SafeAILab/EAGLE). This issue is under - investigation and tracked here: [https://github.com/vllm-project/vllm/issues/9565](https://github.com/vllm-project/vllm/issues/9565). + investigation and tracked here: . A variety of EAGLE draft models are available on the Hugging Face hub: @@ -259,17 +256,17 @@ speculative decoding, breaking down the guarantees into three key areas: 2. **Algorithmic Losslessness** \- vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include: - > - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target - > distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252) - > - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling - > without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, - > provides a lossless guarantee. Almost all of the tests in . - > verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291) + > - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target + > distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252) + > - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling + > without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, + > provides a lossless guarantee. Almost all of the tests in . + > verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291) 3. **vLLM Logprob Stability** \- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the same request across runs. For more details, see the FAQ section - titled *Can the output of a prompt vary across runs in vLLM?* in the [FAQs][faq]. + titled *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](../usage/faq.md). 
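As a rough sketch of the greedy-sampling-equality check described above (the model names and memory settings are illustrative, and both engines must fit on the GPU for this to run as written):

```python
from vllm import LLM, SamplingParams

prompts = ["The future of AI is"]
greedy = SamplingParams(temperature=0, max_tokens=32)

# Target model without speculative decoding.
baseline = LLM(model="facebook/opt-6.7b", gpu_memory_utilization=0.4)
base_text = baseline.generate(prompts, greedy)[0].outputs[0].text

# Same target model with a small draft model proposing 5 tokens at a time.
spec = LLM(
    model="facebook/opt-6.7b",
    gpu_memory_utilization=0.4,
    speculative_config={
        "model": "facebook/opt-125m",
        "num_speculative_tokens": 5,
    },
)
spec_text = spec.generate(prompts, greedy)[0].outputs[0].text

# With greedy sampling the two generations are expected to match in most cases,
# though hardware and batching effects can still introduce small differences.
print(base_text == spec_text)
```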
While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding can occur due to following factors: @@ -278,7 +275,7 @@ can occur due to following factors: - **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially due to non-deterministic behavior in batched operations or numerical instability. -For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the [FAQs][faq]. +For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the [FAQs](../usage/faq.md). ## Resources for vLLM contributors diff --git a/docs/features/structured_outputs.md b/docs/features/structured_outputs.md index 614b0bfe9679..4f737afa80f5 100644 --- a/docs/features/structured_outputs.md +++ b/docs/features/structured_outputs.md @@ -1,7 +1,4 @@ ---- -title: Structured Outputs ---- -[](){ #structured-outputs } +# Structured Outputs vLLM supports the generation of structured outputs using [xgrammar](https://github.com/mlc-ai/xgrammar) or @@ -21,7 +18,7 @@ The following parameters are supported, which must be added as extra parameters: - `guided_grammar`: the output will follow the context free grammar. - `structural_tag`: Follow a JSON schema within a set of specified tags within the generated text. -You can see the complete list of supported parameters on the [OpenAI-Compatible Server][serving-openai-compatible-server] page. +You can see the complete list of supported parameters on the [OpenAI-Compatible Server](../serving/openai_compatible_server.md) page. Structured outputs are supported by default in the OpenAI-Compatible Server. You may choose to specify the backend to use by setting the @@ -33,7 +30,7 @@ text. Now let´s see an example for each of the cases, starting with the `guided_choice`, as it´s the easiest one: -??? Code +??? code ```python from openai import OpenAI @@ -55,7 +52,7 @@ Now let´s see an example for each of the cases, starting with the `guided_choic The next example shows how to use the `guided_regex`. The idea is to generate an email address, given a simple regex template: -??? Code +??? code ```python completion = client.chat.completions.create( @@ -79,7 +76,7 @@ For this we can use the `guided_json` parameter in two different ways: The next example shows how to use the `guided_json` parameter with a Pydantic model: -??? Code +??? code ```python from pydantic import BaseModel @@ -127,7 +124,7 @@ difficult to use, but it´s really powerful. It allows us to define complete languages like SQL queries. It works by using a context free EBNF grammar. As an example, we can use to define a specific format of simplified SQL queries: -??? Code +??? code ```python simplified_sql_grammar = """ @@ -157,7 +154,7 @@ As an example, we can use to define a specific format of simplified SQL queries: print(completion.choices[0].message.content) ``` -See also: [full example](https://docs.vllm.ai/en/latest/examples/online_serving/structured_outputs.html) +See also: [full example](../examples/online_serving/structured_outputs.md) ## Reasoning Outputs @@ -169,7 +166,7 @@ vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-7B --reasoning-parser deepseek_r Note that you can use reasoning with any provided structured outputs feature. The following uses one with JSON schema: -??? Code +??? 
code ```python from pydantic import BaseModel @@ -200,7 +197,7 @@ Note that you can use reasoning with any provided structured outputs feature. Th print("content: ", completion.choices[0].message.content) ``` -See also: [full example](https://docs.vllm.ai/en/latest/examples/online_serving/structured_outputs.html) +See also: [full example](../examples/online_serving/structured_outputs.md) ## Experimental Automatic Parsing (OpenAI API) @@ -212,7 +209,7 @@ For the following examples, vLLM was setup using `vllm serve meta-llama/Llama-3. Here is a simple example demonstrating how to get structured output using Pydantic models: -??? Code +??? code ```python from pydantic import BaseModel @@ -248,7 +245,7 @@ Age: 28 Here is a more complex example using nested Pydantic models to handle a step-by-step math solution: -??? Code +??? code ```python from typing import List @@ -308,7 +305,7 @@ These parameters can be used in the same way as the parameters from the Online Serving examples above. One example for the usage of the `choice` parameter is shown below: -??? Code +??? code ```python from vllm import LLM, SamplingParams @@ -325,4 +322,4 @@ shown below: print(outputs[0].outputs[0].text) ``` -See also: [full example](https://docs.vllm.ai/en/latest/examples/online_serving/structured_outputs.html) +See also: [full example](../examples/online_serving/structured_outputs.md) diff --git a/docs/features/tool_calling.md b/docs/features/tool_calling.md index 8858b9a4015a..ce74683a1620 100644 --- a/docs/features/tool_calling.md +++ b/docs/features/tool_calling.md @@ -1,10 +1,10 @@ # Tool Calling -vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`) and `none` options for the `tool_choice` field in the chat completion API. +vLLM currently supports named function calling, as well as the `auto`, `required` (as of `vllm>=0.8.3`), and `none` options for the `tool_choice` field in the chat completion API. ## Quickstart -Start the server with tool calling enabled. This example uses Meta's Llama 3.1 8B model, so we need to use the llama3 tool calling chat template from the vLLM examples directory: +Start the server with tool calling enabled. This example uses Meta's Llama 3.1 8B model, so we need to use the `llama3_json` tool calling chat template from the vLLM examples directory: ```bash vllm serve meta-llama/Llama-3.1-8B-Instruct \ @@ -13,9 +13,9 @@ vllm serve meta-llama/Llama-3.1-8B-Instruct \ --chat-template examples/tool_chat_template_llama3.1_json.jinja ``` -Next, make a request to the model that should result in it using the available tools: +Next, make a request that triggers the model to use the available tools: -??? Code +??? code ```python from openai import OpenAI @@ -73,7 +73,7 @@ This example demonstrates: You can also specify a particular function using named function calling by setting `tool_choice={"type": "function", "function": {"name": "get_weather"}}`. Note that this will use the guided decoding backend - so the first time this is used, there will be several seconds of latency (or more) as the FSM is compiled for the first time before it is cached for subsequent requests. -Remember that it's the callers responsibility to: +Remember that it's the caller's responsibility to: 1. Define appropriate tools in the request 2. 
Include relevant context in the chat messages @@ -84,7 +84,7 @@ For more advanced usage, including parallel tool calls and different model-speci ## Named Function Calling vLLM supports named function calling in the chat completion API by default. It does so using Outlines through guided decoding, so this is -enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +enabled by default and will work with any supported model. You are guaranteed a validly-parsable function call - not a high-quality one. vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. @@ -95,7 +95,7 @@ specify the `name` of one of the tools in the `tool_choice` parameter of the cha ## Required Function Calling -vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The required guided decoding features (JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends are on the [roadmap](https://docs.vllm.ai/en/latest/usage/v1_guide.html#feature-model) for the V1 engine. +vLLM supports the `tool_choice='required'` option in the chat completion API. Similar to the named function calling, it also uses guided decoding, so this is enabled by default and will work with any supported model. The guided decoding features for `tool_choice='required'` (such as JSON schema with `anyOf`) are currently only supported in the V0 engine with the guided decoding backend `outlines`. However, support for alternative decoding backends are on the [roadmap](../usage/v1_guide.md#features) for the V1 engine. When tool_choice='required' is set, the model is guaranteed to generate one or more tool calls based on the specified tool list in the `tools` parameter. The number of tool calls depends on the user's query. The output format strictly follows the schema defined in the `tools` parameter. @@ -103,24 +103,22 @@ When tool_choice='required' is set, the model is guaranteed to generate one or m vLLM supports the `tool_choice='none'` option in the chat completion API. When this option is set, the model will not generate any tool calls and will respond with regular text content only, even if tools are defined in the request. -By default, when `tool_choice='none'` is specified, vLLM excludes tool definitions from the prompt to optimize context usage. To include tool definitions even with `tool_choice='none'`, use the `--expand-tools-even-if-tool-choice-none` option. - -Note: This behavior will change in v0.10.0, where tool definitions will be included by default even with `tool_choice='none'`. +However, when `tool_choice='none'` is specified, vLLM includes tool definitions from the prompt. ## Automatic Function Calling To enable this feature, you should set the following flags: -* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. It tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. * `--tool-call-parser` -- select the tool parser to use (listed below). 
Additional tool parsers -will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`. +will continue to be added in the future. You can also register your own tool parsers in the `--tool-parser-plugin`. * `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`. -* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +* `--chat-template` -- **optional** for auto tool choice. It's the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) -from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json). If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! @@ -132,7 +130,7 @@ All Nous Research Hermes-series models newer than Hermes 2 Pro should be support * `NousResearch/Hermes-2-Theta-*` * `NousResearch/Hermes-3-*` -_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality and capabilities due to the merge step in their creation_. Flags: `--tool-call-parser hermes` @@ -148,13 +146,13 @@ Known issues: 1. Mistral 7B struggles to generate parallel tool calls correctly. 2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is -much shorter than what vLLM generates. Since an exception is thrown when this condition -is not met, the following additional chat templates are provided: + much shorter than what vLLM generates. Since an exception is thrown when this condition + is not met, the following additional chat templates are provided: -* - this is the "official" Mistral chat template, but tweaked so that -it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) -* - this is a "better" version that adds a tool-use system prompt -when tools are provided, that results in much better reliability when working with parallel tool calling. + * - this is the "official" Mistral chat template, but tweaked so that + it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) + * - this is a "better" version that adds a tool-use system prompt + when tools are provided, that results in much better reliability when working with parallel tool calling. 
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` @@ -168,17 +166,17 @@ All Llama 3.1, 3.2 and 4 models should be supported. * `meta-llama/Llama-3.2-*` * `meta-llama/Llama-4-*` -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. +The tool calling that is supported is the [JSON-based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for Llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. Other tool calling formats like the built in python tool calling or custom tool calling are not supported. Known issues: -1. Parallel tool calls are not supported for llama 3, but it is supported in llama 4 models. -2. The model can generate parameters with a wrong format, such as generating +1. Parallel tool calls are not supported for Llama 3, but it is supported in Llama 4 models. +2. The model can generate parameters in an incorrect format, such as generating an array serialized as string instead of an array. -VLLM provides two JSON based chat templates for Llama 3.1 and 3.2: +VLLM provides two JSON-based chat templates for Llama 3.1 and 3.2: * - this is the "official" chat template for the Llama 3.1 models, but tweaked so that it works better with vLLM. @@ -187,7 +185,8 @@ images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` -VLLM also provides a pythonic and JSON based chat template for Llama 4, but pythonic tool calling is recommended: +VLLM also provides a pythonic and JSON-based chat template for Llama 4, but pythonic tool calling is recommended: + * - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models. For Llama 4 model, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`. @@ -198,21 +197,21 @@ Supported models: * `ibm-granite/granite-3.0-8b-instruct` -Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` + Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` -: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. + : this is a modified chat template from the original on Hugging Face. Parallel function calls are supported. * `ibm-granite/granite-3.1-8b-instruct` -Recommended flags: `--tool-call-parser granite` + Recommended flags: `--tool-call-parser granite` -The chat template from Huggingface can be used directly. Parallel function calls are supported. + The chat template from Huggingface can be used directly. Parallel function calls are supported. 
* `ibm-granite/granite-20b-functioncalling` -Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` + Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` -: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. + : this is a modified chat template from the original on Hugging Face, which is not vLLM-compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. ### InternLM Models (`internlm`) @@ -248,10 +247,12 @@ The xLAM tool parser is designed to support models that generate tool calls in v Parallel function calls are supported, and the parser can effectively separate text content from tool calls. Supported models: + * Salesforce Llama-xLAM models: `Salesforce/Llama-xLAM-2-8B-fc-r`, `Salesforce/Llama-xLAM-2-70B-fc-r` * Qwen-xLAM models: `Salesforce/xLAM-1B-fc-r`, `Salesforce/xLAM-3B-fc-r`, `Salesforce/Qwen-xLAM-32B-fc-r` Flags: + * For Llama-based xLAM models: `--tool-call-parser xlam --chat-template examples/tool_chat_template_xlam_llama.jinja` * For Qwen-based xLAM models: `--tool-call-parser xlam --chat-template examples/tool_chat_template_xlam_qwen.jinja` @@ -268,10 +269,10 @@ Flags: `--tool-call-parser hermes` Supported models: -* `MiniMaxAi/MiniMax-M1-40k` (use with ) -* `MiniMaxAi/MiniMax-M1-80k` (use with ) +* `MiniMaxAi/MiniMax-M1-40k` (use with ) +* `MiniMaxAi/MiniMax-M1-80k` (use with ) -Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax.jinja` +Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax_m1.jinja` ### DeepSeek-V3 Models (`deepseek_v3`) @@ -282,6 +283,25 @@ Supported models: Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}` +### Kimi-K2 Models (`kimi_k2`) + +Supported models: + +* `moonshotai/Kimi-K2-Instruct` + +Flags: `--tool-call-parser kimi_k2` + +### Hunyuan Models (`hunyuan_a13b`) + +Supported models: + +* `tencent/Hunyuan-A13B-Instruct` (The chat template is already included in the Hugging Face model files.) + +Flags: + +* For non-reasoning: `--tool-call-parser hunyuan_a13b` +* For reasoning: `--tool-call-parser hunyuan_a13b --reasoning-parser hunyuan_a13b --enable_reasoning` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. 
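For illustration only (the tool name and arguments below are made up), a model using this convention emits its tool calls as a Python list of function-call expressions rather than as JSON, e.g.:

```python
# A single assistant message carrying two parallel tool calls in pythonic form:
[get_weather(city="San Francisco", unit="celsius"), get_weather(city="New York", unit="celsius")]
```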
@@ -299,28 +319,25 @@ Limitations: Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct`\* (use with ) -* `meta-llama/Llama-3.2-3B-Instruct`\* (use with ) +* `meta-llama/Llama-3.2-1B-Instruct` ⚠️ (use with ) +* `meta-llama/Llama-3.2-3B-Instruct` ⚠️ (use with ) * `Team-ACE/ToolACE-8B` (use with ) * `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with ) -* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with ) -* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with ) +* `meta-llama/Llama-4-Scout-17B-16E-Instruct` ⚠️ (use with ) +* `meta-llama/Llama-4-Maverick-17B-128E-Instruct` ⚠️ (use with ) Flags: `--tool-call-parser pythonic --chat-template {see_above}` ---- -**WARNING** -Llama's smaller models frequently fail to emit tool calls in the correct format. Your mileage may vary. - ---- +!!! warning + Llama's smaller models frequently fail to emit tool calls in the correct format. Results may vary depending on the model. -## How to write a tool parser plugin +## How to Write a Tool Parser Plugin A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in . Here is a summary of a plugin file: -??? Code +??? code ```python diff --git a/docs/getting_started/installation/README.md b/docs/getting_started/installation/README.md index c5348adfa528..a252343dcee8 100644 --- a/docs/getting_started/installation/README.md +++ b/docs/getting_started/installation/README.md @@ -1,7 +1,4 @@ ---- -title: Installation ---- -[](){ #installation-index } +# Installation vLLM supports the following hardware platforms: diff --git a/docs/getting_started/installation/cpu.md b/docs/getting_started/installation/cpu.md index 5f2d0dbe27d3..2d2598da943c 100644 --- a/docs/getting_started/installation/cpu.md +++ b/docs/getting_started/installation/cpu.md @@ -76,80 +76,62 @@ Currently, there are no pre-built CPU wheels. ### Build image from source -??? Commands - - ```bash - docker build -f docker/Dockerfile.cpu \ - --tag vllm-cpu-env \ - --target vllm-openai . - - # Launching OpenAI server - docker run --rm \ - --privileged=true \ - --shm-size=4g \ - -p 8000:8000 \ - -e VLLM_CPU_KVCACHE_SPACE= \ - -e VLLM_CPU_OMP_THREADS_BIND= \ - vllm-cpu-env \ - --model=meta-llama/Llama-3.2-1B-Instruct \ - --dtype=bfloat16 \ - other vLLM OpenAI server arguments - ``` +=== "Intel/AMD x86" -!!! tip - For ARM or Apple silicon, use `docker/Dockerfile.arm` + --8<-- "docs/getting_started/installation/cpu/x86.inc.md:build-image-from-source" + +=== "ARM AArch64" -!!! tip - For IBM Z (s390x), use `docker/Dockerfile.s390x` and in `docker run` use flag `--dtype float` + --8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-image-from-source" -## Supported features +=== "Apple silicon" -vLLM CPU backend supports the following vLLM features: + --8<-- "docs/getting_started/installation/cpu/arm.inc.md:build-image-from-source" -- Tensor Parallel -- Model Quantization (`INT8 W8A8, AWQ, GPTQ`) -- Chunked-prefill -- Prefix-caching -- FP8-E5M2 KV cache +=== "IBM Z (S390X)" + --8<-- "docs/getting_started/installation/cpu/s390x.inc.md:build-image-from-source" ## Related runtime environment variables - `VLLM_CPU_KVCACHE_SPACE`: specify the KV Cache size (e.g, `VLLM_CPU_KVCACHE_SPACE=40` means 40 GiB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. Default value is `0`. 
-- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads. For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node. By setting to `all`, the OpenMP threads of each rank uses all CPU cores available on the system. Default value is `auto`. -- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `0`. -- `VLLM_CPU_MOE_PREPACK`: whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). -- `VLLM_CPU_SGL_KERNEL` (Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). +- `VLLM_CPU_OMP_THREADS_BIND`: specify the CPU cores dedicated to the OpenMP threads, can be set as CPU id lists or `auto` (by default). For example, `VLLM_CPU_OMP_THREADS_BIND=0-31` means there will be 32 OpenMP threads bound on 0-31 CPU cores. `VLLM_CPU_OMP_THREADS_BIND=0-31|32-63` means there will be 2 tensor parallel processes, 32 OpenMP threads of rank0 are bound on 0-31 CPU cores, and the OpenMP threads of rank1 are bound on 32-63 CPU cores. By setting to `auto`, the OpenMP threads of each rank are bound to the CPU cores in each NUMA node respectively. +- `VLLM_CPU_NUM_OF_RESERVED_CPU`: specify the number of CPU cores which are not dedicated to the OpenMP threads for each rank. The variable only takes effect when VLLM_CPU_OMP_THREADS_BIND is set to `auto`. Default value is `None`. If the value is not set and use `auto` thread binding, no CPU will be reserved for `world_size == 1`, 1 CPU per rank will be reserved for `world_size > 1`. +- `VLLM_CPU_MOE_PREPACK` (x86 only): whether to use prepack for MoE layer. This will be passed to `ipex.llm.modules.GatedMLPMOE`. Default is `1` (True). On unsupported CPUs, you might need to set this to `0` (False). +- `VLLM_CPU_SGL_KERNEL` (x86 only, Experimental): whether to use small-batch optimized kernels for linear layer and MoE layer, especially for low-latency requirements like online serving. The kernels require AMX instruction set, BFloat16 weight type and weight shapes divisible by 32. Default is `0` (False). -## Performance tips +## FAQ -- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run: +### Which `dtype` should be used? -```bash -sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library -find / -name *libtcmalloc* # find the dynamic link library path -export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD -python examples/offline_inference/basic/basic.py # run vLLM -``` +- Currently vLLM CPU uses model default settings as `dtype`. 
However, due to unstable float16 support in torch CPU, it is recommended to explicitly set `dtype=bfloat16` if there are any performance or accuracy problem. -- When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP: +### How to launch a vLLM service on CPU? + +- When using the online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 31 for the framework and using CPU 0-30 for inference threads: ```bash export VLLM_CPU_KVCACHE_SPACE=40 -export VLLM_CPU_OMP_THREADS_BIND=0-29 -vllm serve facebook/opt-125m +export VLLM_CPU_OMP_THREADS_BIND=0-30 +vllm serve facebook/opt-125m --dtype=bfloat16 ``` or using default auto thread binding: ```bash export VLLM_CPU_KVCACHE_SPACE=40 -export VLLM_CPU_NUM_OF_RESERVED_CPU=2 -vllm serve facebook/opt-125m +export VLLM_CPU_NUM_OF_RESERVED_CPU=1 +vllm serve facebook/opt-125m --dtype=bfloat16 ``` -- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread on each physical CPU core using `VLLM_CPU_OMP_THREADS_BIND` or using auto thread binding feature by default. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores: +Note, it is recommended to manually reserve 1 CPU for vLLM front-end process when `world_size == 1`. + +### How to decide `VLLM_CPU_OMP_THREADS_BIND`? + +- Default `auto` thread-binding is recommended for most cases. Ideally, each OpenMP thread will be bound to a dedicated physical core respectively, threads of each rank will be bound to a same NUMA node respectively, and 1 CPU per rank will be reserved for other vLLM components when `world_size > 1`. If have any performance problems or unexpected binding behaviours, please try to bind threads as following. + +- On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores: -??? Commands +??? console "Commands" ```console $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores @@ -178,34 +160,36 @@ vllm serve facebook/opt-125m $ python examples/offline_inference/basic/basic.py ``` -- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using `VLLM_CPU_OMP_THREADS_BIND` to avoid cross NUMA node memory access. +- When deploy vLLM CPU backend on a multi-socket machine with NUMA and enable tensor parallel or pipeline parallel, each NUMA node is treated as a TP/PP rank. So be aware to set CPU cores of a single rank on a same NUMA node to avoid cross NUMA node memory access. -## Other considerations +### How to decide `VLLM_CPU_KVCACHE_SPACE`? -- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance. + - This value is 4GB by default. Larger space can support more concurrent requests, longer context length. However, users should take care of memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, the TP worker will be killed with `exitcode 9` due to out-of-memory. -- Decouple the HTTP serving components from the inference components. 
In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. +### How to do performance tuning for vLLM CPU? -- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the [topology](https://github.com/intel/intel-extension-for-pytorch/blob/main/docs/tutorials/performance_tuning/tuning_guide.md#non-uniform-memory-access-numa). For NUMA architecture, Tensor Parallel is a option for better performance. +First of all, please make sure the thread-binding and KV cache space are properly set and take effect. You can check the thread-binding by running a vLLM benchmark and observing CPU cores usage via `htop`. - - Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving: +Inference batch size is a important parameter for the performance. Larger batch usually provides higher throughput, smaller batch provides lower latency. Tuning max batch size starts from default value to balance throughput and latency is an effective way to improve vLLM CPU performance on specific platforms. There are two important related parameters in vLLM: - ```bash - VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" \ - vllm serve meta-llama/Llama-2-7b-chat-hf \ - -tp=2 \ - --distributed-executor-backend mp - ``` +- `--max-num-batched-tokens`, defines the limit of token numbers in a single batch, has more impacts on the first token performance. The default value is set as: + - Offline Inference: `4096 * world_size` + - Online Serving: `2048 * world_size` +- `--max-num-seqs`, defines the limit of sequence numbers in a single batch, has more impacts on the output token performance. + - Offline Inference: `256 * world_size` + - Online Serving: `128 * world_size` - or using default auto thread binding: +vLLM CPU supports tensor parallel (TP) and pipeline parallel (PP) to leverage multiple CPU sockets and memory nodes. For more detials of tuning TP and PP, please refer to [Optimization and Tuning](../../configuration/optimization.md). For vLLM CPU, it is recommend to use TP and PP togther if there are enough CPU sockets and memory nodes. - ```bash - VLLM_CPU_KVCACHE_SPACE=40 \ - vllm serve meta-llama/Llama-2-7b-chat-hf \ - -tp=2 \ - --distributed-executor-backend mp - ``` +### Which quantization configs does vLLM CPU support? + + - vLLM CPU supports quantizations: + - AWQ (x86 only) + - GPTQ (x86 only) + - compressed-tensor INT8 W8A8 (x86, s390x) - - For each thread id list in `VLLM_CPU_OMP_THREADS_BIND`, users should guarantee threads in the list belong to a same NUMA node. +### (x86 only) What is the purpose of `VLLM_CPU_MOE_PREPACK` and `VLLM_CPU_SGL_KERNEL`? - - Meanwhile, users should also take care of memory capacity of each NUMA node. The memory usage of each TP rank is the sum of `weight shard size` and `VLLM_CPU_KVCACHE_SPACE`, if it exceeds the capacity of a single NUMA node, TP worker will be killed due to out-of-memory. + - Both of them requires `amx` CPU flag. 
+ - `VLLM_CPU_MOE_PREPACK` can provide better performance for MoE models + - `VLLM_CPU_SGL_KERNEL` can provide better performance for MoE models and small-batch scenarios. diff --git a/docs/getting_started/installation/cpu/apple.inc.md b/docs/getting_started/installation/cpu/apple.inc.md index 1771213f5591..0816f38ac68a 100644 --- a/docs/getting_started/installation/cpu/apple.inc.md +++ b/docs/getting_started/installation/cpu/apple.inc.md @@ -35,28 +35,24 @@ pip install -e . !!! note On macOS the `VLLM_TARGET_DEVICE` is automatically set to `cpu`, which currently is the only supported device. -#### Troubleshooting - -If the build has error like the following snippet where standard C++ headers cannot be found, try to remove and reinstall your -[Command Line Tools for Xcode](https://developer.apple.com/download/all/). - -```text -[...] fatal error: 'map' file not found - 1 | #include - | ^~~~~ - 1 error generated. - [2/8] Building CXX object CMakeFiles/_C.dir/csrc/cpu/pos_encoding.cpp.o - -[...] fatal error: 'cstddef' file not found - 10 | #include - | ^~~~~~~~~ - 1 error generated. -``` +!!! example "Troubleshooting" + If the build has an error like the following snippet where standard C++ headers cannot be found, try to remove and reinstall your + [Command Line Tools for Xcode](https://developer.apple.com/download/all/). + + ```text + [...] fatal error: 'map' file not found + 1 | #include + | ^~~~~ + 1 error generated. + [2/8] Building CXX object CMakeFiles/_C.dir/csrc/cpu/pos_encoding.cpp.o + + [...] fatal error: 'cstddef' file not found + 10 | #include + | ^~~~~~~~~ + 1 error generated. + ``` # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] # --8<-- [end:pre-built-images] diff --git a/docs/getting_started/installation/cpu/arm.inc.md b/docs/getting_started/installation/cpu/arm.inc.md index 6c05900cf45c..63ae351b395f 100644 --- a/docs/getting_started/installation/cpu/arm.inc.md +++ b/docs/getting_started/installation/cpu/arm.inc.md @@ -28,14 +28,26 @@ ARM CPU backend currently supports Float32, FP16 and BFloat16 datatypes. Testing has been conducted on AWS Graviton3 instances for compatibility. # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] - +```bash +docker build -f docker/Dockerfile.arm \ + --tag vllm-cpu-env . + +# Launching OpenAI server +docker run --rm \ + --privileged=true \ + --shm-size=4g \ + -p 8000:8000 \ + -e VLLM_CPU_KVCACHE_SPACE= \ + -e VLLM_CPU_OMP_THREADS_BIND= \ + vllm-cpu-env \ + --model=meta-llama/Llama-3.2-1B-Instruct \ + --dtype=bfloat16 \ + other vLLM OpenAI server arguments +``` # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/build.inc.md b/docs/getting_started/installation/cpu/build.inc.md index d9ca04edee02..fa777fe0c8a1 100644 --- a/docs/getting_started/installation/cpu/build.inc.md +++ b/docs/getting_started/installation/cpu/build.inc.md @@ -2,7 +2,7 @@ First, install recommended compiler. 
We recommend to use `gcc/g++ >= 12.3.0` as ```bash sudo apt-get update -y -sudo apt-get install -y gcc-12 g++-12 libnuma-dev python3-dev +sudo apt-get install -y --no-install-recommends ccache git curl wget ca-certificates gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 ``` @@ -17,7 +17,7 @@ Third, install Python packages for vLLM CPU backend building: ```bash pip install --upgrade pip -pip install "cmake>=3.26.1" wheel packaging ninja "setuptools-scm>=8" numpy +pip install -v -r requirements/cpu-build.txt --extra-index-url https://download.pytorch.org/whl/cpu pip install -v -r requirements/cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu ``` @@ -33,4 +33,7 @@ If you want to develop vllm, install it in editable mode instead. VLLM_TARGET_DEVICE=cpu python setup.py develop ``` +!!! note + If you are building vLLM from source and not using the pre-built images, remember to set `LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD"` on x86 machines before running vLLM. + # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/s390x.inc.md b/docs/getting_started/installation/cpu/s390x.inc.md index 6c6c40baecec..acfb3396896b 100644 --- a/docs/getting_started/installation/cpu/s390x.inc.md +++ b/docs/getting_started/installation/cpu/s390x.inc.md @@ -56,14 +56,28 @@ Execute the following commands to build and install vLLM from the source. ``` # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] +```bash +docker build -f docker/Dockerfile.s390x \ + --tag vllm-cpu-env . + +# Launching OpenAI server +docker run --rm \ + --privileged=true \ + --shm-size=4g \ + -p 8000:8000 \ + -e VLLM_CPU_KVCACHE_SPACE= \ + -e VLLM_CPU_OMP_THREADS_BIND= \ + vllm-cpu-env \ + --model=meta-llama/Llama-3.2-1B-Instruct \ + --dtype=float \ + other vLLM OpenAI server arguments +``` + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/cpu/x86.inc.md b/docs/getting_started/installation/cpu/x86.inc.md index 0412d4ccef00..49e223f9b9bf 100644 --- a/docs/getting_started/installation/cpu/x86.inc.md +++ b/docs/getting_started/installation/cpu/x86.inc.md @@ -1,19 +1,15 @@ # --8<-- [start:installation] -vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. - -!!! warning - There are no pre-built wheels or images for this device, so you must build vLLM from source. +vLLM supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. # --8<-- [end:installation] # --8<-- [start:requirements] - OS: Linux -- Compiler: `gcc/g++ >= 12.3.0` (optional, recommended) -- Instruction Set Architecture (ISA): AVX512 (optional, recommended) +- CPU flags: `avx512f`, `avx512_bf16` (Optional), `avx512_vnni` (Optional) !!! tip - [Intel Extension for PyTorch (IPEX)](https://github.com/intel/intel-extension-for-pytorch) extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. + Use `lscpu` to check the CPU flags. 
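As a quick sanity check, the relevant flags can be listed directly with `lscpu`; this is only a minimal sketch, and the exact set of flags reported will vary from machine to machine:

```bash
# List the AVX-512 related flags reported by this host.
# avx512f is required; avx512_bf16 and avx512_vnni are optional extras.
lscpu | grep -o 'avx512[a-z0-9_]*' | sort -u
```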
# --8<-- [end:requirements] # --8<-- [start:set-up-using-python] @@ -26,21 +22,37 @@ vLLM initially supports basic model inferencing and serving on x86 CPU platform, --8<-- "docs/getting_started/installation/cpu/build.inc.md" -!!! note - - AVX512_BF16 is an extension ISA provides native BF16 data type conversion and vector product instructions, which brings some performance improvement compared with pure AVX512. The CPU backend build script will check the host CPU flags to determine whether to enable AVX512_BF16. - - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable `VLLM_CPU_AVX512BF16=1` before the building. - # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] -See [https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo) +[https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-cpu-release-repo) + +!!! warning + If deploying the pre-built images on machines that only contain `avx512f`, an `Illegal instruction` error may be raised. It is recommended to build images for these machines with `--build-arg VLLM_CPU_AVX512BF16=false` and `--build-arg VLLM_CPU_AVX512VNNI=false`. # --8<-- [end:pre-built-images] # --8<-- [start:build-image-from-source] +```bash +docker build -f docker/Dockerfile.cpu \ + --build-arg VLLM_CPU_AVX512BF16=false (default)|true \ + --build-arg VLLM_CPU_AVX512VNNI=false (default)|true \ + --tag vllm-cpu-env \ + --target vllm-openai . + +# Launching OpenAI server +docker run --rm \ + --privileged=true \ + --shm-size=4g \ + -p 8000:8000 \ + -e VLLM_CPU_KVCACHE_SPACE= \ + -e VLLM_CPU_OMP_THREADS_BIND= \ + vllm-cpu-env \ + --model=meta-llama/Llama-3.2-1B-Instruct \ + --dtype=bfloat16 \ + other vLLM OpenAI server arguments +``` + # --8<-- [end:build-image-from-source] # --8<-- [start:extra-information] # --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/google_tpu.md b/docs/getting_started/installation/google_tpu.md index 5dc2a7c93f4e..55d69d11fa40 100644 --- a/docs/getting_started/installation/google_tpu.md +++ b/docs/getting_started/installation/google_tpu.md @@ -37,7 +37,7 @@ information, see [Storage options for Cloud TPU data](https://cloud.devsite.corp - Google Cloud TPU VM - TPU versions: v6e, v5e, v5p, v4 -- Python: 3.10 or newer +- Python: 3.11 or newer ### Provision Cloud TPUs @@ -117,7 +117,7 @@ source ~/.bashrc Create and activate a Conda environment for vLLM: ```bash -conda create -n vllm python=3.10 -y +conda create -n vllm python=3.12 -y conda activate vllm ``` diff --git a/docs/getting_started/installation/gpu.md b/docs/getting_started/installation/gpu.md index 1be7557b79e5..e688cefea076 100644 --- a/docs/getting_started/installation/gpu.md +++ b/docs/getting_started/installation/gpu.md @@ -46,11 +46,11 @@ vLLM is a Python library that supports the following GPU variants. Select your G === "AMD ROCm" - There is no extra information on creating a new Python environment for this device. + --8<-- "docs/getting_started/installation/gpu/rocm.inc.md:set-up-using-python" === "Intel XPU" - There is no extra information on creating a new Python environment for this device. 
+ --8<-- "docs/getting_started/installation/gpu/xpu.inc.md:set-up-using-python" ### Pre-built wheels diff --git a/docs/getting_started/installation/gpu/cuda.inc.md b/docs/getting_started/installation/gpu/cuda.inc.md index 0417a25f85ad..5ca5296d0a65 100644 --- a/docs/getting_started/installation/gpu/cuda.inc.md +++ b/docs/getting_started/installation/gpu/cuda.inc.md @@ -232,9 +232,6 @@ pip install -e . ``` # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] See [deployment-docker-pre-built-image][deployment-docker-pre-built-image] for instructions on using the official Docker image. @@ -261,4 +258,3 @@ See [deployment-docker-build-image-from-source][deployment-docker-build-image-fr See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. # --8<-- [end:supported-features] -# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/gpu/rocm.inc.md b/docs/getting_started/installation/gpu/rocm.inc.md index aa4cacaf1aed..560883d3caf9 100644 --- a/docs/getting_started/installation/gpu/rocm.inc.md +++ b/docs/getting_started/installation/gpu/rocm.inc.md @@ -2,6 +2,9 @@ vLLM supports AMD GPUs with ROCm 6.3. +!!! tip + [Docker](#set-up-using-docker) is the recommended way to use vLLM on ROCm. + !!! warning There are no pre-built wheels for this device, so you must either use the pre-built Docker image or build vLLM from source. @@ -14,6 +17,8 @@ vLLM supports AMD GPUs with ROCm 6.3. # --8<-- [end:requirements] # --8<-- [start:set-up-using-python] +There is no extra information on creating a new Python environment for this device. + # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] @@ -90,7 +95,7 @@ Currently, there are no pre-built ROCm wheels. 4. Build vLLM. For example, vLLM on ROCM 6.3 can be built with the following steps: - ??? Commands + ??? console "Commands" ```bash pip install --upgrade pip @@ -123,9 +128,7 @@ Currently, there are no pre-built ROCm wheels. - For MI300x (gfx942) users, to achieve optimal performance, please refer to [MI300x tuning guide](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/index.html) for performance optimization and tuning tips on system and workflow level. For vLLM, please refer to [vLLM performance optimization](https://rocm.docs.amd.com/en/latest/how-to/tuning-guides/mi300x/workload.html#vllm-performance-optimization). -## Set up using Docker (Recommended) - -# --8<-- [end:set-up-using-docker] +# --8<-- [end:build-wheel-from-source] # --8<-- [start:pre-built-images] The [AMD Infinity hub for vLLM](https://hub.docker.com/r/rocm/vllm/tags) offers a prebuilt, optimized @@ -203,7 +206,7 @@ DOCKER_BUILDKIT=1 docker build \ To run the above docker image `vllm-rocm`, use the below command: -??? Command +??? console "Command" ```bash docker run -it \ @@ -227,4 +230,3 @@ Where the `` is the location where the model is stored, for examp See [feature-x-hardware][feature-x-hardware] compatibility matrix for feature support information. # --8<-- [end:supported-features] -# --8<-- [end:extra-information] diff --git a/docs/getting_started/installation/gpu/xpu.inc.md b/docs/getting_started/installation/gpu/xpu.inc.md index 4469be36c007..b77c4e00cf0c 100644 --- a/docs/getting_started/installation/gpu/xpu.inc.md +++ b/docs/getting_started/installation/gpu/xpu.inc.md @@ -14,6 +14,8 @@ vLLM initially supports basic model inference and serving on Intel GPU platform. 
# --8<-- [end:requirements] # --8<-- [start:set-up-using-python] +There is no extra information on creating a new Python environment for this device. + # --8<-- [end:set-up-using-python] # --8<-- [start:pre-built-wheels] @@ -43,9 +45,6 @@ VLLM_TARGET_DEVICE=xpu python setup.py install type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. # --8<-- [end:build-wheel-from-source] -# --8<-- [start:set-up-using-docker] - -# --8<-- [end:set-up-using-docker] # --8<-- [start:pre-built-images] Currently, there are no pre-built XPU images. @@ -81,4 +80,8 @@ python -m vllm.entrypoints.openai.api_server \ By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the helper script. # --8<-- [end:supported-features] -# --8<-- [end:extra-information] +# --8<-- [start:distributed-backend] + +XPU platform uses **torch-ccl** for torch<2.8 and **xccl** for torch>=2.8 as distributed backend, since torch 2.8 supports **xccl** as built-in backend for XPU. + +# --8<-- [end:distributed-backend] diff --git a/docs/getting_started/installation/intel_gaudi.md b/docs/getting_started/installation/intel_gaudi.md index 7a7a5a51c24c..0be0d02d0679 100644 --- a/docs/getting_started/installation/intel_gaudi.md +++ b/docs/getting_started/installation/intel_gaudi.md @@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run: hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed -pip list | grep neural # verify that neural_compressor is installed +pip list | grep neural # verify that neural_compressor_pt is installed ``` Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade) @@ -109,8 +109,8 @@ docker run \ ### Supported features -- [Offline inference][offline-inference] -- Online serving via [OpenAI-Compatible Server][serving-openai-compatible-server] +- [Offline inference](../../serving/offline_inference.md) +- Online serving via [OpenAI-Compatible Server](../../serving/openai_compatible_server.md) - HPU autodetection - no need to manually select device within vLLM - Paged KV cache with algorithms enabled for Intel Gaudi accelerators - Custom Intel Gaudi implementations of Paged Attention, KV cache ops, @@ -120,12 +120,13 @@ docker run \ - Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html) for accelerating low-batch latency and throughput - Attention with Linear Biases (ALiBi) +- INC quantization ### Unsupported features - Beam search - LoRA adapters -- Quantization +- AWQ quantization - Prefill chunking (mixed-batch inferencing) ### Supported configurations @@ -133,36 +134,20 @@ docker run \ The following configurations have been validated to function with Gaudi2 devices. Configurations that are not listed may or may not work. 
-- [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) - on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 - datatype with random or greedy sampling -- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) - on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 - datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) - on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 - datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 - datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) - on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 - datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) - on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16 - datatype with random or greedy sampling -- [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) - with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) - with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) - with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) - with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) - with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling -- [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) - with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling +| Model | TP Size| dtype | Sampling | +|-------|--------|--------|----------| +| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy | +| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy | +| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy | +| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy | +| 
[meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy | +| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy | ## Performance tuning @@ -237,7 +222,7 @@ As an example, if a request of 3 sequences, with max sequence length of 412 come Warmup is an optional, but highly recommended step occurring before vLLM server starts listening. It executes a forward pass for each bucket with dummy data. The goal is to pre-compile all graphs and not incur any graph compilation overheads within bucket boundaries during server runtime. Each warmup step is logged during vLLM startup: -??? Logs +??? console "Logs" ```text INFO 08-01 22:26:47 hpu_model_runner.py:1066] [Warmup][Prompt][1/24] batch_size:4 seq_len:1024 free_mem:79.16 GiB @@ -286,7 +271,7 @@ When there's large amount of requests pending, vLLM scheduler will attempt to fi Each described step is logged by vLLM server, as follows (negative values correspond to memory being released): -??? Logs +??? console "Logs" ```text INFO 08-02 17:37:44 hpu_model_runner.py:493] Prompt bucket config (min, step, max_warmup) bs:[1, 32, 4], seq:[128, 128, 1024] diff --git a/docs/getting_started/quickstart.md b/docs/getting_started/quickstart.md index 39100e4ca540..74235db16a15 100644 --- a/docs/getting_started/quickstart.md +++ b/docs/getting_started/quickstart.md @@ -1,7 +1,4 @@ ---- -title: Quickstart ---- -[](){ #quickstart } +# Quickstart This guide will help you quickly get started with vLLM to perform: @@ -43,7 +40,7 @@ uv pip install vllm --torch-backend=auto ``` !!! note - For more detail and non-CUDA platforms, please refer [here][installation-index] for specific instructions on how to install vLLM. + For more detail and non-CUDA platforms, please refer [here](installation/README.md) for specific instructions on how to install vLLM. [](){ #quickstart-offline } @@ -77,7 +74,7 @@ prompts = [ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) ``` -The [LLM][vllm.LLM] class initializes vLLM's engine and the [OPT-125M model](https://arxiv.org/abs/2205.01068) for offline inference. The list of supported models can be found [here][supported-models]. +The [LLM][vllm.LLM] class initializes vLLM's engine and the [OPT-125M model](https://arxiv.org/abs/2205.01068) for offline inference. The list of supported models can be found [here](../models/supported_models.md). ```python llm = LLM(model="facebook/opt-125m") @@ -147,7 +144,7 @@ curl http://localhost:8000/v1/completions \ Since this server is compatible with OpenAI API, you can use it as a drop-in replacement for any applications using OpenAI API. For example, another way to query the server is via the `openai` Python package: -??? Code +??? code ```python from openai import OpenAI @@ -186,7 +183,7 @@ curl http://localhost:8000/v1/chat/completions \ Alternatively, you can use the `openai` Python package: -??? Code +??? 
code ```python from openai import OpenAI diff --git a/docs/mkdocs/hooks/generate_argparse.py b/docs/mkdocs/hooks/generate_argparse.py new file mode 100644 index 000000000000..22cf41e6041d --- /dev/null +++ b/docs/mkdocs/hooks/generate_argparse.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +import sys +from argparse import SUPPRESS, HelpFormatter +from pathlib import Path +from typing import Literal +from unittest.mock import MagicMock, patch + +ROOT_DIR = Path(__file__).parent.parent.parent.parent +ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse" + +sys.path.insert(0, str(ROOT_DIR)) +sys.modules["aiohttp"] = MagicMock() +sys.modules["blake3"] = MagicMock() +sys.modules["vllm._C"] = MagicMock() + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402 +from vllm.entrypoints.openai.cli_args import make_arg_parser # noqa: E402 +from vllm.utils import FlexibleArgumentParser # noqa: E402 + +logger = logging.getLogger("mkdocs") + + +class MarkdownFormatter(HelpFormatter): + """Custom formatter that generates markdown for argument groups.""" + + def __init__(self, prog, starting_heading_level=3): + super().__init__(prog, + max_help_position=float('inf'), + width=float('inf')) + self._section_heading_prefix = "#" * starting_heading_level + self._argument_heading_prefix = "#" * (starting_heading_level + 1) + self._markdown_output = [] + + def start_section(self, heading): + if heading not in {"positional arguments", "options"}: + heading_md = f"\n{self._section_heading_prefix} {heading}\n\n" + self._markdown_output.append(heading_md) + + def end_section(self): + pass + + def add_text(self, text): + if text: + self._markdown_output.append(f"{text.strip()}\n\n") + + def add_usage(self, usage, actions, groups, prefix=None): + pass + + def add_arguments(self, actions): + for action in actions: + if (len(action.option_strings) == 0 + or "--help" in action.option_strings): + continue + + option_strings = f'`{"`, `".join(action.option_strings)}`' + heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n" + self._markdown_output.append(heading_md) + + if choices := action.choices: + choices = f'`{"`, `".join(str(c) for c in choices)}`' + self._markdown_output.append( + f"Possible choices: {choices}\n\n") + + self._markdown_output.append(f"{action.help}\n\n") + + if (default := action.default) != SUPPRESS: + self._markdown_output.append(f"Default: `{default}`\n\n") + + def format_help(self): + """Return the formatted help as markdown.""" + return "".join(self._markdown_output) + + +def create_parser(cls, **kwargs) -> FlexibleArgumentParser: + """Create a parser for the given class with markdown formatting. + + Args: + cls: The class to create a parser for + **kwargs: Additional keyword arguments to pass to `cls.add_cli_args`. + + Returns: + FlexibleArgumentParser: A parser with markdown formatting for the class. 
+ """ + parser = FlexibleArgumentParser() + parser.formatter_class = MarkdownFormatter + with patch("vllm.config.DeviceConfig.__post_init__"): + return cls.add_cli_args(parser, **kwargs) + + +def create_serve_parser() -> FlexibleArgumentParser: + """Create a parser for the serve command with markdown formatting.""" + parser = FlexibleArgumentParser() + parser.formatter_class = lambda prog: MarkdownFormatter( + prog, starting_heading_level=4) + return make_arg_parser(parser) + + +def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + logger.info("Generating argparse documentation") + logger.debug("Root directory: %s", ROOT_DIR.resolve()) + logger.debug("Output directory: %s", ARGPARSE_DOC_DIR.resolve()) + + # Create the ARGPARSE_DOC_DIR if it doesn't exist + if not ARGPARSE_DOC_DIR.exists(): + ARGPARSE_DOC_DIR.mkdir(parents=True) + + # Create parsers to document + parsers = { + "engine_args": create_parser(EngineArgs), + "async_engine_args": create_parser(AsyncEngineArgs, + async_args_only=True), + "serve": create_serve_parser(), + } + + # Generate documentation for each parser + for stem, parser in parsers.items(): + doc_path = ARGPARSE_DOC_DIR / f"{stem}.md" + with open(doc_path, "w") as f: + f.write(parser.format_help()) + logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/hooks/generate_examples.py b/docs/mkdocs/hooks/generate_examples.py index 7cfc89605150..0ee52bb34603 100644 --- a/docs/mkdocs/hooks/generate_examples.py +++ b/docs/mkdocs/hooks/generate_examples.py @@ -1,19 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import logging from dataclasses import dataclass, field from pathlib import Path from typing import Literal import regex as re +logger = logging.getLogger("mkdocs") + ROOT_DIR = Path(__file__).parent.parent.parent.parent ROOT_DIR_RELATIVE = '../../../../..' 
EXAMPLE_DIR = ROOT_DIR / "examples" EXAMPLE_DOC_DIR = ROOT_DIR / "docs/examples" -print(ROOT_DIR.resolve()) -print(EXAMPLE_DIR.resolve()) -print(EXAMPLE_DOC_DIR.resolve()) def fix_case(text: str) -> str: @@ -135,6 +135,11 @@ def generate(self) -> str: def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): + logger.info("Generating example documentation") + logger.debug("Root directory: %s", ROOT_DIR.resolve()) + logger.debug("Example directory: %s", EXAMPLE_DIR.resolve()) + logger.debug("Example document directory: %s", EXAMPLE_DOC_DIR.resolve()) + # Create the EXAMPLE_DOC_DIR if it doesn't exist if not EXAMPLE_DOC_DIR.exists(): EXAMPLE_DOC_DIR.mkdir(parents=True) @@ -156,8 +161,8 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool): for example in sorted(examples, key=lambda e: e.path.stem): example_name = f"{example.path.stem}.md" doc_path = EXAMPLE_DOC_DIR / example.category / example_name - print(doc_path) if not doc_path.parent.exists(): doc_path.parent.mkdir(parents=True) with open(doc_path, "w+") as f: f.write(example.generate()) + logger.debug("Example generated: %s", doc_path.relative_to(ROOT_DIR)) diff --git a/docs/mkdocs/hooks/url_schemes.py b/docs/mkdocs/hooks/url_schemes.py index 6484581ed947..6fce6bd8130e 100644 --- a/docs/mkdocs/hooks/url_schemes.py +++ b/docs/mkdocs/hooks/url_schemes.py @@ -1,5 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +This is basically a port of MyST parser’s external URL resolution mechanism +(https://myst-parser.readthedocs.io/en/latest/syntax/cross-referencing.html#customising-external-url-resolution) +to work with MkDocs. + +It allows Markdown authors to use GitHub shorthand links like: + + - [Text](gh-issue:123) + - + - [File](gh-file:path/to/file.py#L10) + +These are automatically rewritten into fully qualified GitHub URLs pointing to +issues, pull requests, files, directories, or projects in the +`vllm-project/vllm` repository. + +The goal is to simplify cross-referencing common GitHub resources +in project docs. +""" + import regex as re from mkdocs.config.defaults import MkDocsConfig from mkdocs.structure.files import Files @@ -7,11 +26,42 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, - files: Files): + files: Files) -> str: + """ + Custom MkDocs plugin hook to rewrite special GitHub reference links + in Markdown. + + This function scans the given Markdown content for specially formatted + GitHub shorthand links, such as: + - `[Link text](gh-issue:123)` + - `` + + And rewrites them into fully-qualified GitHub URLs with GitHub icons: + - `[:octicons-mark-github-16: Link text](https://github.com/vllm-project/vllm/issues/123)` + - `[:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456)` + + Supported shorthand types: + - `gh-issue` + - `gh-pr` + - `gh-project` + - `gh-dir` + - `gh-file` + + Args: + markdown (str): The raw Markdown content of the page. + page (Page): The MkDocs page object being processed. + config (MkDocsConfig): The MkDocs site configuration. + files (Files): The collection of files in the MkDocs build. + + Returns: + str: The updated Markdown content with GitHub shorthand links replaced. 
+ """ gh_icon = ":octicons-mark-github-16:" gh_url = "https://github.com" repo_url = f"{gh_url}/vllm-project/vllm" org_url = f"{gh_url}/orgs/vllm-project" + + # Mapping of shorthand types to their corresponding GitHub base URLs urls = { "issue": f"{repo_url}/issues", "pr": f"{repo_url}/pull", @@ -19,6 +69,8 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, "dir": f"{repo_url}/tree/main", "file": f"{repo_url}/blob/main", } + + # Default title prefixes for auto links titles = { "issue": "Issue #", "pr": "Pull Request #", @@ -27,11 +79,19 @@ def on_page_markdown(markdown: str, *, page: Page, config: MkDocsConfig, "file": "", } + # Regular expression to match GitHub shorthand links scheme = r"gh-(?P.+?):(?P.+?)(#(?P.+?))?" inline_link = re.compile(r"\[(?P[^\[]+?)\]\(" + scheme + r"\)") auto_link = re.compile(f"<{scheme}>") def replace_inline_link(match: re.Match) -> str: + """ + Replaces a matched inline-style GitHub shorthand link + with a full Markdown link. + + Example: + [My issue](gh-issue:123) → [:octicons-mark-github-16: My issue](https://github.com/vllm-project/vllm/issues/123) + """ url = f'{urls[match.group("type")]}/{match.group("path")}' if fragment := match.group("fragment"): url += f"#{fragment}" @@ -39,6 +99,13 @@ def replace_inline_link(match: re.Match) -> str: return f'[{gh_icon} {match.group("title")}]({url})' def replace_auto_link(match: re.Match) -> str: + """ + Replaces a matched autolink-style GitHub shorthand + with a full Markdown link. + + Example: + <gh-pr:456> → [:octicons-mark-github-16: Pull Request #456](https://github.com/vllm-project/vllm/pull/456) + """ type = match.group("type") path = match.group("path") title = f"{titles[type]}{path}" @@ -48,6 +115,7 @@ def replace_auto_link(match: re.Match) -> str: return f"[{gh_icon} {title}]({url})" + # Replace both inline and autolinks markdown = inline_link.sub(replace_inline_link, markdown) markdown = auto_link.sub(replace_auto_link, markdown) diff --git a/docs/mkdocs/overrides/partials/toc-item.html b/docs/mkdocs/overrides/partials/toc-item.html new file mode 100644 index 000000000000..284af59cbe2c --- /dev/null +++ b/docs/mkdocs/overrides/partials/toc-item.html @@ -0,0 +1,21 @@ +<!-- Enables the use of toc_depth in document frontmatter https://github.com/squidfunk/mkdocs-material/issues/4827#issuecomment-1869812019 --> +<li class="md-nav__item"> + <a href="{{ toc_item.url }}" class="md-nav__link"> + <span class="md-ellipsis"> + {{ toc_item.title }} + </span> + </a> + + <!-- Table of contents list --> + {% if toc_item.children %} + <nav class="md-nav" aria-label="{{ toc_item.title | striptags }}"> + <ul class="md-nav__list"> + {% for toc_item in toc_item.children %} + {% if not page.meta.toc_depth or toc_item.level <= page.meta.toc_depth %} + {% include "partials/toc-item.html" %} + {% endif %} + {% endfor %} + </ul> + </nav> + {% endif %} + </li> \ No newline at end of file diff --git a/docs/mkdocs/stylesheets/extra.css b/docs/mkdocs/stylesheets/extra.css index 892013c1cddf..fb44d9cdcf3d 100644 --- a/docs/mkdocs/stylesheets/extra.css +++ b/docs/mkdocs/stylesheets/extra.css @@ -39,6 +39,8 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . 
:root { --md-admonition-icon--announcement: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" width="16" height="16"><path d="M3.25 9a.75.75 0 0 1 .75.75c0 2.142.456 3.828.733 4.653a.122.122 0 0 0 .05.064.212.212 0 0 0 .117.033h1.31c.085 0 .18-.042.258-.152a.45.45 0 0 0 .075-.366A16.743 16.743 0 0 1 6 9.75a.75.75 0 0 1 1.5 0c0 1.588.25 2.926.494 3.85.293 1.113-.504 2.4-1.783 2.4H4.9c-.686 0-1.35-.41-1.589-1.12A16.4 16.4 0 0 1 2.5 9.75.75.75 0 0 1 3.25 9Z"></path><path d="M0 6a4 4 0 0 1 4-4h2.75a.75.75 0 0 1 .75.75v6.5a.75.75 0 0 1-.75.75H4a4 4 0 0 1-4-4Zm4-2.5a2.5 2.5 0 1 0 0 5h2v-5Z"></path><path d="M15.59.082A.75.75 0 0 1 16 .75v10.5a.75.75 0 0 1-1.189.608l-.002-.001h.001l-.014-.01a5.775 5.775 0 0 0-.422-.25 10.63 10.63 0 0 0-1.469-.64C11.576 10.484 9.536 10 6.75 10a.75.75 0 0 1 0-1.5c2.964 0 5.174.516 6.658 1.043.423.151.787.302 1.092.443V2.014c-.305.14-.669.292-1.092.443C11.924 2.984 9.713 3.5 6.75 3.5a.75.75 0 0 1 0-1.5c2.786 0 4.826-.484 6.155-.957.665-.236 1.154-.47 1.47-.64.144-.077.284-.161.421-.25l.014-.01a.75.75 0 0 1 .78-.061Z"></path></svg>'); --md-admonition-icon--important: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16" width="16" height="16"><path d="M4.47.22A.749.749 0 0 1 5 0h6c.199 0 .389.079.53.22l4.25 4.25c.141.14.22.331.22.53v6a.749.749 0 0 1-.22.53l-4.25 4.25A.749.749 0 0 1 11 16H5a.749.749 0 0 1-.53-.22L.22 11.53A.749.749 0 0 1 0 11V5c0-.199.079-.389.22-.53Zm.84 1.28L1.5 5.31v5.38l3.81 3.81h5.38l3.81-3.81V5.31L10.69 1.5ZM8 4a.75.75 0 0 1 .75.75v3.5a.75.75 0 0 1-1.5 0v-3.5A.75.75 0 0 1 8 4Zm0 8a1 1 0 1 1 0-2 1 1 0 0 1 0 2Z"></path></svg>'); + --md-admonition-icon--code: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="m11.28 3.22 4.25 4.25a.75.75 0 0 1 0 1.06l-4.25 4.25a.749.749 0 0 1-1.275-.326.75.75 0 0 1 .215-.734L13.94 8l-3.72-3.72a.749.749 0 0 1 .326-1.275.75.75 0 0 1 .734.215m-6.56 0a.75.75 0 0 1 1.042.018.75.75 0 0 1 .018 1.042L2.06 8l3.72 3.72a.749.749 0 0 1-.326 1.275.75.75 0 0 1-.734-.215L.47 8.53a.75.75 0 0 1 0-1.06Z"/></svg>'); + --md-admonition-icon--console: url('data:image/svg+xml;charset=utf-8,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M0 2.75C0 1.784.784 1 1.75 1h12.5c.966 0 1.75.784 1.75 1.75v10.5A1.75 1.75 0 0 1 14.25 15H1.75A1.75 1.75 0 0 1 0 13.25Zm1.75-.25a.25.25 0 0 0-.25.25v10.5c0 .138.112.25.25.25h12.5a.25.25 0 0 0 .25-.25V2.75a.25.25 0 0 0-.25-.25ZM7.25 8a.75.75 0 0 1-.22.53l-2.25 2.25a.749.749 0 0 1-1.275-.326.75.75 0 0 1 .215-.734L5.44 8 3.72 6.28a.749.749 0 0 1 .326-1.275.75.75 0 0 1 .734.215l2.25 2.25c.141.14.22.331.22.53m1.5 1.5h3a.75.75 0 0 1 0 1.5h-3a.75.75 0 0 1 0-1.5"/></svg>'); } .md-typeset .admonition.announcement, @@ -49,6 +51,14 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . .md-typeset details.important { border-color: rgb(239, 85, 82); } +.md-typeset .admonition.code, +.md-typeset details.code { + border-color: #64dd17 +} +.md-typeset .admonition.console, +.md-typeset details.console { + border-color: #64dd17 +} .md-typeset .announcement > .admonition-title, .md-typeset .announcement > summary { @@ -58,6 +68,14 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . 
.md-typeset .important > summary { background-color: rgb(239, 85, 82, 0.1); } +.md-typeset .code > .admonition-title, +.md-typeset .code > summary { + background-color: #64dd171a; +} +.md-typeset .console > .admonition-title, +.md-typeset .console > summary { + background-color: #64dd171a; +} .md-typeset .announcement > .admonition-title::before, .md-typeset .announcement > summary::before { @@ -71,6 +89,18 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . -webkit-mask-image: var(--md-admonition-icon--important); mask-image: var(--md-admonition-icon--important); } +.md-typeset .code > .admonition-title::before, +.md-typeset .code > summary::before { + background-color: #64dd17; + -webkit-mask-image: var(--md-admonition-icon--code); + mask-image: var(--md-admonition-icon--code); +} +.md-typeset .console > .admonition-title::before, +.md-typeset .console > summary::before { + background-color: #64dd17; + -webkit-mask-image: var(--md-admonition-icon--console); + mask-image: var(--md-admonition-icon--console); +} /* Make label fully visible on hover */ .md-content__button[href*="edit"]:hover::after { @@ -143,3 +173,13 @@ body[data-md-color-scheme="slate"] .md-nav__item--section > label.md-nav__link . [data-md-color-scheme="slate"] .logo-light { display: none; } + +/* Outline for content tabs */ +.md-typeset .tabbed-set { + border: 0.075rem solid var(--md-default-fg-color); + border-radius: 0.2rem; +} + +.md-typeset .tabbed-content { + padding: 0 0.6em; +} \ No newline at end of file diff --git a/docs/models/extensions/runai_model_streamer.md b/docs/models/extensions/runai_model_streamer.md index 60b43d21d9f6..992dddf385d0 100644 --- a/docs/models/extensions/runai_model_streamer.md +++ b/docs/models/extensions/runai_model_streamer.md @@ -1,7 +1,4 @@ ---- -title: Loading models with Run:ai Model Streamer ---- -[](){ #runai-model-streamer } +# Loading models with Run:ai Model Streamer Run:ai Model Streamer is a library to read tensors in concurrency, while streaming it to GPU memory. Further reading can be found in [Run:ai Model Streamer Documentation](https://github.com/run-ai/runai-model-streamer/blob/master/docs/README.md). diff --git a/docs/models/extensions/tensorizer.md b/docs/models/extensions/tensorizer.md index e0b4479c0beb..6ea61b080cda 100644 --- a/docs/models/extensions/tensorizer.md +++ b/docs/models/extensions/tensorizer.md @@ -1,7 +1,4 @@ ---- -title: Loading models with CoreWeave's Tensorizer ---- -[](){ #tensorizer } +# Loading models with CoreWeave's Tensorizer vLLM supports loading models with [CoreWeave's Tensorizer](https://docs.coreweave.com/coreweave-machine-learning-and-ai/inference/tensorizer). vLLM model tensors that have been serialized to disk, an HTTP/HTTPS endpoint, or S3 endpoint can be deserialized @@ -10,7 +7,7 @@ shorter Pod startup times and CPU memory usage. Tensor encryption is also suppor For more information on CoreWeave's Tensorizer, please refer to [CoreWeave's Tensorizer documentation](https://github.com/coreweave/tensorizer). For more information on serializing a vLLM model, as well a general usage guide to using Tensorizer with vLLM, see -the [vLLM example script](https://docs.vllm.ai/en/latest/examples/others/tensorize_vllm_model.html). +the [vLLM example script](../../examples/others/tensorize_vllm_model.md). !!! note Note that to use this feature you will need to install `tensorizer` by running `pip install vllm[tensorizer]`. 
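For orientation, here is a rough sketch of what loading a Tensorizer-serialized model can look like, assuming the `--load-format tensorizer` and `--model-loader-extra-config` options; the S3 URI and the `tensorizer_uri` key below are placeholders, so check them against the example script referenced above before relying on them:

```bash
# Install vLLM together with the optional tensorizer dependency.
pip install "vllm[tensorizer]"

# Serve a model whose tensors were previously serialized with Tensorizer.
# The bucket path below is a placeholder, not a real artifact.
vllm serve facebook/opt-125m \
  --load-format tensorizer \
  --model-loader-extra-config '{"tensorizer_uri": "s3://my-bucket/opt-125m/model.tensors"}'
```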
diff --git a/docs/models/generative_models.md b/docs/models/generative_models.md index fd5c659921de..21ad115e411a 100644 --- a/docs/models/generative_models.md +++ b/docs/models/generative_models.md @@ -1,7 +1,4 @@ ---- -title: Generative Models ---- -[](){ #generative-models } +# Generative Models vLLM provides first-class support for generative models, which covers most of LLMs. @@ -85,7 +82,7 @@ and automatically applies the model's [chat template](https://huggingface.co/doc In general, only instruction-tuned models have a chat template. Base models may perform poorly as they are not trained to respond to the chat conversation. -??? Code +??? code ```python from vllm import LLM @@ -134,7 +131,7 @@ outputs = llm.chat(conversation, chat_template=custom_template) ## Online Serving -Our [OpenAI-Compatible Server][serving-openai-compatible-server] provides endpoints that correspond to the offline APIs: +Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs: - [Completions API][completions-api] is similar to `LLM.generate` but only accepts text. -- [Chat API][chat-api] is similar to `LLM.chat`, accepting both text and [multi-modal inputs][multimodal-inputs] for models with a chat template. +- [Chat API][chat-api] is similar to `LLM.chat`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for models with a chat template. diff --git a/docs/models/hardware_supported_models/tpu.md b/docs/models/hardware_supported_models/tpu.md index dca5e20cb343..da03a3b3160a 100644 --- a/docs/models/hardware_supported_models/tpu.md +++ b/docs/models/hardware_supported_models/tpu.md @@ -1,7 +1,4 @@ ---- -title: TPU ---- -[](){ #tpu-supported-models } +# TPU # TPU Supported Models ## Text-only Language Models diff --git a/docs/models/pooling_models.md b/docs/models/pooling_models.md index 693212e64bd2..741ae2d79c1e 100644 --- a/docs/models/pooling_models.md +++ b/docs/models/pooling_models.md @@ -1,7 +1,4 @@ ---- -title: Pooling Models ---- -[](){ #pooling-models } +# Pooling Models vLLM also supports pooling models, including embedding, reranking and reward models. @@ -11,29 +8,54 @@ before returning them. !!! note We currently support pooling models primarily as a matter of convenience. - As shown in the [Compatibility Matrix][compatibility-matrix], most vLLM features are not applicable to + As shown in the [Compatibility Matrix](../features/compatibility_matrix.md), most vLLM features are not applicable to pooling models as they only work on the generation or decode stage, so performance may not improve as much. -For pooling models, we support the following `--task` options. -The selected option sets the default pooler used to extract the final hidden states: +If the model doesn't implement this interface, you can set `--task` which tells vLLM +to convert the model into a pooling model. -| Task | Pooling Type | Normalization | Softmax | -|---------------------------------|----------------|-----------------|-----------| -| Embedding (`embed`) | `LAST` | ✅︎ | ❌ | -| Classification (`classify`) | `LAST` | ❌ | ✅︎ | -| Sentence Pair Scoring (`score`) | \* | \* | \* | +| `--task` | Model type | Supported pooling tasks | +|------------|----------------------|-------------------------------| +| `embed` | Embedding model | `encode`, `embed` | +| `classify` | Classification model | `encode`, `classify`, `score` | +| `reward` | Reward model | `encode` | -\*The default pooler is always defined by the model. 
+## Pooling Tasks -!!! note - If the model's implementation in vLLM defines its own pooler, the default pooler is set to that instead of the one specified in this table. +In vLLM, we define the following pooling tasks and corresponding APIs: + +| Task | APIs | +|------------|--------------------| +| `encode` | `encode` | +| `embed` | `embed`, `score`\* | +| `classify` | `classify` | +| `score` | `score` | + +\*The `score` API falls back to `embed` task if the model does not support `score` task. + +Each pooling model in vLLM supports one or more of these tasks according to [Pooler.get_supported_tasks][vllm.model_executor.layers.Pooler.get_supported_tasks]. + +By default, the pooler assigned to each task has the following attributes: + +| Task | Pooling Type | Normalization | Softmax | +|------------|----------------|---------------|---------| +| `encode` | `ALL` | ❌ | ❌ | +| `embed` | `LAST` | ✅︎ | ❌ | +| `classify` | `LAST` | ❌ | ✅︎ | + +These defaults may be overridden by the model's implementation in vLLM. When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models, -we attempt to override the default pooler based on its Sentence Transformers configuration file (`modules.json`). +we attempt to override the defaults based on its Sentence Transformers configuration file (`modules.json`), +which takes priority over the model's defaults. + +You can further customize this via the `--override-pooler-config` option, +which takes priority over both the model's and Sentence Transformers's defaults. + +!!! note -!!! tip - You can customize the model's pooling method via the `--override-pooler-config` option, - which takes priority over both the model's and Sentence Transformers's defaults. + The above configuration may be disregarded if the model's implementation in vLLM defines its own pooler + that is not based on [PoolerConfig][vllm.config.PoolerConfig]. ## Offline Inference @@ -113,10 +135,10 @@ A code example can be found here: <gh-file:examples/offline_inference/basic/scor ## Online Serving -Our [OpenAI-Compatible Server][serving-openai-compatible-server] provides endpoints that correspond to the offline APIs: +Our [OpenAI-Compatible Server](../serving/openai_compatible_server.md) provides endpoints that correspond to the offline APIs: - [Pooling API][pooling-api] is similar to `LLM.encode`, being applicable to all types of pooling models. -- [Embeddings API][embeddings-api] is similar to `LLM.embed`, accepting both text and [multi-modal inputs][multimodal-inputs] for embedding models. +- [Embeddings API][embeddings-api] is similar to `LLM.embed`, accepting both text and [multi-modal inputs](../features/multimodal_inputs.md) for embedding models. - [Classification API][classification-api] is similar to `LLM.classify` and is applicable to sequence classification models. - [Score API][score-api] is similar to `LLM.score` for cross-encoder models. 
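For instance, once a server is running with an embedding model, the Embeddings API can be exercised with a plain HTTP request; this is a minimal sketch, and the model name is only an assumption (use whichever embedding model you actually served):

```bash
# Query the OpenAI-compatible Embeddings API of a running vLLM server.
curl http://localhost:8000/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
        "model": "intfloat/e5-small",
        "input": "Follow the white rabbit."
      }'
```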
@@ -152,11 +174,11 @@ You can change the output dimensions of embedding models that support Matryoshka ```python from vllm import LLM, PoolingParams -model = LLM(model="jinaai/jina-embeddings-v3", - task="embed", - trust_remote_code=True) -outputs = model.embed(["Follow the white rabbit."], - pooling_params=PoolingParams(dimensions=32)) +llm = LLM(model="jinaai/jina-embeddings-v3", + task="embed", + trust_remote_code=True) +outputs = llm.embed(["Follow the white rabbit."], + pooling_params=PoolingParams(dimensions=32)) print(outputs[0].outputs) ``` diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7ec91df98b28..c8b6c6c86120 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -1,7 +1,4 @@ ---- -title: Supported Models ---- -[](){ #supported-models } +# Supported Models vLLM supports [generative](./generative_models.md) and [pooling](./pooling_models.md) models across various tasks. If a model supports more than one task, you can set the task via the `--task` argument. @@ -21,7 +18,7 @@ These models are what we list in [supported-text-models][supported-text-models] ### Transformers -vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models are supported, and vision language model support is planned! +vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases. To check if the modeling backend is Transformers, you can simply do this: @@ -31,14 +28,17 @@ llm = LLM(model=..., task="generate") # Name or path of your model llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersForCausalLM` then it means it's based on Transformers! +If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers! !!! tip - You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference][offline-inference] or `--model-impl transformers` for the [openai-compatible-server][serving-openai-compatible-server]. + You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md). !!! note vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. +!!! note + In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance. + #### Custom models If a model is neither supported natively by vLLM or Transformers, it can still be used in vLLM! @@ -53,8 +53,8 @@ For a model to be compatible with the Transformers backend for vLLM it must: If the compatible model is: -- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for [offline-inference][offline-inference] or `--trust-remote-code` for the [openai-compatible-server][serving-openai-compatible-server]. 
-- in a local directory, simply pass directory path to `model=<MODEL_DIR>` for [offline-inference][offline-inference] or `vllm serve <MODEL_DIR>` for the [openai-compatible-server][serving-openai-compatible-server]. +- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for [offline-inference](../serving/offline_inference.md) or `--trust-remote-code` for the [openai-compatible-server](../serving/openai_compatible_server.md). +- in a local directory, simply pass directory path to `model=<MODEL_DIR>` for [offline-inference](../serving/offline_inference.md) or `vllm serve <MODEL_DIR>` for the [openai-compatible-server](../serving/openai_compatible_server.md). This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! @@ -102,7 +102,7 @@ Here is what happens in the background when this model is loaded: 1. The config is loaded. 2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`. -3. `MyModel` is loaded into `TransformersForCausalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. +3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used. That's it! @@ -171,7 +171,7 @@ The [Transformers backend][transformers-backend] enables you to run models direc If vLLM successfully returns text (for generative models) or hidden states (for pooling models), it indicates that your model is supported. -Otherwise, please refer to [Adding a New Model][new-model] for instructions on how to implement your model in vLLM. +Otherwise, please refer to [Adding a New Model](../contributing/model/README.md) for instructions on how to implement your model in vLLM. Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support. #### Download a model @@ -308,91 +308,104 @@ print(output) ### Generative Models -See [this page][generative-models] for more information on how to use generative models. +See [this page](generative_models.md) for more information on how to use generative models. #### Text Generation Specified using `--task generate`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|---------------------------------------------------|-----------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | -| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | | -| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. 
| | ✅︎ | | -| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | -| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. | | ✅︎ | ✅︎ | -| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat` etc. | | ✅︎ | ✅︎ | -| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat` etc. | | ✅︎ | ✅︎ | -| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. | | ✅︎ | ✅︎ | -| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst` etc. | | ✅︎ | ✅︎ | -| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`,etc. | | ✅︎ | ✅︎ | -| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. | | ✅︎ | ✅︎ | -| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | -| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | -| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | | -| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | -| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | -| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | -| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | -| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. 
| ✅︎ | ✅︎ | | -| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | ✅︎ | -| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`etc. | | | ✅︎ | -| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | -| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | -| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | -| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | | -| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | -| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | | -| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | -| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | -| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | -| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | -| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. 
| | ✅︎ | ✅︎ | -| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | -| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | -| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | | ✅︎ | ✅︎ | -| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | | ✅︎ | ✅︎ | -| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | -| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | -| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`etc. | | | | -| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | -| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | +<style> +th { + white-space: nowrap; + min-width: 0 !important; +} +</style> + +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `ArceeForCausalLM` | Arcee (AFM) | `arcee-ai/AFM-4.5B-Base`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ | +| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ | +| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | | +| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | | +| `ChatGLMModel`, `ChatGLMForConditionalGeneration` | ChatGLM | `THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, `ShieldLM-6B-chatglm3`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `CohereForCausalLM`, `Cohere2ForCausalLM` | Command-R | `CohereForAI/c4ai-command-r-v01`, `CohereForAI/c4ai-command-r7b-12-2024`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DbrxForCausalLM` | DBRX | `databricks/dbrx-base`, `databricks/dbrx-instruct`, etc. 
| | ✅︎ | ✅︎ | +| `DeciLMForCausalLM` | DeciLM | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `DeepseekForCausalLM` | DeepSeek | `deepseek-ai/deepseek-llm-67b-base`, `deepseek-ai/deepseek-llm-7b-chat`, etc. | | ✅︎ | ✅︎ | +| `DeepseekV2ForCausalLM` | DeepSeek-V2 | `deepseek-ai/DeepSeek-V2`, `deepseek-ai/DeepSeek-V2-Chat`, etc. | | ✅︎ | ✅︎ | +| `DeepseekV3ForCausalLM` | DeepSeek-V3 | `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3`, etc. | | ✅︎ | ✅︎ | +| `Dots1ForCausalLM` | dots.llm1 | `rednote-hilab/dots.llm1.base`, `rednote-hilab/dots.llm1.inst`, etc. | | ✅︎ | ✅︎ | +| `Ernie4_5_ForCausalLM` | Ernie4.5 | `baidu/ERNIE-4.5-0.3B-PT`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Ernie4_5_MoeForCausalLM` | Ernie4.5MoE | `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc. |✅︎| ✅︎ | ✅︎ | +| `ExaoneForCausalLM` | EXAONE-3 | `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Exaone4ForCausalLM` | EXAONE-4 | `LGAI-EXAONE/EXAONE-4.0-32B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Fairseq2LlamaForCausalLM` | Llama (fairseq2 format) | `mgleize/fairseq2-dummy-Llama-3.2-1B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `FalconForCausalLM` | Falcon | `tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc. | | ✅︎ | ✅︎ | +| `FalconMambaForCausalLM` | FalconMamba | `tiiuae/falcon-mamba-7b`, `tiiuae/falcon-mamba-7b-instruct`, etc. | | ✅︎ | ✅︎ | +| `FalconH1ForCausalLM` | Falcon-H1 | `tiiuae/Falcon-H1-34B-Base`, `tiiuae/Falcon-H1-34B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ | +| `GlmForCausalLM` | GLM-4 | `THUDM/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4ForCausalLM` | GLM-4-0414 | `THUDM/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GPT2LMHeadModel` | GPT-2 | `gpt2`, `gpt2-xl`, etc. | | ✅︎ | ✅︎ | +| `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ | ✅︎ | +| `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ | ✅︎ | +| `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. 
| ✅︎ | ✅︎ | ✅︎ | +| `HunYuanDenseV1ForCausalLM` | Hunyuan-7B-Instruct-0124 | `tencent/Hunyuan-7B-Instruct-0124` | ✅︎ | | ✅︎ | +| `HunYuanMoEV1ForCausalLM` | Hunyuan-80B-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | | ✅︎ | +| `InternLMForCausalLM` | InternLM | `internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ | +| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | | +| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | | +| `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ | +| `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPM3ForCausalLM` | MiniCPM3 | `openbmb/MiniCPM3-4B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MistralForCausalLM` | Mistral, Mistral-Instruct | `mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MixtralForCausalLM` | Mixtral-8x7B, Mixtral-8x7B-Instruct | `mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MPTForCausalLM` | MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter | `mosaicml/mpt-7b`, `mosaicml/mpt-7b-storywriter`, `mosaicml/mpt-30b`, etc. | | ✅︎ | ✅︎ | +| `NemotronForCausalLM` | Nemotron-3, Nemotron-4, Minitron | `nvidia/Minitron-8B-Base`, `mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `NemotronHForCausalLM` | Nemotron-H | `nvidia/Nemotron-H-8B-Base-8K`, `nvidia/Nemotron-H-47B-Base-8K`, `nvidia/Nemotron-H-56B-Base-8K`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `OLMoForCausalLM` | OLMo | `allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc. | | ✅︎ | ✅︎ | +| `OLMo2ForCausalLM` | OLMo2 | `allenai/OLMo-2-0425-1B`, etc. | | ✅︎ | ✅︎ | +| `OLMoEForCausalLM` | OLMoE | `allenai/OLMoE-1B-7B-0924`, `allenai/OLMoE-1B-7B-0924-Instruct`, etc. | | ✅︎ | ✅︎ | +| `OPTForCausalLM` | OPT, OPT-IML | `facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc. | | ✅︎ | ✅︎ | +| `OrionForCausalLM` | Orion | `OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc. | | ✅︎ | ✅︎ | +| `PhiForCausalLM` | Phi | `microsoft/phi-1_5`, `microsoft/phi-2`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. 
| | | | +| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ | +| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | | +| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForCausalLM` | QwQ, Qwen2 | `Qwen/QwQ-32B-Preview`, `Qwen/Qwen2-7B-Instruct`, `Qwen/Qwen2-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ | +| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ | +| `SolarForCausalLM` | Solar Pro | `upstage/solar-pro-preview-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | | +| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | | +| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ | !!! note Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. @@ -412,19 +425,19 @@ See [this page](./pooling_models.md) for more information on how to use pooling Specified using `--task embed`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|--------------------------------------------------------|---------------------|---------------------------------------------------------------------------------------------------------------------|----------------------|---------------------------|-----------------------| -| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | -| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | -| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | -| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | ︎ | | | -| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | ︎ | ︎ | | -| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | ︎ | ︎ | | -| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | ︎ | ︎ | | -| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `BertModel` | BERT-based | `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. | | | | +| `Gemma2Model` | Gemma 2-based | `BAAI/bge-multilingual-gemma2`, etc. | ✅︎ | | ✅︎ | +| `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | +| `GteModel` | Arctic-Embed-2.0-M | `Snowflake/snowflake-arctic-embed-m-v2.0`. | | | | +| `GteNewModel` | mGTE-TRM (see note) | `Alibaba-NLP/gte-multilingual-base`, etc. | | | | +| `ModernBertModel` | ModernBERT-based | `Alibaba-NLP/gte-modernbert-base`, etc. | | | | +| `NomicBertModel` | Nomic BERT | `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. | | | | +| `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. | Llama-based | `intfloat/e5-mistral-7b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2Model`, `Qwen2ForCausalLM` | Qwen2-based | `ssmits/Qwen2-7B-Instruct-embed-base` (see note), `Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen3Model`, `Qwen3ForCausalLM` | Qwen3-based | `Qwen/Qwen3-Embedding-0.6B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `RobertaModel`, `RobertaForMaskedLM` | RoBERTa-based | `sentence-transformers/all-roberta-large-v1`, etc. | | | | !!! note `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config. @@ -448,12 +461,12 @@ of the whole prompt are extracted from the normalized hidden state corresponding Specified using `--task reward`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|---------------------------|-----------------|------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `InternLM2ForRewardModel` | InternLM2-based | `internlm/internlm2-1_8b-reward`, `internlm/internlm2-7b-reward`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `LlamaForCausalLM` | Llama-based | `peiyi9979/math-shepherd-mistral-7b-prm`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-RM-72B`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2ForProcessRewardModel` | Qwen2-based | `Qwen/Qwen2.5-Math-PRM-7B`, etc. | ✅︎ | ✅︎ | ✅︎ | If your model is not in the above list, we will try to automatically convert the model using [as_reward_model][vllm.model_executor.models.adapters.as_reward_model]. By default, we return the hidden states of each token directly. 
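As a quick illustration of the reward workflow described above, here is a minimal offline sketch (an editorial example, not part of the upstream docs). It assumes the `internlm/internlm2-1_8b-reward` checkpoint from the table, the `task="reward"` option mirroring `--task reward`, and vLLM's generic `LLM.encode()` pooling API; the exact output field name is an assumption and may differ between versions.

```python
from vllm import LLM

# Sketch only: model choice, task="reward", and the output.outputs.data
# field (per-token hidden states) are assumptions based on the docs above.
llm = LLM(
    model="internlm/internlm2-1_8b-reward",
    task="reward",
    trust_remote_code=True,
)

(output,) = llm.encode("vLLM makes reward-model inference straightforward.")
# By default, reward models return one hidden-state vector per prompt token.
print(output.outputs.data)
```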
@@ -466,10 +479,10 @@ If your model is not in the above list, we will try to automatically convert the Specified using `--task classify`. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|----------------------------------|----------|----------------------------------------|------------------------|-----------------------------|-----------------------| -| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | -| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `JambaForSequenceClassification` | Jamba | `ai21labs/Jamba-tiny-reward-dev`, etc. | ✅︎ | ✅︎ | | +| `GPT2ForSequenceClassification` | GPT2 | `nie3e/sentiment-polish-gpt2-small` | | | ✅︎ | If your model is not in the above list, we will try to automatically convert the model using [as_seq_cls_model][vllm.model_executor.models.adapters.as_seq_cls_model]. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. @@ -478,13 +491,21 @@ If your model is not in the above list, we will try to automatically convert the Specified using `--task score`. -| Architecture | Models | Example HF Models | [V1](gh-issue:8779) | -|---------------------------------------|-------------------|--------------------------------------------------------------------------------------|---------------------| -| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | -| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | -| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | -| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | -| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | +| Architecture | Models | Example HF Models | [V1](gh-issue:8779) | +|--------------|--------|-------------------|---------------------| +| `BertForSequenceClassification` | BERT-based | `cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. | | +| `GemmaForSequenceClassification` | Gemma-based | `BAAI/bge-reranker-v2-gemma` (see note), etc. | | +| `Qwen2ForSequenceClassification` | Qwen2-based | `mixedbread-ai/mxbai-rerank-base-v2` (see note), etc. | ✅︎ | +| `Qwen3ForSequenceClassification` | Qwen3-based | `tomaarsen/Qwen3-Reranker-0.6B-seq-cls`, `Qwen/Qwen3-Reranker-0.6B` (see note), etc. | ✅︎ | +| `RobertaForSequenceClassification` | RoBERTa-based | `cross-encoder/quora-roberta-base`, etc. | | +| `XLMRobertaForSequenceClassification` | XLM-RoBERTa-based | `BAAI/bge-reranker-v2-m3`, etc. | | + +!!! note + Load the official original `BAAI/bge-reranker-v2-gemma` by using the following command. + + ```bash + vllm serve BAAI/bge-reranker-v2-gemma --hf_overrides '{"architectures": ["GemmaForSequenceClassification"],"classifier_from_token": ["Yes"],"method": "no_post_processing"}' + ``` !!! note Load the official original `mxbai-rerank-v2` by using the following command. @@ -519,7 +540,7 @@ On the other hand, modalities separated by `/` are mutually exclusive. 
- e.g.: `T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. -See [this page][multimodal-inputs] on how to pass multi-modal inputs to the model. +See [this page](../features/multimodal_inputs.md) on how to pass multi-modal inputs to the model. !!! important **To enable multiple multi-modal items per text prompt in vLLM V0**, you have to set `limit_mm_per_prompt` (offline inference) @@ -549,56 +570,58 @@ See [this page][multimodal-inputs] on how to pass multi-modal inputs to the mode ### Generative Models -See [this page][generative-models] for more information on how to use generative models. +See [this page](generative_models.md) for more information on how to use generative models. #### Text Generation Specified using `--task generate`. -| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|----------------------------------------------|--------------------------------------------------------------------------|-----------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ | -| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | -| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | -| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b` etc. | | ✅︎ | ✅︎ | -| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2` etc. | | ✅︎ | ✅︎ | -| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large` etc. | | | | -| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b` etc. | | ✅︎ | ✅︎ | -| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | -| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220` etc. | ✅︎ | ✅︎ | ✅︎ | -| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinkg`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | -| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎\* | -| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3` etc. | ✅︎ | | ✅︎ | -| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | -| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | -| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | -| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | | ✅︎ | ✅︎ | -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | -| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | -| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | -| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | -| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | -| `Mistral3ForConditionalGeneration` | Mistral3 | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | -| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | -| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | -| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | -| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | -| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | -| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | -| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. 
| ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | -| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎\* | -| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | -| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | -| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`,`omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | -| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`,`omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| +| `AriaForConditionalGeneration` | Aria | T + I<sup>+</sup> | `rhymes-ai/Aria` | | | ✅︎ | +| `AyaVisionForConditionalGeneration` | Aya Vision | T + I<sup>+</sup> | `CohereForAI/aya-vision-8b`, `CohereForAI/aya-vision-32b`, etc. | | ✅︎ | ✅︎ | +| `Blip2ForConditionalGeneration` | BLIP-2 | T + I<sup>E</sup> | `Salesforce/blip2-opt-2.7b`, `Salesforce/blip2-opt-6.7b`, etc. | | ✅︎ | ✅︎ | +| `ChameleonForConditionalGeneration` | Chameleon | T + I | `facebook/chameleon-7b`, etc. | | ✅︎ | ✅︎ | +| `DeepseekVLV2ForCausalLM`<sup>^</sup> | DeepSeek-VL2 | T + I<sup>+</sup> | `deepseek-ai/deepseek-vl2-tiny`, `deepseek-ai/deepseek-vl2-small`, `deepseek-ai/deepseek-vl2`, etc. | | ✅︎ | ✅︎ | +| `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | | +| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ | +| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ | +| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ | +| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ | +| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ | +| `InternVLChatModel` | InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ | +| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ | +| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ | +| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ | +| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ | +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ | +| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ | +| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ | +| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ | +| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | | +| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ | +| `Ovis` | Ovis2, Ovis1.6 | T + I<sup>+</sup> | `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. | | ✅︎ | ✅︎ | +| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ | +| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ | +| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ | +| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. 
| ✅︎ | ✅︎ | ✅︎ | +| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ | +| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5_VLForConditionalGeneration` | Qwen2.5-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ | +| `Qwen2_5OmniThinkerForConditionalGeneration` | Qwen2.5-Omni | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup> | `Qwen/Qwen2.5-Omni-7B` | | ✅︎ | ✅︎ | +| `SkyworkR1VChatModel` | Skywork-R1V-38B | T + I | `Skywork/Skywork-R1V-38B` | | ✅︎ | ✅︎ | +| `SmolVLMForConditionalGeneration` | SmolVLM2 | T + I | `SmolVLM2-2.2B-Instruct` | ✅︎ | | ✅︎ | +| `TarsierForConditionalGeneration` | Tarsier | T + I<sup>E+</sup> | `omni-search/Tarsier-7b`, `omni-search/Tarsier-34b` | | ✅︎ | ✅︎ | +| `Tarsier2ForConditionalGeneration`<sup>^</sup> | Tarsier2 | T + I<sup>E+</sup> + V<sup>E+</sup> | `omni-research/Tarsier2-Recap-7b`, `omni-research/Tarsier2-7b-0115` | | ✅︎ | ✅︎ | <sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.     • For example, to use DeepSeek-VL2 series models: @@ -634,7 +657,7 @@ Specified using `--task generate`. For the best results, we recommend using the following dependency versions (tested on A10 and L40): - ??? Dependency versions + ??? code "Dependency versions" ```text # Core vLLM-compatible dependencies with Molmo accuracy setup (tested on L40) @@ -677,9 +700,9 @@ Specified using `--task transcription`. Speech2Text models trained specifically for Automatic Speech Recognition. -| Architecture | Models | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|----------------------------------------------|------------------|------------------------------------------------------------------|------------------------|-----------------------------|-----------------------| -| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | +| Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|-------------------|----------------------|---------------------------|---------------------| +| `WhisperForConditionalGeneration` | Whisper | `openai/whisper-small`, `openai/whisper-large-v3-turbo`, etc. | | | | ### Pooling Models @@ -700,13 +723,21 @@ Any text generation model can be converted into an embedding model by passing `- The following table lists those that are tested in vLLM. 
-| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | -|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------| -| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | -| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | | +| Architecture | Models | Inputs | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/distributed_serving.md) | [V1](gh-issue:8779) | +|--------------|--------|--------|-------------------|----------------------|---------------------------|---------------------| +| `LlavaNextForConditionalGeneration` | LLaVA-NeXT-based | T / I | `royokong/e5-v` | | | | +| `Phi3VForCausalLM` | Phi-3-Vision-based | T + I | `TIGER-Lab/VLM2Vec-Full` | 🚧 | ✅︎ | | --- +#### Scoring + +Specified using `--task score`. + +| Architecture | Models | Inputs | Example HF Models | [LoRA][lora-adapter] | [PP][distributed-serving] | [V1](gh-issue:8779) | +|-------------------------------------|--------------------|----------|--------------------------|------------------------|-----------------------------|-----------------------| +| `JinaVLForSequenceClassification` | JinaVL-based | T + I<sup>E+</sup> | `jinaai/jina-reranker-m0`, etc. | | | ✅︎ | + ## Model Support Policy At vLLM, we are committed to facilitating the integration and support of third-party models within our ecosystem. Our approach is designed to balance the need for robustness and the practical limitations of supporting a wide range of models. Here’s how we manage third-party model support: diff --git a/docs/serving/data_parallel_deployment.md b/docs/serving/data_parallel_deployment.md new file mode 100644 index 000000000000..9ff9f59c54e5 --- /dev/null +++ b/docs/serving/data_parallel_deployment.md @@ -0,0 +1,120 @@ +# Data Parallel Deployment + +vLLM supports Data Parallel deployment, where model weights are replicated across separate instances/GPUs to process independent batches of requests. + +This will work with both dense and MoE models. + +For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Latent Attention), it can be advantageous to use data parallel for the attention layers and expert or tensor parallel (EP or TP) for the expert layers. + +In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks. + +The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case). + +In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size. + +For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. 
This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP). + +In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and its KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently. + +This document focuses on online deployments (with the API server). DP + EP is also supported for offline usage (via the LLM class); see <gh-file:examples/offline_inference/data_parallel.py> for an example. + +Two distinct modes are supported for online deployments - self-contained with internal load balancing, or external per-rank deployment and load balancing. + +## Internal Load Balancing + +vLLM supports "self-contained" data parallel deployments that expose a single API endpoint. + +It can be configured simply by including e.g. `--data-parallel-size=4` in the `vllm serve` command-line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs. + +Running a single data parallel deployment across multiple nodes requires a separate `vllm serve` command to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks. Clients always talk to that single entrypoint, as sketched below.
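A minimal client-side sketch of that single-entrypoint property (an editorial example; the port and model name are assumptions, and any OpenAI-compatible client works). vLLM's internal load balancer spreads the requests across the DP engines transparently:

```python
from openai import OpenAI

# The internally load-balanced deployment exposes one endpoint,
# regardless of the configured data parallel size.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

for i in range(8):
    # These requests are distributed across the DP engine processes
    # by vLLM's internal load balancer.
    completion = client.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        prompt=f"Request {i}: one short sentence about data parallelism.",
        max_tokens=32,
    )
    print(completion.choices[0].text.strip())
```

The launch commands below show how to bring such a deployment up.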
+ +This will run DP=4, TP=2 on a single 8-GPU node: + +```bash +vllm serve $MODEL --data-parallel-size 4 --tensor-parallel-size 2 +``` + +This will run DP=4 with DP ranks 0 and 1 on the head node and ranks 2 and 3 on the second node: + +```bash +# Node 0 (with ip address 10.99.48.128) +vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 2 \ + --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345 +# Node 1 +vllm serve $MODEL --headless --data-parallel-size 4 --data-parallel-size-local 2 \ + --data-parallel-start-rank 2 \ + --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345 +``` + +This will run DP=4 with only the API server on the first node and all engines on the second node: + +```bash +# Node 0 (with ip address 10.99.48.128) +vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 0 \ + --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345 +# Node 1 +vllm serve $MODEL --headless --data-parallel-size 4 --data-parallel-size-local 4 \ + --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345 +``` + +This DP mode can also be used with Ray by specifying `--data-parallel-backend=ray`: + +```bash +vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 2 \ + --data-parallel-backend=ray +``` + +There are several notable differences when using Ray: + +- A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node +- There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address` +- There is no need to specify `--data-parallel-rpc-port` +- Remote DP ranks will be allocated based on node resources of the Ray cluster + +Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic. + +When deploying large DP sizes using this method, the API server process can become a bottleneck. In this case, the orthogonal `--api-server-count` command line option can be used to scale this out (for example `--api-server-count=4`). This is transparent to users - a single HTTP endpoint / port is still exposed. Note that this API server scale-out is "internal" and still confined to the "head" node. + +<figure markdown="1"> +![DP Internal LB Diagram](../assets/deployment/dp_internal_lb.png) +</figure> + +## External Load Balancing + +For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally. + +In this case, it's more convenient to treat each DP rank like a separate vLLM deployment, with its own endpoint, and have an external router balance HTTP requests between them, making use of appropriate real-time telemetry from each server for routing decisions. + +This can already be done trivially for non-MoE models, since each deployed server is fully independent. No data parallel CLI options need to be used for this. + +We support an equivalent topology for MoE DP+EP which can be configured via the following CLI arguments. 
+ +If DP ranks are co-located (same node / ip address), a default RPC port is used, but a different HTTP server port must be specified for each rank: + +```bash +# Rank 0 +CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 0 \ + --port 8000 +# Rank 1 +CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 1 \ + --port 8001 +``` + +For multi-node cases, the address/port of rank 0 must also be specified: + +```bash +# Rank 0 (with ip address 10.99.48.128) +vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 0 \ + --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345 +# Rank 1 +vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 1 \ + --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345 +``` + +The coordinator process also runs in this scenario, co-located with the DP rank 0 engine. + +<figure markdown="1"> +![DP External LB Diagram](../assets/deployment/dp_external_lb.png) +</figure> + +In the above diagram, each of the dotted boxes corresponds to a separate launch of `vllm serve` - these could be separate Kubernetes pods, for example. diff --git a/docs/serving/distributed_serving.md b/docs/serving/distributed_serving.md index 6665955411ad..a1f522cc5f14 100644 --- a/docs/serving/distributed_serving.md +++ b/docs/serving/distributed_serving.md @@ -1,7 +1,4 @@ ---- -title: Distributed Inference and Serving ---- -[](){ #distributed-serving } +# Distributed Inference and Serving ## How to decide the distributed inference strategy? @@ -18,6 +15,10 @@ After adding enough GPUs and nodes to hold the model, you can run vLLM first, wh !!! note There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs. +### Distributed serving of MoE (Mixture of Experts) models + +It is often advantageous to exploit the inherent parallelism of experts by using a separate parallelism strategy for the expert layers. vLLM supports large-scale deployment combining Data Parallel attention with Expert or Tensor Parallel MoE layers. See the page on [Data Parallel Deployment](data_parallel_deployment.md) for more information. + ## Running vLLM on a single node vLLM supports distributed tensor-parallel and pipeline-parallel inference and serving. Currently, we support [Megatron-LM's tensor parallel algorithm](https://arxiv.org/pdf/1909.08053.pdf). We manage the distributed runtime with either [Ray](https://github.com/ray-project/ray) or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inference currently requires Ray. diff --git a/docs/serving/integrations/langchain.md b/docs/serving/integrations/langchain.md index 1a24ab29c19c..47074f411ac9 100644 --- a/docs/serving/integrations/langchain.md +++ b/docs/serving/integrations/langchain.md @@ -1,7 +1,4 @@ ---- -title: LangChain ---- -[](){ #serving-langchain } +# LangChain vLLM is also available via [LangChain](https://github.com/langchain-ai/langchain) . @@ -13,7 +10,7 @@ pip install langchain langchain_community -q To run inference on a single or multiple GPUs, use `VLLM` class from `langchain`. -??? Code +??? 
code ```python from langchain_community.llms import VLLM diff --git a/docs/serving/integrations/llamaindex.md b/docs/serving/integrations/llamaindex.md index 4feed63bd46b..4b838cbcaa9d 100644 --- a/docs/serving/integrations/llamaindex.md +++ b/docs/serving/integrations/llamaindex.md @@ -1,7 +1,4 @@ ---- -title: LlamaIndex ---- -[](){ #serving-llamaindex } +# LlamaIndex vLLM is also available via [LlamaIndex](https://github.com/run-llama/llama_index) . diff --git a/docs/serving/offline_inference.md b/docs/serving/offline_inference.md index b238199e4144..ddda47690002 100644 --- a/docs/serving/offline_inference.md +++ b/docs/serving/offline_inference.md @@ -1,12 +1,6 @@ ---- -title: Offline Inference ---- -[](){ #offline-inference } +# Offline Inference -You can run vLLM in your own code on a list of prompts. - -The offline API is based on the [LLM][vllm.LLM] class. -To initialize the vLLM engine, create a new instance of `LLM` and specify the model to run. +Offline inference is possible in your own code using vLLM's [`LLM`][vllm.LLM] class. For example, the following code downloads the [`facebook/opt-125m`](https://huggingface.co/facebook/opt-125m) model from HuggingFace and runs it in vLLM using the default configuration. @@ -14,16 +8,53 @@ and runs it in vLLM using the default configuration. ```python from vllm import LLM +# Initialize the vLLM engine. llm = LLM(model="facebook/opt-125m") ``` -After initializing the `LLM` instance, you can perform model inference using various APIs. -The available APIs depend on the type of model that is being run: - -- [Generative models][generative-models] output logprobs which are sampled from to obtain the final output text. -- [Pooling models][pooling-models] output their hidden states directly. +After initializing the `LLM` instance, use the available APIs to perform model inference. +The available APIs depend on the model type: -Please refer to the above pages for more details about each API. +- [Generative models](../models/generative_models.md) output logprobs which are sampled from to obtain the final output text. +- [Pooling models](../models/pooling_models.md) output their hidden states directly. !!! info [API Reference][offline-inference-api] + +## Ray Data LLM API + +Ray Data LLM is an alternative offline inference API that uses vLLM as the underlying engine. +This API adds several batteries-included capabilities that simplify large-scale, GPU-efficient inference: + +- Streaming execution processes datasets that exceed aggregate cluster memory. +- Automatic sharding, load balancing, and autoscaling distribute work across a Ray cluster with built-in fault tolerance. +- Continuous batching keeps vLLM replicas saturated and maximizes GPU utilization. +- Transparent support for tensor and pipeline parallelism enables efficient multi-GPU inference. +- Reading and writing to most popular file formats and cloud object storage. +- Scaling up the workload without code changes. + +??? 
code + + ```python + import ray # Requires ray>=2.44.1 + from ray.data.llm import vLLMEngineProcessorConfig, build_llm_processor + + config = vLLMEngineProcessorConfig(model_source="unsloth/Llama-3.2-1B-Instruct") + processor = build_llm_processor( + config, + preprocess=lambda row: { + "messages": [ + {"role": "system", "content": "You are a bot that completes unfinished haikus."}, + {"role": "user", "content": row["item"]}, + ], + "sampling_params": {"temperature": 0.3, "max_tokens": 250}, + }, + postprocess=lambda row: {"answer": row["generated_text"]}, + ) + + ds = ray.data.from_items(["An old silent pond..."]) + ds = processor(ds) + ds.write_parquet("local:///tmp/data/") + ``` + +For more information about the Ray Data LLM API, see the [Ray Data LLM documentation](https://docs.ray.io/en/latest/data/working-with-llms.html). diff --git a/docs/serving/openai_compatible_server.md b/docs/serving/openai_compatible_server.md index 5371e45d8052..edec40f41760 100644 --- a/docs/serving/openai_compatible_server.md +++ b/docs/serving/openai_compatible_server.md @@ -1,11 +1,8 @@ ---- -title: OpenAI-Compatible Server ---- -[](){ #serving-openai-compatible-server } +# OpenAI-Compatible Server vLLM provides an HTTP server that implements OpenAI's [Completions API](https://platform.openai.com/docs/api-reference/completions), [Chat API](https://platform.openai.com/docs/api-reference/chat), and more! This functionality lets you serve models and interact with them using an HTTP client. -In your terminal, you can [install](../getting_started/installation/README.md) vLLM, then start the server with the [`vllm serve`][serve-args] command. (You can also use our [Docker][deployment-docker] image.) +In your terminal, you can [install](../getting_started/installation/README.md) vLLM, then start the server with the [`vllm serve`](../configuration/serve_args.md) command. (You can also use our [Docker](../deployment/docker.md) image.) ```bash vllm serve NousResearch/Meta-Llama-3-8B-Instruct \ @@ -15,7 +12,7 @@ vllm serve NousResearch/Meta-Llama-3-8B-Instruct \ To call the server, in your preferred text editor, create a script that uses an HTTP client. Include any messages that you want to send to the model. Then run that script. Below is an example script using the [official OpenAI Python client](https://github.com/openai/openai-python). -??? Code +??? code ```python from openai import OpenAI @@ -146,7 +143,7 @@ completion = client.chat.completions.create( Only `X-Request-Id` HTTP request header is supported for now. It can be enabled with `--enable-request-id-headers`. -??? Code +??? code ```python completion = client.chat.completions.create( @@ -185,7 +182,7 @@ Code example: <gh-file:examples/online_serving/openai_completion_client.py> The following [sampling parameters][sampling-params] are supported. -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:completion-sampling-params" @@ -193,7 +190,7 @@ The following [sampling parameters][sampling-params] are supported. The following extra parameters are supported: -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:completion-extra-params" @@ -208,7 +205,7 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai We support both [Vision](https://platform.openai.com/docs/guides/vision)- and [Audio](https://platform.openai.com/docs/guides/audio?audio-generation-quickstart-example=audio-in)-related parameters; -see our [Multimodal Inputs][multimodal-inputs] guide for more information. 
+see our [Multimodal Inputs](../features/multimodal_inputs.md) guide for more information. - *Note: `image_url.detail` parameter is not supported.* Code example: <gh-file:examples/online_serving/openai_chat_completion_client.py> @@ -217,7 +214,7 @@ Code example: <gh-file:examples/online_serving/openai_chat_completion_client.py> The following [sampling parameters][sampling-params] are supported. -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-sampling-params" @@ -225,7 +222,7 @@ The following [sampling parameters][sampling-params] are supported. The following extra parameters are supported: -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:chat-completion-extra-params" @@ -268,7 +265,7 @@ and passing a list of `messages` in the request. Refer to the examples below for Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: - ??? Code + ??? code ```python import requests @@ -327,7 +324,7 @@ The following [pooling parameters][pooling-params] are supported. The following extra parameters are supported by default: -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:embedding-extra-params" @@ -335,7 +332,7 @@ The following extra parameters are supported by default: For chat-like input (i.e. if `messages` is passed), these extra parameters are supported instead: -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:chat-embedding-extra-params" @@ -354,11 +351,16 @@ you can use the [official OpenAI Python client](https://github.com/openai/openai Code example: <gh-file:examples/online_serving/openai_transcription_client.py> <!-- TODO: api enforced limits + uploading audios --> +#### API Enforced Limits + +Set the maximum audio file size (in MB) that VLLM will accept, via the +`VLLM_MAX_AUDIO_CLIP_FILESIZE_MB` environment variable. Default is 25 MB. + #### Extra Parameters The following [sampling parameters][sampling-params] are supported. -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:transcription-sampling-params" @@ -366,7 +368,7 @@ The following [sampling parameters][sampling-params] are supported. The following extra parameters are supported: -??? Code +??? code ```python --8<-- "vllm/entrypoints/openai/protocol.py:transcription-extra-params" @@ -446,9 +448,9 @@ curl -v "http://127.0.0.1:8000/classify" \ }' ``` -??? Response +??? console "Response" - ```bash + ```json { "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2", "object": "list", @@ -494,9 +496,9 @@ curl -v "http://127.0.0.1:8000/classify" \ }' ``` -??? Response +??? console "Response" - ```bash + ```json { "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682", "object": "list", @@ -540,7 +542,7 @@ The following extra parameters are supported: ### Score API -Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair. +Our Score API can apply a cross-encoder model or an embedding model to predict scores for sentence or multimodal pairs. When using an embedding model the score corresponds to the cosine similarity between each embedding pair. Usually, the score for a sentence pair refers to the similarity between two sentences, on a scale of 0 to 1. 
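To make the embedding-model case concrete, here is a minimal client-side sketch (an illustration, not part of the official examples) that queries a running server both through the Score API and through the Embeddings API and recomputes the cosine similarity by hand; for embedding models the two numbers should closely match. The model name, port, and endpoint paths are assumptions taken from the surrounding examples.

```python
import math

import requests

BASE = "http://127.0.0.1:8000"
MODEL = "BAAI/bge-base-en-v1.5"  # placeholder embedding model

text_1 = "What is the capital of France?"
text_2 = "The capital of France is Paris."

# Ask the Score API directly.
score_resp = requests.post(
    f"{BASE}/score",
    json={"model": MODEL, "text_1": text_1, "text_2": text_2},
)
score_resp.raise_for_status()
api_score = score_resp.json()["data"][0]["score"]

# Recompute the same quantity from raw embeddings.
emb_resp = requests.post(
    f"{BASE}/v1/embeddings",
    json={"model": MODEL, "input": [text_1, text_2]},
)
emb_resp.raise_for_status()
e1, e2 = (item["embedding"] for item in emb_resp.json()["data"])

dot = sum(a * b for a, b in zip(e1, e2))
norm = math.sqrt(sum(a * a for a in e1)) * math.sqrt(sum(b * b for b in e2))
print(f"Score API: {api_score:.4f}, cosine similarity: {dot / norm:.4f}")
```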
You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). @@ -564,9 +566,9 @@ curl -X 'POST' \ }' ``` -??? Response +??? console "Response" - ```bash + ```json { "id": "score-request-id", "object": "list", @@ -589,7 +591,7 @@ You can pass a string to `text_1` and a list to `text_2`, forming multiple sente where each pair is built from `text_1` and a string in `text_2`. The total number of pairs is `len(text_2)`. -??? Request +??? console "Request" ```bash curl -X 'POST' \ @@ -606,9 +608,9 @@ The total number of pairs is `len(text_2)`. }' ``` -??? Response +??? console "Response" - ```bash + ```json { "id": "score-request-id", "object": "list", @@ -634,7 +636,7 @@ You can pass a list to both `text_1` and `text_2`, forming multiple sentence pai where each pair is built from a string in `text_1` and the corresponding string in `text_2` (similar to `zip()`). The total number of pairs is `len(text_2)`. -??? Request +??? console "Request" ```bash curl -X 'POST' \ @@ -655,9 +657,9 @@ The total number of pairs is `len(text_2)`. }' ``` -??? Response +??? console "Response" - ```bash + ```json { "id": "score-request-id", "object": "list", @@ -679,6 +681,55 @@ The total number of pairs is `len(text_2)`. } ``` +#### Multi-modal inputs + +You can pass multi-modal inputs to scoring models by passing `content` including a list of multi-modal input (image, etc.) in the request. Refer to the examples below for illustration. + +=== "JinaVL-Reranker" + + To serve the model: + + ```bash + vllm serve jinaai/jina-reranker-m0 + ``` + + Since the request schema is not defined by OpenAI client, we post a request to the server using the lower-level `requests` library: + + ??? Code + + ```python + import requests + + response = requests.post( + "http://localhost:8000/v1/score", + json={ + "model": "jinaai/jina-reranker-m0", + "text_1": "slm markdown", + "text_2": { + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + }, + }, + ] + } + }, + ) + response.raise_for_status() + response_json = response.json() + print("Scoring output:", response_json["data"][0]["score"]) + print("Scoring output:", response_json["data"][1]["score"]) + ``` +Full example: <gh-file:examples/online_serving/openai_cross_encoder_score_for_multimodal.py> + #### Extra parameters The following [pooling parameters][pooling-params] are supported. @@ -698,8 +749,7 @@ The following extra parameters are supported: ### Re-rank API Our Re-rank API can apply an embedding model or a cross-encoder model to predict relevant scores between a single query, and -each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences, on -a scale of 0 to 1. +each of a list of documents. Usually, the score for a sentence pair refers to the similarity between two sentences or multi-modal inputs (image, etc.), on a scale of 0 to 1. You can find the documentation for cross encoder models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). 
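As a quick Python counterpart to the curl examples that follow, the sketch below posts a query and a few documents to the Re-rank API; treat the endpoint path, model name, and response fields as assumptions mirroring those examples rather than a normative client.

```python
import requests

response = requests.post(
    "http://127.0.0.1:8000/rerank",
    json={
        "model": "BAAI/bge-reranker-base",  # placeholder reranker model
        "query": "What is the capital of France?",
        "documents": [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ],
    },
)
response.raise_for_status()

# Results come back sorted by relevance; `index` points at the original
# position of each document in the request.
for result in response.json()["results"]:
    print(result["index"], result["relevance_score"])
```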
@@ -716,7 +766,7 @@ Code example: <gh-file:examples/online_serving/jinaai_rerank_client.py> Note that the `top_n` request parameter is optional and will default to the length of the `documents` field. Result documents will be sorted by relevance, and the `index` property can be used to determine original order. -??? Request +??? console "Request" ```bash curl -X 'POST' \ @@ -734,9 +784,9 @@ Result documents will be sorted by relevance, and the `index` property can be us }' ``` -??? Response +??? console "Response" - ```bash + ```json { "id": "rerank-fae51b2b664d4ed38f5969b612edff77", "model": "BAAI/bge-reranker-base", @@ -775,3 +825,17 @@ The following extra parameters are supported: ```python --8<-- "vllm/entrypoints/openai/protocol.py:rerank-extra-params" ``` + +## Ray Serve LLM + +Ray Serve LLM enables scalable, production-grade serving of the vLLM engine. It integrates tightly with vLLM and extends it with features such as auto-scaling, load balancing, and back-pressure. + +Key capabilities: + +- Exposes an OpenAI-compatible HTTP API as well as a Pythonic API. +- Scales from a single GPU to a multi-node cluster without code changes. +- Provides observability and autoscaling policies through Ray dashboards and metrics. + +The following example shows how to deploy a large model like DeepSeek R1 with Ray Serve LLM: <gh-file:examples/online_serving/ray_serve_deepseek.py>. + +Learn more about Ray Serve LLM with the official [Ray Serve LLM documentation](https://docs.ray.io/en/latest/serve/llm/serving-llms.html). diff --git a/docs/usage/faq.md b/docs/usage/faq.md index 51977d4434f5..2c8680cb6f7b 100644 --- a/docs/usage/faq.md +++ b/docs/usage/faq.md @@ -1,7 +1,4 @@ ---- -title: Frequently Asked Questions ---- -[](){ #faq } +# Frequently Asked Questions > Q: How can I serve multiple models on a single port using the OpenAI API? @@ -12,7 +9,7 @@ A: Assuming that you're referring to using OpenAI compatible server to serve mul > Q: Which model to use for offline inference embedding? A: You can try [e5-mistral-7b-instruct](https://huggingface.co/intfloat/e5-mistral-7b-instruct) and [BAAI/bge-base-en-v1.5](https://huggingface.co/BAAI/bge-base-en-v1.5); -more are listed [here][supported-models]. +more are listed [here](../models/supported_models.md). By extracting hidden states, vLLM can automatically convert text generation models like [Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B), [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) into embedding models, diff --git a/docs/usage/metrics.md b/docs/usage/metrics.md index 4350ab5025f5..d756e32476f0 100644 --- a/docs/usage/metrics.md +++ b/docs/usage/metrics.md @@ -4,7 +4,7 @@ vLLM exposes a number of metrics that can be used to monitor the health of the system. These metrics are exposed via the `/metrics` endpoint on the vLLM OpenAI compatible API server. -You can start the server using Python, or using [Docker][deployment-docker]: +You can start the server using Python, or using [Docker](../deployment/docker.md): ```bash vllm serve unsloth/Llama-3.2-1B-Instruct @@ -12,7 +12,7 @@ vllm serve unsloth/Llama-3.2-1B-Instruct Then query the endpoint to get the latest metrics from the server: -??? Output +??? console "Output" ```console $ curl http://0.0.0.0:8000/metrics @@ -33,7 +33,7 @@ Then query the endpoint to get the latest metrics from the server: The following metrics are exposed: -??? Code +??? 
code ```python --8<-- "vllm/engine/metrics.py:metrics-definitions" diff --git a/docs/usage/troubleshooting.md b/docs/usage/troubleshooting.md index 7f1f76ce3d2e..f9ba32c58c4e 100644 --- a/docs/usage/troubleshooting.md +++ b/docs/usage/troubleshooting.md @@ -1,7 +1,4 @@ ---- -title: Troubleshooting ---- -[](){ #troubleshooting } +# Troubleshooting This document outlines some troubleshooting strategies you can consider. If you think you've discovered a bug, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. @@ -60,7 +57,7 @@ To identify the particular CUDA operation that causes the error, you can add `-- If GPU/CPU communication cannot be established, you can use the following Python script and follow the instructions below to confirm whether the GPU/CPU communication is working correctly. -??? Code +??? code ```python # Test PyTorch NCCL @@ -170,7 +167,7 @@ WARNING 12-11 14:50:37 multiproc_worker_utils.py:281] CUDA was previously or an error from Python that looks like this: -??? Logs +??? console "Logs" ```console RuntimeError: @@ -212,9 +209,9 @@ if __name__ == '__main__': ## `torch.compile` Error -vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](https://github.com/vllm-project/vllm/pull/10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script: +vLLM heavily depends on `torch.compile` to optimize the model for better performance, which introduces the dependency on the `torch.compile` functionality and the `triton` library. By default, we use `torch.compile` to [optimize some functions](gh-pr:10406) in the model. Before running vLLM, you can check if `torch.compile` is working as expected by running the following script: -??? Code +??? code ```python import torch @@ -231,7 +228,7 @@ vLLM heavily depends on `torch.compile` to optimize the model for better perform print(f(x)) ``` -If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See [this issue](https://github.com/vllm-project/vllm/issues/12219) for example. +If it raises errors from `torch/_inductor` directory, usually it means you have a custom `triton` library that is not compatible with the version of PyTorch you are using. See <gh-issue:12219> for example. ## Model failed to be inspected @@ -267,7 +264,7 @@ or: ValueError: Model architectures ['<arch>'] are not supported for now. Supported architectures: [...] ``` -But you are sure that the model is in the [list of supported models][supported-models], there may be some issue with vLLM's model resolution. In that case, please follow [these steps](../configuration/model_resolution.md) to explicitly specify the vLLM implementation for the model. +But you are sure that the model is in the [list of supported models](../models/supported_models.md), there may be some issue with vLLM's model resolution. In that case, please follow [these steps](../configuration/model_resolution.md) to explicitly specify the vLLM implementation for the model. 
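As a pointer to what those steps boil down to, one hedged sketch is to pin the architecture explicitly through `hf_overrides` when constructing the engine; the checkpoint and architecture names below are illustrative only.

```python
from vllm import LLM

# Sketch: explicitly tell vLLM which implementation to use when automatic
# model resolution fails. Replace the architecture with the one reported
# for your checkpoint.
llm = LLM(
    model="cerebras/Cerebras-GPT-1.3B",  # example checkpoint
    hf_overrides={"architectures": ["GPT2LMHeadModel"]},  # assumed override
)
```

The same override can typically be passed to `vllm serve` as JSON via `--hf-overrides`.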
## Failed to infer device type diff --git a/docs/usage/usage_stats.md b/docs/usage/usage_stats.md index 78d2a6784bc5..e78c67522f61 100644 --- a/docs/usage/usage_stats.md +++ b/docs/usage/usage_stats.md @@ -10,7 +10,7 @@ The list of data collected by the latest version of vLLM can be found here: <gh- Here is an example as of v0.4.0: -??? Output +??? console "Output" ```json { diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 82a2710d895c..498ff3da0ca3 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -2,7 +2,7 @@ !!! announcement - We have started the process of deprecating V0. Please read [RFC #18571](https://github.com/vllm-project/vllm/issues/18571) for more details. + We have started the process of deprecating V0. Please read [RFC #18571](gh-issue:18571) for more details. V1 is now enabled by default for all supported use cases, and we will gradually enable it for every use case we plan to support. Please share any feedback on [GitHub](https://github.com/vllm-project/vllm) or in the [vLLM Slack](https://inviter.co/vllm-slack). @@ -83,14 +83,14 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> | | **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> | | **Embedding Models** | <nobr>🟢 Functional</nobr> | -| **Mamba Models** | <nobr>🚧 WIP ([PR #19327](https://github.com/vllm-project/vllm/pull/19327))</nobr> | +| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> | | **Multimodal Models** | <nobr>🟢 Functional</nobr> | vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol. !!! tip - This corresponds to the V1 column in our [list of supported models][supported-models]. + This corresponds to the V1 column in our [list of supported models](../models/supported_models.md). See below for the status of models that are not yet supported or have more features planned in V1. @@ -98,14 +98,20 @@ See below for the status of models that are not yet supported or have more featu The initial basic support is now functional. -Later, we will consider using [hidden states processor](https://github.com/vllm-project/vllm/issues/12249), -which is based on [global logits processor](https://github.com/vllm-project/vllm/pull/13360) +Later, we will consider using [hidden states processor](gh-issue:12249), +which is based on [global logits processor](gh-pr:13360) to enable simultaneous generation and embedding using the same engine instance in V1. #### Mamba Models -Models using selective state-space mechanisms instead of standard transformer attention (e.g., `MambaForCausalLM`, `JambaForCausalLM`) -will be supported via [PR #19327](https://github.com/vllm-project/vllm/pull/19327). +Models using selective state-space mechanisms instead of standard transformer attention are partially supported. +Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers +(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require +disabling prefix caching in V1. + +Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, +`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that +these models currently require disabling prefix caching and using the FlashInfer attention backend in V1. #### Encoder-Decoder Models @@ -120,13 +126,13 @@ are not yet supported. 
| **Chunked Prefill** | <nobr>🚀 Optimized</nobr> | | **LoRA** | <nobr>🚀 Optimized</nobr> | | **Logprobs Calculation** | <nobr>🟢 Functional</nobr> | -| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices ([PR #15191](https://github.com/vllm-project/vllm/pull/15191))</nobr>| +| **FP8 KV Cache** | <nobr>🟢 Functional on Hopper devices (<gh-pr:15191>)</nobr>| | **Spec Decode** | <nobr>🚀 Optimized</nobr> | -| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](https://github.com/vllm-project/vllm/issues/13414))</nobr>| +| **Prompt Logprobs with Prefix Caching** | <nobr>🟡 Planned ([RFC #13414](gh-issue:13414))</nobr>| | **Structured Output Alternative Backends** | <nobr>🟢 Functional</nobr> | | **Request-level Structured Output Backend** | <nobr>🔴 Deprecated</nobr> | -| **best_of** | <nobr>🔴 Deprecated ([RFC #13361](https://github.com/vllm-project/vllm/issues/13361))</nobr>| -| **Per-Request Logits Processors** | <nobr>🔴 Deprecated ([RFC #13360](https://github.com/vllm-project/vllm/pull/13360))</nobr> | +| **best_of** | <nobr>🔴 Deprecated ([RFC #13361](gh-issue:13361))</nobr>| +| **Per-Request Logits Processors** | <nobr>🔴 Deprecated ([RFC #13360](gh-pr:13360))</nobr> | | **GPU <> CPU KV Cache Swapping** | <nobr>🔴 Deprecated</nobr> | !!! note @@ -153,7 +159,7 @@ Support for logprobs with post-sampling adjustments is in progress and will be a **Prompt Logprobs with Prefix Caching** -Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](https://github.com/vllm-project/vllm/issues/13414). +Currently prompt logprobs are only supported when prefix caching is turned off via `--no-enable-prefix-caching`. In a future release, prompt logprobs will be compatible with prefix caching, but a recomputation will be triggered to recover the full prompt logprobs even upon a prefix cache hit. See details in [RFC #13414](gh-issue:13414). #### Deprecated Features @@ -161,11 +167,11 @@ As part of the major architectural rework in vLLM V1, several legacy features ha **Sampling features** -- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](https://github.com/vllm-project/vllm/issues/13361). +- **best_of**: This feature has been deprecated due to limited usage. See details at [RFC #13361](gh-issue:13361). - **Per-Request Logits Processors**: In V0, users could pass custom processing functions to adjust logits on a per-request basis. In vLLM V1, this feature has been deprecated. Instead, the design is moving toward supporting **global logits - processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](https://github.com/vllm-project/vllm/pull/13360). + processors**, a feature the team is actively working on for future releases. See details at [RFC #13360](gh-pr:13360). 
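To make the prompt-logprobs caveat noted earlier in this section concrete, here is a minimal offline sketch (the model name is a placeholder): start the engine with prefix caching disabled, then request `prompt_logprobs` in the sampling parameters. The online equivalent is launching `vllm serve` with `--no-enable-prefix-caching`.

```python
from vllm import LLM, SamplingParams

# Prompt logprobs currently require prefix caching to be turned off in V1.
llm = LLM(model="facebook/opt-125m", enforce_eager=True, enable_prefix_caching=False)
params = SamplingParams(max_tokens=8, prompt_logprobs=1)

(output,) = llm.generate(["The capital of France is"], params)
print(output.prompt_logprobs)  # one logprob entry per prompt token (first is None)
```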
**KV Cache features** diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index 8e5cac78a4b2..8014cb53f16a 100644 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -10,7 +10,7 @@ import os from dataclasses import asdict -from typing import NamedTuple, Optional +from typing import Any, NamedTuple, Optional from huggingface_hub import snapshot_download from transformers import AutoTokenizer @@ -30,7 +30,9 @@ class ModelRequestData(NamedTuple): engine_args: EngineArgs - prompt: str + prompt: Optional[str] = None + prompt_token_ids: Optional[dict[str, list[int]]] = None + multi_modal_data: Optional[dict[str, Any]] = None stop_token_ids: Optional[list[int]] = None lora_requests: Optional[list[LoRARequest]] = None @@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple): # Unless specified, these settings have been tested to work on a single L4. +# Voxtral +def run_voxtral(question: str, audio_count: int) -> ModelRequestData: + from mistral_common.audio import Audio + from mistral_common.protocol.instruct.messages import ( + AudioChunk, + RawAudio, + TextChunk, + UserMessage, + ) + from mistral_common.protocol.instruct.request import ChatCompletionRequest + from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + + model_name = "mistralai/Voxtral-Mini-3B-2507" + tokenizer = MistralTokenizer.from_hf_hub(model_name) + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + limit_mm_per_prompt={"audio": audio_count}, + config_format="mistral", + load_format="mistral", + tokenizer_mode="mistral", + enforce_eager=True, + enable_chunked_prefill=False, + ) + + text_chunk = TextChunk(text=question) + audios = [ + Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) + for i in range(audio_count) + ] + audio_chunks = [ + AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios + ] + + messages = [UserMessage(content=[*audio_chunks, text_chunk])] + + req = ChatCompletionRequest(messages=messages, model=model_name) + + tokens = tokenizer.encode_chat_completion(req) + prompt_ids, audios = tokens.tokens, tokens.audios + + audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios] + + multi_modal_data = {"audio": audios_and_sr} + + return ModelRequestData( + engine_args=engine_args, + prompt_token_ids=prompt_ids, + multi_modal_data=multi_modal_data, + ) + + # Granite Speech def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: # NOTE - the setting in this example are somehat different than what is @@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { + "voxtral": run_voxtral, "granite_speech": run_granite_speech, "minicpmo": run_minicpmo, "phi4_mm": run_phi4mm, @@ -311,16 +368,24 @@ def main(args): temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids ) - mm_data = {} - if audio_count > 0: - mm_data = { - "audio": [ - asset.audio_and_sample_rate for asset in audio_assets[:audio_count] - ] - } + mm_data = req_data.multi_modal_data + if not mm_data: + mm_data = {} + if audio_count > 0: + mm_data = { + "audio": [ + asset.audio_and_sample_rate for asset in audio_assets[:audio_count] + ] + } assert args.num_prompts > 0 - inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data} + inputs = {"multi_modal_data": mm_data} + + if req_data.prompt: + inputs["prompt"] = req_data.prompt + else: + inputs["prompt_token_ids"] = 
req_data.prompt_token_ids + if args.num_prompts > 1: # Batch inference inputs = [inputs] * args.num_prompts diff --git a/examples/offline_inference/basic/classify.py b/examples/offline_inference/basic/classify.py index 219064e97429..aaf0e83c9dee 100644 --- a/examples/offline_inference/basic/classify.py +++ b/examples/offline_inference/basic/classify.py @@ -28,10 +28,10 @@ def main(args: Namespace): # Create an LLM. # You should pass task="classify" for classification models - model = LLM(**vars(args)) + llm = LLM(**vars(args)) # Generate logits. The output is a list of ClassificationRequestOutputs. - outputs = model.classify(prompts) + outputs = llm.classify(prompts) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) diff --git a/examples/offline_inference/basic/embed.py b/examples/offline_inference/basic/embed.py index 1114033d5cea..7ff9c7f5e0eb 100644 --- a/examples/offline_inference/basic/embed.py +++ b/examples/offline_inference/basic/embed.py @@ -31,10 +31,10 @@ def main(args: Namespace): # Create an LLM. # You should pass task="embed" for embedding models - model = LLM(**vars(args)) + llm = LLM(**vars(args)) # Generate embedding. The output is a list of EmbeddingRequestOutputs. - outputs = model.embed(prompts) + outputs = llm.embed(prompts) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) diff --git a/examples/offline_inference/basic/score.py b/examples/offline_inference/basic/score.py index 6a08de2d2c38..d37527b0a131 100644 --- a/examples/offline_inference/basic/score.py +++ b/examples/offline_inference/basic/score.py @@ -27,10 +27,10 @@ def main(args: Namespace): # Create an LLM. # You should pass task="score" for cross-encoder models - model = LLM(**vars(args)) + llm = LLM(**vars(args)) # Generate scores. The output is a list of ScoringRequestOutputs. - outputs = model.score(text_1, texts_2) + outputs = llm.score(text_1, texts_2) # Print the outputs. print("\nGenerated Outputs:\n" + "-" * 60) diff --git a/examples/offline_inference/batch_llm_inference.py b/examples/offline_inference/batch_llm_inference.py index b1c1ef620da8..22408dc95033 100644 --- a/examples/offline_inference/batch_llm_inference.py +++ b/examples/offline_inference/batch_llm_inference.py @@ -3,17 +3,19 @@ """ This example shows how to use Ray Data for data parallel batch inference. -Ray Data is a data processing framework that can handle large datasets -and integrates tightly with vLLM for data-parallel inference. - -As of Ray 2.44, Ray Data has a native integration with -vLLM (under ray.data.llm). +Ray Data is a data processing framework that can process very large datasets +with first-class support for vLLM. Ray Data provides functionality for: -* Reading and writing to cloud storage (S3, GCS, etc.) -* Automatic sharding and load-balancing across a cluster -* Optimized configuration of vLLM using continuous batching -* Compatible with tensor/pipeline parallel inference as well. +* Reading and writing to most popular file formats and cloud object storage. +* Streaming execution, so you can run inference on datasets that far exceed + the aggregate RAM of the cluster. +* Scale up the workload without code changes. +* Automatic sharding, load-balancing, and autoscaling across a Ray cluster, + with built-in fault-tolerance and retry semantics. +* Continuous batching that keeps vLLM replicas saturated and maximizes GPU + utilization. +* Compatible with tensor/pipeline parallel inference. 
Learn more about Ray Data's LLM integration: https://docs.ray.io/en/latest/data/working-with-llms.html diff --git a/examples/offline_inference/convert_model_to_seq_cls.py b/examples/offline_inference/convert_model_to_seq_cls.py new file mode 100644 index 000000000000..72356020330f --- /dev/null +++ b/examples/offline_inference/convert_model_to_seq_cls.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import argparse +import json + +import torch +import transformers + +# Usage: +# for BAAI/bge-reranker-v2-gemma +# Caution: "Yes" and "yes" are two different tokens +# python convert_model_to_seq_cls.py --model_name BAAI/bge-reranker-v2-gemma --classifier_from_tokens '["Yes"]' --method no_post_processing --path ./bge-reranker-v2-gemma-seq-cls +# for mxbai-rerank-v2 +# python convert_model_to_seq_cls.py --model_name mixedbread-ai/mxbai-rerank-base-v2 --classifier_from_tokens '["0", "1"]' --method from_2_way_softmax --path ./mxbai-rerank-base-v2-seq-cls +# for Qwen3-Reranker +# python convert_model_to_seq_cls.py --model_name Qwen/Qwen3-Reranker-0.6B --classifier_from_tokens '["no", "yes"]' --method from_2_way_softmax --path ./Qwen3-Reranker-0.6B-seq-cls + + +def from_2_way_softmax(causal_lm, seq_cls_model, tokenizer, tokens, device): + # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 + assert len(tokens) == 2 + + lm_head_weights = causal_lm.lm_head.weight + + false_id = tokenizer.convert_tokens_to_ids(tokens[0]) + true_id = tokenizer.convert_tokens_to_ids(tokens[1]) + + score_weight = lm_head_weights[true_id].to(device).to( + torch.float32 + ) - lm_head_weights[false_id].to(device).to(torch.float32) + + with torch.no_grad(): + seq_cls_model.score.weight.copy_(score_weight.unsqueeze(0)) + if seq_cls_model.score.bias is not None: + seq_cls_model.score.bias.zero_() + + +def no_post_processing(causal_lm, seq_cls_model, tokenizer, tokens, device): + lm_head_weights = causal_lm.lm_head.weight + + token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] + + score_weight = lm_head_weights[token_ids].to(device) + + with torch.no_grad(): + seq_cls_model.score.weight.copy_(score_weight) + if seq_cls_model.score.bias is not None: + seq_cls_model.score.bias.zero_() + + +method_map = { + function.__name__: function for function in [from_2_way_softmax, no_post_processing] +} + + +def converting( + model_name, classifier_from_tokens, path, method, use_pad_token=False, device="cpu" +): + assert method in method_map + + if method == "from_2_way_softmax": + assert len(classifier_from_tokens) == 2 + num_labels = 1 + else: + num_labels = len(classifier_from_tokens) + + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + causal_lm = transformers.AutoModelForCausalLM.from_pretrained( + model_name, device_map=device + ) + + seq_cls_model = transformers.AutoModelForSequenceClassification.from_pretrained( + model_name, + num_labels=num_labels, + ignore_mismatched_sizes=True, + device_map=device, + ) + + method_map[method]( + causal_lm, seq_cls_model, tokenizer, classifier_from_tokens, device + ) + + # `llm as reranker` defaults to not using pad_token + seq_cls_model.config.use_pad_token = use_pad_token + seq_cls_model.config.pad_token_id = tokenizer.pad_token_id + + seq_cls_model.save_pretrained(path) + tokenizer.save_pretrained(path) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Converting *ForCausalLM models to " + "*ForSequenceClassification 
models." + ) + parser.add_argument( + "--model_name", + type=str, + default="BAAI/bge-reranker-v2-gemma", + help="Model name", + ) + parser.add_argument( + "--classifier_from_tokens", + type=str, + default='["Yes"]', + help="classifier from tokens", + ) + parser.add_argument( + "--method", type=str, default="no_post_processing", help="Converting converting" + ) + parser.add_argument( + "--use-pad-token", action="store_true", help="Whether to use pad_token" + ) + parser.add_argument( + "--path", + type=str, + default="./bge-reranker-v2-gemma-seq-cls", + help="Path to save converted model", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + + converting( + model_name=args.model_name, + classifier_from_tokens=json.loads(args.classifier_from_tokens), + method=args.method, + use_pad_token=args.use_pad_token, + path=args.path, + ) diff --git a/examples/offline_inference/embed_jina_embeddings_v3.py b/examples/offline_inference/embed_jina_embeddings_v3.py index e68128399ba2..7d78b8c63c63 100644 --- a/examples/offline_inference/embed_jina_embeddings_v3.py +++ b/examples/offline_inference/embed_jina_embeddings_v3.py @@ -30,11 +30,11 @@ def main(args: Namespace): # Create an LLM. # You should pass task="embed" for embedding models - model = LLM(**vars(args)) + llm = LLM(**vars(args)) # Generate embedding. The output is a list of EmbeddingRequestOutputs. # Only text matching task is supported for now. See #16120 - outputs = model.embed(prompts) + outputs = llm.embed(prompts) # Print the outputs. print("\nGenerated Outputs:") diff --git a/examples/offline_inference/embed_matryoshka_fy.py b/examples/offline_inference/embed_matryoshka_fy.py index 7f5d74d9a3ae..50a645ba8270 100644 --- a/examples/offline_inference/embed_matryoshka_fy.py +++ b/examples/offline_inference/embed_matryoshka_fy.py @@ -30,10 +30,10 @@ def main(args: Namespace): # Create an LLM. # You should pass task="embed" for embedding models - model = LLM(**vars(args)) + llm = LLM(**vars(args)) # Generate embedding. The output is a list of EmbeddingRequestOutputs. - outputs = model.embed(prompts, pooling_params=PoolingParams(dimensions=32)) + outputs = llm.embed(prompts, pooling_params=PoolingParams(dimensions=32)) # Print the outputs. 
print("\nGenerated Outputs:") diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py index 0b2070c8e253..8b1d235ff974 100644 --- a/examples/offline_inference/neuron_eagle.py +++ b/examples/offline_inference/neuron_eagle.py @@ -54,7 +54,7 @@ def main(): for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, \n\n\n Generated text: {generated_text!r}") if __name__ == "__main__": diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py index 2ef69f29863d..7fc22caee742 100644 --- a/examples/offline_inference/neuron_speculation.py +++ b/examples/offline_inference/neuron_speculation.py @@ -25,7 +25,7 @@ def config_buckets(): os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048" -def initialize_model(): +def initialize_llm(): """Create an LLM with speculative decoding.""" return LLM( model="openlm-research/open_llama_7b", @@ -37,15 +37,14 @@ def initialize_model(): max_num_seqs=4, max_model_len=2048, block_size=2048, - use_v2_block_manager=True, device="neuron", tensor_parallel_size=32, ) -def process_requests(model: LLM, sampling_params: SamplingParams): +def process_requests(llm: LLM, sampling_params: SamplingParams): """Generate texts from prompts and print them.""" - outputs = model.generate(prompts, sampling_params) + outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text @@ -53,12 +52,12 @@ def process_requests(model: LLM, sampling_params: SamplingParams): def main(): - """Main function that sets up the model and processes prompts.""" + """Main function that sets up the llm and processes prompts.""" config_buckets() - model = initialize_model() + llm = initialize_llm() # Create a sampling params object. 
sampling_params = SamplingParams(max_tokens=100, top_k=1) - process_requests(model, sampling_params) + process_requests(llm, sampling_params) if __name__ == "__main__": diff --git a/examples/offline_inference/prithvi_geospatial_mae.py b/examples/offline_inference/prithvi_geospatial_mae.py index 567c448a8c97..4fdc7a3cf709 100644 --- a/examples/offline_inference/prithvi_geospatial_mae.py +++ b/examples/offline_inference/prithvi_geospatial_mae.py @@ -1,122 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -This is a demo script showing how to use the -PrithviGeospatialMAE model with vLLM -This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa - -Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa - -The requirements for running this script are: -- Installing [terratorch, albumentations, rasterio] in your python environment -- downloading the model weights in a 'model' folder local to the script - (temporary measure until the proper config.json file is uploaded to HF) -- download an input example image (India_900498_S2Hand.tif) and place it in - the same folder with the script (or specify with the --data_file argument) - -Run the example: -python prithvi_geospatial_mae.py - -""" # noqa: E501 - import argparse import datetime import os +import re from typing import Union import albumentations import numpy as np import rasterio -import regex as re import torch from einops import rearrange from terratorch.datamodules import Sen1Floods11NonGeoDataModule from vllm import LLM +torch.set_default_dtype(torch.float16) + NO_DATA = -9999 NO_DATA_FLOAT = 0.0001 OFFSET = 0 PERCENTILE = 99 -model_config = """{ - "architectures": ["PrithviGeoSpatialMAE"], - "num_classes": 0, - "pretrained_cfg": { - "task_args": { - "task": "SemanticSegmentationTask", - "model_factory": "EncoderDecoderFactory", - "loss": "ce", - "ignore_index": -1, - "lr": 0.001, - "freeze_backbone": false, - "freeze_decoder": false, - "plot_on_val": 10, - "optimizer": "AdamW", - "scheduler": "CosineAnnealingLR" - }, - "model_args": { - "backbone_pretrained": false, - "backbone": "prithvi_eo_v2_300_tl", - "decoder": "UperNetDecoder", - "decoder_channels": 256, - "decoder_scale_modules": true, - "num_classes": 2, - "rescale": true, - "backbone_bands": [ - "BLUE", - "GREEN", - "RED", - "NIR_NARROW", - "SWIR_1", - "SWIR_2" - ], - "head_dropout": 0.1, - "necks": [ - { - "name": "SelectIndices", - "indices": [ - 5, - 11, - 17, - 23 - ] - }, - { - "name": "ReshapeTokensToImage" - } - ] - }, - "optimizer_params" : { - "lr": 5.0e-05, - "betas": [0.9, 0.999], - "eps": [1.0e-08], - "weight_decay": 0.05, - "amsgrad": false, - "maximize": false, - "capturable": false, - "differentiable": false - }, - "scheduler_params" : { - "T_max": 50, - "eta_min": 0, - "last_epoch": -1, - "verbose": "deprecated" - } - }, - - - "torch_dtype": "float32" -} -""" - -# Temporarily creating the "config.json" for the model. 
-# This is going to disappear once the correct config.json is available on HF -with open( - os.path.join(os.path.dirname(__file__), "./model/config.json"), "w" -) as config_file: - config_file.write(model_config) - datamodule_config = { "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"], "batch_size": 16, @@ -138,28 +43,24 @@ class PrithviMAE: - def __init__(self): - print("Initializing PrithviMAE model") + def __init__(self, model): self.model = LLM( - model=os.path.join(os.path.dirname(__file__), "./model"), - skip_tokenizer_init=True, - dtype="float32", + model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True ) def run(self, input_data, location_coords): - print("################ Running inference on vLLM ##############") # merge the inputs into one data structure + if input_data is not None and input_data.dtype == torch.float32: + input_data = input_data.to(torch.float16) + input_data = input_data[0] + mm_data = { - "pixel_values": torch.empty(0) if input_data is None else input_data, - "location_coords": torch.empty(0) - if location_coords is None - else location_coords, + "pixel_values": input_data, + "location_coords": location_coords, } prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data} - outputs = self.model.encode(prompt, use_tqdm=False) - print("################ Inference done (it took seconds) ##############") return outputs[0].outputs.data @@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels): """ Args: orig_img: torch.Tensor representing original image (reference) - with shape = (bands, H, W). + with shape = (bands, H, W). channels: list of indices representing RGB channels. Returns: - torch.Tensor with shape (num_channels, height, width) for original image + torch.Tensor with shape (num_channels, height, width) + for original image """ orig_img = orig_img[channels, ...] @@ -260,10 +162,10 @@ def load_example( Args: file_paths: list of file paths . - mean: list containing mean values for each band in the images - in *file_paths*. - std: list containing std values for each band in the images - in *file_paths*. + mean: list containing mean values for each band in the + images in *file_paths*. + std: list containing std values for each band in the + images in *file_paths*. 
Returns: np.array containing created example @@ -308,7 +210,7 @@ def load_example( print(f"Could not extract timestamp for {file} ({e})") imgs = np.stack(imgs, axis=0) # num_frames, H, W, C - imgs = np.moveaxis(imgs, -1, 0).astype("float32") + imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W imgs = np.expand_dims(imgs, axis=0) # add batch di return imgs, temporal_coords, location_coords, metas @@ -332,8 +234,10 @@ def run_model( ) # Build sliding window + batch_size = 1 - batch = torch.tensor(input_data, device="cpu") + # batch = torch.tensor(input_data, device="cpu") + batch = torch.tensor(input_data) windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size) h1, w1 = windows.shape[3:5] windows = rearrange( @@ -344,18 +248,16 @@ def run_model( num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1 windows = torch.tensor_split(windows, num_batches, dim=0) - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - if temporal_coords: - temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0) + temporal_coords = torch.tensor(temporal_coords).unsqueeze(0) else: temporal_coords = None if location_coords: - location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0) + location_coords = torch.tensor(location_coords[0]).unsqueeze(0) else: location_coords = None - # Run model + # Run Prithvi-EO-V2-300M-TL-Sen1Floods11 pred_imgs = [] for x in windows: # Apply standardization @@ -363,15 +265,7 @@ def run_model( x = datamodule.aug(x)["image"] with torch.no_grad(): - x = x.to(device) pred = model.run(x, location_coords=location_coords) - if lightning_model: - pred_lightning = lightning_model( - x, temporal_coords=temporal_coords, location_coords=location_coords - ) - pred_lightning = pred_lightning.output.detach().cpu() - if not torch.equal(pred, pred_lightning): - print("Inference output is not equal") y_hat = pred.argmax(dim=1) y_hat = torch.nn.functional.interpolate( @@ -403,52 +297,18 @@ def run_model( return pred_imgs -def parse_args(): - parser = argparse.ArgumentParser("MAE run inference", add_help=False) - - parser.add_argument( - "--data_file", - type=str, - default="./India_900498_S2Hand.tif", - help="Path to the file.", - ) - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Path to the directory where to save outputs.", - ) - parser.add_argument( - "--input_indices", - default=[1, 2, 3, 8, 11, 12], - type=int, - nargs="+", - help="0-based indices of the six Prithvi channels to be selected from the " - "input. By default selects [1,2,3,8,11,12] for S2L1C data.", - ) - parser.add_argument( - "--rgb_outputs", - action="store_true", - help="If present, output files will only contain RGB channels. 
" - "Otherwise, all bands will be saved.", - ) - - def main( data_file: str, + model: str, output_dir: str, rgb_outputs: bool, input_indices: list[int] = None, ): os.makedirs(output_dir, exist_ok=True) - # Load model --------------------------------------------------------------- - - model_obj = PrithviMAE() + model_obj = PrithviMAE(model=model) datamodule = generate_datamodule() - img_size = 256 # Size of Sen1Floods11 - - # Loading data ------------------------------------------------------------- + img_size = 512 # Size of Sen1Floods11 input_data, temporal_coords, location_coords, meta_data = load_example( file_paths=[data_file], @@ -460,8 +320,6 @@ def main( if input_data.mean() > 1: input_data = input_data / 10000 # Convert to range 0-1 - # Running model ------------------------------------------------------------ - channels = [ datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"] ] # BGR -> RGB @@ -469,7 +327,6 @@ def main( pred = run_model( input_data, temporal_coords, location_coords, model_obj, datamodule, img_size ) - # Save pred meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0) pred_file = os.path.join( @@ -487,6 +344,7 @@ def main( orig_img=torch.Tensor(input_data[0, :, 0, ...]), channels=channels, ) + rgb_orig = rgb_orig.to(torch.float32) pred[pred == 0.0] = np.nan img_pred = rgb_orig * 0.7 + pred * 0.3 @@ -503,9 +361,10 @@ def main( # Save image rgb if rgb_outputs: + name_suffix = os.path.splitext(os.path.basename(data_file))[0] rgb_file = os.path.join( output_dir, - f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff", + f"original_rgb_{name_suffix}.tiff", ) save_geotiff( image=_convert_np_uint8(rgb_orig), @@ -515,6 +374,42 @@ def main( if __name__ == "__main__": - args = parse_args() + parser = argparse.ArgumentParser("MAE run inference", add_help=False) + + parser.add_argument( + "--data_file", + type=str, + default="./India_900498_S2Hand.tif", + help="Path to the file.", + ) + parser.add_argument( + "--model", + type=str, + default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + help="Path to a checkpoint file to load from.", + ) + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Path to the directory where to save outputs.", + ) + parser.add_argument( + "--input_indices", + default=[1, 2, 3, 8, 11, 12], + type=int, + nargs="+", + help=""" + 0-based indices of the six Prithvi channels to be selected from the input. + By default selects [1,2,3,8,11,12] for S2L1C data. + """, + ) + parser.add_argument( + "--rgb_outputs", + action="store_true", + help="If present, output files will only contain RGB channels. " + "Otherwise, all bands will be saved.", + ) + args = parser.parse_args() main(**vars(args)) diff --git a/examples/offline_inference/qwen3_reranker.py b/examples/offline_inference/qwen3_reranker.py index fe3cebc348f1..b0fd57237d47 100644 --- a/examples/offline_inference/qwen3_reranker.py +++ b/examples/offline_inference/qwen3_reranker.py @@ -17,13 +17,13 @@ # Models converted offline using this method can not only be more efficient # and support the vllm score API, but also make the init parameters more # concise, for example. -# model = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score") +# llm = LLM(model="tomaarsen/Qwen3-Reranker-0.6B-seq-cls", task="score") # If you want to load the official original version, the init parameters are # as follows. 
-def get_model() -> LLM: +def get_llm() -> LLM: """Initializes and returns the LLM model for Qwen3-Reranker.""" return LLM( model=model_name, @@ -77,8 +77,8 @@ def main() -> None: ] documents = [document_template.format(doc=doc, suffix=suffix) for doc in documents] - model = get_model() - outputs = model.score(queries, documents) + llm = get_llm() + outputs = llm.score(queries, documents) print("-" * 30) print([output.outputs.score for output in outputs]) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index c6e63531a99d..752117a4e362 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -1,17 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -a simple demonstration of RLHF with vLLM, inspired by -the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF . -It follows the design that, training processes and inference processes -are different, and they live on different GPUs. -Training processes send prompts to inference processes to generate data, -and also synchronize the weights of the model by broadcasting the weights -from the training process to the inference process. -Note that this is a simple demonstration of one training instance and one -inference instance. In practice, there could be multiple training instances -and multiple inference instances. For the full implementation, please refer -to the OpenRLHF framework. +Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray. + +The script separates training and inference workloads onto distinct GPUs +so that Ray can manage process placement and inter-process communication. +A Hugging Face Transformer model occupies GPU 0 for training, whereas a +tensor-parallel vLLM inference engine occupies GPU 1–2. + +The example performs the following steps: + +* Load the training model on GPU 0. +* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism + and Ray placement groups. +* Generate text from a list of prompts using the inference engine. +* Update the weights of the training model and broadcast the updated weights + to the inference engine by using a Ray collective RPC group. Note that + for demonstration purposes we simply zero out the weights. + +For a production-ready implementation that supports multiple training and +inference replicas, see the OpenRLHF framework: +https://github.com/OpenRLHF/OpenRLHF + +This example assumes a single-node cluster with three GPUs, but Ray +supports multi-node clusters. vLLM expects the GPUs are only used for vLLM +workloads. Residual GPU activity interferes with vLLM memory profiling and +causes unexpected behavior. """ import os @@ -28,29 +42,27 @@ class MyLLM(LLM): + """Configure the vLLM worker for Ray placement group execution.""" + def __init__(self, *args, **kwargs): - # a hack to make the script work. - # stop ray from manipulating CUDA_VISIBLE_DEVICES - # at the top-level + # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray + # so that vLLM can manage its own device placement within the worker. os.environ.pop("CUDA_VISIBLE_DEVICES", None) super().__init__(*args, **kwargs) -""" -Start the training process, here we use huggingface transformers -as an example to hold a model on GPU 0. -""" - +# Load the OPT-125M model onto GPU 0 for the training workload. 
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") train_model.to("cuda:0") -""" -Start the inference process, here we use vLLM to hold a model on GPU 1 and -GPU 2. For the details on how to use ray, please refer to the ray -documentation https://docs.ray.io/en/latest/ . -""" + +# Initialize Ray and set the visible devices. The vLLM engine will +# be placed on GPUs 1 and 2. os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" ray.init() +# Create a placement group that reserves GPU 1–2 for the vLLM inference engine. +# Learn more about Ray placement groups: +# https://docs.ray.io/en/latest/placement-groups.html pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2) ray.get(pg_inference.ready()) scheduling_inference = PlacementGroupSchedulingStrategy( @@ -58,10 +70,9 @@ def __init__(self, *args, **kwargs): placement_group_capture_child_tasks=True, placement_group_bundle_index=0, ) -""" -launch the vLLM inference engine. -here we use `enforce_eager` to reduce the start time. -""" + +# Launch the vLLM inference engine. The `enforce_eager` flag reduces +# start-up latency. llm = ray.remote( num_cpus=0, num_gpus=0, @@ -74,7 +85,7 @@ def __init__(self, *args, **kwargs): distributed_executor_backend="ray", ) -# Generate texts from the prompts. +# Generate text from the prompts. prompts = [ "Hello, my name is", "The president of the United States is", @@ -93,8 +104,8 @@ def __init__(self, *args, **kwargs): print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") print("-" * 50) -# set up the communication between the training process -# and the inference engine. +# Set up the communication channel between the training process and the +# inference engine. master_address = get_ip() master_port = get_open_port() @@ -107,21 +118,23 @@ def __init__(self, *args, **kwargs): ) ray.get(handle) -# simulate training, modify the weights of the model. +# Simulate a training step by zeroing out all model weights. +# In a real RLHF training loop the weights would be updated using the gradient +# from an RL objective such as PPO on a reward model. for name, p in train_model.named_parameters(): p.data.zero_() -# sync weight from the training process to the inference engine. +# Synchronize the updated weights to the inference engine. for name, p in train_model.named_parameters(): handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape)) model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream()) ray.get(handle) -# check if the weights are updated. +# Verify that the inference weights have been updated. assert all(ray.get(llm.collective_rpc.remote("check_weights_changed"))) -# use the updated model to generate texts, they will be nonsense -# because the weights are all zeros. +# Generate text with the updated model. The output is expected to be nonsense +# because the weights are zero. outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params)) print("-" * 50) for output in outputs_updated: diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 096363e68301..65621023ab6c 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -1,14 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -a simple demonstration to show how to co-locate -vLLM worker with training actors on the same GPUs, -for RLHF-like applications. 
-The key points: -- Control the placement of the vLLM workers with Ray, by setting - VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly. -- Use cuda-ipc to pass tensors, since NCCL does not work when we have - multiple processes on the same GPU. +Demonstrates how to co-locate a vLLM inference worker and training +actors on the same set of GPUs for reinforcement learning from human feedback +(RLHF) workloads. + +Ray serves as the distributed execution framework in this example. Ray +placement groups allocate both training actors and vLLM workers to the +same GPU bundles, enabling fast, in-GPU communication between the two +components. + +The script shows how to do the following: + +* Configure environment variables (`VLLM_RAY_PER_WORKER_GPUS` and + `VLLM_RAY_BUNDLE_INDICES`) so that vLLM workers land on the desired + devices. +* Exchange tensors between processes by means of CUDA inter-process + communication (IPC). CUDA IPC sidesteps NCCL limitations that occur + when multiple processes share a single GPU. + +Note that this example assumes a single-node cluster with four GPUs, but Ray +supports multi-node clusters. vLLM expects exclusive use of the GPUs during +its initialization for memory profiling. Residual GPU activity interferes +with vLLM memory profiling and causes unexpected behavior. + +Learn more about Ray placement groups: +https://docs.ray.io/en/latest/placement-groups.html """ import os @@ -22,13 +39,24 @@ class MyLLM(LLM): - def __init__(self, *args, bundle_indices: list, **kwargs): - # a hack to make the script work. - # stop ray from manipulating CUDA_VISIBLE_DEVICES - # at the top-level + """Configure the vLLM worker for Ray placement group execution. + + The constructor sets environment variables that allow multiple vLLM + workers to share a single physical GPU and that encode the bundle + indices assigned by the placement group. + + Args: + *args: Positional arguments forwarded to `vllm.LLM`. + bundle_indices (list[int]): Placement-group bundle indices + assigned to this worker. + **kwargs: Keyword arguments forwarded to `vllm.LLM`. + """ + + def __init__(self, *args, bundle_indices: list[int], **kwargs): + # Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES variable + # so that vLLM can manage its own device placement inside the worker. os.environ.pop("CUDA_VISIBLE_DEVICES", None) - # every worker will use 0.4 GPU, so that we can schedule - # 2 instances on the same GPUs. + # Each worker uses 0.4 GPU so that two instances fit on the same GPUs. os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4" os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) print(f"creating LLM with bundle_indices={bundle_indices}") @@ -36,17 +64,25 @@ def __init__(self, *args, bundle_indices: list, **kwargs): class RayTrainingActor: + """Training actor that hosts a Facebook OPT-125M model from Hugging Face. + + The model is loaded onto the first GPU assigned to this actor, and exposes + the CUDA IPC handles so that colocated vLLM workers can map tensors + directly. + """ + def __init__(self): - # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs + # Ray sets CUDA_VISIBLE_DEVICES to the GPUs assigned to this actor. from transformers import AutoModelForCausalLM self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") self.model.to("cuda:0") + # Zero out all the parameters. for name, p in self.model.named_parameters(): p.data.zero_() torch.cuda.synchronize() - # the argument for get_device_uuid is the index - # of the GPU in the visible devices.
+ # The argument for `get_device_uuid` is the index of the GPU in the + # list of visible devices. from vllm.platforms import current_platform self.device_uuid = current_platform.get_device_uuid(0) @@ -59,23 +95,23 @@ def get_weight_ipc_handles(self): data = {} for name, p in self.model.named_parameters(): - # the training actor might only have a subset of the weights - # and need to all-gather the weights from all the actors. - # for demonstration, here we assume all training actors have - # the full weights. + # A training actor might hold only a subset of the weights and may + # need to gather weights from other actors. For demonstration + # purposes, each training actor owns the full weight set. data[name] = reduce_tensor(p.detach()) return {self.device_uuid: data} -# ray manages 4 GPUs +# Ray manages four GPUs. + os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" ray.init() -# we want to co-locate vLLM instance and the training actor -# on the same set of GPUs. -# the placement plan is as follows: -# GPU 0 and 1: training actor 0, 1, and vLLM instance 0 (with TP=2) -# GPU 2 and 3: training actor 2, 3, and vLLM instance 1 (with TP=2) +# Co-locate vLLM instances and training actors on the same set of GPUs: +# * GPU 0 and 1: training actor 0, training actor 1, and vLLM instance 0 +# (tensor parallelism = 2). +# * GPU 2 and 3: training actor 2, training actor 3, and vLLM instance 1 +# (tensor parallelism = 2). pg = placement_group([{"GPU": 1, "CPU": 0}] * 4) ray.get(pg.ready()) @@ -104,10 +140,8 @@ def get_weight_ipc_handles(self): training_actor_device_ids.append(device_id) for i, bundle_indices in enumerate([[0, 1], [2, 3]]): - # IMPORTANT: when creating vLLM instances, we need to - # make sure there are no GPU activities on the target GPUs, - # otherwise, they will interfere with the vLLM memory profiling, - # and cause unexpected behaviors. + # Use the following syntax instead of the @ray.remote decorator so that + # the placement group is customized for each bundle. llm = ray.remote( num_cpus=0, num_gpus=0, @@ -125,8 +159,8 @@ def get_weight_ipc_handles(self): bundle_indices=bundle_indices, ) inference_engines.append(llm) - # don't call any method on the inference engine here, - # otherwise it will block until the vLLM instance is created. + # Do not call any method on the inference engine at this point; the call + # blocks until the vLLM instance finishes initialization. for i, llm in enumerate(inference_engines): inference_engine_device_ids.append( @@ -134,26 +168,25 @@ def get_weight_ipc_handles(self): ) print(f"inference engine {i} is on {inference_engine_device_ids[-1]}") -# check the placement -# the first two training actors should be -# on the same GPUs as the first inference engine +# Verify placement: the first two training actors share the same GPUs as +# the first inference engine. assert training_actor_device_ids[:2] == inference_engine_device_ids[0] -# the last two training actors should be -# on the same GPUs as the second inference engine +# Verify placement: the last two training actors share the same GPUs as +# the second inference engine. 
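The weight-update path above relies on `torch.multiprocessing.reductions.reduce_tensor`: it turns a CUDA tensor into a picklable `(rebuild_fn, args)` handle that a second process on the same GPU can map without copying, which is what makes the co-located update cheap. A minimal, Ray-free sketch of that round trip (an illustration under the assumption of one CUDA-capable GPU, not vLLM's implementation):

```python
# Standalone sketch: share a CUDA tensor between two processes via CUDA IPC,
# the same mechanism behind get_weight_ipc_handles above. Requires one GPU.
import torch
import torch.multiprocessing as mp
from torch.multiprocessing.reductions import reduce_tensor


def consumer(handle):
    # The handle is a (rebuild_fn, args) tuple. Calling rebuild_fn maps the
    # producer's GPU memory into this process; no data is copied.
    rebuild_fn, args = handle
    shared = rebuild_fn(*args)
    print("consumer sees:", shared[:4].tolist())


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    weights = torch.arange(8, dtype=torch.float32, device="cuda:0")
    handle = reduce_tensor(weights)
    p = mp.Process(target=consumer, args=(handle,))
    p.start()
    p.join()  # keep `weights` alive until the consumer has finished
```

In the RLHF example, the training actors ship handles of exactly this kind, keyed by GPU UUID, and the co-located vLLM workers rebuild the tensors to overwrite their own weights.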
assert training_actor_device_ids[2:] == inference_engine_device_ids[1] -print("gather all the IPC handles from the training actors") +print("Gather all the IPC handles from the training actors.") ipc_handles = {} for actor in training_actors: ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote())) -print("update the weights of the inference engines") +print("Update the weights of the inference engines.") for llm in inference_engines: ray.get( llm.collective_rpc.remote( "update_weights_from_ipc_handles", args=(ipc_handles,) ) ) -print("check if the weights are updated") +print("Check if the weights are updated.") for llm in inference_engines: assert ray.get(llm.collective_rpc.remote("check_weights_changed", args=tuple())) diff --git a/examples/offline_inference/skip_loading_weights_in_engine_init.py b/examples/offline_inference/skip_loading_weights_in_engine_init.py new file mode 100644 index 000000000000..1a616817dd23 --- /dev/null +++ b/examples/offline_inference/skip_loading_weights_in_engine_init.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm import LLM, RequestOutput, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + +def print_prompts_and_outputs(outputs: list[RequestOutput]) -> None: + print("-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + +def main(): + # Create an LLM without loading real weights + llm = LLM( + model="Qwen/Qwen3-0.6B", + load_format="dummy", + enforce_eager=True, + tensor_parallel_size=4, + ) + outputs = llm.generate(prompts, sampling_params) + print("\nOutputs do not make sense:") + print_prompts_and_outputs(outputs) + + # Update load format from `dummy` to `auto` + llm.collective_rpc( + "update_config", args=({"load_config": {"load_format": "auto"}},) + ) + # Now reload real weights inplace + llm.collective_rpc("reload_weights") + + # Check outputs make sense + outputs = llm.generate(prompts, sampling_params) + print("\nOutputs make sense after loading real weights:") + print_prompts_and_outputs(outputs) + + +if __name__ == "__main__": + main() diff --git a/examples/offline_inference/spec_decode.py b/examples/offline_inference/spec_decode.py index 26e492fed25f..ce735f3b27df 100644 --- a/examples/offline_inference/spec_decode.py +++ b/examples/offline_inference/spec_decode.py @@ -84,6 +84,7 @@ def main(): gpu_memory_utilization=0.8, speculative_config=speculative_config, disable_log_stats=False, + max_model_len=16384, ) sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 5bd75a78f2c4..e4811c023377 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -429,6 +429,44 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: ) +# Nemontron_VL +def run_nemotron_vl(questions: list[str], modality: str) -> ModelRequestData: + model_name = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + 
limit_mm_per_prompt={modality: 1}, + ) + + assert modality == "image" + placeholder = "<image>" + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + messages = [ + [{"role": "user", "content": f"{placeholder}\n{question}"}] + for question in questions + ] + prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + # Stop tokens for InternVL + # models variants may have different stop tokens + # please refer to the model card for the correct "stop words": + # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + stop_token_ids = [token_id for token_id in stop_token_ids if token_id is not None] + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + stop_token_ids=stop_token_ids, + ) + + # Keye-VL def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData: model_name = "Kwai-Keye/Keye-VL-8B-Preview" @@ -1186,6 +1224,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: "h2ovl_chat": run_h2ovl, "idefics3": run_idefics3, "internvl_chat": run_internvl, + "nemotron_vl": run_nemotron_vl, "keye_vl": run_keye_vl, "kimi_vl": run_kimi_vl, "llava": run_llava, diff --git a/examples/offline_inference/vision_language_embedding.py b/examples/offline_inference/vision_language_pooling.py similarity index 66% rename from examples/offline_inference/vision_language_embedding.py rename to examples/offline_inference/vision_language_pooling.py index 9451825f0b73..57963ebd2b10 100644 --- a/examples/offline_inference/vision_language_embedding.py +++ b/examples/offline_inference/vision_language_pooling.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ This example shows how to use vLLM for running offline inference with -the correct prompt format on vision language models for multimodal embedding. +the correct prompt format on vision language models for multimodal pooling. For most models, the prompt format should follow corresponding examples on HuggingFace model repository. 
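The pooling example below gains a scoring path alongside the embedding path; both go through the same `LLM` entry point, with scoring exposed as `llm.score(query, documents)`. For reference, a minimal text-only sketch of that API (the reranker model name is an illustrative assumption, not the multimodal reranker used in this example):

```python
# Minimal sketch of vLLM's offline scoring API with a text-only reranker.
# The model name is an illustrative assumption; the example in this file
# uses a multimodal reranker and passes image documents instead.
from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-base", task="score")

query = "What is the capital of France?"
documents = [
    "Paris is the capital and largest city of France.",
    "The Great Wall of China is thousands of kilometres long.",
]

outputs = llm.score(query, documents)
for doc, out in zip(documents, outputs):
    # Higher scores mean the document is judged more relevant to the query.
    print(f"{out.outputs.score:.4f}  {doc}")
```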
@@ -15,6 +15,7 @@ from PIL.Image import Image from vllm import LLM, EngineArgs +from vllm.entrypoints.score_utils import ScoreMultiModalParam from vllm.multimodal.utils import fetch_image from vllm.utils import FlexibleArgumentParser @@ -35,14 +36,22 @@ class TextImageQuery(TypedDict): image: Image -QueryModality = Literal["text", "image", "text+image"] -Query = Union[TextQuery, ImageQuery, TextImageQuery] +class TextImagesQuery(TypedDict): + modality: Literal["text+images"] + text: str + image: ScoreMultiModalParam + + +QueryModality = Literal["text", "image", "text+image", "text+images"] +Query = Union[TextQuery, ImageQuery, TextImageQuery, TextImagesQuery] class ModelRequestData(NamedTuple): engine_args: EngineArgs - prompt: str - image: Optional[Image] + prompt: Optional[str] = None + image: Optional[Image] = None + query: Optional[str] = None + documents: Optional[ScoreMultiModalParam] = None def run_e5_v(query: Query) -> ModelRequestData: @@ -107,6 +116,29 @@ def run_vlm2vec(query: Query) -> ModelRequestData: ) +def run_jinavl_reranker(query: Query) -> ModelRequestData: + if query["modality"] != "text+images": + raise ValueError(f"Unsupported query modality: '{query['modality']}'") + + engine_args = EngineArgs( + model="jinaai/jina-reranker-m0", + task="score", + max_model_len=32768, + trust_remote_code=True, + mm_processor_kwargs={ + "min_pixels": 3136, + "max_pixels": 602112, + }, + limit_mm_per_prompt={"image": 1}, + ) + + return ModelRequestData( + engine_args=engine_args, + query=query["text"], + documents=query["image"], + ) + + def get_query(modality: QueryModality): if modality == "text": return TextQuery(modality="text", text="A dog sitting in the grass") @@ -128,6 +160,28 @@ def get_query(modality: QueryModality): ), ) + if modality == "text+images": + return TextImagesQuery( + modality="text+images", + text="slm markdown", + image={ + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + }, + }, + ] + }, + ) + msg = f"Modality {modality} is not supported." raise ValueError(msg) @@ -162,16 +216,31 @@ def run_encode(model: str, modality: QueryModality, seed: Optional[int]): print("-" * 50) +def run_score(model: str, modality: QueryModality, seed: Optional[int]): + query = get_query(modality) + req_data = model_example_map[model](query) + + engine_args = asdict(req_data.engine_args) | {"seed": seed} + llm = LLM(**engine_args) + + outputs = llm.score(req_data.query, req_data.documents) + + print("-" * 30) + print([output.outputs.score for output in outputs]) + print("-" * 30) + + model_example_map = { "e5_v": run_e5_v, "vlm2vec": run_vlm2vec, + "jinavl_reranker": run_jinavl_reranker, } def parse_args(): parser = FlexibleArgumentParser( description="Demo on using vLLM for offline inference with " - "vision language models for multimodal embedding" + "vision language models for multimodal pooling tasks." 
) parser.add_argument( "--model-name", @@ -181,6 +250,14 @@ def parse_args(): choices=model_example_map.keys(), help="The name of the embedding model.", ) + parser.add_argument( + "--task", + "-t", + type=str, + default="embedding", + choices=["embedding", "scoring"], + help="The task type.", + ) parser.add_argument( "--modality", type=str, @@ -198,7 +275,12 @@ def parse_args(): def main(args: Namespace): - run_encode(args.model_name, args.modality, args.seed) + if args.task == "embedding": + run_encode(args.model_name, args.modality, args.seed) + elif args.task == "scoring": + run_score(args.model_name, args.modality, args.seed) + else: + raise ValueError(f"Unsupported task: {args.task}") if __name__ == "__main__": diff --git a/examples/online_serving/chart-helm/values.yaml b/examples/online_serving/chart-helm/values.yaml index 28dba9a6f688..815f02a4bfd5 100644 --- a/examples/online_serving/chart-helm/values.yaml +++ b/examples/online_serving/chart-helm/values.yaml @@ -8,7 +8,7 @@ image: # -- Image tag tag: "latest" # -- Container launch command - command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] + command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] # -- Container port containerPort: 8000 diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh index 2966f386c93a..76f5c0c99d0b 100644 --- a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh +++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_example_p2p_nccl_xpyd.sh @@ -93,6 +93,7 @@ ensure_python_library_installed() { cleanup() { echo "Stopping everything…" trap - INT TERM # prevent re-entrancy + pkill -9 -f "disagg_proxy_p2p_nccl_xpyd.py" kill -- -$$ # negative PID == "this whole process-group" wait # reap children so we don't leave zombies exit 0 diff --git a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py index 4e82424d6cd7..ec58a183061e 100644 --- a/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py +++ b/examples/online_serving/disaggregated_serving_p2p_nccl_xpyd/disagg_proxy_p2p_nccl_xpyd.py @@ -4,7 +4,9 @@ import os import socket import threading +import time import uuid +from typing import Any import aiohttp import msgpack @@ -12,12 +14,25 @@ from quart import Quart, make_response, request count = 0 -prefill_instances: dict[str, str] = {} # http_address: zmq_address -decode_instances: dict[str, str] = {} # http_address: zmq_address +prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp) +decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp) prefill_cv = threading.Condition() decode_cv = threading.Condition() +DEFAULT_PING_SECONDS = 5 + + +def _remove_oldest_instances(instances: dict[str, Any]) -> None: + oldest_key = next(iter(instances), None) + while oldest_key is not None: + value = instances[oldest_key] + if value[1] > time.time(): + break + print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]") + instances.pop(oldest_key, None) + oldest_key = next(iter(instances), None) + 
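The proxy now stores each registered instance as a `(zmq_address, expiry_stamp)` pair: every heartbeat re-registers the instance with a fresh deadline (see the registration handler that follows), and `_remove_oldest_instances` evicts entries whose deadline has passed. The same idea in isolation, as a small sketch (class and method names are illustrative, not the proxy's actual interface):

```python
# Standalone sketch of the heartbeat/TTL registry idea used by the proxy.
# Names are illustrative; this is not the proxy's actual interface.
import time

PING_SECONDS = 5


class InstanceRegistry:
    def __init__(self):
        # http_address -> (zmq_address, expiry_timestamp), insertion-ordered
        self._instances: dict[str, tuple[str, float]] = {}

    def register(self, http_addr: str, zmq_addr: str) -> None:
        # Re-inserting moves the entry to the end, so the dict stays ordered
        # by most recent heartbeat and the stalest entry is always first.
        self._instances.pop(http_addr, None)
        self._instances[http_addr] = (zmq_addr, time.time() + PING_SECONDS)
        self._expire()

    def _expire(self) -> None:
        while self._instances:
            oldest = next(iter(self._instances))
            zmq_addr, deadline = self._instances[oldest]
            if deadline > time.time():
                break
            print(f"remove stale instance {oldest} ({zmq_addr})")
            self._instances.pop(oldest, None)


if __name__ == "__main__":
    registry = InstanceRegistry()
    registry.register("10.0.0.1:8100", "10.0.0.1:21001")
    time.sleep(6)  # exceed the TTL
    registry.register("10.0.0.2:8100", "10.0.0.2:21001")  # evicts the stale one
```

Because Python dicts preserve insertion order, re-inserting on every heartbeat keeps the stalest entry at the front, so eviction only ever needs to inspect the first key.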
def _listen_for_register(poller, router_socket): while True: @@ -31,12 +46,23 @@ def _listen_for_register(poller, router_socket): global prefill_instances global prefill_cv with prefill_cv: - prefill_instances[data["http_address"]] = data["zmq_address"] + node = prefill_instances.pop(data["http_address"], None) + prefill_instances[data["http_address"]] = ( + data["zmq_address"], + time.time() + DEFAULT_PING_SECONDS, + ) + _remove_oldest_instances(prefill_instances) + elif data["type"] == "D": global decode_instances global decode_cv with decode_cv: - decode_instances[data["http_address"]] = data["zmq_address"] + node = decode_instances.pop(data["http_address"], None) + decode_instances[data["http_address"]] = ( + data["zmq_address"], + time.time() + DEFAULT_PING_SECONDS, + ) + _remove_oldest_instances(decode_instances) else: print( "Unexpected, Received message from %s, data: %s", @@ -44,6 +70,9 @@ def _listen_for_register(poller, router_socket): data, ) + if node is None: + print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]") + def start_service_discovery(hostname, port): if not hostname: @@ -105,12 +134,14 @@ async def handle_request(): with prefill_cv: prefill_list = list(prefill_instances.items()) prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] + prefill_zmq_addr = prefill_zmq_addr[0] global decode_instances global decode_cv with decode_cv: decode_list = list(decode_instances.items()) decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] + decode_zmq_addr = decode_zmq_addr[0] print( f"handle_request count: {count}, [HTTP:{prefill_addr}, " diff --git a/examples/online_serving/elastic_ep/bench.sh b/examples/online_serving/elastic_ep/bench.sh new file mode 100644 index 000000000000..e47631465618 --- /dev/null +++ b/examples/online_serving/elastic_ep/bench.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite" +LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0" +HOST="localhost" +PORT=8006 +NUM_PROMPTS=20 +REQUEST_RATE=5 + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL_NAME="$2" + shift 2 + ;; + --local-model) + MODEL_NAME=$LOCAL_MODEL_PATH + shift + ;; + --host) + HOST="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --num-prompts) + NUM_PROMPTS="$2" + shift 2 + ;; + --request-rate) + REQUEST_RATE="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --model MODEL_NAME Set model name or path (default: deepseek-ai/DeepSeek-V2-Lite)" + echo " --local-model Use local model path (convenience option)" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use -h or --help for usage information" + exit 1 + ;; + esac +done + +vllm bench serve \ + --model $MODEL_NAME \ + --host $HOST \ + --port $PORT \ + --num-prompts $NUM_PROMPTS \ + --request-rate $REQUEST_RATE diff --git a/examples/online_serving/elastic_ep/scale.py b/examples/online_serving/elastic_ep/scale.py new file mode 100644 index 000000000000..a93c299e3234 --- /dev/null +++ b/examples/online_serving/elastic_ep/scale.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import json +import sys + +import requests + + +def scale(host, port, new_dp_size): + url = f"http://{host}:{port}/scale_elastic_ep" + payload = {"new_data_parallel_size": new_dp_size} + headers = {"Content-Type": 
"application/json"} + + print(f"Sending scale request to {url}") + print(f"Payload: {json.dumps(payload, indent=2)}") + + try: + response = requests.post(url, json=payload, headers=headers, timeout=300) + + print(f"Status Code: {response.status_code}") + print(f"Response: {response.text}") + + if response.status_code == 200: + print("Scale up/down request successful!") + return True + else: + print("Scale up/down request failed!") + return False + + except requests.exceptions.RequestException as e: + print(f"Request failed: {e}") + return False + + +def main(): + parser = argparse.ArgumentParser(description="Test scale up/down functionality") + parser.add_argument("--host", default="localhost", help="API server host") + parser.add_argument("--port", type=int, default=8006, help="API server port") + parser.add_argument( + "--new-dp-size", type=int, default=2, help="New data parallel size" + ) + + args = parser.parse_args() + + success = scale(args.host, args.port, args.new_dp_size) + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/elastic_ep/serve_deepseek_v2.sh b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh new file mode 100644 index 000000000000..1234ebba4d81 --- /dev/null +++ b/examples/online_serving/elastic_ep/serve_deepseek_v2.sh @@ -0,0 +1,72 @@ +#!/bin/bash + +HOST="0.0.0.0" +PORT=8006 +DATA_PARALLEL_SIZE=4 +REDUNDANT_EXPERTS=0 +LOCAL_MODEL_PATH="/models/models--deepseek-ai--DeepSeek-V2-Lite/snapshots/604d5664dddd88a0433dbae533b7fe9472482de0" +MODEL_NAME="deepseek-ai/DeepSeek-V2-Lite" + +while [[ $# -gt 0 ]]; do + case $1 in + --dp) + DATA_PARALLEL_SIZE="$2" + shift 2 + ;; + --re) + REDUNDANT_EXPERTS="$2" + shift 2 + ;; + --host) + HOST="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --model) + MODEL_NAME="$2" + shift 2 + ;; + --local-model) + MODEL_NAME=$LOCAL_MODEL_PATH + shift + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --dp SIZE Set data parallel size (default: 4)" + echo " --re SIZE Set redundant experts (default: 0)" + echo " --host HOST Set host address (default: 0.0.0.0)" + echo " --port PORT Set port number (default: 8006)" + echo " --model MODEL_NAME Set model name or path" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use -h or --help for usage information" + exit 1 + ;; + esac +done + +echo "Starting vLLM server for $MODEL_NAME with data parallel size: $DATA_PARALLEL_SIZE and redundant experts: $REDUNDANT_EXPERTS" + +export RAY_DEDUP_LOGS=0 +export VLLM_USE_V1=1 +export VLLM_ALL2ALL_BACKEND="pplx" +export VLLM_USE_DEEP_GEMM=1 + +vllm serve $MODEL_NAME \ + --data-parallel-size $DATA_PARALLEL_SIZE \ + --data-parallel-size-local $DATA_PARALLEL_SIZE \ + --data-parallel-backend ray \ + --enforce-eager \ + --enable-expert-parallel \ + --enable-eplb \ + --num-redundant-experts $REDUNDANT_EXPERTS \ + --trust-remote-code \ + --host $HOST \ + --port $PORT diff --git a/examples/online_serving/multi-node-serving.sh b/examples/online_serving/multi-node-serving.sh index 067f20c69b88..e8ad8d3de5f4 100644 --- a/examples/online_serving/multi-node-serving.sh +++ b/examples/online_serving/multi-node-serving.sh @@ -1,12 +1,35 @@ #!/bin/bash +# +# Helper script to manually start or join a Ray cluster for online serving of vLLM models. +# This script is first executed on the head node, and then on each worker node with the IP address +# of the head node. 
+# +# Subcommands: +# leader: Launches a Ray head node and blocks until the cluster reaches the expected size (head + workers). +# worker: Starts a worker node that connects to an existing Ray head node. +# +# Example usage: +# On the head node machine, start the Ray head node process and run a vLLM server. +# ./multi-node-serving.sh leader --ray_port=6379 --ray_cluster_size=<SIZE> [<extra ray args>] && \ +# python3 -m vllm.entrypoints.openai.api_server --port 8080 --model meta-llama/Meta-Llama-3.1-405B-Instruct --tensor-parallel-size 8 --pipeline_parallel_size 2 +# +# On each worker node, start the Ray worker node process. +# ./multi-node-serving.sh worker --ray_address=<HEAD_NODE_IP> --ray_port=6379 [<extra ray args>] +# +# About Ray: +# Ray is an open-source distributed execution framework that simplifies +# distributed computing. Learn more: +# https://ray.io/ -subcommand=$1 -shift -ray_port=6379 -ray_init_timeout=300 -declare -a start_params +subcommand=$1 # Either "leader" or "worker". +shift # Remove the subcommand from the argument list. +ray_port=6379 # Port used by the Ray head node. +ray_init_timeout=300 # Seconds to wait before timing out. +declare -a start_params # Parameters forwarded to the underlying 'ray start' command. + +# Handle the worker subcommand. case "$subcommand" in worker) ray_address="" @@ -32,6 +55,7 @@ case "$subcommand" in exit 1 fi + # Retry until the worker node connects to the head node or the timeout expires. for (( i=0; i < $ray_init_timeout; i+=5 )); do ray start --address=$ray_address:$ray_port --block "${start_params[@]}" if [ $? -eq 0 ]; then @@ -45,6 +69,7 @@ case "$subcommand" in exit 1 ;; + # Handle the leader subcommand. leader) ray_cluster_size="" while [ $# -gt 0 ]; do @@ -69,10 +94,10 @@ case "$subcommand" in exit 1 fi - # start the ray daemon + # Start the Ray head node. ray start --head --port=$ray_port "${start_params[@]}" - # wait until all workers are active + # Poll Ray until every worker node is active. for (( i=0; i < $ray_init_timeout; i+=5 )); do active_nodes=`python3 -c 'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))'` if [ $active_nodes -eq $ray_cluster_size ]; then diff --git a/examples/online_serving/openai_cross_encoder_score_for_multimodal.py b/examples/online_serving/openai_cross_encoder_score_for_multimodal.py new file mode 100644 index 000000000000..e49905a864c1 --- /dev/null +++ b/examples/online_serving/openai_cross_encoder_score_for_multimodal.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Example online usage of Score API. + +Run `vllm serve <model> --task score` to start up the server in vLLM. 
+""" + +import argparse +import pprint + +import requests + + +def post_http_request(prompt: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="jinaai/jina-reranker-m0") + return parser.parse_args() + + +def main(args): + api_url = f"http://{args.host}:{args.port}/score" + model_name = args.model + + text_1 = "slm markdown" + text_2 = { + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png" + }, + }, + { + "type": "image_url", + "image_url": { + "url": "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + }, + }, + ] + } + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("\nPrompt when text_1 is string and text_2 is a image list:") + pprint.pprint(prompt) + print("\nScore Response:") + pprint.pprint(score_response.json()) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py index 9471563ddb76..d24b553df27c 100644 --- a/examples/online_serving/ray_serve_deepseek.py +++ b/examples/online_serving/ray_serve_deepseek.py @@ -1,13 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ -Example to deploy DeepSeek R1 or V3 with Ray Serve LLM. -See more details at: -https://docs.ray.io/en/latest/serve/tutorials/serve-deepseek.html -And see Ray Serve LLM documentation at: -https://docs.ray.io/en/latest/serve/llm/serving-llms.html +Deploy DeepSeek R1 or V3 with Ray Serve LLM. + +Ray Serve LLM is a scalable and production-grade model serving library built +on the Ray distributed computing framework and first-class support for the vLLM engine. + +Key features: +- Automatic scaling, back-pressure, and load balancing across a Ray cluster. +- Unified multi-node multi-model deployment. +- Exposes an OpenAI-compatible HTTP API. +- Multi-LoRA support with shared base models. -Run `python3 ray_serve_deepseek.py` to deploy the model. +Run `python3 ray_serve_deepseek.py` to launch an endpoint. + +Learn more in the official Ray Serve LLM documentation: +https://docs.ray.io/en/latest/serve/llm/serving-llms.html """ from ray import serve @@ -16,9 +24,8 @@ llm_config = LLMConfig( model_loading_config={ "model_id": "deepseek", - # Since DeepSeek model is huge, it is recommended to pre-download - # the model to local disk, say /path/to/the/model and specify: - # model_source="/path/to/the/model" + # Pre-downloading the model to local storage is recommended since + # the model is large. Set model_source="/path/to/the/model". "model_source": "deepseek-ai/DeepSeek-R1", }, deployment_config={ @@ -27,10 +34,10 @@ "max_replicas": 1, } }, - # Change to the accelerator type of the node + # Set to the node's accelerator type. accelerator_type="H100", runtime_env={"env_vars": {"VLLM_USE_V1": "1"}}, - # Customize engine arguments as needed (e.g. vLLM engine kwargs) + # Customize engine arguments as required (for example, vLLM engine kwargs). 
engine_kwargs={ "tensor_parallel_size": 8, "pipeline_parallel_size": 2, @@ -44,6 +51,6 @@ }, ) -# Deploy the application +# Deploy the application. llm_app = build_openai_app({"llm_configs": [llm_config]}) serve.run(llm_app) diff --git a/examples/online_serving/run_cluster.sh b/examples/online_serving/run_cluster.sh index 7b4b40b4b7e2..522b9566212b 100644 --- a/examples/online_serving/run_cluster.sh +++ b/examples/online_serving/run_cluster.sh @@ -1,35 +1,81 @@ #!/bin/bash +# +# Launch a Ray cluster inside Docker for vLLM inference. +# +# This script can start either a head node or a worker node, depending on the +# --head or --worker flag provided as the third positional argument. +# +# Usage: +# 1. Designate one machine as the head node and execute: +# bash run_cluster.sh \ +# vllm/vllm-openai \ +# <head_node_ip> \ +# --head \ +# /abs/path/to/huggingface/cache \ +# -e VLLM_HOST_IP=<head_node_ip> +# +# 2. On every worker machine, execute: +# bash run_cluster.sh \ +# vllm/vllm-openai \ +# <head_node_ip> \ +# --worker \ +# /abs/path/to/huggingface/cache \ +# -e VLLM_HOST_IP=<worker_node_ip> +# +# Each worker requires a unique VLLM_HOST_IP value. +# Keep each terminal session open. Closing a session stops the associated Ray +# node and thereby shuts down the entire cluster. +# Every machine must be reachable at the supplied IP address. +# +# The container is named "node-<random_suffix>". To open a shell inside +# a container after launch, use: +# docker exec -it node-<random_suffix> /bin/bash +# +# Then, you can execute vLLM commands on the Ray cluster as if it were a +# single machine, e.g. vllm serve ... +# +# To stop the container, use: +# docker stop node-<random_suffix> -# Check for minimum number of required arguments +# Check for minimum number of required arguments. if [ $# -lt 4 ]; then - echo "Usage: $0 docker_image head_node_address --head|--worker path_to_hf_home [additional_args...]" + echo "Usage: $0 docker_image head_node_ip --head|--worker path_to_hf_home [additional_args...]" exit 1 fi -# Assign the first three arguments and shift them away +# Extract the mandatory positional arguments and remove them from $@. DOCKER_IMAGE="$1" HEAD_NODE_ADDRESS="$2" -NODE_TYPE="$3" # Should be --head or --worker +NODE_TYPE="$3" # Should be --head or --worker. PATH_TO_HF_HOME="$4" shift 4 -# Additional arguments are passed directly to the Docker command +# Preserve any extra arguments so they can be forwarded to Docker. ADDITIONAL_ARGS=("$@") -# Validate node type +# Validate the NODE_TYPE argument. if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then echo "Error: Node type must be --head or --worker" exit 1 fi -# Define a function to cleanup on EXIT signal +# Generate a unique container name with random suffix. +# Docker container names must be unique on each host. +# The random suffix allows multiple Ray containers to run simultaneously on the same machine, +# for example, on a multi-GPU machine. +CONTAINER_NAME="node-${RANDOM}" + +# Define a cleanup routine that removes the container when the script exits. +# This prevents orphaned containers from accumulating if the script is interrupted. cleanup() { - docker stop node - docker rm node + docker stop "${CONTAINER_NAME}" + docker rm "${CONTAINER_NAME}" } trap cleanup EXIT -# Command setup for head or worker node +# Build the Ray start command based on the node role. +# The head node manages the cluster and accepts connections on port 6379, +# while workers connect to the head's address. 
RAY_START_CMD="ray start --block" if [ "${NODE_TYPE}" == "--head" ]; then RAY_START_CMD+=" --head --port=6379" @@ -37,11 +83,15 @@ else RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379" fi -# Run the docker command with the user specified parameters and additional arguments +# Launch the container with the assembled parameters. +# --network host: Allows Ray nodes to communicate directly via host networking +# --shm-size 10.24g: Increases shared memory +# --gpus all: Gives container access to all GPUs on the host +# -v HF_HOME: Mounts HuggingFace cache to avoid re-downloading models docker run \ --entrypoint /bin/bash \ --network host \ - --name node \ + --name "${CONTAINER_NAME}" \ --shm-size 10.24g \ --gpus all \ -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \ diff --git a/examples/others/tensorize_vllm_model.py b/examples/others/tensorize_vllm_model.py index 11233229561b..64a6c42ae235 100644 --- a/examples/others/tensorize_vllm_model.py +++ b/examples/others/tensorize_vllm_model.py @@ -4,6 +4,7 @@ import argparse import dataclasses import json +import logging import os import uuid @@ -15,9 +16,13 @@ TensorizerConfig, tensorize_lora_adapter, tensorize_vllm_model, + tensorizer_kwargs_arg, ) from vllm.utils import FlexibleArgumentParser +logger = logging.getLogger() + + # yapf conflicts with isort for this docstring # yapf: disable """ @@ -119,7 +124,7 @@ """ -def parse_args(): +def get_parser(): parser = FlexibleArgumentParser( description="An example script that can be used to serialize and " "deserialize vLLM models. These models " @@ -135,13 +140,13 @@ def parse_args(): required=False, help="Path to a LoRA adapter to " "serialize along with model tensors. This can then be deserialized " - "along with the model by passing a tensorizer_config kwarg to " - "LoRARequest with type TensorizerConfig. See the docstring for this " - "for a usage example." - + "along with the model by instantiating a TensorizerConfig object, " + "creating a dict from it with TensorizerConfig.to_serializable(), " + "and passing it to LoRARequest's initializer with the kwarg " + "tensorizer_config_dict." ) - subparsers = parser.add_subparsers(dest='command') + subparsers = parser.add_subparsers(dest='command', required=True) serialize_parser = subparsers.add_parser( 'serialize', help="Serialize a model to `--serialized-directory`") @@ -171,6 +176,14 @@ def parse_args(): "where `suffix` is given by `--suffix` or a random UUID if not " "provided.") + serialize_parser.add_argument( + "--serialization-kwargs", + type=tensorizer_kwargs_arg, + required=False, + help=("A JSON string containing additional keyword arguments to " + "pass to Tensorizer's TensorSerializer during " + "serialization.")) + serialize_parser.add_argument( "--keyfile", type=str, @@ -186,9 +199,17 @@ def parse_args(): deserialize_parser.add_argument( "--path-to-tensors", type=str, - required=True, + required=False, help="The local path or S3 URI to the model tensors to deserialize. ") + deserialize_parser.add_argument( + "--serialized-directory", + type=str, + required=False, + help="Directory with model artifacts for loading. Assumes a " + "model.tensors file exists therein. 
Can supersede " + "--path-to-tensors.") + deserialize_parser.add_argument( "--keyfile", type=str, @@ -196,11 +217,27 @@ def parse_args(): help=("Path to a binary key to use to decrypt the model weights," " if the model was serialized with encryption")) - TensorizerArgs.add_cli_args(deserialize_parser) + deserialize_parser.add_argument( + "--deserialization-kwargs", + type=tensorizer_kwargs_arg, + required=False, + help=("A JSON string containing additional keyword arguments to " + "pass to Tensorizer's `TensorDeserializer` during " + "deserialization.")) - return parser.parse_args() + TensorizerArgs.add_cli_args(deserialize_parser) + return parser +def merge_extra_config_with_tensorizer_config(extra_cfg: dict, + cfg: TensorizerConfig): + for k, v in extra_cfg.items(): + if hasattr(cfg, k): + setattr(cfg, k, v) + logger.info( + "Updating TensorizerConfig with %s from " + "--model-loader-extra-config provided", k + ) def deserialize(args, tensorizer_config): if args.lora_path: @@ -230,7 +267,8 @@ def deserialize(args, tensorizer_config): lora_request=LoRARequest("sql-lora", 1, args.lora_path, - tensorizer_config = tensorizer_config) + tensorizer_config_dict = tensorizer_config + .to_serializable()) ) ) else: @@ -243,7 +281,8 @@ def deserialize(args, tensorizer_config): def main(): - args = parse_args() + parser = get_parser() + args = parser.parse_args() s3_access_key_id = (getattr(args, 's3_access_key_id', None) or os.environ.get("S3_ACCESS_KEY_ID", None)) @@ -265,13 +304,24 @@ def main(): else: keyfile = None + extra_config = {} if args.model_loader_extra_config: - config = json.loads(args.model_loader_extra_config) - tensorizer_args = \ - TensorizerConfig(**config)._construct_tensorizer_args() - tensorizer_args.tensorizer_uri = args.path_to_tensors - else: - tensorizer_args = None + extra_config = json.loads(args.model_loader_extra_config) + + + tensorizer_dir = (args.serialized_directory or + extra_config.get("tensorizer_dir")) + tensorizer_uri = (getattr(args, "path_to_tensors", None) + or extra_config.get("tensorizer_uri")) + + if tensorizer_dir and tensorizer_uri: + parser.error("--serialized-directory and --path-to-tensors " + "cannot both be provided") + + if not tensorizer_dir and not tensorizer_uri: + parser.error("Either --serialized-directory or --path-to-tensors " + "must be provided") + if args.command == "serialize": eng_args_dict = {f.name: getattr(args, f.name) for f in @@ -281,7 +331,7 @@ def main(): argparse.Namespace(**eng_args_dict) ) - input_dir = args.serialized_directory.rstrip('/') + input_dir = tensorizer_dir.rstrip('/') suffix = args.suffix if args.suffix else uuid.uuid4().hex base_path = f"{input_dir}/vllm/{model_ref}/{suffix}" if engine_args.tensor_parallel_size > 1: @@ -292,21 +342,29 @@ def main(): tensorizer_config = TensorizerConfig( tensorizer_uri=model_path, encryption_keyfile=keyfile, - **credentials) + serialization_kwargs=args.serialization_kwargs or {}, + **credentials + ) if args.lora_path: tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir tensorize_lora_adapter(args.lora_path, tensorizer_config) + merge_extra_config_with_tensorizer_config(extra_config, + tensorizer_config) tensorize_vllm_model(engine_args, tensorizer_config) elif args.command == "deserialize": - if not tensorizer_args: - tensorizer_config = TensorizerConfig( - tensorizer_uri=args.path_to_tensors, - encryption_keyfile = keyfile, - **credentials - ) + tensorizer_config = TensorizerConfig( + tensorizer_uri=args.path_to_tensors, + tensorizer_dir=args.serialized_directory, + 
encryption_keyfile=keyfile, + deserialization_kwargs=args.deserialization_kwargs or {}, + **credentials + ) + + merge_extra_config_with_tensorizer_config(extra_config, + tensorizer_config) deserialize(args, tensorizer_config) else: raise ValueError("Either serialize or deserialize must be specified.") diff --git a/examples/tool_chat_template_deepseekr1.jinja b/examples/tool_chat_template_deepseekr1.jinja index 9ae19341fc48..908574be9df5 100644 --- a/examples/tool_chat_template_deepseekr1.jinja +++ b/examples/tool_chat_template_deepseekr1.jinja @@ -11,7 +11,7 @@ {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} {%- endif %} {%- endif %} -{%- endfor %} +{%- endfor -%} {#- Adapted from https://github.com/sgl-project/sglang/blob/main/examples/chat_template/tool_chat_template_deepseekr1.jinja #} {% if tools is defined and tools is not none %} @@ -27,8 +27,8 @@ {% set ns.system_prompt = ns.system_prompt + '\n\n' + tool_ns.text %} {% endif %} -{{ bos_token }} -{{ ns.system_prompt }} +{{- bos_token }} +{{- ns.system_prompt }} {%- for message in messages %} {% set content = message['content'] %} {%- if message['role'] == 'user' %} @@ -45,7 +45,7 @@ {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} {%- set ns.is_last_user = false -%} {%- if ns.is_tool %} - {{'<|tool▁outputs▁end|>'}} + {{- '<|tool▁outputs▁end|>'}} {%- endif %} {%- set ns.is_first = false %} {%- set ns.is_tool = false -%} @@ -53,40 +53,40 @@ {%- for tool in message['tool_calls'] %} {%- if not ns.is_first %} {%- if content is none %} - {{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {{- '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} {%- else %} - {{content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {{- content + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} {%- endif %} {%- set ns.is_first = true -%} {%- else %} - {{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {{- '\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} {%- endif %} {%- endfor %} - {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {{- '<|tool▁calls▁end|><|end▁of▁sentence|>'}} {%- endif %} {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%} {%- set ns.is_last_user = false -%} {%- if ns.is_tool %} - {{'<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}} + {{- '<|tool▁outputs▁end|>' + content + '<|end▁of▁sentence|>'}} {%- set ns.is_tool = false -%} {%- else %} - {{content + '<|end▁of▁sentence|>'}} + {{- content + '<|end▁of▁sentence|>'}} {%- endif %} {%- endif %} {%- if message['role'] == 'tool' %} {%- set 
ns.is_last_user = false -%} {%- set ns.is_tool = true -%} {%- if ns.is_output_first %} - {{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}} + {{- '<|tool▁outputs▁begin|><|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}} {%- set ns.is_output_first = false %} {%- else %} - {{'\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}} + {{- '\n<|tool▁output▁begin|>' + content + '<|tool▁output▁end|>'}} {%- endif %} {%- endif %} {%- endfor -%} {% if ns.is_tool %} - {{'<|tool▁outputs▁end|>'}} -{% endif %} + {{- '<|tool▁outputs▁end|>'}} +{%- endif %} {% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %} - {{'<|Assistant|>'}} -{% endif %} + {{- '<|Assistant|>'}} +{%- endif %} \ No newline at end of file diff --git a/examples/tool_chat_template_hunyuan_a13b.jinja b/examples/tool_chat_template_hunyuan_a13b.jinja new file mode 100644 index 000000000000..a0808e44858a --- /dev/null +++ b/examples/tool_chat_template_hunyuan_a13b.jinja @@ -0,0 +1,113 @@ +{% set loop_messages = messages %} +{% if tools %} + {% set weekday_map = {'Monday': '星期一', 'Tuesday': '星期二', 'Wednesday': '星期三', 'Thursday': '星期四', 'Friday': '星期五', 'Saturday': '星期六', 'Sunday': '星期日'} %} + {% set weekday_cn = weekday_map[strftime_now('%A')] %} + {% set datetime_str = strftime_now('%Y-%m-%d %H:%M:%S') %} + {% set datetime_str = datetime_str + ' ' + weekday_cn %} + {% for message in loop_messages %} + {% if 'content' in message %} + {% set content = message['content'] %} + {% else %} + {% set content = '' %} + {% endif %} + {% if loop.index0 == 0 %} + {% set content_tmp = '你是一位函数组合专家。你会得到一个问题和一组可能的函数。根据问题,你需要进行一个或多个函数/工具调用以实现目的。 +如果没有一个函数可以使用,请直接使用自然语言回复用户,以助手:开头。 +如果给定的问题缺少函数所需的参数,请使用自然语言进行提问,向用户询问必要信息,以助手:开头。 +如果调用结果已经足够回答用户问题,请对历史结果进行总结,使用自然语言回复用户,以助手:开头。 +你应该只在工具调用部分返回函数调用。如果你决定调用任何函数,你必须将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>。你不应该在回复中包含任何其他文本。以下是你可以调用的函数列表,格式为JSON。 +' %} + {% set content_tmp = content_tmp + ' +' + tools | tojson + ' +' %} + {% if message['role'] == 'system' %} + {% set content_tmp = content_tmp + ' +额外要求: +' + content + ' + +如果你决定返回函数调用,请将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>,不得包含其他文本。如果额外要求里有格式要求,请忽略,以此处为准。 +否则,请参考开头说的三种情况,以助手:开头进行回复。 + +如果额外要求里有时间信息,就以额外要求里的时间为准,否则,参考当前时间:' + datetime_str %} + {% set content = '<|startoftext|>' + content_tmp + '<|extra_4|>' %} + {% elif message['role'] == 'user' %} + {% set content_tmp = content_tmp + ' +如果你决定返回函数调用,请将其格式化为<tool_calls>[{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},...]</tool_calls>,不得包含其他文本。 +否则,请参考开头说的三种情况,以助手:开头进行回复。 + +当前时间:' + datetime_str %} + {% set content_tmp = '<|startoftext|>' + content_tmp + '<|extra_4|>'%} + {% set content = content_tmp + '用户:' + content + '<|extra_0|>' %} + {% endif %} + {% else %} + {% if message['role'] == 'user' %} + {% set content = '用户:' + content + '<|extra_0|>' %} + {% elif message['role'] == 'assistant' %} + {% if 'tool_calls' in message %} + {% set tool_calls = message['tool_calls'] %} + {% set ns = namespace(tool_calls="[") %} + {% for tool_call in tool_calls %} + {% set function = tool_call['function'] %} + {% set name = function['name'] %} + {% set ns.tool_calls = ns.tool_calls + '{"name": "' + name + '", '%} + {% set arguments = function['arguments'] %} + {% if arguments is not string %} + {% set arguments = arguments | tojson %} + {% endif %} + {% set 
ns.tool_calls = ns.tool_calls + '"arguments": ' + arguments + '}' %} + {% if not loop.last %} + {% set ns.tool_calls = ns.tool_calls + ', '%} + {% endif %} + {% endfor %} + {% set ns.tool_calls = ns.tool_calls + ']' %} + {% set content = content + '<tool_calls>' + ns.tool_calls + '</tool_calls>' %} + {% else %} + {% set content = '助手:' + content %} + {% endif %} + {% set content = content + '<|eos|>' %} + {% elif message['role'] == 'tool' %} + {% if content is not string %} + {set content = content | tojson } + {% endif %} + {% set content = '<tool_response>' + content + '</tool_response>' %} + {% set content = content + '<|extra_0|>' %} + {% endif %} + {% endif %} + {{- content -}} + {% endfor %} +{% else %} + {% set context = {'has_head': true} %} + {% for message in loop_messages %} + {% if 'content' in message %} + {% set content = message['content'] %} + {% else %} + {% set content = '' %} + {% endif %} + {% if loop.index0 == 0 %} + {% if content == '' %} + {% set _ = context.update({'has_head': false}) %} + {% elif message['role'] == 'system' %} + {% set content = '<|startoftext|>' + content + '<|extra_4|>' %} + {% endif %} + {% endif %} + {% if message['role'] == 'user' %} + {% if loop.index0 == 1 and not context.has_head %} + {% set content = '<|startoftext|>' + content %} + {% endif %} + {% if loop.index0 == 1 and context.has_head %} + {% set content = content + '<|extra_0|>' %} + {% else %} + {% set content = '<|startoftext|>' + content + '<|extra_0|>' %} + {% endif %} + {% elif message['role'] == 'assistant' %} + {% set content = content + '<|eos|>' %} + {% elif message['role'] == 'tool' %} + {% set content = content + '<|extra_0|>' %} + {% endif %} + {{- content -}} + {% endfor %} +{% endif %} +{%- if enable_thinking is defined and enable_thinking is false %} + {{- '<think>\n\n</think>\n' }} +{%- endif %} + diff --git a/mkdocs.yaml b/mkdocs.yaml index 45b6ffadbeb7..b392fb160c2a 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -3,6 +3,7 @@ site_url: https://docs.vllm.ai repo_url: https://github.com/vllm-project/vllm edit_uri: edit/main/docs/ exclude_docs: | + argparse *.inc.md *.template.md theme: @@ -47,6 +48,7 @@ theme: hooks: - docs/mkdocs/hooks/remove_announcement.py - docs/mkdocs/hooks/generate_examples.py + - docs/mkdocs/hooks/generate_argparse.py - docs/mkdocs/hooks/url_schemes.py # Required to stop api-autonav from raising an error @@ -59,6 +61,7 @@ plugins: - search - autorefs - awesome-nav + - glightbox # For API reference generation - api-autonav: modules: ["vllm"] diff --git a/pyproject.toml b/pyproject.toml index 340abb385657..a65267942d47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ "packaging>=24.2", "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.7.0", + "torch == 2.7.1", "wheel", "jinja2", ] @@ -72,8 +72,6 @@ line-length = 80 "vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] -"vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] # Python 3.8 typing - skip utils for ROCm "vllm/utils/__init__.py" = ["UP006", "UP035"] @@ -174,3 +172,186 @@ respect-ignore-files = true [tool.ty.environment] python = "./.venv" + +[tool.typos.files] +# these files may be written in non english words +extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*", + "benchmarks/sonnet.txt", "tests/lora/data/*", "build/*", + "vllm/third_party/*"] +ignore-hidden = true 
+ignore-files = true +ignore-dot = true +ignore-vcs = true +ignore-global = true +ignore-parent = true + +[tool.typos.default] +binary = false +check-filename = false +check-file = true +unicode = true +ignore-hex = true +identifier-leading-digits = false +locale = "en" +extend-ignore-identifiers-re = ["NVML_*", ".*Unc.*", ".*_thw", + ".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", + ".*[Tt]h[rR].*"] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.default.extend-identifiers] +bbc5b7ede = "bbc5b7ede" +womens_doubles = "womens_doubles" +v_2nd = "v_2nd" +# splitted_input = "splitted_input" +NOOPs = "NOOPs" +typ = "typ" +nin_shortcut = "nin_shortcut" +UperNetDecoder = "UperNetDecoder" +subtile = "subtile" +cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin" +SFOuput = "SFOuput" +# huggingface transformers repo uses these words +depthwise_seperable_out_channel = "depthwise_seperable_out_channel" +DepthWiseSeperableConv1d = "DepthWiseSeperableConv1d" +depthwise_seperable_CNN = "depthwise_seperable_CNN" + +[tool.typos.default.extend-words] +iy = "iy" +tendencias = "tendencias" +# intel cpu features +tme = "tme" +dout = "dout" +Pn = "Pn" +arange = "arange" + +[tool.typos.type.py] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.py.extend-identifiers] +arange = "arange" +NDArray = "NDArray" +EOFError = "EOFError" +fo = "fo" +ba = "ba" + +[tool.typos.type.py.extend-words] + +[tool.typos.type.cpp] +extend-glob = ["*.cu"] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.cpp.extend-identifiers] +countr_one = "countr_one" +k_ot = "k_ot" +ot = "ot" + +[tool.typos.type.cpp.extend-words] + +[tool.typos.type.rust] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.rust.extend-identifiers] +flate2 = "flate2" + +[tool.typos.type.rust.extend-words] +ser = "ser" + +[tool.typos.type.lock] +extend-glob = [] +check-file = false +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.lock.extend-identifiers] + +[tool.typos.type.lock.extend-words] + +[tool.typos.type.jl] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.jl.extend-identifiers] + +[tool.typos.type.jl.extend-words] +modul = "modul" +egals = "egals" +usig = "usig" +egal = "egal" + +[tool.typos.type.go] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.go.extend-identifiers] +flate = "flate" + +[tool.typos.type.go.extend-words] + +[tool.typos.type.css] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.css.extend-identifiers] +nd = "nd" + +[tool.typos.type.css.extend-words] + +[tool.typos.type.man] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.man.extend-identifiers] +Nd = "Nd" + +[tool.typos.type.man.extend-words] + +[tool.typos.type.cert] +extend-glob = [] +check-file = false +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.cert.extend-identifiers] + +[tool.typos.type.cert.extend-words] + +[tool.typos.type.sh] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] 
+extend-ignore-re = [] + +[tool.typos.type.sh.extend-identifiers] +ot = "ot" + +[tool.typos.type.sh.extend-words] + +[tool.typos.type.vimscript] +extend-glob = [] +extend-ignore-identifiers-re = [] +extend-ignore-words-re = [] +extend-ignore-re = [] + +[tool.typos.type.vimscript.extend-identifiers] +windo = "windo" + +[tool.typos.type.vimscript.extend-words] diff --git a/requirements/build.txt b/requirements/build.txt index 528cd3b538ef..dd644d621efc 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -4,7 +4,7 @@ ninja packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.7.0 +torch==2.7.1 wheel jinja2>=3.1.6 regex diff --git a/requirements/common.txt b/requirements/common.txt index 8bc0be7779af..013b831be32e 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -7,13 +7,12 @@ requests >= 2.26.0 tqdm blake3 py-cpuinfo -transformers >= 4.51.1 -huggingface-hub[hf_xet] >= 0.33.0 # Required for Xet downloads. +transformers >= 4.53.2 tokenizers >= 0.21.1 # Required for fast incremental detokenization. protobuf # Required by LlamaTokenizer. fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint. aiohttp -openai >= 1.52.0, <= 1.90.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support) +openai >= 1.87.0, <= 1.90.0 # Ensure modern openai package (ensure ResponsePrompt exists in type.responses and max_completion_tokens field support) pydantic >= 2.10 prometheus_client >= 0.18.0 pillow # Required for image processing @@ -21,9 +20,11 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.11, < 0.11 llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" -outlines == 0.1.11 +outlines_core == 0.2.10 +# required for outlines backend disk cache +diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" +xgrammar == 0.1.21; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs @@ -31,17 +32,18 @@ pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 importlib_metadata; python_version < '3.10' -mistral_common[opencv] >= 1.6.2 +mistral_common[image,audio] >= 1.8.2 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. 
compressed-tensors == 0.10.2 # required for compressed-tensors -depyf==0.18.0 # required for profiling and debugging with compilation config +depyf==0.19.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/others/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu pybase64 # fast base64 implementation +cbor2 # Required for cross-language serialization of hashable objects diff --git a/requirements/cpu.txt b/requirements/cpu.txt index df3a3393563a..d80354342bc2 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -24,6 +24,4 @@ datasets # for benchmark scripts # Intel Extension for PyTorch, only for x86_64 CPUs intel-openmp==2024.2.1; platform_machine == "x86_64" intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218 -py-libnuma; platform_system != "Darwin" -psutil; platform_system != "Darwin" triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile. diff --git a/requirements/cuda.txt b/requirements/cuda.txt index f1535eae3cae..40a33f25ed40 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -7,9 +7,9 @@ numba == 0.61.2; python_version > '3.9' # Dependencies for NVIDIA GPUs pyarrow == 19.0.1 # temporary fix for missing pyarrow in ray until 2.44.2 ray release ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.7.0 -torchaudio==2.7.0 +torch==2.7.1 +torchaudio==2.7.1 # These must be updated alongside torch -torchvision==0.22.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -# https://github.com/facebookresearch/xformers/releases/tag/v0.0.30 -xformers==0.0.30; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 +torchvision==0.22.1 # Required for phi3v processor. 
See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# https://github.com/facebookresearch/xformers/releases/tag/v0.0.31 +xformers==0.0.31; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 diff --git a/requirements/docs.txt b/requirements/docs.txt index 64c70cb65c55..1ddc825a9cdd 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -4,6 +4,24 @@ mkdocs-material mkdocstrings-python mkdocs-gen-files mkdocs-awesome-nav +mkdocs-glightbox python-markdown-math regex ruff + +# Required for argparse hook only +-f https://download.pytorch.org/whl/cpu +cachetools +cbor2 +cloudpickle +fastapi +msgspec +openai +partial-json-parser +pillow +psutil +pybase64 +pydantic +torch +transformers +zmq diff --git a/requirements/hpu.txt b/requirements/hpu.txt deleted file mode 100644 index a88777268a34..000000000000 --- a/requirements/hpu.txt +++ /dev/null @@ -1,12 +0,0 @@ -# Common dependencies --r common.txt - -# Dependencies for HPU code -ray -triton==3.1.0 -pandas -numpy==1.26.4 -tabulate -setuptools>=77.0.3,<80.0.0 -setuptools-scm>=8 -vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624 diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 0bade084fdf6..09a32cc13c72 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -1,6 +1,6 @@ # testing pytest -tensorizer>=2.9.0 +tensorizer==2.10.1 pytest-forked pytest-asyncio pytest-rerunfailures @@ -23,7 +23,7 @@ jiwer # required for audio tests timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test -mistral_common[opencv] >= 1.6.2 # required for pixtral test +mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test @@ -31,7 +31,6 @@ lm-eval[api]==0.4.8 # required for model evaluation test mteb>=1.38.11, <2 # required for mteb test transformers==4.52.4 tokenizers==0.21.1 -huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. 
# quantization bitsandbytes>=0.46.1 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index d33021fc7597..7038c9024c6b 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -11,9 +11,10 @@ datasets ray>=2.10.0,<2.45.0 peft pytest-asyncio -tensorizer>=2.9.0 +tensorizer==2.10.1 packaging>=24.2 setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 +conch-triton-kernels==1.2.1 diff --git a/requirements/test.in b/requirements/test.in index 5f8b97a0e341..899b857fb98f 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -1,6 +1,6 @@ # testing pytest -tensorizer>=2.9.0 +tensorizer==2.10.1 pytest-forked pytest-asyncio pytest-rerunfailures @@ -22,21 +22,21 @@ sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests timm # required for internvl test -torch==2.7.0 -torchaudio==2.7.0 -torchvision==0.22.0 +torch==2.7.1 +torchaudio==2.7.1 +torchvision==0.22.1 transformers_stream_generator # required for qwen-vl test mamba_ssm # required for plamo2 test matplotlib # required for qwen-vl test -mistral_common[opencv] >= 1.6.2 # required for pixtral test +mistral_common[image,audio] >= 1.8.2 # required for voxtral test num2words # required for smolvlm test +open_clip_torch==2.32.0 # Required for nemotron_vl test opencv-python-headless >= 4.11.0 # required for video test datamodel_code_generator # required for minicpm3 test lm-eval[api]==0.4.8 # required for model evaluation test mteb[bm25s]>=1.38.11, <2 # required for mteb test -transformers==4.52.4 +transformers==4.53.2 tokenizers==0.21.1 -huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads. schemathesis>=3.39.15 # Required for openai schema test. # quantization bitsandbytes==0.46.1 @@ -53,3 +53,4 @@ runai-model-streamer==0.11.0 runai-model-streamer-s3==0.11.0 fastsafetensors>=0.1.10 pydantic>=2.10 # 2.9 leads to error on python 3.10 +terratorch==1.1rc2 # required for PrithviMAE test \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index f6f599df758f..0c609d5f7e2b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -6,6 +6,10 @@ accelerate==1.0.1 # via # lm-eval # peft +aenum==3.1.16 + # via lightly +affine==2.4.0 + # via rasterio aiohappyeyeballs==2.4.3 # via aiohttp aiohttp==3.10.11 @@ -21,8 +25,18 @@ aiosignal==1.3.1 # via # aiohttp # ray +albucore==0.0.16 + # via terratorch +albumentations==1.4.6 + # via terratorch +alembic==1.16.4 + # via mlflow annotated-types==0.7.0 # via pydantic +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf anyio==4.6.2.post1 # via # httpx @@ -34,10 +48,12 @@ arrow==1.3.0 attrs==24.2.0 # via # aiohttp + # fiona # hypothesis # jsonlines # jsonschema # pytest-subtests + # rasterio # referencing audioread==3.0.1 # via librosa @@ -46,9 +62,13 @@ backoff==2.2.1 # -r requirements/test.in # schemathesis bitsandbytes==0.46.1 - # via -r requirements/test.in + # via + # -r requirements/test.in + # lightning black==24.10.0 # via datamodel-code-generator +blinker==1.9.0 + # via flask blobfile==3.0.0 # via -r requirements/test.in bm25s==0.2.13 @@ -64,11 +84,18 @@ bounded-pool-executor==0.0.3 buildkite-test-collector==0.1.9 # via -r requirements/test.in cachetools==5.5.2 - # via google-auth + # via + # google-auth + # mlflow-skinny certifi==2024.8.30 # via + # fiona # httpcore # httpx + # lightly + # pyogrio + # pyproj + # rasterio # requests cffi==1.17.1 # via soundfile @@ -79,11 +106,28 @@ charset-normalizer==3.4.0 
click==8.1.7 # via # black + # click-plugins + # cligj + # fiona + # flask # jiwer + # mlflow-skinny # nltk + # rasterio # ray # schemathesis # typer + # uvicorn +click-plugins==1.1.1.2 + # via + # fiona + # rasterio +cligj==0.7.2 + # via + # fiona + # rasterio +cloudpickle==3.1.1 + # via mlflow-skinny colorama==0.4.6 # via # sacrebleu @@ -99,6 +143,8 @@ cupy-cuda12x==13.3.0 # via ray cycler==0.12.1 # via matplotlib +databricks-sdk==0.59.0 + # via mlflow-skinny datamodel-code-generator==0.26.3 # via -r requirements/test.in dataproperty==1.0.1 @@ -122,13 +168,21 @@ distlib==0.3.9 # via virtualenv dnspython==2.7.0 # via email-validator +docker==7.1.0 + # via mlflow docopt==0.6.2 # via num2words -einops==0.8.0 +docstring-parser==0.17.0 + # via jsonargparse +efficientnet-pytorch==0.7.1 + # via segmentation-models-pytorch +einops==0.8.1 # via # -r requirements/test.in # encodec # mamba-ssm + # terratorch + # torchgeo # vector-quantize-pytorch # vocos einx==0.3.0 @@ -141,6 +195,8 @@ eval-type-backport==0.2.2 # via mteb evaluate==0.4.3 # via lm-eval +fastapi==0.116.1 + # via mlflow-skinny fastparquet==2024.11.0 # via genai-perf fastrlock==0.8.2 @@ -156,6 +212,10 @@ filelock==3.16.1 # torch # transformers # virtualenv +fiona==1.10.1 + # via torchgeo +flask==3.1.1 + # via mlflow fonttools==4.54.1 # via matplotlib fqdn==1.5.1 @@ -173,26 +233,53 @@ fsspec==2024.9.0 # evaluate # fastparquet # huggingface-hub + # lightning + # pytorch-lightning # torch +ftfy==6.3.1 + # via open-clip-torch genai-perf==0.0.8 # via -r requirements/test.in genson==1.3.0 # via datamodel-code-generator +geopandas==1.0.1 + # via terratorch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.44 + # via mlflow-skinny google-api-core==2.24.2 # via opencensus google-auth==2.40.2 - # via google-api-core + # via + # databricks-sdk + # google-api-core googleapis-common-protos==1.70.0 # via google-api-core +graphene==3.4.3 + # via mlflow graphql-core==3.2.6 - # via hypothesis-graphql + # via + # graphene + # graphql-relay + # hypothesis-graphql +graphql-relay==3.2.0 + # via graphene +greenlet==3.2.3 + # via sqlalchemy grpcio==1.71.0 # via ray +gunicorn==23.0.0 + # via mlflow h11==0.14.0 - # via httpcore + # via + # httpcore + # uvicorn +h5py==3.13.0 + # via terratorch harfile==0.3.0 # via schemathesis -hf-xet==1.1.3 +hf-xet==1.1.7 # via huggingface-hub hiredis==3.0.0 # via tensorizer @@ -202,20 +289,26 @@ httpx==0.27.2 # via # -r requirements/test.in # schemathesis -huggingface-hub==0.33.0 +huggingface-hub==0.33.1 # via - # -r requirements/test.in # accelerate # datasets # evaluate + # open-clip-torch # peft + # segmentation-models-pytorch # sentence-transformers + # terratorch # timm # tokenizers # transformers # vocos humanize==4.11.0 # via runai-model-streamer +hydra-core==1.3.2 + # via + # lightly + # lightning hypothesis==6.131.0 # via # hypothesis-graphql @@ -233,6 +326,14 @@ idna==3.10 # jsonschema # requests # yarl +imageio==2.37.0 + # via scikit-image +importlib-metadata==8.7.0 + # via + # mlflow-skinny + # opentelemetry-api +importlib-resources==6.5.2 + # via typeshed-client inflect==5.6.2 # via datamodel-code-generator iniconfig==2.0.0 @@ -241,9 +342,13 @@ isoduration==20.11.0 # via jsonschema isort==5.13.2 # via datamodel-code-generator +itsdangerous==2.2.0 + # via flask jinja2==3.1.6 # via # datamodel-code-generator + # flask + # mlflow # torch jiwer==3.0.5 # via -r requirements/test.in @@ -256,6 +361,10 @@ joblib==1.4.2 # librosa # nltk # scikit-learn +jsonargparse==4.35.0 + # via + # lightning + # terratorch 
jsonlines==4.0.0 # via lm-eval jsonpointer==3.0.0 @@ -274,12 +383,33 @@ kaleido==0.2.1 # via genai-perf kiwisolver==1.4.7 # via matplotlib +kornia==0.8.1 + # via torchgeo +kornia-rs==0.1.9 + # via kornia lazy-loader==0.4 - # via librosa + # via + # librosa + # scikit-image libnacl==2.1.0 # via tensorizer librosa==0.10.2.post1 # via -r requirements/test.in +lightly==1.5.20 + # via + # terratorch + # torchgeo +lightly-utils==0.0.2 + # via lightly +lightning==2.5.1.post0 + # via + # terratorch + # torchgeo +lightning-utilities==0.14.3 + # via + # lightning + # pytorch-lightning + # torchmetrics llvmlite==0.44.0 # via numba lm-eval==0.4.8 @@ -288,16 +418,27 @@ lxml==5.3.0 # via # blobfile # sacrebleu +mako==1.3.10 + # via alembic mamba-ssm==2.2.4 # via -r requirements/test.in +markdown==3.8.2 + # via mlflow markdown-it-py==3.0.0 # via rich markupsafe==3.0.1 # via + # flask # jinja2 + # mako # werkzeug matplotlib==3.9.2 - # via -r requirements/test.in + # via + # -r requirements/test.in + # lightning + # mlflow + # pycocotools + # torchgeo mbstrdecoder==1.1.3 # via # dataproperty @@ -305,8 +446,12 @@ mbstrdecoder==1.1.3 # typepy mdurl==0.1.2 # via markdown-it-py -mistral-common==1.6.2 +mistral-common==1.8.2 # via -r requirements/test.in +mlflow==2.22.0 + # via terratorch +mlflow-skinny==2.22.0 + # via mlflow more-itertools==10.5.0 # via lm-eval mpmath==1.3.0 @@ -325,10 +470,14 @@ multiprocess==0.70.16 # via # datasets # evaluate +munch==4.0.0 + # via pretrainedmodels mypy-extensions==1.0.0 # via black networkx==3.2.1 - # via torch + # via + # scikit-image + # torch ninja==1.11.1.3 # via mamba-ssm nltk==3.9.1 @@ -345,6 +494,8 @@ numpy==1.26.4 # via # -r requirements/test.in # accelerate + # albucore + # albumentations # bitsandbytes # bm25s # contourpy @@ -355,9 +506,15 @@ numpy==1.26.4 # evaluate # fastparquet # genai-perf + # geopandas + # h5py + # imageio # librosa + # lightly + # lightly-utils # matplotlib # mistral-common + # mlflow # mteb # numba # numexpr @@ -365,18 +522,30 @@ numpy==1.26.4 # pandas # patsy # peft + # pycocotools + # pyogrio + # rasterio + # rioxarray # rouge-score # runai-model-streamer # sacrebleu + # scikit-image # scikit-learn # scipy + # segmentation-models-pytorch + # shapely # soxr # statsmodels + # tensorboardx # tensorizer + # tifffile + # torchgeo + # torchmetrics # torchvision # transformers # tritonclient # vocos + # xarray nvidia-cublas-cu12==12.8.3.14 # via # nvidia-cudnn-cu12 @@ -414,6 +583,12 @@ nvidia-nvjitlink-cu12==12.8.61 # torch nvidia-nvtx-cu12==12.8.55 # via torch +omegaconf==2.3.0 + # via + # hydra-core + # lightning +open-clip-torch==2.32.0 + # via -r requirements/test.in opencensus==0.11.4 # via ray opencensus-context==0.1.3 @@ -421,7 +596,18 @@ opencensus-context==0.1.3 opencv-python-headless==4.11.0.86 # via # -r requirements/test.in + # albucore + # albumentations # mistral-common +opentelemetry-api==1.35.0 + # via + # mlflow-skinny + # opentelemetry-sdk + # opentelemetry-semantic-conventions +opentelemetry-sdk==1.35.0 + # via mlflow-skinny +opentelemetry-semantic-conventions==0.56b0 + # via opentelemetry-sdk packaging==24.2 # via # accelerate @@ -430,26 +616,44 @@ packaging==24.2 # datasets # evaluate # fastparquet + # geopandas + # gunicorn # huggingface-hub + # hydra-core + # kornia # lazy-loader + # lightning + # lightning-utilities # mamba-ssm # matplotlib + # mlflow-skinny # peft # plotly # pooch + # pyogrio # pytest # pytest-rerunfailures + # pytorch-lightning # ray + # rioxarray + # scikit-image # statsmodels + # tensorboardx + # 
torchmetrics # transformers # typepy + # xarray pandas==2.2.3 # via # datasets # evaluate # fastparquet # genai-perf + # geopandas + # mlflow # statsmodels + # torchgeo + # xarray pathspec==0.12.1 # via black pathvalidate==3.2.1 @@ -463,9 +667,14 @@ peft==0.13.2 pillow==10.4.0 # via # genai-perf + # imageio + # lightly-utils # matplotlib # mistral-common + # scikit-image + # segmentation-models-pytorch # sentence-transformers + # torchgeo # torchvision platformdirs==4.3.6 # via @@ -484,6 +693,8 @@ portalocker==2.10.1 # via sacrebleu pqdm==0.2.0 # via -r requirements/test.in +pretrainedmodels==0.7.4 + # via segmentation-models-pytorch prometheus-client==0.22.0 # via ray propcache==0.2.0 @@ -494,8 +705,10 @@ protobuf==5.28.3 # via # google-api-core # googleapis-common-protos + # mlflow-skinny # proto-plus # ray + # tensorboardx # tensorizer psutil==6.1.0 # via @@ -510,6 +723,7 @@ pyarrow==18.0.0 # via # datasets # genai-perf + # mlflow pyasn1==0.6.1 # via # pyasn1-modules @@ -518,6 +732,10 @@ pyasn1-modules==0.4.2 # via google-auth pybind11==2.13.6 # via lm-eval +pycocotools==2.0.8 + # via terratorch +pycountry==24.6.1 + # via pydantic-extra-types pycparser==2.22 # via cffi pycryptodomex==3.22.0 @@ -525,23 +743,39 @@ pycryptodomex==3.22.0 pydantic==2.11.5 # via # -r requirements/test.in + # albumentations # datamodel-code-generator + # fastapi + # lightly # mistral-common + # mlflow-skinny # mteb + # pydantic-extra-types # ray pydantic-core==2.33.2 # via pydantic +pydantic-extra-types==2.10.5 + # via mistral-common pygments==2.18.0 # via rich +pyogrio==0.11.0 + # via geopandas pyparsing==3.2.0 - # via matplotlib + # via + # matplotlib + # rasterio +pyproj==3.7.1 + # via + # geopandas + # rioxarray + # torchgeo pyrate-limiter==3.7.0 # via schemathesis pystemmer==3.0.0 # via mteb pytablewriter==1.2.0 # via lm-eval -pytest==8.3.3 +pytest==8.3.5 # via # -r requirements/test.in # buildkite-test-collector @@ -554,6 +788,7 @@ pytest==8.3.3 # pytest-subtests # pytest-timeout # schemathesis + # terratorch pytest-asyncio==0.24.0 # via -r requirements/test.in pytest-forked==1.6.0 @@ -568,15 +803,23 @@ pytest-subtests==0.14.1 # via schemathesis pytest-timeout==2.3.1 # via -r requirements/test.in +python-box==7.3.2 + # via terratorch python-dateutil==2.9.0.post0 # via # arrow # botocore + # graphene + # lightly # matplotlib # pandas # typepy python-rapidjson==1.20 # via tritonclient +pytorch-lightning==2.5.2 + # via + # lightly + # lightning pytrec-eval-terrier==0.5.7 # via mteb pytz==2024.2 @@ -586,11 +829,17 @@ pytz==2024.2 pyyaml==6.0.2 # via # accelerate + # albumentations # datamodel-code-generator # datasets # genai-perf # huggingface-hub + # jsonargparse + # lightning + # mlflow-skinny + # omegaconf # peft + # pytorch-lightning # ray # responses # schemathesis @@ -599,6 +848,11 @@ pyyaml==6.0.2 # vocos rapidfuzz==3.12.1 # via jiwer +rasterio==1.4.3 + # via + # rioxarray + # terratorch + # torchgeo ray==2.43.0 # via -r requirements/test.in redis==5.2.0 @@ -610,18 +864,23 @@ referencing==0.35.1 regex==2024.9.11 # via # nltk + # open-clip-torch # sacrebleu # tiktoken # transformers requests==2.32.3 # via # buildkite-test-collector + # databricks-sdk # datasets + # docker # evaluate # google-api-core # huggingface-hub + # lightly # lm-eval # mistral-common + # mlflow-skinny # mteb # pooch # ray @@ -639,8 +898,11 @@ rfc3987==1.3.8 rich==13.9.4 # via # genai-perf + # lightning # mteb # typer +rioxarray==0.19.0 + # via terratorch rouge-score==0.1.2 # via lm-eval rpds-py==0.20.1 @@ -649,6 +911,8 @@ 
rpds-py==0.20.1 # referencing rsa==4.9.1 # via google-auth +rtree==1.4.0 + # via torchgeo runai-model-streamer==0.11.0 # via -r requirements/test.in runai-model-streamer-s3==0.11.0 @@ -660,26 +924,38 @@ sacrebleu==2.4.3 safetensors==0.4.5 # via # accelerate + # open-clip-torch # peft # timm # transformers schemathesis==3.39.15 # via -r requirements/test.in +scikit-image==0.25.2 + # via albumentations scikit-learn==1.5.2 # via + # albumentations # librosa # lm-eval + # mlflow # mteb # sentence-transformers scipy==1.13.1 # via + # albumentations # bm25s # librosa + # mlflow # mteb + # scikit-image # scikit-learn # sentence-transformers # statsmodels # vocos +segmentation-models-pytorch==0.4.0 + # via + # terratorch + # torchgeo sentence-transformers==3.2.1 # via # -r requirements/test.in @@ -688,21 +964,30 @@ sentencepiece==0.2.0 # via mistral-common setuptools==77.0.3 # via + # lightning-utilities # mamba-ssm # pytablewriter # torch # triton +shapely==2.1.1 + # via + # geopandas + # torchgeo shellingham==1.5.4 # via typer six==1.16.0 # via # junit-xml + # lightly # opencensus # python-dateutil # rfc3339-validator # rouge-score + # segmentation-models-pytorch smart-open==7.1.0 # via ray +smmap==5.0.2 + # via gitdb sniffio==1.3.1 # via # anyio @@ -713,12 +998,22 @@ soundfile==0.12.1 # via # -r requirements/test.in # librosa + # mistral-common soxr==0.5.0.post1 - # via librosa + # via + # librosa + # mistral-common +sqlalchemy==2.0.41 + # via + # alembic + # mlflow sqlitedict==2.1.0 # via lm-eval +sqlparse==0.5.3 + # via mlflow-skinny starlette==0.46.2 # via + # fastapi # schemathesis # starlette-testclient starlette-testclient==0.4.1 @@ -739,16 +1034,29 @@ tenacity==9.0.0 # via # lm-eval # plotly -tensorizer==2.9.0 +tensorboardx==2.6.4 + # via lightning +tensorizer==2.10.1 + # via -r requirements/test.in +terratorch==1.1rc2 # via -r requirements/test.in threadpoolctl==3.5.0 # via scikit-learn +tifffile==2025.3.30 + # via + # scikit-image + # terratorch tiktoken==0.7.0 # via # lm-eval # mistral-common -timm==1.0.11 - # via -r requirements/test.in +timm==1.0.15 + # via + # -r requirements/test.in + # open-clip-torch + # segmentation-models-pytorch + # terratorch + # torchgeo tokenizers==0.21.1 # via # -r requirements/test.in @@ -757,50 +1065,81 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.7.0+cu128 +torch==2.7.1+cu128 # via # -r requirements/test.in # accelerate # bitsandbytes + # efficientnet-pytorch # encodec # fastsafetensors + # kornia + # lightly + # lightning # lm-eval # mamba-ssm # mteb + # open-clip-torch # peft + # pretrainedmodels + # pytorch-lightning # runai-model-streamer + # segmentation-models-pytorch # sentence-transformers # tensorizer + # terratorch # timm # torchaudio + # torchgeo + # torchmetrics # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.7.0+cu128 +torchaudio==2.7.1+cu128 # via # -r requirements/test.in # encodec # vocos -torchvision==0.22.0+cu128 +torchgeo==0.7.0 + # via terratorch +torchmetrics==1.7.4 + # via + # lightning + # pytorch-lightning + # terratorch + # torchgeo +torchvision==0.22.1+cu128 # via # -r requirements/test.in + # lightly + # open-clip-torch + # pretrainedmodels + # segmentation-models-pytorch + # terratorch # timm + # torchgeo tqdm==4.66.6 # via # datasets # evaluate # huggingface-hub + # lightly + # lightning # lm-eval # mteb # nltk + # open-clip-torch # peft # pqdm + # pretrainedmodels + # pytorch-lightning + # segmentation-models-pytorch # sentence-transformers # tqdm-multiprocess # transformers 
tqdm-multiprocess==0.0.11 # via lm-eval -transformers==4.52.4 +transformers==4.53.2 # via # -r requirements/test.in # genai-perf @@ -811,7 +1150,7 @@ transformers==4.52.4 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.3.0 +triton==3.3.1 # via torch tritonclient==2.51.0 # via @@ -826,17 +1165,34 @@ typer==0.15.2 # via fastsafetensors types-python-dateutil==2.9.0.20241206 # via arrow +typeshed-client==2.8.2 + # via jsonargparse typing-extensions==4.12.2 # via + # albumentations + # alembic + # fastapi + # graphene # huggingface-hub # librosa + # lightning + # lightning-utilities # mistral-common + # mlflow-skinny # mteb + # opentelemetry-api + # opentelemetry-sdk + # opentelemetry-semantic-conventions # pqdm # pydantic # pydantic-core + # pydantic-extra-types + # pytorch-lightning + # sqlalchemy # torch + # torchgeo # typer + # typeshed-client # typing-inspection typing-inspection==0.4.1 # via pydantic @@ -848,23 +1204,33 @@ urllib3==2.2.3 # via # blobfile # botocore + # docker + # lightly # requests # responses # tritonclient +uvicorn==0.35.0 + # via mlflow-skinny vector-quantize-pytorch==1.21.2 # via -r requirements/test.in virtualenv==20.31.2 # via ray vocos==0.1.0 # via -r requirements/test.in +wcwidth==0.2.13 + # via ftfy webcolors==24.11.1 # via jsonschema werkzeug==3.1.3 - # via schemathesis + # via + # flask + # schemathesis word2number==1.1 # via lm-eval wrapt==1.17.2 # via smart-open +xarray==2025.7.1 + # via rioxarray xxhash==3.5.0 # via # datasets @@ -873,5 +1239,7 @@ yarl==1.17.1 # via # aiohttp # schemathesis +zipp==3.23.0 + # via importlib-metadata zstandard==0.23.0 # via lm-eval diff --git a/requirements/tpu.txt b/requirements/tpu.txt index 2b5fd8941647..354771482ee3 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -18,9 +18,8 @@ setuptools==78.1.0 --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.8.0.dev20250618 -torchvision==0.23.0.dev20250618 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250618-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250618-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250618-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.9.0.dev20250716 +torchvision==0.24.0.dev20250716 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250716-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250716-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" diff --git a/setup.py b/setup.py index ea7cd0169c8b..d46e678e7aa4 100644 --- a/setup.py +++ b/setup.py @@ -410,29 +410,6 @@ def run(self) -> None: package_data[package_name].append(file_name) -def _is_hpu() -> bool: - # if VLLM_TARGET_DEVICE env var was set explicitly, skip HPU autodetection - if os.getenv("VLLM_TARGET_DEVICE", None) == VLLM_TARGET_DEVICE: - return VLLM_TARGET_DEVICE == "hpu" - - # if VLLM_TARGET_DEVICE was not set explicitly, check 
if hl-smi succeeds, - # and if it doesn't, check if habanalabs driver is loaded - is_hpu_available = False - try: - out = subprocess.run(["hl-smi"], capture_output=True, check=True) - is_hpu_available = out.returncode == 0 - except (FileNotFoundError, PermissionError, subprocess.CalledProcessError): - if sys.platform.startswith("linux"): - try: - output = subprocess.check_output( - 'lsmod | grep habanalabs | wc -l', shell=True) - is_hpu_available = int(output) > 0 - except (ValueError, FileNotFoundError, PermissionError, - subprocess.CalledProcessError): - pass - return is_hpu_available - - def _no_device() -> bool: return VLLM_TARGET_DEVICE == "empty" @@ -440,7 +417,7 @@ def _no_device() -> bool: def _is_cuda() -> bool: has_cuda = torch.version.cuda is not None return (VLLM_TARGET_DEVICE == "cuda" and has_cuda - and not (_is_neuron() or _is_tpu() or _is_hpu())) + and not (_is_neuron() or _is_tpu())) def _is_hip() -> bool: @@ -573,12 +550,6 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"{sep}neuron{neuron_version_str}" - elif _is_hpu(): - # Get the Intel Gaudi Software Suite version - gaudi_sw_version = str(get_gaudi_sw_version()) - if gaudi_sw_version != MAIN_CUDA_VERSION: - gaudi_sw_version = gaudi_sw_version.replace(".", "")[:3] - version += f"{sep}gaudi{gaudi_sw_version}" elif _is_tpu(): version += f"{sep}tpu" elif _is_cpu(): @@ -625,8 +596,6 @@ def _read_requirements(filename: str) -> list[str]: requirements = _read_requirements("rocm.txt") elif _is_neuron(): requirements = _read_requirements("neuron.txt") - elif _is_hpu(): - requirements = _read_requirements("hpu.txt") elif _is_tpu(): requirements = _read_requirements("tpu.txt") elif _is_cpu(): @@ -635,8 +604,7 @@ def _read_requirements(filename: str) -> list[str]: requirements = _read_requirements("xpu.txt") else: raise ValueError( - "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, " - "or CPU.") + "Unsupported platform, please use CUDA, ROCm, Neuron, or CPU.") return requirements @@ -689,10 +657,12 @@ def _read_requirements(filename: str) -> list[str]: install_requires=get_requirements(), extras_require={ "bench": ["pandas", "datasets"], - "tensorizer": ["tensorizer>=2.9.0"], + "tensorizer": ["tensorizer==2.10.1"], "fastsafetensors": ["fastsafetensors >= 0.1.10"], - "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], - "audio": ["librosa", "soundfile"], # Required for audio processing + "runai": + ["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"], + "audio": ["librosa", "soundfile", + "mistral_common[audio]"], # Required for audio processing "video": [] # Kept for backwards compatibility }, cmdclass=cmdclass, diff --git a/tests/async_engine/test_api_server.py b/tests/async_engine/test_api_server.py index 38ecaf2233d9..76c94bdf80ca 100644 --- a/tests/async_engine/test_api_server.py +++ b/tests/async_engine/test_api_server.py @@ -29,7 +29,7 @@ def _query_server_long(prompt: str) -> dict: @pytest.fixture -def api_server(tokenizer_pool_size: int, distributed_executor_backend: str): +def api_server(distributed_executor_backend: str): script_path = Path(__file__).parent.joinpath( "api_server_async_engine.py").absolute() commands = [ @@ -40,8 +40,6 @@ def api_server(tokenizer_pool_size: int, distributed_executor_backend: str): "facebook/opt-125m", "--host", "127.0.0.1", - "--tokenizer-pool-size", - str(tokenizer_pool_size), "--distributed-executor-backend", distributed_executor_backend, ] @@ 
-54,10 +52,8 @@ def api_server(tokenizer_pool_size: int, distributed_executor_backend: str): uvicorn_process.terminate() -@pytest.mark.parametrize("tokenizer_pool_size", [0, 2]) @pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"]) -def test_api_server(api_server, tokenizer_pool_size: int, - distributed_executor_backend: str): +def test_api_server(api_server, distributed_executor_backend: str): """ Run the API server and test it. diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 2e103019f7af..13ddf035a55e 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -236,13 +236,13 @@ def test_failed_model_execution(vllm_runner, monkeypatch) -> None: monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: - if isinstance(vllm_model.model.llm_engine, LLMEngineV1): + if isinstance(vllm_model.llm.llm_engine, LLMEngineV1): v1_test_failed_model_execution(vllm_model) def v1_test_failed_model_execution(vllm_model): - engine = vllm_model.model.llm_engine + engine = vllm_model.llm.llm_engine mocked_execute_model = Mock( side_effect=RuntimeError("Mocked Critical Error")) engine.engine_core.engine_core.model_executor.execute_model =\ diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 4a422e8555da..4816b76996fc 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -294,61 +294,3 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) - - -@pytest.mark.parametrize("model", ["facebook/opt-125m"]) -@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) -@pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") -def test_models_cpu( - hf_runner: HfRunner, - vllm_runner: VllmRunner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - chunked_prefill_token_size: int, - enforce_eager: bool, - attention_backend: str, - monkeypatch: pytest.MonkeyPatch, -) -> None: - test_models( - hf_runner, - vllm_runner, - example_prompts, - model, - dtype, - max_tokens, - chunked_prefill_token_size, - enforce_eager, - 1, - attention_backend, - monkeypatch, - ) - - -@pytest.mark.parametrize("max_tokens", [16]) -@pytest.mark.parametrize("enforce_eager", [False]) -@pytest.mark.parametrize("chunk_size", [30, 32]) -@pytest.mark.parametrize("dtype", ["bfloat16", "half"]) -@pytest.mark.cpu_model -@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") -def test_with_prefix_caching_cpu( - vllm_runner: VllmRunner, - max_tokens: int, - enforce_eager: bool, - chunk_size: int, - dtype: str, -) -> None: - test_with_prefix_caching( - vllm_runner, - max_tokens, - enforce_eager, - chunk_size, - 1, - dtype, - ) diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 341a39a42b85..db2fa2f6bef6 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -81,7 +81,7 @@ def test_chunked_prefill_recompute( disable_log_stats=False, ) as vllm_model: vllm_outputs = 
vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) for i in range(len(example_prompts)): @@ -118,10 +118,10 @@ def test_preemption( distributed_executor_backend=distributed_executor_backend, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) total_preemption = ( - vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) + vllm_model.llm.llm_engine.scheduler[0].num_cumulative_preemption) check_outputs_equal( outputs_0_lst=hf_outputs, @@ -174,12 +174,12 @@ def test_preemption_infeasible( ) as vllm_model: sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True) - req_outputs = vllm_model.model.generate( + req_outputs = vllm_model.llm.generate( example_prompts, sampling_params=sampling_params, ) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt + assert (vllm_model.llm.llm_engine.scheduler[0].artificial_preempt_cnt < ARTIFICIAL_PREEMPTION_MAX_CNT) # Verify the request is ignored and not hang. diff --git a/tests/compile/piecewise/test_multiple_graphs.py b/tests/compile/piecewise/test_multiple_graphs.py new file mode 100644 index 000000000000..e460d7095178 --- /dev/null +++ b/tests/compile/piecewise/test_multiple_graphs.py @@ -0,0 +1,350 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Test (piecewise) compilation with a simple model where multiple submodules +are compiled and graph captured separately. 
+""" +import torch +from torch import nn +from torch.library import Library + +from vllm.compilation.backends import set_model_tag +from vllm.compilation.counter import compilation_counter +from vllm.compilation.decorators import (ignore_torch_compile, + support_torch_compile) +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) +from vllm.envs import VLLM_USE_V1 +from vllm.forward_context import set_forward_context +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + +BATCH_SIZE = 32 +MLP_SIZE = 128 +HIDDEN_SIZE = 1024 +RANDOM_SEED = 0 + + +def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + out.copy_(q) + out += k + out += v + + +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: + return + + +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + +@support_torch_compile +class ParentModel(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class Attention(nn.Module): + + def __init__(self, mlp_size: int, hidden_size: int) -> None: + super().__init__() + self.pre_attn = nn.Linear(mlp_size, hidden_size, bias=False) + self.post_attn = nn.Linear(hidden_size, mlp_size, bias=False) + self.rms_norm_weight = nn.Parameter(torch.ones(hidden_size)) + + # Initialize to same weights for testing + nn.init.xavier_normal_( + self.pre_attn.weight.data, + generator=torch.Generator().manual_seed(RANDOM_SEED), + gain=0.001) + nn.init.xavier_normal_( + self.post_attn.weight.data, + generator=torch.Generator().manual_seed(RANDOM_SEED), + gain=0.001) + + def rms_norm_ref(self, x: torch.Tensor) -> torch.Tensor: + x_f32 = x.float() + return (x_f32 * torch.rsqrt( + torch.mean(x_f32.square(), dim=-1, keepdim=True) + 1e-6) * + self.rms_norm_weight).to(x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pre_attn(x) + x = self.rms_norm_ref(x) + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = self.rms_norm_ref(x) + x = self.post_attn(x) + return x + + +@support_torch_compile +class CompiledAttention(nn.Module): + + def __init__(self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + self.attn = Attention(mlp_size, hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.attn(x) + + +@support_torch_compile +class CompiledAttentionTwo(CompiledAttention): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.attn(x) + x + + +@ignore_torch_compile +class SimpleModelWithTwoGraphs(ParentModel): + + def __init__(self, + *, + mlp_size: int, + hidden_size: int, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__(vllm_config=vllm_config, prefix=prefix) + # Test will fail without set_model_tag here with error: + # "ValueError: too many values to unpack (expected 3)" + # This is because CompiledAttention and CompiledAttentionTwo + # have different implmentations but the same torch.compile + # cache dir will be used as default prefix is 'model_tag' + with set_model_tag("attn_one"): + 
self.attn_one = CompiledAttention( + mlp_size=mlp_size, + hidden_size=hidden_size, + vllm_config=vllm_config, + prefix=f"{prefix}.attn_one", + ) + with set_model_tag("attn_two"): + self.attn_two = CompiledAttentionTwo( + mlp_size=mlp_size, + hidden_size=hidden_size, + vllm_config=vllm_config, + prefix=f"{prefix}.attn_two", + ) + + self.hidden_states = torch.zeros((BATCH_SIZE, MLP_SIZE)).cuda() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bsz = x.shape[0] + # CUDAGraph expects same tensor addresses for each run + self.hidden_states[:bsz].copy_(x) + x = self.attn_one(self.hidden_states[:bsz]) + self.hidden_states[:bsz].copy_(x) + x = self.attn_two(self.hidden_states[:bsz]) + return x + + +def test_ignore_torch_compile_decorator(): + assert VLLM_USE_V1 + + # piecewise + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + + @support_torch_compile + class A(nn.Module): + + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = '', + **kwargs) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + x + attn_output = torch.empty_like(x) + torch.ops.silly.attention(x, x, x, attn_output) + x = attn_output + x = x * 3 + return x + + @ignore_torch_compile + class B(A): + ... + + @support_torch_compile + class C(B): + ... + + with set_current_vllm_config(vllm_config): + mod_A = A(vllm_config=vllm_config, prefix='').eval().cuda() + + # A has support_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ), set_forward_context({}, vllm_config=vllm_config): + # first run is for compile + mod_A(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + # run cudagraph captured sizes + mod_A(torch.randn(2, MLP_SIZE).cuda()) + mod_A(torch.randn(1, MLP_SIZE).cuda()) + + with set_current_vllm_config(vllm_config): + mod_B = B(vllm_config=vllm_config, prefix='').eval().cuda() + + # B's ignore_torch_compile should override A's support_torch_compile + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ), set_forward_context({}, vllm_config=vllm_config): + mod_B(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + mod_B(torch.randn(2, MLP_SIZE).cuda()) + mod_B(torch.randn(1, MLP_SIZE).cuda()) + + with set_current_vllm_config(vllm_config): + mod_C = C(vllm_config=vllm_config, prefix='').eval().cuda() + + # C's support_torch_compile should override B's ignore_torch_compile + with compilation_counter.expect( + num_graphs_seen=1, + num_piecewise_graphs_seen=3, + num_piecewise_capturable_graphs_seen=2, + num_backend_compilations=2, + num_cudagraph_captured=4, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ), set_forward_context({}, vllm_config=vllm_config): + mod_C(torch.randn(BATCH_SIZE, MLP_SIZE).cuda()) + mod_C(torch.randn(2, MLP_SIZE).cuda()) + mod_C(torch.randn(1, MLP_SIZE).cuda()) + + +@torch.inference_mode +def run_model(vllm_config, model: nn.Module, inputs: torch.Tensor): + with set_forward_context({}, vllm_config=vllm_config): + # First run is for compile + model(inputs) + + # Run CUDAGraph captured sizes + model(inputs[:2]) + model(inputs[:1]) + + output = 
model(inputs[:2]) + + output = output.cpu() + return output.cpu() + + +def test_multi_graph_piecewise_compile_outputs_equal(): + outputs = [] + + # piecewise compile + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=True, + splitting_ops=["silly.attention"], + cudagraph_capture_sizes=[1, 2], + )) + + with set_current_vllm_config(vllm_config): + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='').eval().cuda() + + # Pre-allocate memory for CUDAGraph which expects + # static tensor addresses + inputs = torch.randn(BATCH_SIZE, MLP_SIZE).cuda() + + with compilation_counter.expect( + num_graphs_seen=2, # two graphs for the model + num_piecewise_graphs_seen=6, + # attn_one, attn_two each has 3 piecewise graphs + # (pre attn, post attn, silly_attention) each + num_piecewise_capturable_graphs_seen=4, + # attn_one, attn_two has pre attn and post attn each, total=4 + num_backend_compilations=4, # num_piecewise_capturable_graphs_seen + num_cudagraph_captured=8, + # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen + ): + outputs.append(run_model(vllm_config, model, inputs)) + + # no compile or cudagraph + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.NO_COMPILATION, )) + + with set_current_vllm_config(vllm_config): + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='').eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=0, + num_piecewise_graphs_seen=0, + num_piecewise_capturable_graphs_seen=0, + num_backend_compilations=0, + num_cudagraph_captured=0, + ): + outputs.append(run_model(vllm_config, model, inputs)) + + # piecewise compile without CUDA graph + vllm_config = VllmConfig(compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + use_cudagraph=False, + splitting_ops=["silly.attention"], + )) + + with set_current_vllm_config(vllm_config): + model = SimpleModelWithTwoGraphs(mlp_size=MLP_SIZE, + hidden_size=HIDDEN_SIZE, + vllm_config=vllm_config, + prefix='').eval().cuda() + + with compilation_counter.expect( + num_graphs_seen=2, + num_piecewise_graphs_seen=6, + num_piecewise_capturable_graphs_seen=4, + num_backend_compilations=4, + num_cudagraph_captured=0, # no cudagraph captured + ): + outputs.append(run_model(vllm_config, model, inputs)) + + # Generally don't expect outputs with and without inductor + # to be bitwise equivalent + assert torch.allclose(outputs[0], outputs[1]) + + # Expect bitwise equivalence using inductor w/ and w/o cudagraph + assert torch.equal(outputs[0], outputs[2]) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 8679d5c3019b..0ba59f4b5a05 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -26,6 +26,30 @@ def test_use_cudagraphs_dynamic(monkeypatch): assert not vllm_config.compilation_config.use_cudagraph +# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends +# on the state of the cache directory on the current machine, which +# may be influenced by other tests. +@pytest.mark.parametrize("val", ["1"]) +def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val): + assert vllm.envs.VLLM_USE_V1 + + # spawn means that the counters are in the same process. 
+ monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn") + monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val) + + compilation_config = { + "use_cudagraph": False, # speed things up a bit + } + with ( + compilation_counter.expect(num_cache_entries_updated=0, + num_compiled_artifacts_saved=0), + # loading the model causes compilation (if enabled) to happen + vllm_runner('facebook/opt-125m', + compilation_config=compilation_config, + gpu_memory_utilization=0.4) as _): + pass + + @pytest.mark.parametrize("enabled", [True, False]) def test_use_cudagraphs(vllm_runner, monkeypatch, enabled): assert vllm.envs.VLLM_USE_V1 diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index 1d000fe00c59..72f962ed7484 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -3,6 +3,7 @@ from __future__ import annotations +import tempfile from typing import Any, Optional, Union import pytest @@ -111,6 +112,11 @@ def test_full_graph( pass_config=PassConfig(enable_fusion=True, enable_noop=True)), model) for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"]) + ] + [ + # Test depyf integration works + (CompilationConfig(level=CompilationLevel.PIECEWISE, + debug_dump_path=tempfile.gettempdir()), + ("facebook/opt-125m", {})), ]) # only test some of the models @create_new_process_for_each_test() diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 040fd176fec1..4a3820e20fd8 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -44,7 +44,9 @@ def __init__(self, hidden_size: int, eps: float, static: bool, ] self.fp8_linear = Fp8LinearOp( cutlass_fp8_supported=cutlass_fp8_enabled, - use_per_token_if_dynamic=True) + act_quant_static=static, + act_quant_group_shape=group_shape, + ) def forward(self, x): resid = torch.sqrt(x) @@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, maybe_create_device_identity() # needed for certain non-cutlass fp8 paths vllm_config = VllmConfig(compilation_config=CompilationConfig( - level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) - vllm_config.compilation_config.pass_config = \ - PassConfig(enable_fusion=True, enable_noop=True) + level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm", "+quant_fp8"], + pass_config=PassConfig(enable_fusion=True, enable_noop=True), + )) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_fusion_all_reduce.py b/tests/compile/test_fusion_all_reduce.py new file mode 100644 index 000000000000..492e90f2a75f --- /dev/null +++ b/tests/compile/test_fusion_all_reduce.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from importlib.util import find_spec + +import pytest +import torch + +import vllm.envs as envs +from vllm.compilation.collective_fusion import AllReduceFusionPass +from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig, + ModelConfig, PassConfig, VllmConfig) +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.platforms import current_platform +from vllm.utils import update_environment_variables + +from ..utils import multi_gpu_test +from .backend import 
TestBackend + + +class TestAllReduceRMSNormModel(torch.nn.Module): + + def __init__(self, hidden_size=16, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = RMSNorm(hidden_size, eps) + + def forward(self, hidden_states, residual): + view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(view) + norm = self.norm(all_reduce) + return norm + + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + +class TestAllReduceFusedAddRMSNormModel(torch.nn.Module): + + def __init__(self, hidden_size=16, eps=1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm = RMSNorm(hidden_size, eps) + + def forward(self, hidden_states, residual): + view = hidden_states.reshape(-1, self.hidden_size) + all_reduce = tensor_model_parallel_all_reduce(view) + norm, _ = self.norm(all_reduce, residual) + return norm + + def ops_in_model_before(self): + return [torch.ops.vllm.all_reduce.default] + + def ops_in_model_after(self): + return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default] + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize( + "test_model", + [TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [8]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], + reason="Only test on CUDA") +@pytest.mark.skipif(not find_spec("flashinfer"), + reason="flashinfer is not installed") +@pytest.mark.skipif(not current_platform.is_device_capability(100), + reason="Only test on SM100") +def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module, + batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + num_processes = 2 + + def run_torch_spawn(fn, nprocs): + torch.multiprocessing.spawn(fn, + args=(num_processes, test_model, + batch_size, seq_len, hidden_size, + dtype), + nprocs=nprocs) + + run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes) + + +def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int, + test_model_cls: torch.nn.Module, + batch_size: int, seq_len: int, + hidden_size: int, dtype: torch.dtype): + current_platform.seed_everything(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables({ + 'RANK': str(local_rank), + 'LOCAL_RANK': str(local_rank), + 'WORLD_SIZE': str(world_size), + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=world_size) + + vllm_config = VllmConfig( + compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE, + custom_ops=["+rms_norm"], + compile_sizes=[2, 4, 8])) + vllm_config.compilation_config.pass_config = PassConfig( + enable_fi_allreduce_fusion=True) + vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) + + # this is a fake model name to construct the model config + # in the vllm_config, it's not really used. 
+ model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e" + vllm_config.model_config = ModelConfig(model=model_name, + task="auto", + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=True, + dtype=dtype, + seed=42) + + all_reduce_fusion_pass = AllReduceFusionPass(vllm_config) + backend = TestBackend(all_reduce_fusion_pass) + + model = test_model_cls(hidden_size) + + hidden_states = torch.randn((batch_size * seq_len, hidden_size), + requires_grad=False) + residual = torch.randn((batch_size * seq_len, hidden_size), + requires_grad=False) + + compiled_model = torch.compile(model, backend=backend) + compiled_model(hidden_states, residual) + + backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False) + backend.check_after_ops(model.ops_in_model_after()) + del all_reduce_fusion_pass diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 37ec753bbc9e..70750eb9ac4e 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # DYNAMO_ONCE does not properly propagate shapes. level=CompilationLevel.DYNAMO_AS_IS, backend="tests.compile.test_fusion_attn.backend_unfused", + custom_ops=["+quant_fp8"], ) vllm_config = VllmConfig(compilation_config=compile_config) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) @@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, # DYNAMO_ONCE does not properly propagate shapes. level=CompilationLevel.DYNAMO_AS_IS, backend="tests.compile.test_fusion_attn.backend", + custom_ops=["+quant_fp8"], ) vllm_config = VllmConfig(compilation_config=compile_config) diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index df36b86abdbe..5351a3cf35ba 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -4,33 +4,56 @@ import torch import vllm.envs as envs -from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm.compilation.noop_elimination import NoOpEliminationPass from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_FP8_SUPPORTED, Fp8LinearOp) +from vllm.platforms import current_platform from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, *args, **kwargs): + def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args, + **kwargs): super().__init__(*args, **kwargs) self.silu_and_mul = SiluAndMul() + self.wscale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32) + self.w = (torch.rand( + hidden_size, + hidden_size).to(dtype=current_platform.fp8_dtype()).t()) + + self.fp8_linear = Fp8LinearOp( + cutlass_fp8_supported=cutlass_fp8_enabled, + act_quant_static=True, + act_quant_group_shape=GroupShape.PER_TENSOR, + ) + def forward(self, x): y = self.silu_and_mul(x) - x2 = scaled_fp8_quant(y, self.scale) + x2 = self.fp8_linear.apply(y, + self.w, + self.wscale, + input_scale=self.wscale) return x2 @pytest.mark.parametrize("num_tokens", [256]) @pytest.mark.parametrize("hidden_size", [64]) 
+@pytest.mark.parametrize("cutlass_fp8_enabled", + [True, False] if CUTLASS_FP8_SUPPORTED else [False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm") -def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, + cutlass_fp8_enabled): torch.set_default_device("cuda") torch.set_default_dtype(torch.float16) @@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): pass_config=PassConfig(enable_fusion=True, enable_noop=True)) fusion_pass = ActivationQuantFusionPass(config) - backend = TestBackend(fusion_pass) - model = TestModel() + backend = TestBackend(NoOpEliminationPass(config), fusion_pass) + model = TestModel(hidden_size, cutlass_fp8_enabled) # First dimension dynamic - x = torch.rand(num_tokens, hidden_size) + x = torch.rand(num_tokens, hidden_size * 2) torch._dynamo.mark_dynamic(x, 0) result = model(x) diff --git a/tests/conftest.py b/tests/conftest.py index b294b50a5cdd..a18dbf58c803 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -759,7 +759,8 @@ class VllmRunner: - `trust_remote_code`: Set to `True` instead of `False` for convenience. - `seed`: Set to `0` instead of `None` for test reproducibility. - `max_model_len`: Set to `1024` instead of `None` to reduce memory usage. - - `block_size`: Set to `16` instead of `None` to reduce memory usage. + - `block_size`: To reduce memory usage, set default to `64` if on XPU + devices, otherwise default to `16`. - `enable_chunked_prefill`: Set to `False` instead of `None` for test reproducibility. - `enforce_eager`: Set to `False` to test CUDA graph. @@ -777,13 +778,13 @@ def __init__( dtype: str = "auto", disable_log_stats: bool = True, tensor_parallel_size: int = 1, - block_size: int = 16, + block_size: int = 16 if not torch.xpu.is_available() else 64, enable_chunked_prefill: Optional[bool] = False, swap_space: int = 4, enforce_eager: Optional[bool] = False, **kwargs, ) -> None: - self.model = LLM( + self.llm = LLM( model=model_name, task=task, tokenizer=tokenizer_name, @@ -803,7 +804,7 @@ def __init__( def get_inputs( self, - prompts: Union[list[str], list[torch.Tensor]], + prompts: Union[list[str], list[torch.Tensor], list[int]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, @@ -825,11 +826,16 @@ def get_inputs( if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio - text_prompt_kwargs = { - ("prompt" if isinstance(prompt, str) else "prompt_embeds"): - prompt, + text_prompt_kwargs: dict[str, Any] = { "multi_modal_data": multi_modal_data or None } + if isinstance(prompt, str): + text_prompt_kwargs["prompt"] = prompt + elif isinstance(prompt, list): + text_prompt_kwargs["prompt_token_ids"] = prompt + else: + text_prompt_kwargs["prompt_embeds"] = prompt + inputs.append(TextPrompt(**text_prompt_kwargs)) return inputs @@ -848,9 +854,9 @@ def generate( videos=videos, audios=audios) - req_outputs = self.model.generate(inputs, - sampling_params=sampling_params, - **kwargs) + req_outputs = self.llm.generate(inputs, + sampling_params=sampling_params, + **kwargs) outputs: list[tuple[list[list[int]], list[str]]] = [] for req_output in req_outputs: @@ -896,9 +902,9 @@ def generate_w_logprobs( videos=videos, audios=audios) - req_outputs = self.model.generate(inputs, - sampling_params=sampling_params, - **kwargs) + req_outputs = self.llm.generate(inputs, + 
sampling_params=sampling_params, + **kwargs) toks_str_logsprobs_prompt_logprobs = ( self._final_steps_generate_w_logprobs(req_outputs)) @@ -918,8 +924,8 @@ def generate_encoder_decoder_w_logprobs( ''' assert sampling_params.logprobs is not None - req_outputs = self.model.generate(encoder_decoder_prompts, - sampling_params=sampling_params) + req_outputs = self.llm.generate(encoder_decoder_prompts, + sampling_params=sampling_params) toks_str_logsprobs_prompt_logprobs = ( self._final_steps_generate_w_logprobs(req_outputs)) # Omit prompt logprobs if not required by sampling params @@ -1012,7 +1018,7 @@ def generate_beam_search( videos=videos, audios=audios) - outputs = self.model.beam_search( + outputs = self.llm.beam_search( inputs, BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens)) returned_outputs = [] @@ -1023,7 +1029,7 @@ def generate_beam_search( return returned_outputs def classify(self, prompts: list[str]) -> list[list[float]]: - req_outputs = self.model.classify(prompts) + req_outputs = self.llm.classify(prompts) return [req_output.outputs.probs for req_output in req_outputs] def embed(self, @@ -1038,11 +1044,11 @@ def embed(self, videos=videos, audios=audios) - req_outputs = self.model.embed(inputs, *args, **kwargs) + req_outputs = self.llm.embed(inputs, *args, **kwargs) return [req_output.outputs.embedding for req_output in req_outputs] def encode(self, prompts: list[str]) -> list[list[float]]: - req_outputs = self.model.encode(prompts) + req_outputs = self.llm.encode(prompts) return [req_output.outputs.data for req_output in req_outputs] def score( @@ -1052,18 +1058,18 @@ def score( *args, **kwargs, ) -> list[float]: - req_outputs = self.model.score(text_1, text_2, *args, **kwargs) + req_outputs = self.llm.score(text_1, text_2, *args, **kwargs) return [req_output.outputs.score for req_output in req_outputs] def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]: - executor = self.model.llm_engine.model_executor + executor = self.llm.llm_engine.model_executor return executor.apply_model(func) def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - del self.model + del self.llm cleanup_dist_env_and_memory() diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index 1b958e34df87..9e1b7913dfb9 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -37,7 +37,7 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, num_scheduler_steps=num_scheduler_steps, enable_chunked_prefill=enable_chunked_prefill, enforce_eager=enforce_eager) - engine: LLMEngine = runner.model.llm_engine + engine: LLMEngine = runner.llm.llm_engine # In multi-step + chunked-prefill there is no separate single prompt step. # What is scheduled will run for num_scheduler_steps always. 
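[Editorial sketch, not part of the patch] The conftest.py hunks above rename VllmRunner's engine handle from self.model to self.llm, which is why call sites such as runner.model.llm_engine become runner.llm.llm_engine throughout the rest of this diff. A minimal sketch of the updated access pattern is below; it assumes the standard vllm_runner fixture, and the model name and sampling settings are illustrative only.

    # Hypothetical test using the renamed attribute; only `runner.llm`
    # (formerly `runner.model`) is taken from the diff above.
    from vllm import SamplingParams

    def test_engine_access(vllm_runner):
        with vllm_runner("facebook/opt-125m", enforce_eager=True) as runner:
            # The wrapped vllm.LLM instance is now exposed as `runner.llm`.
            engine = runner.llm.llm_engine
            outputs = runner.llm.generate(
                ["Hello, my name is"],
                sampling_params=SamplingParams(max_tokens=5))
            assert engine is not None and len(outputs) == 1
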
diff --git a/tests/core/test_serialization.py b/tests/core/test_serialization.py index 8281298d6634..ee9ac2129f2d 100644 --- a/tests/core/test_serialization.py +++ b/tests/core/test_serialization.py @@ -6,7 +6,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook from vllm.sequence import ExecuteModelRequest -from ..spec_decode.utils import create_batch +from .utils import create_batch def test_msgspec_serialization(): diff --git a/tests/core/utils.py b/tests/core/utils.py index b746c1786464..033fffd2c4e2 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -4,15 +4,16 @@ import time from collections import defaultdict from collections.abc import Sequence as GenericSequence -from typing import Any, Optional +from itertools import count +from typing import Any, Optional, Union import torch -from vllm import SamplingParams from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs from vllm.lora.request import LoRARequest -from vllm.sequence import (Logprob, Sequence, SequenceGroup, +from vllm.sampling_params import SamplingParams +from vllm.sequence import (Logprob, Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata) @@ -262,3 +263,130 @@ def last_schedule_ret( self, ) -> tuple[list[SequenceGroupMetadata], SchedulerOutputs, Any]: _, _, ret = self.call_history["schedule"][-1] return ret + + +def create_seq_group_metadata_from_prompts( + prompts: list[list[int]], + num_gpu_blocks: int, + block_size: int, + final_prompt_lens: list[int], + continuations: Optional[list[list[int]]] = None, + seq_ids: Optional[list[int]] = None, +) -> list[SequenceGroupMetadata]: + + if continuations is None: + continuations = [[] for _ in prompts] + + if seq_ids is None: + seq_ids = list(i for i, _ in enumerate(prompts)) + + free_gpu_blocks = list(range(num_gpu_blocks)) + + block_allocations = { + i: [ + free_gpu_blocks.pop() + for _ in range(round_up_to_next_block(final_len, block_size)) + ] + for i, final_len in enumerate(final_prompt_lens) + } + + seq_grou_metadata_list = [] + for i, (prompt_token_ids, + cont_token_ids) in enumerate(zip(prompts, continuations)): + data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids) + data.update_num_computed_tokens( + len(prompt_token_ids) + len(cont_token_ids) - 1) + seq_data = {i: data} + seq_grou_metadata_list.append( + SequenceGroupMetadata( + request_id=str(i), + is_prompt=len(cont_token_ids) == 0, + seq_data=seq_data, + sampling_params=SamplingParams(temperature=0.0), + block_tables={i: block_allocations[i][:]}, + )) + return seq_grou_metadata_list + + +def create_chunked_seq_group_metadata_from_prompt( + prompt: list[int], + num_gpu_blocks: int, + chunk_size: int, + block_size: int, + seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]: + + if seq_id is None: + seq_id = 0 + + free_gpu_blocks = list(range(num_gpu_blocks)) + + block_allocations = [ + free_gpu_blocks.pop() + for _ in range(round_up_to_next_block(len(prompt), block_size)) + ] + + seq_group_metadata_list = [] + for i, idx in enumerate(range(0, len(prompt), chunk_size)): + chunk_ids = prompt[idx:idx + chunk_size] + data = SequenceData.from_seqs(prompt) + data.update_num_computed_tokens(idx) + seq_data = {i: data} + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=str(seq_id), + is_prompt=True, + do_sample=idx + chunk_size >= len(prompt), # terminal chunk + seq_data=seq_data, + sampling_params=SamplingParams(temperature=0.0), + block_tables={i: 
block_allocations}, + token_chunk_size=len(chunk_ids))) + return seq_group_metadata_list + + +def create_batch(batch_size, + k, + prompt_len: Union[int, list[int]] = 10, + prev_output_token_len: int = 10, + seq_ids: Optional[list[int]] = None, + num_gpu_blocks: Optional[int] = None, + block_size: Optional[int] = None, + prefill_chunk_size: Optional[int] = None): + if block_size is None: + block_size = 8 + + if num_gpu_blocks is None: + num_gpu_blocks = 2048 // block_size + + iterator = count() + + if isinstance(prompt_len, int): + prompt_lens = [prompt_len for _ in range(batch_size)] + else: + prompt_lens = prompt_len + + prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens] + + if prefill_chunk_size: + # Create a batch of chunked prompts. + if not seq_ids: + seq_ids = list(range(len(prompts))) + seq_group_metadata_list = [] + for p, sid in zip(prompts, seq_ids): + seq_group_metadata_list += \ + create_chunked_seq_group_metadata_from_prompt( + p, num_gpu_blocks, prefill_chunk_size, block_size, sid) + seq_group_metadata_list = seq_group_metadata_list[:batch_size] + prev_output_tokens = [] + else: + prev_output_tokens = [[ + next(iterator) for _ in range(prev_output_token_len) + ] for _ in range(batch_size)] + final_prompt_lens = [ + len(prompt) + len(prev_output_token) + k + 1 + for prompt, prev_output_token in zip(prompts, prev_output_tokens) + ] + + seq_group_metadata_list = create_seq_group_metadata_from_prompts( + prompts, num_gpu_blocks, block_size, final_prompt_lens, + prev_output_tokens, seq_ids) + return seq_group_metadata_list, prompts, prev_output_tokens diff --git a/tests/detokenizer/test_stop_reason.py b/tests/detokenizer/test_stop_reason.py index 9716f7d72a58..1ff679789c95 100644 --- a/tests/detokenizer/test_stop_reason.py +++ b/tests/detokenizer/test_stop_reason.py @@ -28,7 +28,7 @@ def vllm_model(vllm_runner): def test_stop_reason(vllm_model, example_prompts): tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL) stop_token_id = tokenizer.convert_tokens_to_ids(STOP_STR) - llm = vllm_model.model + llm = vllm_model.llm # test stop token outputs = llm.generate(example_prompts, diff --git a/tests/detokenizer/test_stop_strings.py b/tests/detokenizer/test_stop_strings.py index efe938a20c4f..cb87c44cc399 100644 --- a/tests/detokenizer/test_stop_strings.py +++ b/tests/detokenizer/test_stop_strings.py @@ -101,42 +101,42 @@ def _stop_token_id(llm): def test_stop_strings(): # If V0, must set enforce_eager=False since we use # async output processing below. 
- vllm_model = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1) + llm = LLM(MODEL, enforce_eager=envs.VLLM_USE_V1) if envs.VLLM_USE_V1: - _stop_basic(vllm_model) + _stop_basic(llm) else: - _set_async_mode(vllm_model, True) - _stop_basic(vllm_model) + _set_async_mode(llm, True) + _stop_basic(llm) - _set_async_mode(vllm_model, False) - _stop_basic(vllm_model) + _set_async_mode(llm, False) + _stop_basic(llm) if envs.VLLM_USE_V1: - _stop_multi_tokens(vllm_model) + _stop_multi_tokens(llm) else: - _set_async_mode(vllm_model, True) - _stop_multi_tokens(vllm_model) + _set_async_mode(llm, True) + _stop_multi_tokens(llm) - _set_async_mode(vllm_model, False) - _stop_multi_tokens(vllm_model) + _set_async_mode(llm, False) + _stop_multi_tokens(llm) if envs.VLLM_USE_V1: - _stop_partial_token(vllm_model) + _stop_partial_token(llm) else: - _set_async_mode(vllm_model, True) - _stop_partial_token(vllm_model) + _set_async_mode(llm, True) + _stop_partial_token(llm) - _set_async_mode(vllm_model, False) - _stop_partial_token(vllm_model) + _set_async_mode(llm, False) + _stop_partial_token(llm) if envs.VLLM_USE_V1: # FIXME: this does not respect include_in_output=False - # _stop_token_id(vllm_model) + # _stop_token_id(llm) pass else: - _set_async_mode(vllm_model, True) - _stop_token_id(vllm_model) + _set_async_mode(llm, True) + _stop_token_id(llm) - _set_async_mode(vllm_model, False) - _stop_token_id(vllm_model) + _set_async_mode(llm, False) + _stop_token_id(llm) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 7d569fd83821..2391430a083a 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -14,8 +14,9 @@ import pytest -from vllm.config import TaskOption +from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, TaskOption from vllm.logger import init_logger +from vllm.transformers_utils.config import get_config from ..models.registry import HF_EXAMPLE_MODELS from ..utils import compare_two_settings, create_new_process_for_each_test @@ -158,7 +159,7 @@ def iter_params(self, model_id: str): "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"), "Deci/DeciLM-7B-instruct": PPTestSettings.fast(), "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), - "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(), + "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(tp_base=2), "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(), "tiiuae/falcon-7b": PPTestSettings.fast(), "google/gemma-1.1-2b-it": PPTestSettings.fast(), @@ -176,7 +177,7 @@ def iter_params(self, model_id: str): "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(), "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(), # Tests TransformersForCausalLM - "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(), + "hmellor/Ilama-3.2-1B": PPTestSettings.fast(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(), "openbmb/MiniCPM3-4B": PPTestSettings.fast(), # Uses Llama @@ -210,9 +211,11 @@ def iter_params(self, model_id: str): EMBEDDING_MODELS = { # type: ignore[var-annotated] # [Text-only] - "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), - "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(), - "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"), + "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(task="embed"), + "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(task="embed"), + "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast( + load_format="dummy", task="embed" + ), } MULTIMODAL_MODELS 
= { @@ -246,8 +249,9 @@ def iter_params(self, model_id: str): # [LANGUAGE GENERATION] "microsoft/Phi-3.5-MoE-instruct", "meta-llama/Llama-3.2-1B-Instruct", - "ArthurZ/Ilama-3.2-1B", + "hmellor/Ilama-3.2-1B", "ibm/PowerLM-3b", + "deepseek-ai/DeepSeek-V2-Lite-Chat", # [LANGUAGE EMBEDDING] "intfloat/e5-mistral-7b-instruct", "BAAI/bge-multilingual-gemma2", @@ -287,6 +291,11 @@ def _compare_tp( trust_remote_code = model_info.trust_remote_code tokenizer_mode = model_info.tokenizer_mode hf_overrides = model_info.hf_overrides + hf_config = get_config(model_id, trust_remote_code) + + dtype = "float16" + if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS: + dtype = "bfloat16" if load_format == "dummy": # Avoid OOM @@ -316,7 +325,7 @@ def _compare_tp( common_args = [ # use half precision for speed and memory savings in CI environment "--dtype", - "float16", + dtype, "--max-model-len", "2048", "--max-num-seqs", @@ -338,6 +347,7 @@ def _compare_tp( common_args.extend(["--hf-overrides", json.dumps(hf_overrides)]) specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill + testing_ray_compiled_graph = False if distributed_backend == "ray" and (vllm_major_version == "1" or specific_case): # For V1, test Ray Compiled Graph for all the tests @@ -351,6 +361,7 @@ def _compare_tp( # Temporary. Currently when zeromq + SPMD is used, it does not properly # terminate because of a Ray Compiled Graph issue. common_args.append("--disable-frontend-multiprocessing") + testing_ray_compiled_graph = True elif distributed_backend == "mp": # Both V0/V1 of multiprocessing executor support PP pp_env = { @@ -394,7 +405,6 @@ def _compare_tp( tp_env, method=method) except Exception: - testing_ray_compiled_graph = pp_env is not None if testing_ray_compiled_graph and vllm_major_version == "0": # Ray Compiled Graph tests are flaky for V0, # so we don't want to fail the test diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 5b32b90f3cfe..abfad9ebfe7d 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -4,6 +4,7 @@ import multiprocessing import os +import numpy as np import pytest import torch import torch.distributed @@ -177,6 +178,38 @@ def test_pynccl_all_gather(): distributed_run(all_gather_worker_fn, 2) +@worker_fn_wrapper +def all_gatherv_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + assert world_size <= 8 + sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] + num_elems = sizes[rank] + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * 100 + result = torch.zeros(sum(sizes), dtype=torch.float32, device=device) + + expected = torch.cat([ + torch.arange(sizes[r], dtype=torch.float32) + r * 100 + for r in range(world_size) + ]).to(device) + + pynccl_comm.all_gatherv(result, tensor, sizes=sizes) + torch.cuda.synchronize() + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_all_gatherv(): + distributed_run(all_gatherv_worker_fn, 2) + + @worker_fn_wrapper def reduce_scatter_worker_fn(): pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, @@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter(): distributed_run(reduce_scatter_worker_fn, 2) +@worker_fn_wrapper +def reduce_scatterv_worker_fn(): + 
pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + assert world_size <= 8 + sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size] + num_elems = sum(sizes) + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * 100 + result = torch.zeros(sizes[rank], dtype=torch.float32, device=device) + + # Calculate expected result for this rank's chunk + all_tensors = [ + torch.arange(num_elems, dtype=torch.float32) + r * 100 + for r in range(world_size) + ] + sizes_cumsum = np.cumsum(sizes) + start = 0 if rank == 0 else sizes_cumsum[rank - 1] + end = sizes_cumsum[rank] + expected = sum(tensor[start:end] for tensor in all_tensors).to(device) + + pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes) + torch.cuda.synchronize() + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_reduce_scatterv(): + distributed_run(reduce_scatterv_worker_fn, 2) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") def test_pynccl_with_cudagraph(): diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 86e28c687847..5a91758414a5 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json -from argparse import ArgumentError, ArgumentTypeError +from argparse import ArgumentError from contextlib import nullcontext from dataclasses import dataclass, field from typing import Annotated, Literal, Optional @@ -12,8 +12,8 @@ from vllm.config import CompilationConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, get_type_hints, is_not_builtin, - is_type, literal_to_kwargs, nullable_kvs, - optional_type, parse_type) + is_type, literal_to_kwargs, optional_type, + parse_type) from vllm.utils import FlexibleArgumentParser @@ -25,18 +25,10 @@ "foo": 1, "bar": 2 }), - (json.loads, "foo=1,bar=2", { - "foo": 1, - "bar": 2 - }), ]) def test_parse_type(type, value, expected): parse_type_func = parse_type(type) - context = nullcontext() - if value == "foo=1,bar=2": - context = pytest.warns(DeprecationWarning) - with context: - assert parse_type_func(value) == expected + assert parse_type_func(value) == expected def test_optional_type(): @@ -203,34 +195,6 @@ def test_get_kwargs(): assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 -@pytest.mark.parametrize(("arg", "expected"), [ - (None, dict()), - ("image=16", { - "image": 16 - }), - ("image=16,video=2", { - "image": 16, - "video": 2 - }), - ("Image=16, Video=2", { - "image": 16, - "video": 2 - }), -]) -def test_limit_mm_per_prompt_parser(arg, expected): - """This functionality is deprecated and will be removed in the future. - This argument should be passed as JSON string instead. 
- - TODO: Remove with nullable_kvs.""" - parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) - if arg is None: - args = parser.parse_args([]) - else: - args = parser.parse_args(["--limit-mm-per-prompt", arg]) - - assert args.limit_mm_per_prompt == expected - - @pytest.mark.parametrize( ("arg", "expected"), [ @@ -326,18 +290,6 @@ def test_prefix_cache_default(): assert not engine_args.enable_prefix_caching -@pytest.mark.parametrize( - ("arg"), - [ - "image", # Missing = - "image=4,image=5", # Conflicting values - "image=video=4" # Too many = in tokenized arg - ]) -def test_bad_nullable_kvs(arg): - with pytest.raises(ArgumentTypeError): - nullable_kvs(arg) - - # yapf: disable @pytest.mark.parametrize(("arg", "expected", "option"), [ (None, None, "mm-processor-kwargs"), diff --git a/tests/entrypoints/llm/test_accuracy.py b/tests/entrypoints/llm/test_accuracy.py index a2d35486a5e8..6c5706d16340 100644 --- a/tests/entrypoints/llm/test_accuracy.py +++ b/tests/entrypoints/llm/test_accuracy.py @@ -15,15 +15,18 @@ from vllm.platforms import current_platform MODEL_NAMES = [ - "Qwen/Qwen2-1.5B-Instruct", + "Qwen/Qwen3-1.7B", "google/gemma-3-1b-it", ] +FP8_KV_MODEL_NAMES = [ + "Qwen/Qwen3-1.7B", +] NUM_CONCURRENT = 500 TASK = "gsm8k" FILTER = "exact_match,strict-match" RTOL = 0.03 EXPECTED_VALUES = { - "Qwen/Qwen2-1.5B-Instruct": 0.58, + "Qwen/Qwen3-1.7B": 0.68, "google/gemma-3-1b-it": 0.25, } @@ -69,6 +72,10 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): more_args = None if current_platform.is_tpu(): # Limit compilation time for TPU V1 + + # xet doesn't work well for both Qwen/Qwen3-1.7B and + # google/gemma-3-1b-it + m.setenv("HF_HUB_DISABLE_XET", "1") more_args = "max_model_len=2048,max_num_seqs=64" # Add TP test (if provided) @@ -78,9 +85,27 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch): run_test(model, more_args) -def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch): - """Run with the V0 Engine.""" +@pytest.mark.skipif(not current_platform.is_cuda() + and not current_platform.is_tpu(), + reason="V1 is currently only supported on CUDA and TPU") +@pytest.mark.parametrize("model", FP8_KV_MODEL_NAMES) +def test_lm_eval_accuracy_v1_engine_fp8_kv_cache( + model, monkeypatch: pytest.MonkeyPatch): + """Run with the V1 Engine.""" with monkeypatch.context() as m: - m.setenv("VLLM_USE_V1", "0") - run_test("Qwen/Qwen2-1.5B-Instruct") + m.setenv("VLLM_USE_V1", "1") + + more_args = None + if current_platform.is_tpu(): + # Limit compilation time for TPU V1 + + # xet doesn't work well for Qwen/Qwen3-1.7B + m.setenv("HF_HUB_DISABLE_XET", "1") + more_args = "max_model_len=2048,max_num_seqs=128,kv_cache_dtype=fp8" + + # Add TP test (if provided) + if TPU_TP_TEST_STR: + more_args += ",{}".format(TPU_TP_TEST_STR) + + run_test(model, more_args) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index d41b0a436c62..55578341cb2e 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -16,14 +16,18 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" -GUIDED_DECODING_BACKENDS = [ + +# Separate backends which support grammars vs ones +# which only support regex based constraints in tests. 
+GRAMMAR_DECODING_BACKENDS = [ # (backend, disable_any_whitespace), - ("outlines", False), ("lm-format-enforcer", False), ("xgrammar", True), ("guidance", True), ] +ALL_DECODING_BACKENDS = ([("outlines", False)] + GRAMMAR_DECODING_BACKENDS) + @pytest.fixture(scope="module") def llm(): @@ -39,7 +43,7 @@ def llm(): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, disable_any_whitespace: bool): sampling_params = SamplingParams( @@ -49,6 +53,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, regex=sample_regex, backend=guided_decoding_backend, disable_any_whitespace=disable_any_whitespace)) + outputs = llm.generate(prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2, @@ -69,7 +74,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_json_completion(sample_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -103,7 +108,7 @@ def test_guided_json_completion(sample_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_complex_json_completion(sample_complex_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -138,7 +143,7 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_definition_json_completion(sample_definition_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -173,7 +178,7 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_enum_json_completion(sample_enum_json_schema, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -218,7 +223,7 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_choice_completion(sample_guided_choice, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -248,7 +253,7 @@ def test_guided_choice_completion(sample_guided_choice, llm, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + GRAMMAR_DECODING_BACKENDS) def test_guided_grammar(sample_sql_statements, llm, guided_decoding_backend: str, disable_any_whitespace: bool): @@ -344,7 +349,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + GRAMMAR_DECODING_BACKENDS) def test_guided_json_object(llm, guided_decoding_backend: str, disable_any_whitespace: bool): sampling_params = SamplingParams( @@ -377,7 +382,9 @@ 
def test_guided_json_object(llm, guided_decoding_backend: str, # Parse to verify it is valid JSON parsed_json = json.loads(generated_text) - assert isinstance(parsed_json, dict) + # A list is not what was intended, but is still valid + # json. + assert isinstance(parsed_json, (dict, list)) class CarType(str, Enum): @@ -395,7 +402,7 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, disable_any_whitespace: bool): json_schema = CarDescription.model_json_schema() @@ -427,7 +434,7 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, @pytest.mark.skip_global_cleanup @pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", - GUIDED_DECODING_BACKENDS) + ALL_DECODING_BACKENDS) def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, disable_any_whitespace: bool): sample_output_schema = { diff --git a/tests/entrypoints/openai/correctness/test_lmeval.py b/tests/entrypoints/openai/correctness/test_lmeval.py index 41b70f80e3b8..a07a147cdc2b 100644 --- a/tests/entrypoints/openai/correctness/test_lmeval.py +++ b/tests/entrypoints/openai/correctness/test_lmeval.py @@ -69,8 +69,9 @@ def run_test(more_args): @pytest.mark.skipif(not current_platform.is_cuda() - and not current_platform.is_tpu(), - reason="V1 currently only supported on CUDA and TPU") + and not current_platform.is_tpu() + and not current_platform.is_xpu(), + reason="V1 currently only supported on CUDA, XPU and TPU") def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch): """Run with the V1 Engine.""" diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index dab947b21b28..e7c3ffaa6a9f 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer): @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME, ""]) -async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer, - model_name: str): - +async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer): openai_api_key = "EMPTY" openai_api_base = f"http://localhost:{server.port}/v1" @@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer, messages=messages, ) assert response.model == MODEL_NAME + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" 
+ }] + + request_args = { + "model": MODEL_NAME, + "messages": messages, + "max_completion_tokens": 5, + "temperature": 0.0, + "logprobs": False, + } + + chat_completion = await client.chat.completions.create(**request_args) + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + chat_output = chat_completion.model_dump() + invocation_output = invocation_response.json() + + assert chat_output.keys() == invocation_output.keys() + assert chat_output["choices"] == invocation_output["choices"] diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py index 6d5f925152c3..b2472658ca81 100644 --- a/tests/entrypoints/openai/test_classification.py +++ b/tests/entrypoints/openai/test_classification.py @@ -155,3 +155,29 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer, assert output.object == "list" assert isinstance(output.data, list) assert len(output.data) == 0 + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer): + request_args = { + "model": MODEL_NAME, + "input": "This product was excellent and exceeded my expectations" + } + + classification_response = requests.post(server.url_for("classify"), + json=request_args) + classification_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + classification_output = classification_response.json() + invocation_output = invocation_response.json() + + assert classification_output.keys() == invocation_output.keys() + for classification_data, invocation_data in zip( + classification_output["data"], invocation_output["data"]): + assert classification_data.keys() == invocation_data.keys() + assert classification_data["probs"] == pytest.approx( + invocation_data["probs"], rel=0.01) diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index 504fd72aa4ae..b20838956d72 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -153,3 +153,13 @@ def test_chat_template_validation_for_sad_paths(serve_parser): args = serve_parser.parse_args(args=["--chat-template", "does/not/exist"]) with pytest.raises(ValueError): validate_parsed_serve_args(args) + + +@pytest.mark.parametrize( + "cli_args, expected_middleware", + [(["--middleware", "middleware1", "--middleware", "middleware2" + ], ["middleware1", "middleware2"]), ([], [])]) +def test_middleware(serve_parser, cli_args, expected_middleware): + """Ensure multiple middleware args are parsed properly""" + args = serve_parser.parse_args(args=cli_args) + assert args.middleware == expected_middleware diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 7933ca5cd6c6..6eca3e767f3f 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # imports for guided decoding tests import json +import os import shutil from tempfile import TemporaryDirectory from typing import Optional @@ -11,6 +12,7 @@ import pytest import pytest_asyncio import regex as re +import requests # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -25,10 +27,6 @@ # technically these adapters use a different base model, 
# but we're not testing generation quality here LORA_NAME = "typeof/zephyr-7b-beta-lora" -PA_NAME = "swapnilbp/llama_tweet_ptune" -# if PA_NAME changes, PA_NUM_VIRTUAL_TOKENS might also -# need to change to match the prompt adapter -PA_NUM_VIRTUAL_TOKENS = 8 GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"] @@ -55,13 +53,7 @@ def zephyr_lora_added_tokens_files(zephyr_lora_files): @pytest.fixture(scope="module") -def zephyr_pa_files(): - return snapshot_download(repo_id=PA_NAME) - - -@pytest.fixture(scope="module") -def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, - zephyr_pa_files): +def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files): return [ # use half precision for speed and memory savings in CI environment "--dtype", @@ -80,15 +72,6 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, "64", "--max-cpu-loras", "2", - # pa config - "--enable-prompt-adapter", - "--prompt-adapters", - f"zephyr-pa={zephyr_pa_files}", - f"zephyr-pa2={zephyr_pa_files}", - "--max-prompt-adapters", - "2", - "--max-prompt-adapter-token", - "128", ] @@ -97,8 +80,19 @@ def default_server_args(zephyr_lora_files, zephyr_lora_added_tokens_files, def server(default_server_args, request): if request.param: default_server_args.append(request.param) - with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server + + original_value = os.environ.get('VLLM_USE_V1') + os.environ['VLLM_USE_V1'] = '0' + try: + with RemoteOpenAIServer(MODEL_NAME, + default_server_args) as remote_server: + yield remote_server + finally: + # Restore original env value + if original_value is None: + os.environ.pop('VLLM_USE_V1', None) + else: + os.environ['VLLM_USE_V1'] = original_value @pytest_asyncio.fixture @@ -109,14 +103,11 @@ async def client(server): @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters - "model_name,num_virtual_tokens", - [(MODEL_NAME, 0), ("zephyr-lora", 0), ("zephyr-lora2", 0), - ("zephyr-pa", PA_NUM_VIRTUAL_TOKENS), - ("zephyr-pa2", PA_NUM_VIRTUAL_TOKENS)], + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], ) -async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, - num_virtual_tokens: int): +async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): completion = await client.completions.create(model=model_name, prompt="Hello, my name is", max_tokens=5, @@ -129,9 +120,7 @@ async def test_single_completion(client: openai.AsyncOpenAI, model_name: str, assert len(choice.text) >= 5 assert choice.finish_reason == "length" assert completion.usage == openai.types.CompletionUsage( - completion_tokens=5, - prompt_tokens=6 + num_virtual_tokens, - total_tokens=11 + num_virtual_tokens) + completion_tokens=5, prompt_tokens=6, total_tokens=11) # test using token IDs completion = await client.completions.create( @@ -174,9 +163,9 @@ async def test_added_lora_tokens_base_model(client: openai.AsyncOpenAI): @pytest.mark.asyncio @pytest.mark.parametrize( - # first test base model, then test loras, then test prompt adapters + # first test base model, then test loras "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-lora2", "zephyr-pa", "zephyr-pa2"], + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], ) async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -193,9 +182,9 @@ async def test_no_logprobs(client: 
openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( - # just test 1 lora and 1 pa hereafter + # just test 1 lora "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -216,7 +205,7 @@ async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): # test using token IDs @@ -237,7 +226,7 @@ async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, model_name: str): @@ -313,7 +302,7 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_completion_streaming(client: openai.AsyncOpenAI, model_name: str): @@ -347,7 +336,7 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): """Streaming for parallel sampling. @@ -381,7 +370,7 @@ async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_completion_stream_options(client: openai.AsyncOpenAI, model_name: str): @@ -518,7 +507,7 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", - [MODEL_NAME, "zephyr-lora", "zephyr-pa"], + [MODEL_NAME, "zephyr-lora"], ) async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str): # test both text and token IDs @@ -833,3 +822,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI, assert content is not None and saying in content else: assert content is not None and saying not in content + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI): + request_args = { + "model": MODEL_NAME, + "prompt": "Hello, my name is", + "max_tokens": 5, + "temperature": 0.0, + "logprobs": None, + } + + completion = await client.completions.create(**request_args) + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + completion_output = completion.model_dump() + invocation_output = invocation_response.json() + + assert completion_output.keys() == invocation_output.keys() + assert completion_output["choices"] == invocation_output["choices"] diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py index 84ad7a09165a..eca048d855b5 100644 --- a/tests/entrypoints/openai/test_completion_with_function_calling.py +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -72,8 +72,43 @@ async def 
test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "The unit to fetch the temperature in", "enum": ["celsius", "fahrenheit"], }, + "options": { + "$ref": "#/$defs/WeatherOptions", + "description": + "Optional parameters for weather query", + }, }, "required": ["country", "unit"], + "$defs": { + "WeatherOptions": { + "title": "WeatherOptions", + "type": "object", + "additionalProperties": False, + "properties": { + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "default": "celsius", + "description": "Temperature unit", + "title": "Temperature Unit", + }, + "include_forecast": { + "type": "boolean", + "default": False, + "description": + "Whether to include a 24-hour forecast", + "title": "Include Forecast", + }, + "language": { + "type": "string", + "default": "zh-CN", + "description": "Language of the response", + "title": "Language", + "enum": ["zh-CN", "en-US", "ja-JP"], + }, + }, + }, + }, }, }, }, @@ -145,7 +180,11 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str, "enable_thinking": enable_thinking } }) - + if enable_thinking: + assert chat_completion.choices[0].message.\ + reasoning_content is not None + assert chat_completion.choices[0].message.\ + reasoning_content != "" assert chat_completion.choices[0].message.tool_calls is not None assert len(chat_completion.choices[0].message.tool_calls) > 0 else: diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index adb094127e40..f03c96b12179 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -14,6 +14,7 @@ from ...models.language.pooling.embed_utils import ( run_embedding_correctness_test) +from ...models.utils import check_embeddings_close from ...utils import RemoteOpenAIServer MODEL_NAME = "intfloat/multilingual-e5-small" @@ -296,3 +297,75 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI, assert "error" in response.object assert "truncate_prompt_tokens value is greater than max_model_len. "\ "Please, select a smaller truncation size." 
in response.message + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer, + client: openai.AsyncOpenAI): + input_texts = [ + "The chef prepared a delicious meal.", + ] + + request_args = { + "model": MODEL_NAME, + "input": input_texts, + "encoding_format": "float", + } + + completion_response = await client.embeddings.create(**request_args) + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + completion_output = completion_response.model_dump() + invocation_output = invocation_response.json() + + assert completion_output.keys() == invocation_output.keys() + for completion_data, invocation_data in zip(completion_output["data"], + invocation_output["data"]): + assert completion_data.keys() == invocation_data.keys() + check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="completion", + name_1="invocation") + + +@pytest.mark.asyncio +async def test_invocations_conversation(server: RemoteOpenAIServer): + messages = [{ + "role": "user", + "content": "The cat sat on the mat.", + }, { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }] + + request_args = { + "model": MODEL_NAME, + "messages": messages, + "encoding_format": "float", + } + + chat_response = requests.post(server.url_for("v1/embeddings"), + json=request_args) + chat_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + chat_output = chat_response.json() + invocation_output = invocation_response.json() + + assert chat_output.keys() == invocation_output.keys() + for chat_data, invocation_data in zip(chat_output["data"], + invocation_output["data"]): + assert chat_data.keys() == invocation_data.keys() + check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]], + embeddings_1_lst=[invocation_data["embedding"]], + name_0="chat", + name_1="invocation") diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 4ded37595384..580bf34f20c4 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json from typing import Final import pytest @@ -29,7 +30,7 @@ def server(): "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", - f"image={MAXIMUM_IMAGES}", + json.dumps({"image": MAXIMUM_IMAGES}), ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -95,6 +96,10 @@ def test_openapi_stateless(case: schemathesis.Case): case.operation.method.upper(), case.operation.path, ) + if case.operation.path.startswith("/v1/responses"): + # Skip responses API as it is meant to be stateful. 
+ return + timeout = { # requires a longer timeout ("POST", "/v1/chat/completions"): diff --git a/tests/entrypoints/openai/test_pooling.py b/tests/entrypoints/openai/test_pooling.py index 41c30e71684b..02165ee6d58e 100644 --- a/tests/entrypoints/openai/test_pooling.py +++ b/tests/entrypoints/openai/test_pooling.py @@ -13,7 +13,7 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" +MODEL_NAME = "internlm/internlm2-1_8b-reward" DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 @@ -21,15 +21,16 @@ def server(): args = [ "--task", - "classify", + "reward", # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", "--enforce-eager", "--max-model-len", - "8192", + "512", "--chat-template", DUMMY_CHAT_TEMPLATE, + "--trust-remote-code", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 1 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 8 assert poolings.usage.completion_tokens == 0 - assert poolings.usage.prompt_tokens == 7 - assert poolings.usage.total_tokens == 7 + assert poolings.usage.prompt_tokens == 8 + assert poolings.usage.total_tokens == 8 # test using token IDs input_tokens = [1, 1, 1, 1, 1] @@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 1 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 5 assert poolings.usage.completion_tokens == 0 assert poolings.usage.prompt_tokens == 5 assert poolings.usage.total_tokens == 5 @@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 3 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 8 assert poolings.usage.completion_tokens == 0 - assert poolings.usage.prompt_tokens == 25 - assert poolings.usage.total_tokens == 25 + assert poolings.usage.prompt_tokens == 29 + assert poolings.usage.total_tokens == 29 # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], @@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str): assert poolings.id is not None assert len(poolings.data) == 4 - assert len(poolings.data[0].data) == 2 + assert len(poolings.data[0].data) == 5 assert poolings.usage.completion_tokens == 0 assert poolings.usage.prompt_tokens == 17 assert poolings.usage.total_tokens == 17 @@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer, chat_response.raise_for_status() chat_poolings = PoolingResponse.model_validate(chat_response.json()) - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + tokenizer = get_tokenizer( + tokenizer_name=model_name, + tokenizer_mode="fast", + trust_remote_code=True, + ) prompt = tokenizer.apply_chat_template( messages, chat_template=DUMMY_CHAT_TEMPLATE, @@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) float_response.raise_for_status() responses_float = PoolingResponse.model_validate(float_response.json()) + float_data = [ + np.array(d.data).squeeze(-1).tolist() for d in responses_float.data + ] base64_response = requests.post( 
server.url_for("pooling"), @@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, np.frombuffer(base64.b64decode(data.data), dtype="float32").tolist()) - check_embeddings_close( - embeddings_0_lst=[d.data for d in responses_float.data], - embeddings_1_lst=decoded_responses_base64_data, - name_0="float32", - name_1="base64") + check_embeddings_close(embeddings_0_lst=float_data, + embeddings_1_lst=decoded_responses_base64_data, + name_0="float32", + name_1="base64") # Default response is float32 decoded from base64 by OpenAI Client default_response = requests.post( @@ -240,9 +247,83 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer, ) default_response.raise_for_status() responses_default = PoolingResponse.model_validate(default_response.json()) + default_data = [ + np.array(d.data).squeeze(-1).tolist() for d in responses_default.data + ] + + check_embeddings_close(embeddings_0_lst=float_data, + embeddings_1_lst=default_data, + name_0="float32", + name_1="default") + + +@pytest.mark.asyncio +async def test_invocations(server: RemoteOpenAIServer): + input_texts = [ + "The chef prepared a delicious meal.", + ] + + request_args = { + "model": MODEL_NAME, + "input": input_texts, + "encoding_format": "float", + } + + completion_response = requests.post(server.url_for("pooling"), + json=request_args) + completion_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + completion_output = completion_response.json() + invocation_output = invocation_response.json() + + assert completion_output.keys() == invocation_output.keys() + for completion_data, invocation_data in zip(completion_output["data"], + invocation_output["data"]): + assert completion_data.keys() == invocation_data.keys() + check_embeddings_close(embeddings_0_lst=completion_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="completion", + name_1="invocation") + + +@pytest.mark.asyncio +async def test_invocations_conversation(server: RemoteOpenAIServer): + messages = [{ + "role": "user", + "content": "The cat sat on the mat.", + }, { + "role": "assistant", + "content": "A feline was resting on a rug.", + }, { + "role": "user", + "content": "Stars twinkle brightly in the night sky.", + }] + + request_args = { + "model": MODEL_NAME, + "messages": messages, + "encoding_format": "float", + } + + chat_response = requests.post(server.url_for("pooling"), json=request_args) + chat_response.raise_for_status() - check_embeddings_close( - embeddings_0_lst=[d.data for d in responses_default.data], - embeddings_1_lst=[d.data for d in responses_default.data], - name_0="float32", - name_1="base64") + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + chat_output = chat_response.json() + invocation_output = invocation_response.json() + + assert chat_output.keys() == invocation_output.keys() + for chat_data, invocation_data in zip(chat_output["data"], + invocation_output["data"]): + assert chat_data.keys() == invocation_data.keys() + check_embeddings_close(embeddings_0_lst=chat_data["data"], + embeddings_1_lst=invocation_data["data"], + name_0="chat", + name_1="invocation") diff --git a/tests/entrypoints/openai/test_rerank.py b/tests/entrypoints/openai/test_rerank.py index e40bbca9a8ad..4da97fe13691 100644 --- a/tests/entrypoints/openai/test_rerank.py +++ b/tests/entrypoints/openai/test_rerank.py @@ -94,3 
+94,34 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str): # Assert just a small fragments of the response assert "Please reduce the length of the input." in \ rerank_response.text + + +def test_invocations(server: RemoteOpenAIServer): + query = "What is the capital of France?" + documents = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + request_args = { + "model": MODEL_NAME, + "query": query, + "documents": documents, + } + + rerank_response = requests.post(server.url_for("rerank"), + json=request_args) + rerank_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + rerank_output = rerank_response.json() + invocation_output = invocation_response.json() + + assert rerank_output.keys() == invocation_output.keys() + for rerank_result, invocations_result in zip(rerank_output["results"], + invocation_output["results"]): + assert rerank_result.keys() == invocations_result.keys() + assert rerank_result["relevance_score"] == pytest.approx( + invocations_result["relevance_score"], rel=0.01) diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index 099062e55c72..af58fbd4b364 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -13,7 +13,6 @@ from .test_completion import default_server_args # noqa: F401 from .test_completion import zephyr_lora_added_tokens_files # noqa: F401 from .test_completion import zephyr_lora_files # noqa: F401 -from .test_completion import zephyr_pa_files # noqa: F401 from .test_completion import MODEL_NAME diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py index 8927fe771809..187542b7bafc 100644 --- a/tests/entrypoints/openai/test_score.py +++ b/tests/entrypoints/openai/test_score.py @@ -191,3 +191,32 @@ def test_score_max_model_len(self, server: RemoteOpenAIServer, assert score_response.status_code == 400 assert "Please, select a smaller truncation size." in \ score_response.text + + def test_invocations(self, server: RemoteOpenAIServer, model: dict[str, + Any]): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." 
+ + request_args = { + "model": model["name"], + "text_1": text_1, + "text_2": text_2, + } + + score_response = requests.post(server.url_for("score"), + json=request_args) + score_response.raise_for_status() + + invocation_response = requests.post(server.url_for("invocations"), + json=request_args) + invocation_response.raise_for_status() + + score_output = score_response.json() + invocation_output = invocation_response.json() + + assert score_output.keys() == invocation_output.keys() + for score_data, invocation_data in zip(score_output["data"], + invocation_output["data"]): + assert score_data.keys() == invocation_data.keys() + assert score_data["score"] == pytest.approx( + invocation_data["score"], rel=0.01) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index ad80946b5671..8a7892cf6d6a 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -7,6 +7,8 @@ from typing import Any, Optional from unittest.mock import MagicMock +import pytest + from vllm.config import MultiModalConfig from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest @@ -73,7 +75,8 @@ def test_async_serving_chat_init(): assert serving_completion.chat_template == CHAT_TEMPLATE -def test_serving_chat_should_set_correct_max_tokens(): +@pytest.mark.asyncio +async def test_serving_chat_should_set_correct_max_tokens(): mock_engine = MagicMock(spec=MQLLMEngineClient) mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False @@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens(): chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", request_logger=None) + req = ChatCompletionRequest( model=MODEL_NAME, messages=[{ @@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens(): ) with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 93 req.max_tokens = 10 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 10 @@ -143,7 +147,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ) with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 10 @@ -151,7 +155,7 @@ def test_serving_chat_should_set_correct_max_tokens(): req.max_tokens = 15 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 10 @@ -159,7 +163,7 @@ def test_serving_chat_should_set_correct_max_tokens(): req.max_tokens = 5 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 5 @@ -198,7 +202,7 @@ def test_serving_chat_should_set_correct_max_tokens(): ) with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 93 @@ -206,7 +210,7 @@ def test_serving_chat_should_set_correct_max_tokens(): 
req.max_tokens = 100 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 93 @@ -214,12 +218,13 @@ def test_serving_chat_should_set_correct_max_tokens(): req.max_tokens = 5 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].max_tokens == 5 -def test_serving_chat_could_load_correct_generation_config(): +@pytest.mark.asyncio +async def test_serving_chat_could_load_correct_generation_config(): mock_model_config = MockModelConfig() mock_model_config.diff_sampling_param = { @@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config(): chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", request_logger=None) + req = ChatCompletionRequest( model=MODEL_NAME, messages=[{ @@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config(): ) with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].temperature == 0.5 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 @@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config(): req.temperature = 0.1 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].temperature == 0.1 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 @@ -270,13 +276,14 @@ def test_serving_chat_could_load_correct_generation_config(): req.temperature = 0.0 with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 -def test_serving_chat_did_set_correct_cache_salt(): +@pytest.mark.asyncio +async def test_serving_chat_did_set_correct_cache_salt(): mock_model_config = MockModelConfig() mock_engine = MagicMock(spec=MQLLMEngineClient) @@ -306,11 +313,11 @@ def test_serving_chat_did_set_correct_cache_salt(): # By default cache_salt in the engine prompt is not set with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert "cache_salt" not in mock_engine.generate.call_args.args[0] # Test with certain cache_salt req.cache_salt = "test_salt" with suppress(Exception): - asyncio.run(serving_chat.create_chat_completion(req)) + await serving_chat.create_chat_completion(req) assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt" diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index 5f334c754a3f..c3b458d717fb 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -32,8 +32,7 @@ async def _async_serving_models_init() -> OpenAIServingModels: serving_models = OpenAIServingModels(engine_client=mock_engine_client, base_model_paths=BASE_MODEL_PATHS, model_config=mock_model_config, - lora_modules=None, - prompt_adapters=None) + lora_modules=None) await serving_models.init_static_loras() return serving_models diff --git 
a/tests/entrypoints/openai/test_tensorizer_entrypoint.py b/tests/entrypoints/openai/test_tensorizer_entrypoint.py index e143150356d9..4bf379850365 100644 --- a/tests/entrypoints/openai/test_tensorizer_entrypoint.py +++ b/tests/entrypoints/openai/test_tensorizer_entrypoint.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import gc -import json +import os import tempfile import openai @@ -58,18 +58,20 @@ def tensorize_model_and_lora(tmp_dir, model_uri): @pytest.fixture(scope="module") def server(model_uri, tensorize_model_and_lora): - model_loader_extra_config = { - "tensorizer_uri": model_uri, - } + # In this case, model_uri is a directory with a model.tensors + # file and all necessary model artifacts, particularly a + # HF `config.json` file. In this case, Tensorizer can infer the + # `TensorizerConfig` so --model-loader-extra-config can be completely + # omitted. ## Start OpenAI API server args = [ - "--load-format", "tensorizer", "--device", "cuda", - "--model-loader-extra-config", - json.dumps(model_loader_extra_config), "--enable-lora" + "--load-format", "tensorizer", "--served-model-name", MODEL_NAME, + "--enable-lora" ] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + model_dir = os.path.dirname(model_uri) + with RemoteOpenAIServer(model_dir, args) as remote_server: yield remote_server diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 57dd25fe1b16..0dbbdfbfd24a 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811 f"zephyr-lora2={zephyr_lora_added_tokens_files}", "--max-lora-rank", "64", + "--enable-tokenizer-info-endpoint", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -283,3 +284,106 @@ async def test_detokenize( response.raise_for_status() assert response.json() == {"prompt": prompt} + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], +) +async def test_tokenizer_info_basic( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): + """Test basic tokenizer info endpoint functionality.""" + response = requests.get(server.url_for("tokenizer_info")) + response.raise_for_status() + result = response.json() + assert "tokenizer_class" in result + assert isinstance(result["tokenizer_class"], str) + assert result["tokenizer_class"] + + +@pytest.mark.asyncio +async def test_tokenizer_info_schema(server: RemoteOpenAIServer): + """Test that the response matches expected schema types.""" + response = requests.get(server.url_for("tokenizer_info")) + response.raise_for_status() + result = response.json() + field_types = { + "add_bos_token": bool, + "add_prefix_space": bool, + "clean_up_tokenization_spaces": bool, + "split_special_tokens": bool, + "bos_token": str, + "eos_token": str, + "pad_token": str, + "unk_token": str, + "chat_template": str, + "errors": str, + "model_max_length": int, + "additional_special_tokens": list, + "added_tokens_decoder": dict, + } + for field, expected_type in field_types.items(): + if field in result and result[field] is not None: + assert isinstance( + result[field], + expected_type), (f"{field} should be {expected_type.__name__}") + + +@pytest.mark.asyncio +async def test_tokenizer_info_added_tokens_structure( + 
server: RemoteOpenAIServer, ): + """Test added_tokens_decoder structure if present.""" + response = requests.get(server.url_for("tokenizer_info")) + response.raise_for_status() + result = response.json() + added_tokens = result.get("added_tokens_decoder") + if added_tokens: + for token_id, token_info in added_tokens.items(): + assert isinstance(token_id, str), "Token IDs should be strings" + assert isinstance(token_info, dict), "Token info should be a dict" + assert "content" in token_info, "Token info should have content" + assert "special" in token_info, ( + "Token info should have special flag") + assert isinstance(token_info["special"], + bool), ("Special flag should be boolean") + + +@pytest.mark.asyncio +async def test_tokenizer_info_consistency_with_tokenize( + server: RemoteOpenAIServer, ): + """Test that tokenizer info is consistent with tokenization endpoint.""" + info_response = requests.get(server.url_for("tokenizer_info")) + info_response.raise_for_status() + info = info_response.json() + tokenize_response = requests.post( + server.url_for("tokenize"), + json={ + "model": MODEL_NAME, + "prompt": "Hello world!" + }, + ) + tokenize_response.raise_for_status() + tokenize_result = tokenize_response.json() + info_max_len = info.get("model_max_length") + tokenize_max_len = tokenize_result.get("max_model_len") + if info_max_len and tokenize_max_len: + assert info_max_len >= tokenize_max_len, ( + "Info max length should be >= tokenize max length") + + +@pytest.mark.asyncio +async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer): + """Test chat template is properly included.""" + response = requests.get(server.url_for("tokenizer_info")) + response.raise_for_status() + result = response.json() + chat_template = result.get("chat_template") + if chat_template: + assert isinstance(chat_template, + str), ("Chat template should be a string") + assert chat_template.strip(), "Chat template should not be empty" \ No newline at end of file diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py index e1d175d9c6e1..a8e2eb40b157 100644 --- a/tests/entrypoints/openai/test_transcription_validation.py +++ b/tests/entrypoints/openai/test_transcription_validation.py @@ -17,6 +17,11 @@ from ...utils import RemoteOpenAIServer +MISTRAL_FORMAT_ARGS = [ + "--tokenizer_mode", "mistral", "--config_format", "mistral", + "--load_format", "mistral" +] + @pytest.fixture def mary_had_lamb(): @@ -33,9 +38,15 @@ def winning_call(): @pytest.mark.asyncio -async def test_basic_audio(mary_had_lamb): - model_name = "openai/whisper-large-v3-turbo" +@pytest.mark.parametrize( + "model_name", + ["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"]) +async def test_basic_audio(mary_had_lamb, model_name): server_args = ["--enforce-eager"] + + if model_name.startswith("mistralai"): + server_args += MISTRAL_FORMAT_ARGS + # Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb. 
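    # Note: this test is parametrized over both Whisper and Voxtral; for the
    # mistralai model the server is launched with MISTRAL_FORMAT_ARGS
    # (mistral tokenizer_mode, config_format and load_format).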
with RemoteOpenAIServer(model_name, server_args) as remote_server: client = remote_server.get_async_client() @@ -65,10 +76,13 @@ async def test_bad_requests(mary_had_lamb): @pytest.mark.asyncio -async def test_long_audio_request(mary_had_lamb): - model_name = "openai/whisper-large-v3-turbo" +@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"]) +async def test_long_audio_request(mary_had_lamb, model_name): server_args = ["--enforce-eager"] + if model_name.startswith("openai"): + return + mary_had_lamb.seek(0) audio, sr = librosa.load(mary_had_lamb) # Add small silence after each audio for repeatability in the split process @@ -87,7 +101,8 @@ async def test_long_audio_request(mary_had_lamb): response_format="text", temperature=0.0) out = json.loads(transcription)['text'] - assert out.count("Mary had a little lamb") == 10 + counts = out.count("Mary had a little lamb") + assert counts == 10, counts @pytest.mark.asyncio @@ -154,7 +169,8 @@ async def post_with_stream(*args, **kwargs): file=winning_call, language="en", temperature=0.0, - extra_body=dict(stream=True)) + extra_body=dict(stream=True), + timeout=30) # Reconstruct from chunks and validate async for chunk in res: # just a chunk @@ -184,7 +200,8 @@ async def post_with_stream(*args, **kwargs): temperature=0.0, extra_body=dict(stream=True, stream_include_usage=True, - stream_continuous_usage_stats=True)) + stream_continuous_usage_stats=True), + timeout=30) final = False continuous = True async for chunk in res: diff --git a/tests/entrypoints/openai/test_translation_validation.py b/tests/entrypoints/openai/test_translation_validation.py index 0c2cb367f330..79e769e3a1aa 100644 --- a/tests/entrypoints/openai/test_translation_validation.py +++ b/tests/entrypoints/openai/test_translation_validation.py @@ -39,8 +39,8 @@ async def test_basic_audio(foscolo): # TODO remove once language detection is implemented extra_body=dict(language="it"), temperature=0.0) - out = json.loads(translation)['text'].strip() - assert "Nor will I ever touch the sacred" in out + out = json.loads(translation)['text'].strip().lower() + assert "greek sea" in out @pytest.mark.asyncio @@ -168,5 +168,4 @@ async def test_long_audio_request(foscolo): response_format="text", temperature=0.0) out = json.loads(translation)['text'].strip().lower() - # TODO investigate higher model uncertainty in for longer translations. 
- assert out.count("nor will i ever") == 2 + assert out.count("greek sea") == 2 diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index fd613842f986..b6f1d64803e5 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -36,11 +36,11 @@ ], [ "The image shows a Venn diagram with three over", - "This image shows a Venn diagram with three over", + "The image shows a Venn diagram with three intersect", ], [ "This image displays a gradient of colors ranging from", - "This image displays a gradient of colors transitioning from", + "The image displays a gradient of colors ranging from", ], ] diff --git a/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py new file mode 100644 index 000000000000..bd8e06513e13 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_hunyuan_a13b_tool_parser.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json +from unittest.mock import MagicMock + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, run_tool_extraction_streaming) +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + + +def make_tool_call(name, arguments): + return ToolCall(type="function", + function=FunctionCall(name=name, + arguments=json.dumps(arguments))) + + +# TODO: add reason prefix and suffix. + + +@pytest.mark.parametrize( + "model_output,expected_tool_calls,expected_content", + [ + # No tool call + ("How can I help you today?", [], "How can I help you today?"), + # Single tool call, no content + ( + "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>", #noqa: E501 + [ + make_tool_call("get_weather", { + "city": "San Francisco", + "metric": "celsius" + }) + ], + None), + # Multiple tool calls + ( + "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>", #noqa: E501 + [ + make_tool_call("get_weather", { + "city": "San Francisco", + "metric": "celsius" + }), + make_tool_call( + "register_user", { + "name": "John Doe", + "age": 37, + "address": { + "city": "San Francisco", + "state": "CA" + }, + "role": None, + "passed_test": True, + "aliases": ["John", "Johnny"] + }) + ], + None), + # Content before tool call + ( + "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>", #noqa: E501 + [make_tool_call("get_weather", {"city": "Boston"})], + "I will call the tool now. 
"), + # Content after tool call (should be stripped) + ( + "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!", #noqa: E501 + [make_tool_call("get_weather", {"city": "Seattle"})], + None), + ( + "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>", + [ + make_tool_call( + "complex_tool", + {"level1": { + "level2": { + "level3": { + "value": 123 + } + } + }}) + ], + None, + ), + ]) +def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls, + expected_content): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "hunyuan_a13b")(mock_tokenizer) + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=False) + + # align the random id. + for idx in range(len(tool_calls)): + tool_calls[idx].id = expected_tool_calls[idx].id + assert tool_calls == expected_tool_calls + assert content == expected_content + + +# Streaming test: simulate incremental output +@pytest.mark.parametrize("model_deltas,expected_tool_calls", [ + ([ + "<tool_calls>[{\"name\": \"get_weather\", ", + "\"arguments\": {\"city\": \"San Francisco\", ", + "\"metric\": \"celsius\"}}]", "</tool_calls>" + ], [ + make_tool_call("get_weather", { + "city": "San Francisco", + "metric": "celsius" + }) + ]), + ([ + "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":", + " {\"city\": \"Boston\"}", "}]", "</tool_calls>" + ], [make_tool_call("get_weather", {"city": "Boston"})]), + ([ + "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":", + " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>" + ], [make_tool_call("get_weather", {"city": "Boston"})]), + pytest.param([ + "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ", + " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}", + "]</tool_calls>" + ], [ + make_tool_call("complex_tool", + {"level1": { + "level2": { + "level3": { + "value": 123 + } + } + }}) + ], + marks=pytest.mark.xfail( + reason="stream parsing not support nested json yet.")), +]) +def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls): + mock_tokenizer = MagicMock() + + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "hunyuan_a13b")(mock_tokenizer) + reconstructor = run_tool_extraction_streaming( + tool_parser, model_deltas, assert_one_tool_per_delta=False) + + # align the random id. 
+ for idx in range(len(reconstructor.tool_calls)): + reconstructor.tool_calls[idx].id = expected_tool_calls[idx].id + + assert reconstructor.tool_calls == expected_tool_calls diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index e41ea686e992..ed57fe39df64 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -2,11 +2,18 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import warnings -from typing import Optional +from collections.abc import Mapping +from typing import Literal, Optional import pytest +from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy, + SpecialTokens) +from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, + Tekkenizer) +from vllm.assets.audio import AudioAsset from vllm.assets.image import ImageAsset +from vllm.assets.video import VideoAsset from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template, parse_chat_messages, @@ -15,8 +22,10 @@ resolve_hf_chat_template) from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import encode_image_base64 +from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64, + encode_video_base64) from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH @@ -28,6 +37,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct" QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct" +QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B" MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B" @@ -48,6 +58,21 @@ def phi3v_model_config(): }) +@pytest.fixture(scope="function") +def phi3v_model_config_mm_interleaved(): + return ModelConfig(PHI3V_MODEL_ID, + task="generate", + tokenizer=PHI3V_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="auto", + seed=0, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + }) + + @pytest.fixture(scope="module") def phi3v_tokenizer(): return TokenizerGroup( @@ -58,6 +83,32 @@ def phi3v_tokenizer(): ) +@pytest.fixture(scope="function") +def qwen25omni_model_config_mm_interleaved(): + return ModelConfig(QWEN25OMNI_MODEL_ID, + task="generate", + tokenizer=QWEN25OMNI_MODEL_ID, + tokenizer_mode="auto", + dtype="auto", + seed=0, + interleave_mm_strings=True, + limit_mm_per_prompt={ + "image": 2, + "audio": 1, + "video": 1, + }) + + +@pytest.fixture(scope="module") +def qwen25omni_tokenizer(): + return TokenizerGroup( + tokenizer_id=QWEN25OMNI_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + @pytest.fixture(scope="module") def mllama_model_config(): return ModelConfig(MLLAMA_MODEL_ID, @@ -113,6 +164,20 @@ def image_url(): return f"data:image/jpeg;base64,{base64}" +@pytest.fixture(scope="module") +def video_url(): + video = VideoAsset('baby_reading', 1) + base64 = encode_video_base64(video.np_ndarrays) + return f"data:video/jpeg;base64,{base64}" + + +@pytest.fixture(scope="module") +def audio_url(): + audio = AudioAsset('mary_had_lamb') + base64 = encode_audio_base64(*audio.audio_and_sample_rate) + return f"data:audio/ogg;base64,{base64}" + + def 
_assert_mm_data_is_image_input( mm_data: Optional[MultiModalDataDict], image_count: int, @@ -126,6 +191,23 @@ def _assert_mm_data_is_image_input( assert isinstance(image_data, list) and len(image_data) == image_count +ModalityType = Literal["image", "video", "audio"] +MultiModalDataCounts = Mapping[ModalityType, int] + + +def _assert_mm_data_inputs( + mm_data: Optional[MultiModalDataDict], + data_count: MultiModalDataCounts, +) -> None: + assert mm_data is not None + assert set(data_count.keys()) == (set(mm_data.keys())) + + for modality, n in data_count.items(): + modality_data = mm_data.get(modality) + assert modality_data is not None + assert isinstance(modality_data, list) and len(modality_data) == n + + def test_parse_chat_messages_single_image( phi3v_model_config, phi3v_tokenizer, @@ -637,6 +719,277 @@ def test_parse_chat_messages_multiple_images_uncommon_input( _assert_mm_data_is_image_input(mm_data, 2) +def test_parse_chat_messages_multiple_images_interleave( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [{ + "type": "text", + "text": "I need you to compare this image" + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "and this one" + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "Do they have differences?" + }] + }], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?" + }] + _assert_mm_data_is_image_input(mm_data, 2) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_interleave_async( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages_futures( + [{ + "role": + "user", + "content": [{ + "type": "text", + "text": "I need you to compare this image" + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "and this one" + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "Do they have differences?" + }] + }], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?" + }] + _assert_mm_data_is_image_input(await mm_data, 2) + + +def test_parse_chat_messages_multiple_images_multiple_messages_interleave( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Be accurate." + }, + ] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "text", + "text": "What's on this image?" 
+ }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }] + }], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "What's on this image?\n<|image_1|>\nBe accurate." + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": "user", + "content": "What's on this image?\n<|image_2|>" + }] + _assert_mm_data_is_image_input(mm_data, 2) + + +def test_parse_chat_messages_multiple_modals_multiple_messages_interleave( + qwen25omni_model_config_mm_interleaved, qwen25omni_tokenizer, + image_url, video_url, audio_url): + conversation, mm_data = parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's on this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "Now listen to this audio" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + } + }, + ] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "text", + "text": "What's on this image?" + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "And what's in the video?" + }, { + "type": "video_url", + "video_url": { + "url": video_url + } + }] + }], + qwen25omni_model_config_mm_interleaved, + qwen25omni_tokenizer, + content_format="string", + ) + + assert conversation == [{ + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "Now listen to this audio\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>" + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": + "What's on this image?\n<|vision_start|><|IMAGE|><|vision_end|>\n" + "And what's in the video?\n<|vision_start|><|VIDEO|><|vision_end|>" + }] + + _assert_mm_data_inputs(mm_data, {"image": 2, "video": 1, "audio": 1}) + + +def test_parse_chat_messages_multiple_images_interleave_with_placeholders( + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + image_url, +): + with pytest.raises( + ValueError, + match=r"Found more '<|image_1|>' placeholders in input prompt " + "than actual multimodal data items."): + parse_chat_messages( + [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": + "text", + "text": + "I need you to compare this image\n<|image_1|>\nand this one\n<|image_2|>\n" # noqa: E501 + "Do they have differences?" + }, + ] + }], + phi3v_model_config_mm_interleaved, + phi3v_tokenizer, + content_format="string", + ) + + ### Mllama currently wraps images / texts as interleaved dictionaries def test_mllama_single_image( mllama_model_config, @@ -1026,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format): ) assert resolved_format == expected_format + + +def test_parse_chat_messages_include_thinking_chunk(mistral_model_config, + mistral_tokenizer): + messages = [{ + "role": + "system", + "content": [{ + "type": "text", + "text": "You are a helpful assistant." + }, { + "type": + "thinking", + "closed": + True, + "thinking": + "Only return the answer when you are confident." + }] + }, { + "role": "user", + "content": "What is 2+2?" + }, { + "role": + "assistant", + "content": [{ + "type": "text", + "text": "Let me think about it." 
+ }, { + "type": "thinking", + "closed": True, + "thinking": "2+2 = 4" + }, { + "type": "text", + "text": "The answer is 4.", + }], + }] + + conversation_with_thinking, _ = parse_chat_messages( + messages, + mistral_model_config, + mistral_tokenizer, + content_format="openai", + ) + + expected_conversation = [{ + "role": + "system", + "content": [{ + "type": "text", + "text": "You are a helpful assistant." + }, { + "type": "text", + "text": "Only return the answer when you are confident." + }], + }, { + "role": + "user", + "content": [{ + "type": "text", + "text": "What is 2+2?" + }], + }, { + "role": + "assistant", + "content": [ + { + "type": "text", + "text": "Let me think about it." + }, + { + "type": "text", + "text": "2+2 = 4" + }, + { + "type": "text", + "text": "The answer is 4." + }, + ] + }] + + assert conversation_with_thinking == expected_conversation + + +def test_apply_mistral_chat_template_thinking_chunk(): + # Moved import here to avoid yapf and isort conflicts + from vllm.entrypoints.chat_utils import apply_mistral_chat_template + messages = [{ + "role": + "system", + "content": [{ + "type": "text", + "text": "You are a helpful assistant." + }, { + "type": + "thinking", + "closed": + True, + "thinking": + "Only return the answer when you are confident." + }] + }, { + "role": "user", + "content": "What is 2+2?" + }, { + "role": + "assistant", + "content": [{ + "type": "text", + "text": "Let me think about it." + }, { + "type": "thinking", + "closed": True, + "thinking": "2+2 = 4" + }, { + "type": "text", + "text": "The answer is 4.", + }], + }, { + "role": "user", + "content": "Thanks, what is 3+3?" + }] + + # TODO(Julien): upon model release change to a tokenizer already configured. + # ================================================================= + mistral_tokenizer = MistralTokenizer.from_pretrained( + "mistralai/Devstral-Small-2507") + assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) + # Add think special tokens to the tokenizer + mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( + rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) + mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( + rank=36, is_control=True, token_str=SpecialTokens.end_think.value) + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { + k: v + for k, v in + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() + if v not in {35, 36} + } + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.begin_think.value] = 35 + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.end_think.value] = 36 + mistral_tokenizer.instruct.BEGIN_THINK = 35 + mistral_tokenizer.instruct.END_THINK = 36 + # ================================================================= + + tokens_ids = apply_mistral_chat_template(mistral_tokenizer, + messages, + chat_template=None, + tools=None) + + string_tokens = mistral_tokenizer.mistral.decode( + tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP) + + expected_tokens = ( + r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the" + r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]" + r"[INST]What is 2+2?[/INST]" + r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>" + r"[INST]Thanks, what is 3+3?[/INST]") + + assert string_tokens == expected_tokens diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index 3722e0eb537f..93bf20da4adb 100644 
--- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -36,7 +36,8 @@ def clear_cache(): DEVICE_MLA_BLOCK_SIZES = { "cuda": [16, 64], # CUDA supports both standard and extended block sizes "hip": [16, 1], # HIP requires special handling for block_size=1 - "cpu": [16] # CPU uses fixed block size from test cases + # "cpu": [16] # CPU uses fixed block size from test cases + "cpu": [] # FIXME(woosuk): Temporarily disable CPU tests } @@ -81,14 +82,14 @@ def test_env( m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0") if device == "cpu": + if not use_v1: + pytest.skip("CPU backend only supports V1") + with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float16, torch.float16, block_size, False) - if use_v1: - assert backend.get_name() == "TORCH_SDPA_VLLM_V1" - else: - assert backend.get_name() == "TORCH_SDPA" + assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "hip": with patch("vllm.attention.selector.current_platform", @@ -204,12 +205,14 @@ def test_fp32_fallback( m.setenv("VLLM_USE_V1", "1" if use_v1 else "0") if device == "cpu": + if not use_v1: + pytest.skip("CPU backend only supports V1") + with patch("vllm.attention.selector.current_platform", CpuPlatform()): backend = get_attn_backend(16, torch.float32, torch.float32, 16, False) - assert (backend.get_name() == "TORCH_SDPA_VLLM_V1" - if use_v1 else "TORCH_SDPA") + assert backend.get_name() == "TORCH_SDPA_VLLM_V1" elif device == "cuda": with patch("vllm.attention.selector.current_platform", diff --git a/tests/kernels/attention/test_blocksparse_attention.py b/tests/kernels/attention/test_blocksparse_attention.py deleted file mode 100644 index 9aee818c9956..000000000000 --- a/tests/kernels/attention/test_blocksparse_attention.py +++ /dev/null @@ -1,441 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from typing import Optional - -import pytest -import torch - -from tests.kernels.allclose_default import get_default_atol, get_default_rtol -from vllm import _custom_ops as ops -from vllm.attention.ops.blocksparse_attention.interface import ( - LocalStridedBlockSparseAttn) -from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes - -FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 -# This will change depending on the compute capability. -# - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -# MAX_SEQ_LEN = 2771 - -# There may not be enough gpu memory due to large NUM_BLOCKS. -# Reduce NUM_BLOCKS when it happens. 
-NUM_BLOCKS = 4321 # Arbitrary values for testing -PARTITION_SIZE = 512 -DTYPES = [torch.half, torch.bfloat16] -NUM_GEN_SEQS = [3] # Arbitrary values for testing -NUM_PREFILL_SEQS = [3] # Arbitrary values for testing -NUM_HEADS = [(40, 40)] # Arbitrary values for testing - -HEAD_SIZES = [64, 112] -BLOCK_SIZES = [16] -USE_ALIBI = [False, True] -KV_CACHE_DTYPE = ["auto", "fp8"] -SEEDS = [0] -CUDA_DEVICES = ['cuda:0'] -BLOCKSPARSE_LOCAL_BLOCKS = [16] -BLOCKSPARSE_VERT_STRIDES = [8] - -BLOCKSPARSE_BLOCK_SIZES = [64] -BLOCKSPARSE_HEADS_SLIDINGS = [2, -1] -BLOCKSPARSE_HOMO_HEADS = [True, False] - - -def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - attn_mask: Optional[torch.Tensor] = None, -) -> torch.Tensor: - attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() - if attn_mask is not None: - attn_weights = attn_weights + attn_mask.float() - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.einsum("hqk,khd->qhd", attn_weights, value) - return out - - -def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], - tp_rank: int = 0, - blocksparse_local_blocks: int = 0, - blocksparse_vert_stride: int = 1, - blocksparse_block_size: int = 64, - blocksparse_head_sliding_step: int = 0, -) -> None: - num_query_heads = query.shape[1] - num_kv_heads = value_cache.shape[1] - head_size = value_cache.shape[2] - block_size = value_cache.shape[3] - num_seqs = query.shape[0] - - block_tables_lst = block_tables.cpu().tolist() - seq_lens_lst = seq_lens.cpu().tolist() - for i in range(num_seqs): - q = query[i].unsqueeze(0) - block_table = block_tables_lst[i] - seq_len = int(seq_lens_lst[i]) - - keys_lst: list[torch.Tensor] = [] - values_lst: list[torch.Tensor] = [] - for j in range(seq_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_kv_heads, head_size) - keys_lst.append(k) - - v = value_cache[block_number, :, :, block_offset] - values_lst.append(v) - keys = torch.stack(keys_lst, dim=0) - values = torch.stack(values_lst, dim=0) - if num_queries_per_kv > 1: - # Handle MQA and GQA - keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) - values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - - alibi_bias = None - if alibi_slopes is not None: - # Create the ALiBi bias used in the paged attention kernel. 
- position_ids = torch.arange(seq_len).int() - alibi_bias = (position_ids - seq_len + 1).float() - alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( - 1, 1, -1) - - if blocksparse_vert_stride >= 1: - bsize = blocksparse_block_size - hsliding = blocksparse_head_sliding_step - vert = blocksparse_vert_stride - locals = blocksparse_local_blocks - qb = (seq_len - 1) // bsize - attn_mask = q.new_zeros( - (num_query_heads, 1, seq_len)).float() - torch.inf - for h in range(num_query_heads): - if hsliding >= 0: # slide with q heads - bs_offset = (tp_rank * num_query_heads + h) * hsliding + 1 - else: # slide with kv heads - bs_offset = (tp_rank * num_kv_heads + - h // num_queries_per_kv) * (-hsliding) + 1 - for kb in range(qb + 1): - kj = kb * bsize - if (qb - kb) < locals or \ - (kb + bs_offset) % vert == 0: - attn_mask[h, 0, kj:min(kj + bsize, seq_len)] = 0 - if alibi_bias is not None: - attn_mask += alibi_bias - else: - attn_mask = alibi_bias - - out = ref_masked_attention(q, keys, values, scale, attn_mask=attn_mask) - out = out.view(num_query_heads, head_size) - output[i].copy_(out, non_blocking=True) - - -@pytest.mark.parametrize("version", ["v1", "v2"]) -@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("use_alibi", USE_ALIBI) -@pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) -@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) -@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) -@pytest.mark.parametrize("blocksparse_head_sliding_step", - BLOCKSPARSE_HEADS_SLIDINGS) -def test_paged_attention( - kv_cache_factory, - version: str, - num_seqs: int, - num_heads: tuple[int, int], - head_size: int, - use_alibi: bool, - block_size: int, - dtype: torch.dtype, - kv_cache_dtype: str, - seed: int, - device: str, - blocksparse_local_blocks: int, - blocksparse_vert_stride: int, - blocksparse_block_size: int, - blocksparse_head_sliding_step: int, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) - query.uniform_(-scale, scale) - - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None - if use_alibi: - alibi_slopes = torch.rand(num_query_heads, dtype=torch.float) - - seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] - seq_lens[-1] = MAX_SEQ_LEN - max_seq_len = max(seq_lens) - seq_lens = torch.tensor(seq_lens, dtype=torch.int) - - # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_seqs): - block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int) - - # Create the KV caches. 
- key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, - num_kv_heads, head_size, - kv_cache_dtype, dtype, seed, - device) - key_cache, value_cache = key_caches[0], value_caches[0] - - # Using default kv_scale - k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) - tp_rank = 0 - - # Call the paged attention kernel. - output = torch.empty_like(query) - if version == "v1": - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank=tp_rank, - blocksparse_local_blocks=blocksparse_local_blocks, - blocksparse_vert_stride=blocksparse_vert_stride, - blocksparse_block_size=blocksparse_block_size, - blocksparse_head_sliding_step=blocksparse_head_sliding_step, - ) - elif version == "v2": - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) - assert PARTITION_SIZE % block_size == 0 - num_seqs, num_heads, head_size = output.shape - tmp_output = torch.empty( - size=(num_seqs, num_heads, num_partitions, head_size), - dtype=output.dtype, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, num_partitions), - dtype=torch.float32, - ) - max_logits = torch.empty_like(exp_sums) - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank=tp_rank, - blocksparse_local_blocks=blocksparse_local_blocks, - blocksparse_vert_stride=blocksparse_vert_stride, - blocksparse_block_size=blocksparse_block_size, - blocksparse_head_sliding_step=blocksparse_head_sliding_step, - ) - else: - raise AssertionError(f"Unknown version: {version}") - - # Run the reference implementation. - if kv_cache_dtype == "fp8": - # Convert cache data back to dtype. - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, - block_size, x) - dequantized_key_cache = torch.empty(size=key_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(dequantized_key_cache, key_cache) - key_cache = dequantized_key_cache - - value_cache_shape = value_cache.shape - dequantized_value_cache = torch.empty(size=value_cache_shape, - dtype=dtype, - device=device) - ops.convert_fp8(dequantized_value_cache, value_cache) - value_cache = dequantized_value_cache - - ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) - - # NOTE(woosuk): Due to the kernel-level differences in the two - # implementations, there is a small numerical difference in the two - # outputs. Thus, we use a relaxed tolerance for the test. - atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 - rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - - # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, - # so we use a relaxed tolerance for the test. 
- atol, rtol = 1e-3, 1e-5 - if kv_cache_dtype == "fp8": - atol, rtol = 1e-2, 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) - - -def ref_multi_query_kv_attention( - cu_seq_lens: list[int], - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, - dtype: torch.dtype, -) -> torch.Tensor: - num_seqs = len(cu_seq_lens) - 1 - ref_outputs = [] - for i in range(num_seqs): - start_idx = cu_seq_lens[i] - end_idx = cu_seq_lens[i + 1] - seq_len = end_idx - start_idx - - # Create attention mask. - attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), - diagonal=1) - attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype) - - ref_output = ref_masked_attention( - query[start_idx:end_idx], - key[start_idx:end_idx], - value[start_idx:end_idx], - scale, - attn_mask=attn_mask, - ) - ref_outputs.append(ref_output) - ref_output = torch.cat(ref_outputs, dim=0) - return ref_output - - -@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("blocksparse_local_blocks", BLOCKSPARSE_LOCAL_BLOCKS) -@pytest.mark.parametrize("blocksparse_vert_stride", BLOCKSPARSE_VERT_STRIDES) -@pytest.mark.parametrize("blocksparse_block_size", BLOCKSPARSE_BLOCK_SIZES) -@pytest.mark.parametrize("blocksparse_homo_heads", BLOCKSPARSE_HOMO_HEADS) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_varlen_blocksparse_attention_prefill( - num_seqs: int, - num_heads: tuple[int, int], - head_size: int, - blocksparse_local_blocks: int, - blocksparse_vert_stride: int, - blocksparse_block_size: int, - blocksparse_homo_heads: bool, - dtype: torch.dtype, - seed: int, - device: str, -) -> None: - current_platform.seed_everything(seed) - torch.set_default_device(device) - # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. - # As the xformers library is already tested with its own tests, we can use - # a smaller MAX_SEQ_LEN here. 
- max_len = min(MAX_SEQ_LEN, 4096) - seq_lens = random.sample(range(1, max_len), num_seqs) - cu_seq_lens = torch.cumsum(torch.tensor([0] + seq_lens), dim=0) - num_tokens = sum(seq_lens) - - scale = float(1.0 / (head_size**0.5)) - num_query_heads, num_kv_heads = num_heads - assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - - qkv = torch.empty(num_tokens, - num_query_heads + 2 * num_kv_heads, - head_size, - dtype=dtype) - qkv.uniform_(-scale, scale) - query, key, value = qkv.split( - [num_query_heads, num_kv_heads, num_kv_heads], dim=1) - - bs_attn_op = LocalStridedBlockSparseAttn( - num_query_heads, - max_len, - local_blocks=blocksparse_local_blocks, - vert_stride=blocksparse_vert_stride, - block_size=blocksparse_block_size, - device=device, - dtype=dtype, - homo_head=blocksparse_homo_heads) - - output = bs_attn_op(query, - key, - value, - cu_seq_lens.to(device), - sm_scale=scale) - - if num_queries_per_kv > 1: - # Handle MQA and GQA - key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) - value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) - - ref_output = ref_multi_query_kv_attention( - cu_seq_lens.tolist(), - query, - key, - value, - scale, - dtype, - ) - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) diff --git a/tests/kernels/attention/test_flashinfer.py b/tests/kernels/attention/test_flashinfer.py index 3ad6e1d32911..8f9b4eceaa72 100644 --- a/tests/kernels/attention/test_flashinfer.py +++ b/tests/kernels/attention/test_flashinfer.py @@ -77,6 +77,7 @@ def ref_paged_attn( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@pytest.mark.parametrize("sliding_window", [None, 64]) @torch.inference_mode def test_flashinfer_decode_with_paged_kv( kv_lens: list[int], @@ -85,6 +86,7 @@ def test_flashinfer_decode_with_paged_kv( dtype: torch.dtype, block_size: int, soft_cap: Optional[float], + sliding_window: Optional[int], ) -> None: torch.set_default_device("cuda") current_platform.seed_everything(0) @@ -136,17 +138,20 @@ def test_flashinfer_decode_with_paged_kv( use_tensor_cores=( (num_query_heads//num_kv_heads) > 4) ) - wrapper.plan(kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - q_data_type=dtype, - kv_data_type=dtype, - logits_soft_cap=soft_cap) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + window_left=sliding_window - 1 if sliding_window is not None else -1, + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap, + ) output = wrapper.run(query, key_value_cache) @@ -157,7 +162,8 @@ def test_flashinfer_decode_with_paged_kv( kv_lens=kv_lens, block_tables=block_tables, scale=scale, - soft_cap=soft_cap) + soft_cap=soft_cap, + sliding_window=sliding_window) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -168,12 +174,17 @@ def test_flashinfer_decode_with_paged_kv( @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@pytest.mark.parametrize("sliding_window", [None, 64]) @torch.inference_mode -def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], - num_heads: tuple[int, int], - head_size: int, dtype: torch.dtype, - block_size: int, - 
soft_cap: Optional[float]) -> None: +def test_flashinfer_prefill_with_paged_kv( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + sliding_window: Optional[int], +) -> None: torch.set_default_device("cuda") current_platform.seed_everything(0) num_seqs = len(seq_lens) @@ -242,6 +253,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], num_kv_heads, head_size, block_size, + window_left=sliding_window - 1 if sliding_window is not None else -1, q_data_type=dtype, kv_data_type=dtype, logits_soft_cap=soft_cap, @@ -259,7 +271,8 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: list[tuple[int, int]], kv_lens=kv_lens, block_tables=block_tables, scale=scale, - soft_cap=soft_cap) + soft_cap=soft_cap, + sliding_window=sliding_window) torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py new file mode 100644 index 000000000000..96eee13695a9 --- /dev/null +++ b/tests/kernels/attention/test_flashinfer_trtllm_decode_attention.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import flashinfer +import pytest +import torch + +from vllm.platforms import current_platform + +if not current_platform.is_device_capability(100): + pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.", + allow_module_level=True) + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 + +# KV Cache Layout for TRT-LLM +# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim) + +NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)] +HEAD_SIZES = [128] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. 
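# Note: with e.g. num_kv_heads=8, block_size=32 and head_size=128, the paged
# KV cache below holds 32768 * 2 * 8 * 32 * 128 = 2**31 elements, one past
# the int32 maximum, presumably the index-calculation overflow this constant
# is meant to exercise.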
+SOFT_CAPS = [None, 30.0, 50.0] + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax * 0.1 + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("kv_layout", ["HND"]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", SOFT_CAPS) +@torch.inference_mode +def test_flashinfer_trtllm_decode_with_baseline( + kv_lens: list[int], + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + kv_layout: str, +) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + kv_cache_shape = None + if kv_layout == "NHD": + kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size) + elif kv_layout == "HND": + kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size) + else: + raise ValueError(f"Invalid kv_layout: {kv_layout}") + key_value_cache = torch.randn(kv_cache_shape, dtype=dtype) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + k_scale = v_scale = 1.0 + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.\ + BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout, + use_tensor_cores=( + (num_query_heads//num_kv_heads) > 4) + ) + wrapper.plan(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + q_data_type=dtype, + kv_data_type=dtype, + logits_soft_cap=soft_cap) + + output = wrapper.run(query, key_value_cache, scale) + + # TRTLLM Decode + max_kv_len = max(kv_lens) + kv_lens_tensor = torch.tensor(kv_lens, + dtype=torch.int, + device=query.device) + output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache( + query.contiguous(), + key_value_cache, + workspace_buffer, + num_query_heads, + num_kv_heads, + scale, + block_tables, + kv_lens_tensor, + block_size, + max_kv_len, + "auto", + k_scale, + v_scale, + ) + + torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - output_trtllm))}" diff --git 
a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 34311b9ccd76..d56d3f4638f1 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -33,8 +33,12 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # change the attention backend to triton MLA m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA") - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 16, + False, + use_mla=True) assert (backend.get_name() == "TRITON_MLA" or backend.get_name() == "TRITON_MLA_VLLM_V1") @@ -42,15 +46,23 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # If use_mla is true # The selected backend is triton MLA m.setenv(STR_BACKEND_ENV_VAR, None) - backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 16, + False, + use_mla=True) assert (backend.get_name() == "TRITON_MLA" or backend.get_name() == "TRITON_MLA_VLLM_V1") # change the attention backend to AITER MLA m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 1, + False, + use_mla=True) assert (backend.get_name() == "ROCM_AITER_MLA" or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") @@ -60,7 +72,11 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): # The selected backend is ROCM_AITER_MLA m.setenv(STR_BACKEND_ENV_VAR, None) m.setenv("VLLM_ROCM_USE_AITER", "1") - backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, - False, True) + backend = get_attn_backend(576, + torch.bfloat16, + "auto", + 1, + False, + use_mla=True) assert (backend.get_name() == "ROCM_AITER_MLA" or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") diff --git a/tests/kernels/core/test_layernorm.py b/tests/kernels/core/test_layernorm.py index 3eac062738f8..02316ceaac73 100644 --- a/tests/kernels/core/test_layernorm.py +++ b/tests/kernels/core/test_layernorm.py @@ -26,6 +26,7 @@ @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("strided_input", [False, True]) @torch.inference_mode() def test_rms_norm( num_tokens: int, @@ -34,13 +35,17 @@ def test_rms_norm( dtype: torch.dtype, seed: int, device: str, + strided_input: bool, ) -> None: current_platform.seed_everything(seed) torch.set_default_device(device) layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype) + last_dim = 2 * hidden_size if strided_input else hidden_size + x = torch.randn(num_tokens, last_dim, dtype=dtype) + x = x[..., :hidden_size] + assert x.is_contiguous() != strided_input x *= scale residual = torch.randn_like(x) * scale if add_residual else None @@ -72,6 +77,7 @@ def test_rms_norm( @pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0]) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("strided_input", [False, True]) def test_fused_rms_norm_quant( num_tokens: int, hidden_size: int, @@ -80,13 +86,18 @@ def test_fused_rms_norm_quant( quant_scale: float, seed: int, device: str, + strided_input: bool, ) -> None: current_platform.seed_everything(seed) 
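    # Note: when strided_input is True, x is taken as a slice of a buffer
    # whose last dim is 2 * hidden_size, so the kernels run on a
    # non-contiguous view; the is_contiguous() assert below verifies this.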
torch.set_default_device(device) weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype) + last_dim = 2 * hidden_size if strided_input else hidden_size + x_base = torch.randn(num_tokens, last_dim, dtype=dtype) + x = x_base[..., :hidden_size] + assert x.is_contiguous() != strided_input + x *= scale if add_residual: residual = torch.randn_like(x) * scale @@ -106,9 +117,11 @@ def test_fused_rms_norm_quant( # Unfused kernel is in-place so it goes second # Also use a separate clone of x to avoid modifying the input - x_unfused = x.clone() + x_unfused_base = x_base.clone() + x_unfused = x_unfused_base[..., :hidden_size] + assert x_unfused.is_contiguous() != strided_input torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6) - torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused, + torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused.contiguous(), quant_scale_t) torch.cuda.synchronize() @@ -116,7 +129,6 @@ def test_fused_rms_norm_quant( residual, atol=1e-2, rtol=1e-2) - opcheck( torch.ops._C.fused_add_rms_norm_static_fp8_quant, (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)) @@ -131,7 +143,7 @@ def test_fused_rms_norm_quant( opcheck(torch.ops._C.rms_norm_static_fp8_quant, (out_quant_fused, x, weight, quant_scale_t, 1e-6)) - torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32), - out_quant.to(dtype=torch.float32), + torch.testing.assert_close(out_quant.to(dtype=torch.float32), + out_quant_fused.to(dtype=torch.float32), atol=1e-3, rtol=1e-3) diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py index addb8bfcda13..411bd9e904b0 100644 --- a/tests/kernels/mamba/test_causal_conv1d.py +++ b/tests/kernels/mamba/test_causal_conv1d.py @@ -6,9 +6,8 @@ import pytest import torch import torch.nn.functional as F +from einops import rearrange -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops # noqa: F401 from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) @@ -144,79 +143,6 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, x = x.contiguous() bias = bias.contiguous() if bias is not None else None - opcheck(torch.ops._C.causal_conv1d_fwd, - (x, weight, bias, conv_states, cu_seq_len, cache_indices, - has_initial_state, activation in ["silu", "swish"], pad_slot_id)) - - -@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) -@pytest.mark.parametrize("silu_activation", [True]) -@pytest.mark.parametrize("has_bias", [True]) -@pytest.mark.parametrize("has_initial_state", [True, False]) -@pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize( - 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]) -@pytest.mark.parametrize('dim', [64]) -@pytest.mark.parametrize('batch', [1]) -def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, - has_initial_state, itype): - device = "cuda" - rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 1e-2, 5e-2 - # set seed - current_platform.seed_everything(0) - x = torch.randn(batch, dim, seqlen, device=device, - dtype=itype).contiguous() - - weight = torch.randn(dim, width, device=device, dtype=itype) - bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None - if has_initial_state: - initial_states = 
torch.randn(batch, - dim, - width - 1, - device=device, - dtype=itype) - has_initial_state_tensor = torch.ones(batch, - dtype=torch.bool, - device=x.device) - else: - initial_states = None - has_initial_state_tensor = None - x_ref = x.clone() - weight_ref = weight.clone() - bias_ref = bias.clone() if bias is not None else None - initial_states_ref = initial_states.clone( - ) if initial_states is not None else None - activation = None if not silu_activation else "silu" - out = causal_conv1d_fn(x, - weight, - bias, - activation=activation, - conv_states=initial_states, - has_initial_state=has_initial_state_tensor) - out_ref, final_states_ref = causal_conv1d_ref( - x_ref, - weight_ref, - bias_ref, - initial_states=initial_states_ref, - return_final_states=True, - activation=activation) - if has_initial_state: - assert initial_states is not None and final_states_ref is not None - assert torch.allclose(initial_states, - final_states_ref, - rtol=rtol, - atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - - causal_conv1d_opcheck_fn(x, - weight, - bias, - activation=activation, - conv_states=initial_states, - has_initial_state=has_initial_state_tensor) - @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @@ -255,22 +181,19 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, - (x, conv_state, weight, bias, activation - in ["silu", "swish"], None, None, PAD_SLOT_ID)) - @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) -@pytest.mark.parametrize("seqlen", [1, 4, 5]) -@pytest.mark.parametrize("width", [2, 3, 4]) -@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +@pytest.mark.parametrize("seqlen", [1, 3]) +@pytest.mark.parametrize("width", [3, 4]) +@pytest.mark.parametrize("dim", [2048 + 16, 4096]) # tests correctness in case subset of the sequences are padded @pytest.mark.parametrize("with_padding", [True, False]) -def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, - seqlen, has_bias, +@pytest.mark.parametrize("batch_size", [3]) +def test_causal_conv1d_update_with_batch_gather(batch_size, with_padding, dim, + width, seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -280,12 +203,15 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, # set seed current_platform.seed_everything(0) - batch_size = 3 padding = 5 if with_padding else 0 padded_batch_size = batch_size + padding + # total_entries = number of cache line total_entries = 10 * batch_size - x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) + # x will be (batch, dim, seqlen) with contiguous along dim-axis + x = torch.randn(padded_batch_size, seqlen, dim, device=device, + dtype=itype).transpose(1, 2) + x_ref = x.clone() conv_state_indices = torch.randperm(total_entries)[:batch_size].to( @@ -300,17 +226,22 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) ], dim=0) + + # conv_state will be (cache_lines, dim, state_len) + # with contiguous along dim-axis conv_state = torch.randn(total_entries, - dim, width - 1, + dim, 
device=device, - dtype=itype) + dtype=itype).transpose(1, 2) + conv_state_for_padding_test = conv_state.clone() weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, conv_state, weight, @@ -325,26 +256,21 @@ def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) assert torch.equal(conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]) - - opcheck(torch.ops._C.causal_conv1d_update, - (x, conv_state, weight, bias, activation - in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID)) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize( - 'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) +@pytest.mark.parametrize('seqlen', [8, 30, 249, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) -# tests correctness in case subset of the sequences are padded @pytest.mark.parametrize('with_padding', [True, False]) -def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, - silu_activation, itype): +@pytest.mark.parametrize('batch', [4, 10]) +def test_causal_conv1d_varlen(batch, with_padding, dim, seqlen, width, + has_bias, silu_activation, itype): device = "cuda" torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -353,14 +279,13 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, # set seed current_platform.seed_everything(0) seqlens = [] - batch_size = 4 - if seqlen < 10: - batch_size = 1 + batch_size = batch padding = 3 if with_padding else 0 padded_batch_size = batch_size + padding nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + seqlens.append( torch.diff( torch.cat( @@ -373,19 +298,22 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) - x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, - dtype=itype)[:, 4096:4096 + dim, :] + x = rearrange( + torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype), + "b s d -> b d s")[:, 4096:4096 + dim, :] + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None x_ref = x.clone() weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" final_states = torch.randn(total_entries, - dim, width - 1, + dim, device=x.device, - dtype=x.dtype) + dtype=x.dtype).transpose(1, 2) final_states_ref = final_states.clone() has_initial_states = torch.randint(0, 2, (cumsum.shape[0] - 1, ), @@ -400,10 +328,16 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), ], dim=-1) + out = causal_conv1d_fn(x.squeeze(0), + weight, + 
bias=bias, + conv_states=final_states, + query_start_loc=cumsum.cuda(), + cache_indices=padded_state_indices, + has_initial_state=has_initial_states, + activation=activation, + pad_slot_id=PAD_SLOT_ID) - out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - padded_state_indices, has_initial_states, - final_states, activation, PAD_SLOT_ID) out_ref = [] out_ref_b = [] @@ -426,13 +360,9 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) out_ref_tensor = torch.cat(out_ref, dim=0) - unpadded_out = out[:, :out_ref_tensor.shape[-1]] - assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) assert torch.allclose(final_states[state_indices], final_states_ref[state_indices], rtol=rtol, atol=atol) - - causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - padded_state_indices, has_initial_states, - final_states, activation) + unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) diff --git a/tests/kernels/mamba/test_mamba_mixer2.py b/tests/kernels/mamba/test_mamba_mixer2.py index f5c6a18614ff..16c310726ad1 100644 --- a/tests/kernels/mamba/test_mamba_mixer2.py +++ b/tests/kernels/mamba/test_mamba_mixer2.py @@ -119,7 +119,8 @@ def mixer2_gated_norm_tensor_parallel( gate_states[..., local_rank * N:(local_rank + 1) * N], ) ref_output = mixer_single_gpu(hidden_states, gate_states) - torch.allclose(output, - ref_output[..., local_rank * N:(local_rank + 1) * N], - atol=1e-3, - rtol=1e-3) + torch.testing.assert_close(output, + ref_output[..., + local_rank * N:(local_rank + 1) * N], + atol=5e-3, + rtol=1e-3) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index ccf0ff6abd16..00c1a2911d7d 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -6,11 +6,11 @@ import torch.nn.functional as F from einops import rearrange, repeat -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - _query_start_loc_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform +from vllm.v1.attention.backends.mamba_attn import ( + _query_start_loc_to_chunk_indices_offsets) # Added by the IBM Team, 2024 @@ -193,6 +193,13 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, # this tests the kernels on a single example (no batching) + # TODO: the bfloat16 case requires higher thresholds. 
To be investigated + + if itype == torch.bfloat16: + atol, rtol = 5e-2, 5e-2 + else: + atol, rtol = 8e-3, 5e-3 + # set seed batch_size = 1 # batch_size # ssd_minimal_discrete requires chunk_size divide seqlen @@ -216,14 +223,14 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, return_final_states=True) # just test the last in sequence - torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3) + torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol) # just test the last head # NOTE, in the kernel we always cast states to fp32 - torch.allclose(final_state[:, -1], - final_state_min[:, -1].to(torch.float32), - atol=1e-3, - rtol=1e-3) + torch.testing.assert_close(final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=atol, + rtol=rtol) @pytest.mark.parametrize("itype", [torch.float32, torch.float16]) @@ -263,6 +270,13 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + # TODO: the irregular chunk size cases have some issues and require higher + # tolerance. This is to be invesigated + if chunk_size not in {8, 256}: + atol, rtol = 5e-1, 5e-1 + else: + atol, rtol = 5e-3, 5e-3 + # hold state during the cutting process so we know if an # example has been exhausted and needs to cycle last_taken: dict = {} # map: eg -> pointer to last taken sample @@ -300,7 +314,7 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, # just test one dim and dstate Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0] Y_min_eg = Y_min[i][:, 0, 0] - torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol) # update states states = new_states diff --git a/tests/spec_decode/__init__.py b/tests/kernels/moe/modular_kernel_tools/__init__.py similarity index 100% rename from tests/spec_decode/__init__.py rename to tests/kernels/moe/modular_kernel_tools/__init__.py diff --git a/tests/kernels/moe/modular_kernel_tools/cli_args.py b/tests/kernels/moe/modular_kernel_tools/cli_args.py new file mode 100644 index 000000000000..b95d87cd04f5 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/cli_args.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig + +from .common import Config +from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + + +def make_config_arg_parser(description: str): + + def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize: + for pf in MK_ALL_PREPARE_FINALIZE_TYPES: + if pf.__name__ == s: + return pf + raise ValueError( + f"Cannot find a PrepareFinalize type that matches {s}") + + def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute: + for fe in MK_FUSED_EXPERT_TYPES: + if fe.__name__ == s: + return fe + raise ValueError(f"Cannot find a FusedExperts type that matches {s}") + + def to_quant_torch_dtype(s: str) -> torch.dtype: + if s == "torch.float8_e4m3fn": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported quant type {s}") + + parser = argparse.ArgumentParser(description=description) + + parser.add_argument( + "--world-size", + type=int, + default=2, + help="Number of ranks that participate in 
all2all", + ) + parser.add_argument( + "--pf-type", + type=to_pf_class_type, + required=True, + help=("Choose a PrepareFinalize Type : " + f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"), + ) + parser.add_argument( + "--experts-type", + type=to_experts_class_type, + required=True, + help=(f"Choose a FusedExpert type : " + f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"), + ) + parser.add_argument( + "-m", + nargs="+", + type=int, + default=[64], + help="num tokens per rank", + ) + parser.add_argument( + "-k", + type=int, + default=7168, + help="hidden-size", + ) + parser.add_argument( + "-n", + type=int, + default=1024, + help="N dimension of the first fused-moe matmul", + ) + parser.add_argument("--num-experts", + type=int, + default=32, + help="Global num experts") + parser.add_argument("--topk", + nargs="+", + type=int, + default=[4, 1], + help="num topk") + parser.add_argument( + "--fused-moe-chunk-size", + type=int, + help="Fused moe chunk size used for the non-batched fused experts impl." + ) + + # Quant args + parser.add_argument("--quant-dtype", + type=to_quant_torch_dtype, + help="Quant datatype") + parser.add_argument("--per-token-quantized-activations", + action='store_true', + help=("The input activations must be per-token " + "quantized")) + parser.add_argument("--per-channel-quantized-weights", + action="store_true", + help="The weights must be per-channel quantized.") + parser.add_argument("--block-shape", + nargs="+", + type=int, + help="Quantization block shape") + + # Torch trace profile generation args + parser.add_argument("--torch-trace-dir-path", + type=str, + default=None, + help="Get torch trace for single execution") + + return parser + + +def _validate_args(args: argparse.Namespace): + + if args.quant_dtype is not None: + assert args.quant_dtype == torch.float8_e4m3fn + if args.block_shape is not None: + assert len(args.block_shape) == 2, ( + f"block shape must have 2 elements. 
got {args.block_shape}") + + if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: + assert args.world_size == 1, ( + "Single GPU objects need world size set to 1") + + if args.torch_trace_dir_path is not None: + from pathlib import Path + assert Path(args.torch_trace_dir_path).is_dir(), ( + f"Please create {args.torch_trace_dir_path}") + + +def make_config(args: argparse.Namespace) -> Config: + + _validate_args(args) + + quant_config = None + if args.quant_dtype is not None: + quant_config = FusedMoEQuantConfig( + quant_dtype=args.quant_dtype, + per_act_token_quant=args.per_token_quantized_activations, + per_out_ch_quant=args.per_channel_quantized_weights, + block_shape=args.block_shape) + + return Config( + Ms=args.m, + K=args.k, + N=args.n, + E=args.num_experts, + topks=args.topk, + dtype=torch.bfloat16, # hard-code + quant_config=quant_config, + prepare_finalize_type=args.pf_type, + fused_experts_type=args.experts_type, + fused_moe_chunk_size=args.fused_moe_chunk_size, + world_size=args.world_size, + torch_trace_dir_path=args.torch_trace_dir_path) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py new file mode 100644 index 000000000000..fd99e8dc5c98 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -0,0 +1,641 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch + +import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from tests.kernels.utils import torch_experts +from vllm.config import VllmConfig +from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size +# Fused experts and PrepareFinalize imports +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig) +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase, + TritonExperts) +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx + +from .parallel_utils import ProcessGroupInfo +from .utils import (make_block_quant_fp8_weights, make_non_quant_weights, + make_quant_fp8_weights, per_token_cast_to_fp8) + +if has_pplx(): + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) +if has_deep_ep(): + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + + +def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str: + if t is None: + return f"{name} 
: None" + else: + return f"{name} : {t.shape} {t.dtype} {t.device}" + + +@dataclass +class Config: + Ms: Union[list[int], int] + K: int + N: int + E: int + topks: Union[list[int], int] + dtype: torch.dtype + quant_config: Optional[FusedMoEQuantConfig] + + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute + + fused_moe_chunk_size: Optional[int] + world_size: int + + torch_trace_dir_path: Optional[str] = None + + def describe(self) -> str: + s = "" + s += "== Config: \n" + s += f" world_size={self.world_size} \n" + s += f" PF={self.prepare_finalize_type.__name__} \n" + s += f" FE={self.fused_experts_type.__name__} \n" + s += f" topk={self.topks} \n" + s += f" dtype={self.dtype} \n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n" + s += " Quant: \n" + s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n " + if self.quant_config is not None: + s += f" q_dtype={self.quant_dtype} \n" + s += f" q_block_shape={self.quant_block_shape} \n" + s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n" + s += f" q_per_act_token={self.is_per_act_token_quant} \n" + else: + s += " quant=None \n" + return s + + @property + def M(self) -> int: + assert isinstance(self.Ms, int) + return self.Ms + + @property + def quant_dtype(self) -> Optional[torch.dtype]: + if self.quant_config is None: + return None + return self.quant_config.quant_dtype + + @property + def is_per_act_token_quant(self) -> bool: + if self.quant_config is None: + return False + return self.quant_config.per_act_token_quant + + @property + def is_per_tensor_act_quant(self) -> bool: + if self.quant_config is None: + return False + return (not self.is_per_act_token_quant + and self.quant_block_shape is None) + + @property + def is_per_out_ch_quant(self) -> bool: + if self.quant_config is None: + return False + return self.quant_config.per_out_ch_quant + + @property + def quant_block_shape(self) -> Optional[list[int]]: + if self.quant_config is None: + return None + return self.quant_config.block_shape + + @property + def topk(self) -> int: + assert isinstance(self.topks, int) + return self.topks + + @property + def topk_ids_dtype(self) -> Optional[torch.dtype]: + topk_ids_dtype = None + if self.prepare_finalize_type == PplxPrepareAndFinalize: + topk_ids_dtype = torch.uint32 + elif self.prepare_finalize_type in [ + DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ]: + topk_ids_dtype = torch.int64 + return topk_ids_dtype + + @property + def num_local_experts(self) -> int: + return self.E // self.world_size + + def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]: + """ + make env data for vllm launch. 
+ """ + vllm_config = VllmConfig() + vllm_config.parallel_config.data_parallel_size = self.world_size + vllm_config.parallel_config.enable_expert_parallel = True + + env_dict = { + "VLLM_ALL2ALL_BACKEND": self.all2all_backend(), + "VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())), + } + if self.fused_moe_chunk_size is not None: + env_dict.update( + {"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)}) + return vllm_config, env_dict + + def is_fp8_block_quantized(self): + return (self.quant_dtype == torch.float8_e4m3fn + and self.quant_block_shape is not None) + + def is_batched_prepare_finalize(self): + return self.prepare_finalize_type in [ + PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ] + + def is_batched_fused_experts(self): + return self.fused_experts_type in [ + CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts, + NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts + ] + + def is_standard_fused_experts(self): + return self.fused_experts_type in [ + CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, + TritonExperts + ] + + def is_fe_16bit_supported(self): + return self.fused_experts_type in [ + BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, + NaiveBatchedExperts, TritonExperts + ] + + def is_fe_fp8_supported(self): + return self.fused_experts_type in [ + BatchedDeepGemmExperts, + BatchedTritonExperts, + BatchedTritonOrDeepGemmExperts, + CutlassExpertsFp8, + DeepGemmExperts, + TritonExperts, + TritonOrDeepGemmExperts, + NaiveBatchedExperts, + ] + + def is_fe_block_fp8_supported(self): + return self.fused_experts_type in [ + BatchedDeepGemmExperts, + BatchedTritonOrDeepGemmExperts, + DeepGemmExperts, + TritonExperts, + TritonOrDeepGemmExperts, + BatchedTritonExperts, + NaiveBatchedExperts, + ] + + def is_fe_supports_chunking(self): + return self.fused_experts_type in [ + CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts, + TritonExperts + ] + + def needs_deep_gemm(self): + return self.fused_experts_type in [ + BatchedDeepGemmExperts, + DeepGemmExperts, + ] + + def needs_pplx(self): + return self.prepare_finalize_type in [PplxPrepareAndFinalize] + + def needs_deep_ep(self): + return self.prepare_finalize_type in [ + DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ] + + def all2all_backend(self): + if self.needs_pplx(): + return "pplx" + if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize: + return "deepep_high_throughput" + if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize: + return "deepep_low_latency" + return "naive" + + def needs_all2all(self): + return self.prepare_finalize_type in [ + PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize + ] + + def is_valid(self): + # Check prepare-finalize and fused-experts compatibility + if self.is_batched_prepare_finalize(): + if not self.is_batched_fused_experts(): + return False + else: + if not self.is_standard_fused_experts(): + return False + + use_chunking = self.fused_moe_chunk_size is not None + if use_chunking and not self.is_fe_supports_chunking(): + return False + + # Check quantization sanity + if (int(self.is_per_act_token_quant) + + int(self.is_per_tensor_act_quant) + + int(self.quant_block_shape is not None)) > 1: + # invalid quant config + return False + + # check bf16 / fp16 support + is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None) + if is_16bit and not self.is_fe_16bit_supported(): + return False + + # Check fp8 support + is_fp8 = self.quant_dtype == torch.float8_e4m3fn + if is_fp8 and not 
self.is_fe_fp8_supported(): + return False + + # Check fp8 block quanization support + is_block_quatized = self.quant_block_shape is not None + if is_block_quatized and not is_fp8: + return False + if is_block_quatized and not self.is_fe_block_fp8_supported(): + return False + + # deep_gemm only works with block-quantized + if self.needs_deep_gemm() and not is_block_quatized: + return False + + # Check dependencies + if self.needs_deep_ep() and not has_deep_ep(): + return False + if self.needs_deep_gemm() and not has_deep_gemm(): + return False + if self.needs_pplx() and not has_pplx(): # noqa: SIM103 + return False + + return True + + +@dataclass +class WeightTensors: + w1: torch.Tensor + w2: torch.Tensor + w1_scale: Optional[torch.Tensor] + w2_scale: Optional[torch.Tensor] + + def describe(self): + s = "" + s += "== Weight Tensors: \n" + s += f' - {_describe_tensor(self.w1, "w1")} \n' + s += f' - {_describe_tensor(self.w2, "w2")} \n' + s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n' + s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n' + return s + + def to_current_device(self): + self.w1 = self.w1.to(device=torch.cuda.current_device()) + self.w2 = self.w2.to(device=torch.cuda.current_device()) + is_quantized = self.w1.dtype == torch.float8_e4m3fn + if is_quantized: + assert self.w1_scale is not None + assert self.w2_scale is not None + self.w1_scale = self.w1_scale.to( + device=torch.cuda.current_device()) + self.w2_scale = self.w2_scale.to( + device=torch.cuda.current_device()) + + def slice_weights(self, rank: int, + num_local_experts: int) -> "WeightTensors": + s = rank * num_local_experts + e = s + num_local_experts + w1 = self.w1[s:e, :, :] + w2 = self.w2[s:e, :, :] + is_quantized = self.w1.dtype == torch.float8_e4m3fn + w1_scale, w2_scale = (None, None) + if is_quantized: + assert self.w1_scale is not None + assert self.w2_scale is not None + w1_scale = self.w1_scale[s:e, :, :] + w2_scale = self.w2_scale[s:e, :, :] + return WeightTensors(w1, w2, w1_scale, w2_scale) + + @staticmethod + def make(config: Config) -> "WeightTensors": + + if config.quant_dtype is None: + # just make normal dtype weights + w1, w2 = make_non_quant_weights(e=config.E, + n=config.N, + k=config.K, + dtype=config.dtype) + return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None) + + assert config.quant_dtype == torch.float8_e4m3fn + if not config.is_fp8_block_quantized(): + w1, w2, w1_scale, w2_scale = make_quant_fp8_weights( + e=config.E, + n=config.N, + k=config.K, + per_out_channel_quant=config.is_per_out_ch_quant, + ) + return WeightTensors(w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale) + + assert config.quant_block_shape is not None + w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + e=config.E, + n=config.N, + k=config.K, + block_size=config.quant_block_shape, + ) + return WeightTensors(w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale) + + +@dataclass +class RankTensors: + hidden_states: torch.Tensor + hidden_states_scale: Optional[torch.Tensor] + + topk_weights: torch.Tensor + topk_ids: torch.Tensor + expert_map: Optional[torch.Tensor] + + quant_config: Optional[FusedMoEQuantConfig] + + def describe(self): + s = "" + s += "== Rank Tensors: \n" + s += f' - {_describe_tensor(self.hidden_states, "HS")} \n' + s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n' + s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n' + s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n' + s += f' - 
{_describe_tensor(self.expert_map, "expert_map")} \n' + return s + + @staticmethod + def make_hidden_states( + config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Return hidden_states + """ + m, k, dtype = (config.M, config.K, config.dtype) + a = (torch.randn( + (m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0) + + if config.quant_dtype is None: + return a, None + + # We dequant and use that as hidden_states so the tests are stable. + # quantizing and dequantizing yield slightly different results + # depending on the hardware. Here we, quantize and dequantize + # first - so further quantize and dequantize will yield the same + # values. + if config.is_per_tensor_act_quant: + a_q, a_scales = ops.scaled_fp8_quant( + a, use_per_token_if_dynamic=False) + return a_q.float().mul(a_scales).to(dtype), a_scales + + if config.is_per_act_token_quant: + a_q, a_scales = ops.scaled_fp8_quant(a, + use_per_token_if_dynamic=True) + return a_q.float().mul(a_scales).to(dtype), None + + assert config.quant_block_shape is not None + block_k = config.quant_block_shape[1] + a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k) + return a_q.float().view( + (-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None + + @staticmethod + def make(config: Config, pgi: ProcessGroupInfo): + + dtype = config.dtype + topk, m, _ = (config.topk, config.M, config.K) + hidden_states, hidden_states_scale = RankTensors.make_hidden_states( + config) + + num_local_experts, global_num_experts = (config.num_local_experts, + config.E) + score = torch.randn((m, global_num_experts), + device="cuda", + dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk, + False) + topk_ids = topk_ids.to(config.topk_ids_dtype) + + # distribute topk_ids evenly + for mi in range(m): + topk_ids[mi] = torch.randperm(config.E)[:topk] + topk_ids = topk_ids.to(device=torch.cuda.current_device()) + + expert_map = None + if config.world_size > 1: + expert_map = torch.full((global_num_experts, ), + fill_value=-1, + dtype=torch.int32) + s = pgi.rank * num_local_experts + e = s + num_local_experts + expert_map[s:e] = torch.tensor(list(range(num_local_experts))) + expert_map = expert_map.to(device=torch.cuda.current_device(), + dtype=torch.int32) + + return RankTensors( + hidden_states=hidden_states, + hidden_states_scale=hidden_states_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + quant_config=config.quant_config, + ) + + +def reference_moe_impl(config: Config, weights: WeightTensors, + rank_tensors: RankTensors) -> torch.Tensor: + + return torch_experts(a=rank_tensors.hidden_states, + w1=weights.w1, + w2=weights.w2, + topk_weight=rank_tensors.topk_weights, + topk_ids=rank_tensors.topk_ids, + global_num_experts=config.E, + expert_map=None, + w1_scale=weights.w1_scale, + w2_scale=weights.w2_scale, + a1_scale=rank_tensors.hidden_states_scale, + quant_dtype=config.quant_dtype, + per_act_token_quant=config.is_per_act_token_quant, + block_shape=config.quant_block_shape, + apply_router_weights_on_input=config.topk == 1) + + +def make_fused_experts( + config: Config, moe: FusedMoEConfig, + num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute: + + use_fp8 = config.quant_dtype == torch.float8_e4m3fn + batch_kwargs = { + "max_num_tokens": moe.max_num_tokens, + "num_dispatchers": num_dispatchers, + } + quant_kwargs = { + "use_fp8_w8a8": use_fp8, + "use_int8_w8a8": False, + "use_int8_w8a16": False, + "use_int4_w4a16": False, + "block_shape": 
config.quant_block_shape, + "per_act_token_quant": config.is_per_act_token_quant, + } + deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()} + + if config.fused_experts_type == BatchedDeepGemmExperts: + kwargs = batch_kwargs | { + "block_shape": config.quant_block_shape, + "per_act_token_quant": config.is_per_act_token_quant, + } + print(f"Making BatchedDeepGemmExperts {kwargs} ...") + experts = BatchedDeepGemmExperts(**kwargs) + elif config.fused_experts_type == BatchedTritonExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making BatchedTritonExperts {kwargs} ...") + experts = BatchedTritonExperts(**kwargs) + elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts: + kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs + print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") + experts = BatchedTritonOrDeepGemmExperts(**kwargs) + elif config.fused_experts_type == DeepGemmExperts: + print("Making DeepGemmExperts () ...") + experts = DeepGemmExperts() + elif config.fused_experts_type == TritonExperts: + kwargs = quant_kwargs + print(f"Making TritonExperts {kwargs} ...") + experts = TritonExperts(**kwargs) + elif config.fused_experts_type == TritonOrDeepGemmExperts: + kwargs = quant_kwargs | deepgemm_kwargs + print(f"Making TritonOrDeepGemmExperts {kwargs} ...") + experts = TritonOrDeepGemmExperts(**kwargs) + elif config.fused_experts_type == NaiveBatchedExperts: + kwargs = batch_kwargs | quant_kwargs + print(f"Making NaiveBatchedExperts {kwargs} ...") + experts = NaiveBatchedExperts(**kwargs) + elif config.fused_experts_type == CutlassExpertsFp8: + use_batched_format = config.is_batched_prepare_finalize() + num_experts = (moe.num_local_experts + if use_batched_format else moe.num_experts) + kwargs = { + "max_experts_per_worker": num_experts, + "out_dtype": moe.in_dtype, + "per_act_token_quant": config.is_per_act_token_quant, + "per_out_ch_quant": config.is_per_out_ch_quant, + "block_shape": config.quant_block_shape, + "num_dispatchers": num_dispatchers, + "use_batched_format": use_batched_format + } + print(f"Making CutlassExpertsFp8 {kwargs} ...") + experts = CutlassExpertsFp8(**kwargs) + + return experts + + +def make_modular_kernel(config: Config, + vllm_config: VllmConfig) -> mk.FusedMoEModularKernel: + + def next_power_of_2(x): + import math + if x == 0: + return 1 + return 2**math.ceil(math.log2(x)) + + # make moe config + moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make( + tp_size_=get_tensor_model_parallel_world_size(), + dp_size_=get_dp_group().world_size, + vllm_parallel_config=vllm_config.parallel_config, + ) + moe = FusedMoEConfig( + num_experts=config.E, + experts_per_token=config.topk, + hidden_dim=config.K, + num_local_experts=config.num_local_experts, + moe_parallel_config=moe_parallel_config, + in_dtype=config.dtype, + quant_config=config.quant_config, + max_num_tokens=next_power_of_2(config.M), + ) + + # make modular kernel + prepare_finalize = None + if config.needs_all2all(): + prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe) + assert prepare_finalize is not None + else: + prepare_finalize = MoEPrepareAndFinalizeNoEP() + + fused_experts = make_fused_experts(config, moe, + prepare_finalize.num_dispatchers()) + + modular_kernel = mk.FusedMoEModularKernel( + prepare_finalize=prepare_finalize, fused_experts=fused_experts) + + return modular_kernel + + +def run_modular_kernel( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + config: Config, + weights: WeightTensors, + rank_tensors: RankTensors, +) 
-> torch.Tensor: + assert isinstance(config.Ms, int) + assert isinstance(config.topks, int) + + # weights for rank + rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) + + mk = make_modular_kernel(config, vllm_config) + + mk_kwargs = { + "hidden_states": rank_tensors.hidden_states.clone( + ), # impls might update the tensor in place + "w1": rank_weights.w1, + "w2": rank_weights.w2, + "topk_weights": rank_tensors.topk_weights, + "topk_ids": rank_tensors.topk_ids, + "expert_map": rank_tensors.expert_map, + "w1_scale": rank_weights.w1_scale, + "w2_scale": rank_weights.w2_scale, + "a1_scale": rank_tensors.hidden_states_scale, + "global_num_experts": config.E, + "apply_router_weight_on_input": config.topk == 1, + } + out = mk.forward(**mk_kwargs) + + return out diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py new file mode 100644 index 000000000000..5dbfdfc153f9 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +from enum import Enum +from itertools import product +from typing import Optional + +import torch +from tqdm import tqdm + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.platforms import current_platform + +from .common import (Config, RankTensors, WeightTensors, reference_moe_impl, + run_modular_kernel) +from .mk_objects import (MK_FUSED_EXPERT_TYPES, + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS) +from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config + + +class Result(Enum): + PASS = 1 + FAIL = 2 + SKIP = 3 + + +def rank_worker( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + cpu_group, + config: Config, + weights: WeightTensors, +): + current_platform.seed_everything(pgi.rank) + + # sanity check + from vllm import envs + if config.fused_moe_chunk_size is not None: + assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + # get weights to this device + weights.to_current_device() + + Ms = config.Ms + assert isinstance(Ms, list) + TOPKs = config.topks + assert isinstance(TOPKs, list) + + for m, topk in product(Ms, TOPKs): + print(f"Running m={m}, topk={topk} ...") + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk + + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) + + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) + + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + + torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + + +def make_feature_matrix(csv_file_path: str): + + from dataclasses import asdict + + import pandas as pd + + def add_to_results(config: Config, + success: Result, + results_df: Optional[pd.DataFrame] = None): + config_dict = asdict(config) + config_dict['prepare_finalize_type'] = config_dict[ + 'prepare_finalize_type'].__name__ + config_dict['fused_experts_type'] = config_dict[ + 'fused_experts_type'].__name__ + config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant + quant_config_dict = config_dict['quant_config'] + del config_dict['quant_config'] + if quant_config_dict is None: + quant_config = FusedMoEQuantConfig(None) + quant_config_dict = 
asdict(quant_config) + + config_dict |= quant_config_dict + result_dict = config_dict | {'success': success.name} + + result_df = pd.DataFrame([result_dict]) + if results_df is None: + results_df = result_df + else: + results_df = pd.concat([results_df, result_df], ignore_index=True) + + return results_df + + Ms = [64] + Ks = [7168] # hidden sizes + Ns = [2048] + TOPKs = [[4, 1]] + Es = [32] + DTYPEs = [torch.bfloat16] + PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + FE_TYPES = MK_FUSED_EXPERT_TYPES + Q_TYPES = MK_QUANT_CONFIGS + + combinations = list( + product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES)) + + results_df: Optional[pd.DataFrame] = None + for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm( + combinations): #noqa: E501 + config = Config(Ms=[m], + K=k, + N=n, + E=e, + topks=topks, + dtype=dtype, + prepare_finalize_type=pf_type, + fused_experts_type=experts_type, + quant_config=quant_config, + world_size=2, + fused_moe_chunk_size=None) + + success = None + if config.is_valid(): + print(f"Running config : {config.describe()} ...") + try: + weights: WeightTensors = WeightTensors.make(config) + vllm_config, env_dict = config.make_env_data() + parallel_launch_with_config(config.world_size, rank_worker, + vllm_config, env_dict, config, + weights) + success = Result.PASS + except Exception as _: + success = Result.FAIL + else: + success = Result.SKIP + + results_df = add_to_results(config, success, results_df) + + if results_df is not None: + results_df.to_csv(f"{csv_file_path}") + + +if __name__ == '__main__': + import argparse + from pathlib import Path + parser = argparse.ArgumentParser(description=( + "Make ModularKernel feature matrix \n" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501 + "-f ./feature_matrices/feature_matrix.csv")) + + parser.add_argument("-f", + "--feature-matrix-csv-file-path", + type=str, + required=True, + help="File name to Generate a .csv file") + args = parser.parse_args() + + csv_path = args.feature_matrix_csv_file_path + assert csv_path.endswith( + 'csv'), f"Need a file path ending with .csv, got {csv_path}" + assert Path(csv_path).parent.is_dir( + ), f"Cannot find parent directory for {Path(csv_path).parent}" + + make_feature_matrix(args.feature_matrix_csv_file_path) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py new file mode 100644 index 000000000000..73214066f7ea --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +# Fused experts and PrepareFinalize imports +from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import ( + BatchedDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts, NaiveBatchedExperts) +from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from vllm.model_executor.layers.fused_moe.prepare_finalize import ( + MoEPrepareAndFinalizeNoEP) +from 
vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from vllm.utils import has_deep_ep, has_pplx + +if has_deep_ep(): + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + +if has_pplx(): + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + +MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = [] +if has_pplx(): + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize] +if has_deep_ep(): + MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [ + DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize + ] + +MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP] + +MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + + MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) + +MK_FUSED_EXPERT_TYPES = [ + BatchedDeepGemmExperts, + BatchedTritonExperts, + NaiveBatchedExperts, + BatchedTritonOrDeepGemmExperts, + CutlassExpertsFp8, + DeepGemmExperts, + TritonOrDeepGemmExperts, + TritonExperts, +] + +MK_QUANT_CONFIGS = [ + None, + # per-channel / per-column weights and per-tensor activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=False, + block_shape=None), + # per-channel / per-column weights and per-token activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=True, + per_act_token_quant=True, + block_shape=None), + # per-tensor weights and per-tensor activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=None), + # per-tensor weights and per-token activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=True, + block_shape=None), + # block-quantized weights and 128 block per-token activations + FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn, + per_out_ch_quant=False, + per_act_token_quant=False, + block_shape=[128, 128]), + # TODO (varun) : Should we test the following combinations ? 
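
The per-tensor and per-token activation-scaling variants enumerated in MK_QUANT_CONFIGS differ only in the granularity of the quantization scale. A minimal plain-PyTorch sketch of that distinction (hypothetical shapes; not the kernels under test):

# Illustrative sketch: per-tensor vs. per-token FP8 activation scaling.
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
finfo = torch.finfo(torch.float8_e4m3fn)

# per-tensor: a single scale for the whole activation tensor
scale_tensor = x.abs().amax().float().clamp(min=1e-12) / finfo.max
xq_tensor = (x.float() / scale_tensor).clamp(finfo.min,
                                             finfo.max).to(torch.float8_e4m3fn)

# per-token: one scale per row (token)
scale_token = x.abs().amax(dim=-1,
                           keepdim=True).float().clamp(min=1e-12) / finfo.max
xq_token = (x.float() / scale_token).clamp(finfo.min,
                                           finfo.max).to(torch.float8_e4m3fn)

print(scale_tensor.shape, scale_token.shape)  # torch.Size([]) torch.Size([4, 1])
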
+ # block-quantized weights and per-token activations + # block-quantized weights and per-tensor activations +] diff --git a/tests/kernels/moe/modular_kernel_tools/parallel_utils.py b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py new file mode 100644 index 000000000000..1f8d21a7a702 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/parallel_utils.py @@ -0,0 +1,138 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import os +import traceback +from typing import Any, Callable, Optional + +import torch +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) +from vllm.utils import get_open_port + +## Parallel Processes Utils + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int, + local_rank: int): + + import tempfile + temp_file = tempfile.mkstemp()[1] + + set_current_vllm_config(vllm_config) + with set_current_vllm_config(vllm_config): + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=f"file://{temp_file}", + local_rank=local_rank, + backend="nccl", + ) + + initialize_model_parallel( + tensor_model_parallel_size=vllm_config.parallel_config. + tensor_parallel_size, + pipeline_model_parallel_size=vllm_config.parallel_config. + pipeline_parallel_size, + ) + cpu_group = torch.distributed.new_group(list(range(world_size)), + backend="gloo") + return cpu_group + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, + P], None], + vllm_config: Optional[VllmConfig], + env_dict: Optional[dict], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + if env_dict is not None: + os.environ.update(env_dict) + + cpu_group = None + if vllm_config is not None: + cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + vllm_config, + cpu_group, + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch_with_config( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None], + vllm_config: VllmConfig, + env_dict: dict[Any, Any], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}", + 
worker, + vllm_config, + env_dict, + ) + args, + nprocs=world_size, + join=True, + ) diff --git a/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py new file mode 100644 index 000000000000..dd16ffb2eabe --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +from itertools import product +from typing import Any, Callable + +import torch + +from vllm.config import VllmConfig +from vllm.platforms import current_platform + +from .common import Config, RankTensors, WeightTensors, make_modular_kernel +from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config + + +def do_profile(fn: Callable, + fn_kwargs: dict[Any, Any], + pgi: ProcessGroupInfo, + config: Config, + num_warmups: int = 5): + for _ in range(num_warmups): + fn(**fn_kwargs) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + record_shapes=True, + ) as tprof: + fn(**fn_kwargs) + torch.cuda.synchronize(torch.cuda.current_device()) + + # TODO (varun): Add a descriptive trace file name + tprof.export_chrome_trace( + f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json") + + +def profile_modular_kernel( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + config: Config, + weights: WeightTensors, + rank_tensors: RankTensors, +) -> None: + assert isinstance(config.Ms, int) + assert isinstance(config.topks, int) + + # weights for rank + rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts) + + # make modular kernel + mk = make_modular_kernel(config, vllm_config) + + mk_kwargs = { + "hidden_states": rank_tensors.hidden_states, + "w1": rank_weights.w1, + "w2": rank_weights.w2, + "topk_weights": rank_tensors.topk_weights, + "topk_ids": rank_tensors.topk_ids, + "expert_map": rank_tensors.expert_map, + "w1_scale": rank_weights.w1_scale, + "w2_scale": rank_weights.w2_scale, + "a1_scale": rank_tensors.hidden_states_scale, + "global_num_experts": config.E, + "apply_router_weight_on_input": config.topk == 1, + } + + do_profile(mk.forward, mk_kwargs, pgi, config) + + +def rank_worker( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + cpu_group, + config: Config, + weights: WeightTensors, +): + current_platform.seed_everything(pgi.rank) + + # sanity check + from vllm import envs + if config.fused_moe_chunk_size is not None: + assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + # get weights to this device + weights.to_current_device() + + Ms = config.Ms + assert isinstance(Ms, list) + TOPKs = config.topks + assert isinstance(TOPKs, list) + + for m, topk in product(Ms, TOPKs): + print(f"Running m={m}, topk={topk} ...") + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk + + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) + profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors) + + +def run(config: Config): + weights: WeightTensors = WeightTensors.make(config) + vllm_config, env_dict = config.make_env_data() + parallel_launch_with_config(config.world_size, rank_worker, vllm_config, + env_dict, config, weights) + + +if __name__ == '__main__': + from .cli_args import make_config, make_config_arg_parser + parser = make_config_arg_parser(description=( + "Run single 
prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + )) + args = parser.parse_args() + assert args.torch_trace_dir_path is not None, ( + "Please pass in a directory to store torch traces") + config = make_config(args) + + run(config) diff --git a/tests/kernels/moe/modular_kernel_tools/utils.py b/tests/kernels/moe/modular_kernel_tools/utils.py new file mode 100644 index 000000000000..09bb4a34f318 --- /dev/null +++ b/tests/kernels/moe/modular_kernel_tools/utils.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math + +import torch + +import vllm._custom_ops as ops + + +def per_token_cast_to_fp8( + x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + pad_size = (block_size - (n % block_size)) % block_size + x = torch.nn.functional.pad(x, + (0, pad_size), value=0) if pad_size > 0 else x + x_view = x.view(m, -1, block_size) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn) + return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1) + + +def per_block_cast_to_fp8( + x: torch.Tensor, block_size_k: int, + block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + ( + int(math.ceil(m / block_size_k)) * block_size_k, + int(math.ceil(n / block_size_n)) * block_size_n, + ), + dtype=x.dtype, + device=x.device, + ) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, block_size_k, + x_padded.size(1) // block_size_k, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def make_non_quant_weights( + e: int, + n: int, + k: int, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Return weights w1, w2 + """ + device = torch.cuda.current_device() + w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15 + w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15 + return w1, w2 + + +def make_block_quant_fp8_weights( + e: int, + n: int, + k: int, + block_size: list[int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return weights w1, w2, w1_scale, w2_scale + """ + dtype = torch.bfloat16 + device = torch.cuda.current_device() + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype) + w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) + w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * n) + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w2 = (n + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device) + + w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), + device=device, + dtype=torch.float32) + w2_s = 
torch.empty((e, n_tiles_w2, k_tiles_w2), + device=device, + dtype=torch.float32) + + assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n, + (k + (block_k - 1)) // block_k) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + for i in range(e): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i], + block_size_k=block_k, + block_size_n=block_n) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i], + block_size_k=block_k, + block_size_n=block_n) + + return w1, w2, w1_s, w2_s + + +def make_quant_fp8_weights( + e: int, + n: int, + k: int, + per_out_channel_quant: bool, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return w1, w2, w1_scale, w2_scale + """ + q_dtype = torch.float8_e4m3fn + + w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16) + + # w1 -> w1_q, w2 -> w2_q + w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype) + w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype) + + n_b_scales = 2 * n if per_out_channel_quant else 1 + k_b_scales = k if per_out_channel_quant else 1 + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_channel_quant) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_channel_quant) + return w1_q, w2_q, w1_scale, w2_scale diff --git a/tests/kernels/moe/parallel_utils.py b/tests/kernels/moe/parallel_utils.py index f4049eb0d095..1ad361ae0733 100644 --- a/tests/kernels/moe/parallel_utils.py +++ b/tests/kernels/moe/parallel_utils.py @@ -4,7 +4,6 @@ DeepEP test utilities """ import dataclasses -import importlib import os import traceback from typing import Callable, Optional @@ -15,10 +14,9 @@ spawn) # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Concatenate, ParamSpec -from vllm.utils import get_open_port +from vllm.utils import get_open_port, has_deep_ep -has_deep_ep = importlib.util.find_spec("deep_ep") is not None -if has_deep_ep: +if has_deep_ep(): from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 DeepEPHTPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index c9a4375ac939..69317405d48b 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -6,7 +6,6 @@ import pytest import torch -import triton.language as tl from tests.kernels.moe.utils import (batched_moe, make_quantized_test_activations, @@ -18,6 +17,7 @@ invoke_moe_batched_triton_kernel) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.platforms import current_platform +from vllm.triton_utils import tl MNK_FACTORS = [ (1, 128, 128), diff --git a/tests/kernels/moe/test_block_fp8.py b/tests/kernels/moe/test_block_fp8.py index c187542205a5..7dc6282326b6 100644 --- a/tests/kernels/moe/test_block_fp8.py +++ b/tests/kernels/moe/test_block_fp8.py @@ -15,13 +15,13 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe) from vllm.platforms import current_platform +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used -dg_available = False -try: - import deep_gemm - 
dg_available = True -except ImportError: - pass +dg_available = has_deep_gemm() + +if dg_available: + from deep_gemm import get_m_alignment_for_contiguous_layout if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -224,6 +224,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed, @pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@pytest.mark.skipif(is_blackwell_deep_gemm_used(), reason="Not E8M0 scale MOE") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, monkeypatch): @@ -238,8 +239,7 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed, torch.manual_seed(seed) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size)) - - block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_m = get_m_alignment_for_contiguous_layout() block_size = [block_m, block_m] dtype = torch.bfloat16 diff --git a/tests/kernels/moe/test_count_expert_num_tokens.py b/tests/kernels/moe/test_count_expert_num_tokens.py new file mode 100644 index 000000000000..0872836b6064 --- /dev/null +++ b/tests/kernels/moe/test_count_expert_num_tokens.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Tests compute_expert_num_tokens kernels +""" + +import dataclasses +from typing import Optional + +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens + + +@dataclasses.dataclass +class TestTensors: + + topk_ids: torch.Tensor + expert_map: Optional[torch.Tensor] = None + + def to_device(self, device: str): + self.topk_ids = self.topk_ids.to(device=device) + if self.expert_map is not None: + self.expert_map = self.expert_map.to(device=device) + + @staticmethod + def make(num_tokens: int, num_topk: int, num_experts: int, device: str, + topk_ids_dtype: torch.dtype) -> "TestTensors": + + # make topk ids + topk_ids = torch.empty((num_tokens, num_topk), + device=device, + dtype=torch.int64) + for x in range(num_tokens): + topk_ids[x] = torch.randperm(num_experts)[:num_topk] + topk_ids = topk_ids.to(dtype=torch.int64) + return TestTensors(topk_ids=topk_ids) + + def with_ep_rank(self, ep_rank: int, num_global_experts: int, + num_local_experts: int, device: str): + # make an expert map + expert_map = torch.empty((num_global_experts), + device=device, + dtype=torch.int32) + expert_map.fill_(-1) + s = ep_rank * num_local_experts + e = s + num_local_experts + expert_map[s:e] = torch.tensor(list(range(num_local_experts)), + device=device) + + return TestTensors(topk_ids=self.topk_ids.clone(), + expert_map=expert_map) + + +def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor): + # do the reference in cpu + tt.to_device("cpu") + expert_ids, counts = tt.topk_ids.unique(return_counts=True) + + for eid, count in zip(expert_ids, counts): + if eid != -1 and tt.expert_map is not None: + eid = tt.expert_map[eid] + + if eid == -1: + continue + + expert_num_tokens[eid] += count + + +def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int, + num_experts: int, ep_size: int, + topk_ids_dtype: torch.dtype): + + assert num_topk <= num_experts + + tt = TestTensors.make(num_tokens, + num_topk, + num_experts, + topk_ids_dtype=topk_ids_dtype, + device="cpu") + + num_global_experts = num_experts + assert num_global_experts % ep_size == 0 
+ num_local_experts = num_global_experts // ep_size + for ep_rank in range(ep_size): + tt_rank = tt.with_ep_rank(ep_rank, num_global_experts, + num_local_experts, "cpu") + + ref_expert_num_tokens = torch.zeros((num_local_experts), + device="cpu", + dtype=torch.int32) + ref_impl(tt_rank, ref_expert_num_tokens) + ref_expert_num_tokens = ref_expert_num_tokens.to("cuda") + + tt_rank.to_device("cuda") + # Test with expert_map + triton_expert_num_tokens_w_emap = count_expert_num_tokens( + tt_rank.topk_ids, num_local_experts, tt_rank.expert_map) + + # Test without expert map + topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype) + triton_expert_num_tokens_wo_emap = count_expert_num_tokens( + topk_ids, num_local_experts, expert_map=None) + + torch.testing.assert_close(ref_expert_num_tokens, + triton_expert_num_tokens_w_emap, + atol=0, + rtol=0) + torch.testing.assert_close(ref_expert_num_tokens, + triton_expert_num_tokens_wo_emap, + atol=0, + rtol=0) + + +@pytest.mark.parametrize( + "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317]) +@pytest.mark.parametrize("num_topk", [2, 6, 8]) +@pytest.mark.parametrize("num_experts", [64]) +@pytest.mark.parametrize("ep_size", [1, 2, 4]) +@pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) +def test_compute_expert_num_tokens(num_tokens: int, num_topk: int, + num_experts: int, ep_size: int, + topk_ids_dtype: torch.dtype): + do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts, + ep_size, topk_ids_dtype) + + +@pytest.mark.parametrize("numel", list(range(1, 8192, 11))) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("ep_size", [2]) +@pytest.mark.parametrize("topk_ids_dtype", [torch.int64]) +def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int, + ep_size: int, + topk_ids_dtype: torch.dtype): + do_test_compute_expert_num_tokens(num_tokens=numel, + num_topk=1, + num_experts=num_experts, + ep_size=ep_size, + topk_ids_dtype=topk_ids_dtype) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 929db9177537..81fb3ec1de18 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses +from math import prod from typing import Optional import pytest @@ -8,9 +9,12 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 +from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8, run_cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, fused_topk) +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) from vllm.platforms import current_platform NUM_EXPERTS = [40, 64] @@ -21,6 +25,7 @@ (2, 1024, 1536), (2, 3072, 1024), (2, 3072, 1536), + (7, 3072, 1536), (64, 1024, 1024), (64, 1024, 1536), (64, 3072, 1024), @@ -236,6 +241,7 @@ def test_cutlass_moe_8_bit_no_graph( per_act_token: bool, per_out_ch: bool, monkeypatch, + ep_size: Optional[int] = None, ): current_platform.seed_everything(7) monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") @@ -254,7 +260,13 @@ def test_cutlass_moe_8_bit_no_graph( triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, topk_ids) - cutlass_output = run_8_bit(mt, topk_weights, topk_ids, 
per_act_token) + if ep_size is not None: + assert e % ep_size == 0, "Cannot distribute experts evenly" + number_local_experts = e // ep_size + else: + number_local_experts = None + cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token, + number_local_experts) # Note 5.5 only needed for larger problem sizes, 5 works ok for # the rest. @@ -340,9 +352,62 @@ def test_cutlass_moe_8_bit_EP( per_out_channel: bool, ep_size: int, monkeypatch, +): + test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, + per_out_channel, monkeypatch, ep_size) + + +LARGE_MNK_FACTORS = [ + (1, 8192, 5120, 31), + (32768, 1024, 1024, 16), + (65536, 512, 1024, 16), +] + + +@pytest.mark.parametrize("m,n,k,topk", LARGE_MNK_FACTORS) +@pytest.mark.parametrize("e", [128]) +@pytest.mark.parametrize("per_act_token", [False]) +@pytest.mark.parametrize("per_out_channel", [True]) +@pytest.mark.parametrize("ep_size", [8]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_cutlass_moe_8_bit_EP_large( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_channel: bool, + ep_size: int, + monkeypatch, +): + test_cutlass_moe_8_bit_no_graph(m, n, k, e, topk, per_act_token, + per_out_channel, monkeypatch, ep_size) + + +@pytest.mark.parametrize("m,n,k,topk", [(1, 8192, 5120, 31)]) +@pytest.mark.parametrize("e", [128]) +@pytest.mark.parametrize("per_act_token", [False]) +@pytest.mark.parametrize("per_out_channel", [True]) +@pytest.mark.parametrize("ep_size", [8]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +def test_run_cutlass_moe_fp8( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_channel: bool, + ep_size: int, ): current_platform.seed_everything(7) - monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192") with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) @@ -352,20 +417,53 @@ def test_cutlass_moe_8_bit_EP( score, topk, renormalize=False) - - # Note that we are using the dequantized versions of the tensors. - # Using a, w1 and w2 directly results in minor output differences. 
- triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights, - topk_ids) - - assert e % ep_size == 0, "Cannot distribute experts evenly" - cutlass_output = run_8_bit(mt, - topk_weights, - topk_ids, - per_act_token, - num_local_experts=e // ep_size) - - torch.testing.assert_close(triton_output, - cutlass_output, - atol=5e-2, - rtol=1e-2) + # we want to make sure there is at least one token that's generated in + # this expert shard and at least one token that's NOT generated in this + # expert shard + topk_ids[0][0] = -1 + topk_ids[0][1] = 1 + + workspace13_shape = (m * topk, max(2 * n, k)) + workspace2_shape = (m * topk, n) + output_shape = (m * topk, k) + + workspace13 = torch.empty(prod(workspace13_shape), + device="cuda", + dtype=mt.a.dtype) + workspace2 = torch.empty(prod(workspace2_shape), + device="cuda", + dtype=mt.a.dtype) + + num_local_experts = e // ep_size + start, end = 0, num_local_experts + expert_map = [-1] * e + expert_map[start:end] = list(range(num_local_experts)) + expert_map = torch.tensor(expert_map, dtype=torch.int32, device="cuda") + + activation = lambda o, i: torch.ops._C.silu_and_mul(o, i) + a1q, a1q_scale = moe_kernel_quantize_input(mt.a, mt.a_scale, + torch.float8_e4m3fn, + per_act_token) + global_num_experts = -1 if mt.w1_q is None else mt.w1_q.size(0) + func = lambda output: run_cutlass_moe_fp8( + output, a1q, mt.w1_q, mt.w2_q, topk_ids, activation, + global_num_experts, expert_map, mt.w1_scale, mt.w2_scale, + a1q_scale, None, workspace13, workspace2, None, mt.a.dtype, + per_act_token, per_out_channel, False) + + workspace13.random_() + output_random_workspace = torch.empty(output_shape, + device="cuda", + dtype=mt.a.dtype) + func(output_random_workspace) + + workspace13.fill_(0) + output_zero_workspace = torch.zeros(output_shape, + device="cuda", + dtype=mt.a.dtype) + func(output_zero_workspace) + + torch.testing.assert_close(output_random_workspace, + output_zero_workspace, + atol=5e-3, + rtol=1e-3) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index b74137eeaaa6..074771e49a06 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -20,6 +20,7 @@ FusedMoEModularKernel) from vllm.platforms import current_platform from vllm.utils import has_deep_ep, has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from .parallel_utils import ProcessGroupInfo, parallel_launch from .utils import make_test_weights @@ -368,6 +369,8 @@ def _test_deepep_deepgemm_moe( @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm +@pytest.mark.skipif(is_blackwell_deep_gemm_used(), + reason="Skipping test for Blackwell DeepGEMM") def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, world_dp_size: tuple[int, int]): """ @@ -423,6 +426,8 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int, @pytest.mark.parametrize("world_dp_size", [(2, 1)]) @requires_deep_ep @requires_deep_gemm +@pytest.mark.skipif(is_blackwell_deep_gemm_used(), + reason="Skipping test for Blackwell DeepGEMM") def test_ll_deepep_deepgemm_moe( mnk: tuple[int, int, int], num_experts: int, diff --git a/tests/kernels/moe/test_deepgemm.py b/tests/kernels/moe/test_deepgemm.py index fa62507179a2..f7578e226917 100644 --- a/tests/kernels/moe/test_deepgemm.py +++ b/tests/kernels/moe/test_deepgemm.py @@ -15,46 +15,17 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts 
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8) -from vllm.utils import cdiv +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8 -has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None - -if has_deep_gemm: - import deep_gemm - BLOCK_M = deep_gemm.get_m_alignment_for_contiguous_layout() - BLOCK_SIZE = [BLOCK_M, BLOCK_M] +BLOCK_SIZE = [128, 128] requires_deep_gemm = pytest.mark.skipif( - not has_deep_gemm, + not has_deep_gemm(), reason="Requires deep_gemm kernels", ) -def calc_diff(x: torch.Tensor, y: torch.Tensor): - x, y = x.double(), y.double() - denominator = (x * x + y * y).sum() - sim = 2 * (x * y).sum() / denominator - return 1 - sim - - -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (cdiv(m, 128) * 128, cdiv(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - def make_block_quant_fp8_weights( e: int, n: int, @@ -124,7 +95,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size): topk_weights, topk_ids = torch.topk(router_logits, k=topk, dim=-1) topk_weights = torch.nn.functional.softmax(topk_weights, dim=-1) - # triton referrence + # triton reference out_triton = fused_experts( hidden_states=tokens_bf16, w1=w1, @@ -155,17 +126,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size): block_shape=block_size, allow_deep_gemm=True, ) - - base = out_triton.abs().mean() - atol = 0.1 * base.clamp(min=1e-2) # 10% of mean, but not lower than 1e-3 - rtol = 0.05 - # ----- Compare ----- - torch.testing.assert_close( - out_deepgemm.to(torch.float32), - out_triton.to(torch.float32), - rtol=rtol, - atol=float(atol), - ) + diff = calc_diff(out_deepgemm, out_triton) + assert diff < 0.001, f"Diff exceeded 1%: {diff}" # Note: W1 has shape (E, 2N, K), so N = 512 diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py new file mode 100644 index 000000000000..6f2869c3a61d --- /dev/null +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import copy +from itertools import product +from typing import Optional + +import pytest +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.config import VllmConfig, current_platform, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501 + BatchedTritonOrDeepGemmExperts) +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.layer import TritonExperts +from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( + TritonOrDeepGemmExperts) +from 
vllm.utils import has_deep_ep, has_deep_gemm, has_pplx + +from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors, + reference_moe_impl, + run_modular_kernel) +from .modular_kernel_tools.mk_objects import ( + MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, + MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES) +from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo, + parallel_launch_with_config) + +# TODO (varun): These requirements are very strict and could be relaxed. +has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx()) + +meets_package_requirements = pytest.mark.skipif( + not has_all_packages, + reason="Requires deep_ep & deep_gemm & pplx packages", +) + + +def rank_worker( + pgi: ProcessGroupInfo, + vllm_config: VllmConfig, + cpu_group, + config: Config, + weights: WeightTensors, +): + current_platform.seed_everything(pgi.rank) + + # sanity check + from vllm import envs + if config.fused_moe_chunk_size is not None: + assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE) + + # get weights to this device + weights.to_current_device() + + Ms = config.Ms + assert isinstance(Ms, list) + TOPKs = config.topks + assert isinstance(TOPKs, list) + + for m, topk in product(Ms, TOPKs): + print(f"Running m={m}, topk={topk} ...") + # override m and topk + cfgx = copy.deepcopy(config) + cfgx.Ms = m + cfgx.topks = topk + + # inputs for rank + rank_tensors = RankTensors.make(cfgx, pgi) + + # modular kernel out + mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights, + rank_tensors) + + with set_current_vllm_config(vllm_config): + ref_out = reference_moe_impl(cfgx, weights, rank_tensors) + + torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2) + + +def run(config: Config): + assert config.is_valid() + print(f"Testing config \n{config.describe()} ...") + + weights: WeightTensors = WeightTensors.make(config) + + vllm_config, env_dict = config.make_env_data() + parallel_launch_with_config(config.world_size, rank_worker, vllm_config, + env_dict, config, weights) + + +Ms = [32, 64] +Ks = [7168] # hidden sizes +Ns = [2048] +TOPKs = [4, 1] +Es = [32] +DTYPEs = [torch.bfloat16] +FUSED_MOE_CHUNK_SIZEs = [None, 16] + + +def is_nyi_config(config: Config) -> bool: + # We know these configs to be legitimate. but still fail. + + if (config.fused_experts_type in [ + BatchedTritonExperts, BatchedTritonOrDeepGemmExperts, + TritonExperts, TritonOrDeepGemmExperts + ]): + # The triton kernels expect both per-act-token-quant and + # per-out-ch-quant or neither. + unsupported_quant_config = ((config.is_per_act_token_quant + + config.is_per_out_ch_quant) == 1) + return unsupported_quant_config + + # cutlass kernels dont support expert_maps yet. 
+ return config.fused_experts_type == CutlassExpertsFp8 + + +@pytest.mark.parametrize("k", Ks) +@pytest.mark.parametrize("n", Ns) +@pytest.mark.parametrize("e", Es) +@pytest.mark.parametrize("dtype", DTYPEs) +@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +@pytest.mark.parametrize( + "combination", + product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) +@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) +@pytest.mark.parametrize("world_size", [2]) +@meets_package_requirements +def test_modular_kernel_combinations_multigpu( + k: int, n: int, e: int, dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + combination: tuple[mk.FusedMoEPrepareAndFinalize, + mk.FusedMoEPermuteExpertsUnpermute], + fused_moe_chunk_size: Optional[int], world_size: int): + + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=fused_moe_chunk_size, + world_size=world_size, + ) + if not config.is_valid(): + pytest.skip(f"Tests config {config} is not valid. Skipping ...") + + if is_nyi_config(config): + pytest.skip(f"Tests config {config} is nyi. Skipping ...") + + print(f"{config.describe()}") + run(config) + + +@pytest.mark.parametrize("k", Ks) +@pytest.mark.parametrize("n", Ns) +@pytest.mark.parametrize("e", Es) +@pytest.mark.parametrize("dtype", DTYPEs) +@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +@pytest.mark.parametrize( + "combination", + product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES)) +@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) +@pytest.mark.parametrize("world_size", [1]) +@meets_package_requirements +def test_modular_kernel_combinations_singlegpu( + k: int, n: int, e: int, dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + combination: tuple[mk.FusedMoEPrepareAndFinalize, + mk.FusedMoEPermuteExpertsUnpermute], + fused_moe_chunk_size: Optional[int], world_size: int): + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=fused_moe_chunk_size, + world_size=world_size, + ) + + if not config.is_valid(): + pytest.skip(f"Tests config {config} is not valid. Skipping ...") + + if is_nyi_config(config): + pytest.skip(f"Tests config {config} is nyi. 
Skipping ...") + + run(config) + + +if __name__ == '__main__': + # Ability to test individual PrepareAndFinalize and FusedExperts combination + from .modular_kernel_tools.cli_args import (make_config, + make_config_arg_parser) + parser = make_config_arg_parser(description=( + "Run single prepare-finalize & fused-experts combination test" + "Example : python3 -m tests.kernels.moe.test_modular_kernel_combinations " #noqa: E501 + "--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts" + )) + args = parser.parse_args() + config = make_config(args) + + run(config) diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 96e3f29b3d79..0f1c78704642 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -174,6 +174,7 @@ def test_fused_moe( use_int8_w8a8=False, use_int8_w8a16=False, use_int4_w4a16=False, + use_mxfp4_w4a4=False, per_act_token_quant=False, block_shape=None) diff --git a/tests/kernels/moe/test_moe_align_block_size.py b/tests/kernels/moe/test_moe_align_block_size.py index e980422a7b97..12ef9e776c3a 100644 --- a/tests/kernels/moe/test_moe_align_block_size.py +++ b/tests/kernels/moe/test_moe_align_block_size.py @@ -1,90 +1,315 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools +"""Tests for the MOE align block size function. + +Run `pytest tests/kernels/moe/test_moe_align_block_size.py`. +""" + +from typing import Optional import pytest import torch -from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( - moe_align_block_size_triton) - - -@pytest.mark.parametrize( - "block_size,num_tokens,topk,num_experts", - list( - itertools.product( - [32, 64, 128, 256], # block_size - [ - 1, - 3, - 7, - 16, - 256, - 2256, - 4096, - ], # num_tokens - [1, 4, 16, 64], # topk - [64, 160, 256, 257, 260, 264], # num_experts - )), -) -def test_moe_align_block_size_compare_implementations(block_size, num_tokens, - topk, num_experts): - topk_ids = torch.stack([ - torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] - for _ in range(num_tokens) - ]) + moe_align_block_size) +from vllm.platforms import current_platform +from vllm.utils import round_up + +NUM_TOKENS = [1, 3, 7, 16, 256, 2256, 4096] +NUM_EXPERTS = [32, 160, 256, 257, 512] +TOP_KS = [1, 2, 16, 32] +BLOCK_SIZES = [32, 64, 128, 256] +current_platform.seed_everything(0) + + +def _group_tokens_by_expert( + sorted_ids: torch.Tensor, + expert_ids: torch.Tensor, + block_size: int, + valid_length: int, + total_tokens: int, +) -> dict: + num_blocks = valid_length // block_size + expert_tokens: dict[int, list[int]] = {} + + for block_idx in range(num_blocks): + expert_id = expert_ids[block_idx].item() + block_start = block_idx * block_size + block_end = min(block_start + block_size, valid_length) + + block_tokens = sorted_ids[block_start:block_end] + valid_tokens = block_tokens[block_tokens < total_tokens] + + if expert_id not in expert_tokens: + expert_tokens[expert_id] = [] + expert_tokens[expert_id].extend(valid_tokens.tolist()) + return expert_tokens + +def _verify_expert_level_sorting( + actual_sorted_ids: torch.Tensor, + golden_sorted_ids: torch.Tensor, + expert_ids: torch.Tensor, + block_size: int, + valid_length: int, + total_tokens: int, +): + """ + Verify that actual_sorted_ids follows the correct expert-level sorting. 
+ The kernel implementation may or may not preserve original token order + in topk_ids in the final sorted_ids; however, this does not impact quality. + """ + # Group tokens by expert from the golden implementation + golden_expert_tokens = _group_tokens_by_expert(golden_sorted_ids, + expert_ids, block_size, + valid_length, total_tokens) + + actual_expert_tokens = _group_tokens_by_expert(actual_sorted_ids, + expert_ids, block_size, + valid_length, total_tokens) + + assert set(golden_expert_tokens.keys()) == set( + actual_expert_tokens.keys()), ( + f"Expert IDs mismatch: golden={set(golden_expert_tokens.keys())}, " + f"actual={set(actual_expert_tokens.keys())}") + + for expert_id in golden_expert_tokens: + golden_tokens = torch.tensor(golden_expert_tokens[expert_id], + device=actual_sorted_ids.device) + actual_tokens = torch.tensor(actual_expert_tokens[expert_id], + device=actual_sorted_ids.device) + assert torch.equal( + torch.sort(golden_tokens)[0], + torch.sort(actual_tokens)[0]), ( + f"Expert {expert_id} token mismatch: " + f"golden={golden_expert_tokens[expert_id]}, " + f"actual={actual_expert_tokens[expert_id]}") + + +def torch_moe_align_block_size( + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + expert_map: Optional[torch.Tensor] = None, + pad_sorted_ids: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Golden torch implementation of moe_align_block_size. + + This function aligns the token distribution across experts to be compatible + with block size for matrix multiplication by sorting tokens by expert and + padding to block boundaries. + """ max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + if pad_sorted_ids: + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + + flattened_token_indices = torch.arange(topk_ids.numel(), + device=topk_ids.device, + dtype=torch.int32) + flattened_expert_ids = topk_ids.flatten() + sorted_expert_ids, sort_indices = torch.sort(flattened_expert_ids, + stable=True) + sorted_token_indices = flattened_token_indices[sort_indices] + + expert_token_counts = torch.zeros(num_experts, + dtype=torch.int64, + device=topk_ids.device) + for expert_id in range(num_experts): + mask = sorted_expert_ids == expert_id + expert_token_counts[expert_id] = mask.sum() + + expert_padded_counts = torch.zeros(num_experts, + dtype=torch.int64, + device=topk_ids.device) + for expert_id in range(num_experts): + original_count = expert_token_counts[expert_id] + if original_count > 0: + expert_padded_counts[expert_id] = ( + (original_count + block_size - 1) // block_size) * block_size - sorted_ids_cuda = torch.empty((max_num_tokens_padded, ), - dtype=torch.int32, - device=topk_ids.device) - sorted_ids_cuda.fill_(topk_ids.numel()) - max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids_cuda = torch.zeros((max_num_m_blocks, ), - dtype=torch.int32, - device=topk_ids.device) - num_tokens_post_pad_cuda = torch.empty((1), - dtype=torch.int32, - device=topk_ids.device) - - sorted_ids_triton = torch.empty_like(sorted_ids_cuda) - sorted_ids_triton.fill_(topk_ids.numel()) - expert_ids_triton = torch.zeros_like(expert_ids_cuda) - num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - - ops.moe_align_block_size( - topk_ids, - num_experts, + sorted_token_ids = torch.full( + (max_num_tokens_padded, ), + topk_ids.numel(), + dtype=torch.int32, + device=topk_ids.device, + ) + max_num_blocks = (max_num_tokens_padded + block_size - 1) // block_size + expert_ids = 
torch.zeros(max_num_blocks, + dtype=torch.int32, + device=topk_ids.device) + + current_pos = 0 + current_block = 0 + for expert_id in range(num_experts): + expert_mask = sorted_expert_ids == expert_id + expert_tokens = sorted_token_indices[expert_mask] + num_expert_tokens = expert_tokens.shape[0] + + if num_expert_tokens > 0: + sorted_token_ids[current_pos:current_pos + + num_expert_tokens] = (expert_tokens) + + expert_blocks_needed = expert_padded_counts[expert_id] // block_size + expert_ids[current_block:current_block + + expert_blocks_needed] = (expert_id) + + current_pos += expert_padded_counts[expert_id] + current_block += expert_blocks_needed + + total_padded_tokens = expert_padded_counts.sum() + num_tokens_post_pad = torch.tensor([total_padded_tokens], + dtype=torch.int32, + device=topk_ids.device) + + if expert_map is not None: + expert_ids = expert_map[expert_ids] + return sorted_token_ids, expert_ids, num_tokens_post_pad + + +@pytest.mark.parametrize("m", NUM_TOKENS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("num_experts", NUM_EXPERTS) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("pad_sorted_ids", [False, True]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_moe_align_block_size(m: int, topk: int, num_experts: int, + block_size: int, pad_sorted_ids: bool): + """Test moe_align_block_size without expert mapping""" + topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32) + for i in range(m): + experts = torch.randperm(num_experts, device="cuda")[:topk] + topk_ids[i] = experts + + actual_sorted_ids, actual_expert_ids, actual_num_tokens = ( + moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + pad_sorted_ids=pad_sorted_ids, + )) + golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( + torch_moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + pad_sorted_ids=pad_sorted_ids, + )) + + torch.testing.assert_close(actual_num_tokens, + golden_num_tokens, + atol=0, + rtol=0) + torch.testing.assert_close(actual_expert_ids, + golden_expert_ids, + atol=0, + rtol=0) + + # For sorted_token_ids, verify block-level correctness rather than exact + # order Tokens within each expert's blocks can be in any order, but expert + # regions must be correct + _verify_expert_level_sorting( + actual_sorted_ids, + golden_sorted_ids, + actual_expert_ids, block_size, - sorted_ids_cuda, - expert_ids_cuda, - num_tokens_post_pad_cuda, + actual_num_tokens.item(), + m * topk, ) - moe_align_block_size_triton( - topk_ids, - num_experts, + total_tokens = m * topk + assert actual_num_tokens.item() % block_size == 0, ( + "num_tokens_post_pad should be divisible by block_size") + assert actual_num_tokens.item() >= total_tokens, ( + "num_tokens_post_pad should be at least total_tokens") + valid_tokens = actual_sorted_ids[actual_sorted_ids < total_tokens] + assert len(valid_tokens) == total_tokens, ( + f"Should have exactly {total_tokens} valid tokens, " + f"got {len(valid_tokens)}") + assert (actual_expert_ids >= 0).all() and ( + actual_expert_ids + < num_experts).all(), "expert_ids should contain valid expert indices" + + +@pytest.mark.parametrize("m", [16, 32]) +@pytest.mark.parametrize("topk", [2, 4]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +def test_moe_align_block_size_with_expert_map(m: int, 
topk: int, + num_experts: int, + block_size: int): + """Test moe_align_block_size with expert mapping (EP scenario)""" + topk_ids = torch.zeros((m, topk), device="cuda", dtype=torch.int32) + for i in range(m): + experts = torch.randperm(num_experts, device="cuda")[:topk] + topk_ids[i] = experts + + expert_map = torch.full((num_experts, ), + -1, + device="cuda", + dtype=torch.int32) + local_experts = list(range(0, num_experts, 2)) + for i, expert_id in enumerate(local_experts): + expert_map[expert_id] = i + + actual_sorted_ids, actual_expert_ids, actual_num_tokens = ( + moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + expert_map=expert_map, + )) + golden_sorted_ids, golden_expert_ids, golden_num_tokens = ( + torch_moe_align_block_size( + topk_ids=topk_ids, + block_size=block_size, + num_experts=num_experts, + expert_map=expert_map, + )) + + torch.testing.assert_close(actual_num_tokens, + golden_num_tokens, + atol=0, + rtol=0) + torch.testing.assert_close(actual_expert_ids, + golden_expert_ids, + atol=0, + rtol=0) + _verify_expert_level_sorting( + actual_sorted_ids, + golden_sorted_ids, + actual_expert_ids, block_size, - sorted_ids_triton, - expert_ids_triton, - num_tokens_post_pad_triton, + actual_num_tokens.item(), + m * topk, ) - assert torch.allclose(expert_ids_cuda, expert_ids_triton), ( - f"Expert IDs mismatch for block_size={block_size}, " - f"num_tokens={num_tokens}, topk={topk}\n" - f"CUDA expert_ids: {expert_ids_cuda}\n" - f"Triton expert_ids: {expert_ids_triton}") - assert torch.allclose( - num_tokens_post_pad_cuda, num_tokens_post_pad_triton), ( - f"Num tokens post pad mismatch for block_size={block_size}, " - f"num_tokens={num_tokens}, topk={topk}\n" - f"CUDA num_tokens_post_pad: {num_tokens_post_pad_cuda}\n" - f"Triton num_tokens_post_pad: {num_tokens_post_pad_triton}") +def test_moe_align_block_size_deterministic(): + m, topk, num_experts, block_size = 128, 2, 32, 64 + + torch.manual_seed(42) + topk_ids = torch.randint(0, + num_experts, (m, topk), + device="cuda", + dtype=torch.int32) + # expect the results to be reproducible + results = [] + for _ in range(5): + sorted_ids, expert_ids, num_tokens = moe_align_block_size( + topk_ids=topk_ids, block_size=block_size, num_experts=num_experts) + results.append( + (sorted_ids.clone(), expert_ids.clone(), num_tokens.clone())) -if __name__ == "__main__": - pytest.main([__file__]) + for i in range(1, len(results)): + assert torch.equal( + results[0][0], + results[i][0]), ("sorted_ids should be deterministic") + assert torch.equal( + results[0][1], + results[i][1]), ("expert_ids should be deterministic") + assert torch.equal( + results[0][2], + results[i][2]), ("num_tokens should be deterministic") diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py index 7cc83b512c8b..8d215a0cbeed 100644 --- a/tests/kernels/moe/test_moe_permute_unpermute.py +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -17,28 +17,34 @@ moe_permute, moe_permute_unpermute_supported, moe_unpermute) from vllm.platforms import current_platform -NUM_EXPERTS = [16, 64] +NUM_EXPERTS = [16, 64, 256] TOP_KS = [2, 4, 6, 8] EP_SIZE = [1, 4, 16] current_platform.seed_everything(0) -def torch_permute(hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, - start_expert: int, - expert_map: Optional[torch.Tensor] = None, - align_block_size: Optional[int] = None, - 
fill_invalid_expert: int = -1) -> list[torch.Tensor]: +def torch_permute( + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + # token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1) -> list[torch.Tensor]: n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] if expert_map is not None: is_local_expert = (expert_map[topk_ids] != -1) not_local_expert = (expert_map[topk_ids] == -1) topk_ids = is_local_expert * ( topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + token_expert_indices = torch.arange(0, + n_token * topk, + dtype=torch.int32, + device=hidden_states.device).reshape( + (n_token, topk)) sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), stable=True) @@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor, valid_row_idx = [] if align_block_size is None: - permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % - n_token, ...] + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // + topk, ...] permuted_row_size = permuted_hidden_states.shape[0] m_indices = torch.empty(permuted_row_size, device="cuda", @@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor, 0, n_token * topk, device="cuda", dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + dst_row_id2src_row_id_map[ + expert_first_token_offset[-1]:] = n_token * topk return [ permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices, valid_row_idx + src_row_id2dst_row_id_map, dst_row_id2src_row_id_map, m_indices, + valid_row_idx ] else: permuted_row_size = (topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1) // align_block_size * align_block_size + permuted_idx = torch.full((permuted_row_size, ), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device) permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype) @@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor, align_first_token_offset = align_expert_first_token_offset[i - 1] align_last_token_offset = align_expert_first_token_offset[i] dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ - first_token_offset:first_token_offset + - n_token_in_expert] % n_token + first_token_offset:first_token_offset + n_token_in_expert] # store token in current expert with align_first_token_offset permuted_hidden_states[align_first_token_offset:\ align_first_token_offset+n_token_in_expert,\ ...] = hidden_states[\ - dst_row_id2src_row_id_in_expert, ...] + dst_row_id2src_row_id_in_expert // topk,\ + ...] 
+ permuted_idx[align_first_token_offset:\ + align_first_token_offset+\ + n_token_in_expert] = dst_row_id2src_row_id_in_expert # set current expert m_indices m_indices[align_first_token_offset:align_last_token_offset] = i - 1 valid_row_idx += [ @@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor, src2dst_idx].reshape((n_token, topk)) return [ permuted_hidden_states, align_expert_first_token_offset, - align_src_row_id2dst_row_id, m_indices, valid_row_idx + align_src_row_id2dst_row_id, permuted_idx, m_indices, valid_row_idx ] @@ -146,15 +162,18 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor, valid_row_idx: torch.Tensor, topk: int, n_expert: int) -> torch.Tensor: # ignore invalid row + n_hidden = permuted_hidden_states.shape[1] mask = torch.zeros(permuted_hidden_states.shape[0], dtype=bool, device="cuda") mask[valid_row_idx] = True permuted_hidden_states[~mask] = 0 - idx = src_row_id2dst_row_id_map.flatten()[ - token_expert_indices.flatten()].reshape(token_expert_indices.shape) - output = permuted_hidden_states[idx, ...] * topk_weights[..., None] - output = output.sum(dim=1).to(permuted_hidden_states.dtype) + + permuted_hidden_states = permuted_hidden_states[ + src_row_id2dst_row_id_map.flatten(), ...] + permuted_hidden_states = permuted_hidden_states.view(-1, topk, n_hidden) + output = (permuted_hidden_states * topk_weights.unsqueeze(2)).sum(1).to( + permuted_hidden_states.dtype) return output @@ -184,43 +203,56 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states, gating_output, topk, False) - gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( - hidden_states, - topk_ids, - token_expert_indices, - topk, - n_expert, - n_local_expert, - start_expert, - expert_map=expert_map, - align_block_size=align_block_size, - fill_invalid_expert=fill_invalid_expert) - - result0, result1, result2, result3 = moe_permute( - hidden_states, topk_weights, topk_ids, token_expert_indices, topk, - n_expert, n_local_expert, expert_map, align_block_size, - fill_invalid_expert) + (gold_permuted_hidden_states, gold_expert_first_token_offset, + gold_inv_permuted_idx, gold_permuted_idx, gold_m_indices, + valid_row_idx) = torch_permute( + hidden_states, + topk_ids, + # token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) + + (permuted_hidden_states, _, expert_first_token_offset, inv_permuted_idx, + m_indices) = moe_permute(hidden_states=hidden_states, + a1q_scale=None, + topk_ids=topk_ids, + n_expert=n_expert, + n_local_expert=n_local_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) # check expert_first_token_offset - torch.testing.assert_close(gold1, result1, atol=0, rtol=0) + torch.testing.assert_close(gold_expert_first_token_offset, + expert_first_token_offset, + atol=0, + rtol=0) # check src_row_id2dst_row_id_map - torch.testing.assert_close(gold2, result2, atol=0, rtol=0) + torch.testing.assert_close(gold_inv_permuted_idx.flatten(), + inv_permuted_idx, + atol=0, + rtol=0) # check mindice - torch.testing.assert_close(gold3, result3, atol=0, rtol=0) + torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0) # check permuted_hidden_states, only valid token - torch.testing.assert_close(gold0[valid_row_idx], - 
result0[valid_row_idx], + torch.testing.assert_close(gold_permuted_hidden_states[valid_row_idx], + permuted_hidden_states[valid_row_idx], atol=0, rtol=0) - # add a random tensor to simulate group gemm - result0 = 0.5 * result0 + torch.randn_like(result0) + result0 = 0.5 * permuted_hidden_states + torch.randn_like( + permuted_hidden_states) + result4 = torch.empty_like(hidden_states) + moe_unpermute(result4, result0, topk_weights, inv_permuted_idx, + expert_first_token_offset) - result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, - topk, n_expert, n_local_expert) gold4 = torch_unpermute(result0, topk_weights, topk_ids, - token_expert_indices, result2, valid_row_idx, topk, - n_local_expert) - + token_expert_indices, inv_permuted_idx, + valid_row_idx, topk, n_local_expert) # check unpermuted hidden torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_mxfp4_moe.py b/tests/kernels/moe/test_mxfp4_moe.py new file mode 100644 index 000000000000..824b072a9f93 --- /dev/null +++ b/tests/kernels/moe/test_mxfp4_moe.py @@ -0,0 +1,57 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib +import importlib.metadata +from dataclasses import dataclass + +import pytest +import torch +from packaging import version + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( + "quark") is not None and version.parse( + importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@pytest.mark.parametrize('model_case', [ + ModelCase("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", tp=1), + ModelCase("fxmarty/deepseek_r1_3_layers_mxfp4", tp=8), + ModelCase("fxmarty/Llama-4-Scout-17B-16E-Instruct-2-layers-mxfp4", tp=1) +]) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +def test_mxfp4_loading_and_execution_moe(vllm_runner, model_case: ModelCase): + if torch.cuda.device_count() < model_case.tp: + pytest.skip(f"This test requires >={model_case.tp} gpus, got only " + f"{torch.cuda.device_count()}") + + with vllm_runner(model_case.model_id, + tensor_parallel_size=model_case.tp, + load_format="dummy") as llm: + + # TODO: llm.apply_model(check_model) currently relies on V0 internals. + # Re-enable this later. 
+ # def check_model(model): + # layer = model.model.layers[0] + + # qkv_proj = layer.self_attn.qkv_proj + + # assert isinstance(qkv_proj.quant_method, QuarkLinearMethod) + # assert isinstance(qkv_proj.scheme, QuarkW4A4MXFP4) + + # assert isinstance(layer.mlp.experts.quant_method, + # QuarkW4A4MXFp4MoEMethod) + + # if model_case.model_id == "fxmarty/qwen_1.5-moe-a2.7b-mxfp4": + # llm.apply_model(check_model) + + output = llm.generate_greedy("Today I am in the French Alps and", + max_tokens=20) + assert output \ No newline at end of file diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py index 3f5412e75821..3ff385360299 100644 --- a/tests/kernels/moe/test_nvfp4_moe.py +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -93,11 +93,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, a1_gscale=a1_gs, w1_fp4=w1_q, w1_blockscale=w1_blockscale, - w1_alphas=(1 / w1_gs), + g1_alphas=(1 / w1_gs), a2_gscale=a2_gs, w2_fp4=w2_q, w2_blockscale=w2_blockscale, - w2_alphas=(1 / w2_gs), + g2_alphas=(1 / w2_gs), topk_weights=topk_weights, topk_ids=topk_ids, m=m, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index d28e0e040629..f7a661b4bc7b 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -32,6 +32,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEModularKernel) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.platforms import current_platform from vllm.utils import round_up @@ -371,6 +373,7 @@ def pplx_prepare_finalize( chunk_topk_weight, chunk_topk_ids, False, + weight_and_reduce_impl=TopKWeightAndReduceDelegate(), ) torch.cuda.synchronize() diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 42d5526dc21f..26aa8d652e63 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -8,19 +8,14 @@ import torch from tests.kernels.quant_utils import (native_per_token_group_quant_fp8, - native_w8a8_block_matmul, - per_block_cast_to_fp8) + native_w8a8_block_matmul) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, + w8a8_block_fp8_matmul) from vllm.platforms import current_platform - -dg_available = False -try: - import deep_gemm - dg_available = True -except ImportError: - pass +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8 if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", @@ -106,7 +101,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed): @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) -@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") +@pytest.mark.skipif(not has_deep_gemm(), + reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -120,9 +116,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 
0.5) * 2 * fp8_max B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - _, block_k = block_size[0], block_size[1] - - A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) + A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1]) B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) As = As_fp8.to(torch.float32) @@ -132,14 +126,14 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): out_dtype) # Transpose earlier so that the testing will not trigger transposing kernels - As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) + As_fp8 = get_col_major_tma_aligned_tensor(As_fp8) out = torch.zeros((M, N), device='cuda', dtype=out_dtype) assert As_fp8.shape == (M, (K + 127) // 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) + fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / diff --git a/tests/kernels/quantization/test_per_token_group_quant.py b/tests/kernels/quantization/test_per_token_group_quant.py new file mode 100644 index 000000000000..f826983fe94e --- /dev/null +++ b/tests/kernels/quantization/test_per_token_group_quant.py @@ -0,0 +1,44 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from unittest.mock import patch + +import pytest +import torch + +from vllm.model_executor.layers.quantization.utils import fp8_utils + + +@pytest.mark.parametrize("shape", [(32, 128), (64, 256), (16, 512)]) +@pytest.mark.parametrize("column_major", [False, True]) +@pytest.mark.parametrize("scale_ue8m0", [False, True]) +@pytest.mark.parametrize("group_size", [64, 128]) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_per_token_group_quant_fp8(shape, column_major: bool, + scale_ue8m0: bool, group_size: int): + device = "cuda" + + torch.manual_seed(42) + num_tokens, hidden_dim = shape + + x = (torch.randn( + (num_tokens, hidden_dim), device=device, dtype=torch.bfloat16) * 8) + + # cuda path + out_q, scale = fp8_utils.per_token_group_quant_fp8( + x, + group_size, + column_major_scales=column_major, + use_ue8m0=scale_ue8m0, + ) + + # triton ref + with patch("vllm.platforms.current_platform.is_cuda", return_value=False): + ref_q, ref_s = fp8_utils.per_token_group_quant_fp8( + x, + group_size, + column_major_scales=column_major, + use_ue8m0=scale_ue8m0, + ) + + assert torch.allclose(out_q.float(), ref_q.float(), atol=0.15, rtol=0.15) + assert torch.allclose(scale, ref_s, atol=0.01, rtol=0.01) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index fcaa93762856..2e8febbdcf26 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -1072,6 +1072,7 @@ def torch_experts( quant_dtype: Optional[torch.dtype] = None, per_act_token_quant=False, block_shape: Optional[list[int]] = None, + apply_router_weights_on_input: bool = False, ) -> torch.Tensor: assert (global_num_experts == -1 or (global_num_experts == w1.shape[0] and expert_map is None) @@ -1081,11 +1082,17 @@ def torch_experts( M, K = a.shape topk = topk_ids.shape[1] + if apply_router_weights_on_input: + assert topk == 1 + a = a * topk_weight.to(a.dtype) + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) - a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype, + if a1_scale: + assert not per_act_token_quant and block_shape is None + a, 
a_scale = moe_kernel_quantize_input(a, a1_scale, quant_dtype, per_act_token_quant, block_shape) num_experts = w1.shape[0] @@ -1104,6 +1111,7 @@ def torch_experts( tmp2 = SiluAndMul()(tmp1) out[mask] = tmp2 @ w2[i].transpose(0, 1) elif block_shape is not None: + # block quantized assert (a_scale is not None and w1_scale is not None and w2_scale is not None) tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask], @@ -1121,15 +1129,27 @@ def torch_experts( assert (a_scale is not None and w1_scale is not None and w2_scale is not None) scales = a_scale if a_scale.numel() == 1 else a_scale[mask] + tmp1 = a[mask].to(f32) * scales w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1) - tmp1 = tmp1 @ w1_dq - tmp2 = SiluAndMul()(tmp1) + tmp1 = (tmp1 @ w1_dq).to(out.dtype) + + tmp2 = SiluAndMul()(tmp1).to(out.dtype) + + tmp2, b_scale = moe_kernel_quantize_input( + tmp2, a2_scale, quant_dtype, per_act_token_quant, + block_shape) + assert b_scale is not None + + tmp2 = tmp2.to(f32) * b_scale w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1) out[mask] = (tmp2 @ w2_dq).to(out.dtype) - return (out.view(M, -1, w2.shape[1]).to(f32) * - topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) + if apply_router_weights_on_input: + return out + else: + return (out.view(M, -1, w2.shape[1]).to(f32) * + topk_weight.view(M, -1, 1)).sum(dim=1).to(out.dtype) def torch_moe(a: torch.Tensor, diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 881d5efa6919..909b73933139 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -221,11 +221,6 @@ def phi2_lora_files(): return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora") -@pytest.fixture(scope="session") -def long_context_lora_files_16k_1(): - return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1") - - @pytest.fixture def llama_2_7b_engine_extra_embeddings(): cleanup_dist_env_and_memory(shutdown_ray=True) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index 3ac3b80ec827..b1ad1fdd0606 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -169,7 +169,8 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", MODEL_PATH, "--lora-path", lora_path, "--tensor-parallel-size", str(tp_size), "serialize", "--serialized-directory", - str(tmp_path), "--suffix", suffix + str(tmp_path), "--suffix", suffix, "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}' ], check=True, capture_output=True, @@ -184,27 +185,26 @@ def test_tp2_serialize_and_deserialize_lora(tmp_path, sql_lora_files, model_uri = tmp_path / "vllm" / model_ref / suffix / model_name tensorizer_config = TensorizerConfig(tensorizer_uri=str(model_uri)) - tensorizer_config.lora_dir = tensorizer_config.tensorizer_dir - loaded_vllm_model = LLM(model=model_ref, - load_format="tensorizer", - enable_lora=True, - enforce_eager=True, - model_loader_extra_config=tensorizer_config, - max_num_seqs=13, - tensor_parallel_size=2, - max_loras=2) + loaded_llm = LLM(model=model_ref, + load_format="tensorizer", + enable_lora=True, + enforce_eager=True, + model_loader_extra_config=tensorizer_config, + max_num_seqs=13, + tensor_parallel_size=2, + max_loras=2) - tensorizer_config_dict = tensorizer_config.to_dict() + tc_as_dict = tensorizer_config.to_serializable() print("lora adapter created") - assert do_sample(loaded_vllm_model, + assert do_sample(loaded_llm, sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, + 
tensorizer_config_dict=tc_as_dict, lora_id=0) == EXPECTED_NO_LORA_OUTPUT print("lora 1") - assert do_sample(loaded_vllm_model, + assert do_sample(loaded_llm, sql_lora_files, - tensorizer_config_dict=tensorizer_config_dict, + tensorizer_config_dict=tc_as_dict, lora_id=1) == EXPECTED_LORA_OUTPUT diff --git a/tests/lora/test_peft_helper.py b/tests/lora/test_peft_helper.py index f16589e06b2d..df8696cf58e0 100644 --- a/tests/lora/test_peft_helper.py +++ b/tests/lora/test_peft_helper.py @@ -38,8 +38,8 @@ ] -def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path): - peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1, +def test_peft_helper_pass(sql_lora_files, tmp_path): + peft_helper = PEFTHelper.from_local_dir(sql_lora_files, max_position_embeddings=4096) lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2) peft_helper.validate_legal(lora_config) @@ -56,15 +56,12 @@ def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path): "embed_tokens", "lm_head", ] - assert peft_helper.context_length == 16384 assert peft_helper.vllm_max_position_embeddings == 4096 - assert peft_helper.vllm_long_context_scaling_factor == float( - math.ceil(peft_helper.context_length / - peft_helper.vllm_max_position_embeddings)) + # test RSLoRA rslora_config = dict(use_rslora=True) test_dir = tmp_path / "test_rslora" - shutil.copytree(long_context_lora_files_16k_1, test_dir) + shutil.copytree(sql_lora_files, test_dir) # Load and modify configuration config_path = test_dir / "adapter_config.json" diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 9d75512a248b..3090941e6367 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - import vllm from vllm.lora.request import LoRARequest @@ -49,9 +47,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> list[str]: return generated_texts -# Skipping for V1 for now as we are hitting, -# "Head size 80 is not supported by FlashAttention." error. -@pytest.mark.skip(reason="Head size 80 is not supported by FlashAttention") def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. diff --git a/tests/lora/test_transformers_model.py b/tests/lora/test_transformers_model.py index 5065a2fb7164..723f7a54778f 100644 --- a/tests/lora/test_transformers_model.py +++ b/tests/lora/test_transformers_model.py @@ -9,7 +9,7 @@ from ..utils import create_new_process_for_each_test, multi_gpu_test -MODEL_PATH = "ArthurZ/ilama-3.2-1B" +MODEL_PATH = "hmellor/Ilama-3.2-1B" PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. 
concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 7bb5d8980d61..8cae8a80d38e 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -1,15 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time - import pytest import ray from prometheus_client import REGISTRY import vllm.envs as envs from vllm import EngineArgs, LLMEngine -from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import RayPrometheusStatLogger @@ -44,7 +41,7 @@ def test_metric_counter_prompt_tokens( dtype=dtype, disable_log_stats=False, gpu_memory_utilization=0.4) as vllm_model: - tokenizer = vllm_model.model.get_tokenizer() + tokenizer = vllm_model.llm.get_tokenizer() prompt_token_counts = [ len(tokenizer.encode(p)) for p in example_prompts ] @@ -56,7 +53,7 @@ def test_metric_counter_prompt_tokens( vllm_prompt_token_count = sum(prompt_token_counts) _ = vllm_model.generate_greedy(example_prompts, max_tokens) - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] metric_count = stat_logger.metrics.counter_prompt_tokens.labels( **stat_logger.labels)._value.get() @@ -80,8 +77,8 @@ def test_metric_counter_generation_tokens( disable_log_stats=False, gpu_memory_utilization=0.4) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - tokenizer = vllm_model.model.get_tokenizer() - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + tokenizer = vllm_model.llm.get_tokenizer() + stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] metric_count = stat_logger.metrics.counter_generation_tokens.labels( **stat_logger.labels)._value.get() vllm_generation_count = 0 @@ -116,8 +113,8 @@ def test_metric_counter_generation_tokens_multi_step( disable_async_output_proc=disable_async_output_proc, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - tokenizer = vllm_model.model.get_tokenizer() - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + tokenizer = vllm_model.llm.get_tokenizer() + stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] metric_count = stat_logger.metrics.counter_generation_tokens.labels( **stat_logger.labels)._value.get() vllm_generation_count = 0 @@ -148,7 +145,7 @@ def test_metric_set_tag_model_name(vllm_runner, model: str, dtype: str, disable_log_stats=False, gpu_memory_utilization=0.3, served_model_name=served_model_name) as vllm_model: - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + stat_logger = vllm_model.llm.llm_engine.stat_loggers['prometheus'] metrics_tag_content = stat_logger.labels["model_name"] if envs.VLLM_CI_USE_S3: @@ -232,149 +229,6 @@ def test_engine_log_metrics_regression( assert_metrics(model, engine, disable_log_stats, len(example_prompts)) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [10]) -def 
test_metric_spec_decode( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - k = 5 - - with vllm_runner( - model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4, - speculative_config={ - "model": model, - "num_speculative_tokens": k, - }, - ) as vllm_model: - - # Force log interval to be 0 to catch all metrics. - stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] - stat_logger.local_interval = 0 - - # Note that the purpose of this test is to verify spec decode - # metrics instead of functional correctness, so the expected values - # are intended to be loose. - metric_name_to_expected_fn = { - "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1, - "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1, - "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k, - "counter_spec_decode_num_draft_tokens": lambda v: v == k, - "counter_spec_decode_num_emitted_tokens": - lambda v: 0 <= v <= k + 1, - } - - # Use one request to better inspect the metrics. - prompts = example_prompts[:1] - - _ = vllm_model.generate_greedy(prompts, max_tokens) - for metric_name, is_expected in metric_name_to_expected_fn.items(): - metric_val = getattr( - stat_logger.metrics, - metric_name).labels(**stat_logger.labels)._value.get() - assert is_expected(metric_val), ( - f"the value of metric {metric_name} ({metric_val}) " - "does not meet expectation") - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [10]) -@pytest.mark.parametrize("log_interval", [1, 3, 5, 7]) -def test_metric_spec_decode_interval( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - log_interval: int, -) -> None: - k = 5 - - engine_args = EngineArgs( - model=model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4, - speculative_config={ - "model": model, - "num_speculative_tokens": k, - }, - enforce_eager=True, - ) - - engine = LLMEngine.from_engine_args(engine_args) - - try: - - engine.add_request( - "request-id-0", - example_prompts[0], - SamplingParams(max_tokens=max_tokens), - ) - - # set log internal - stat_logger = engine.stat_loggers['prometheus'] - stat_logger.local_interval = log_interval - - # prefill - engine.step() - - # wait for 5 seconds to ensure that spec decode metrics - # get triggered in first decode step - time.sleep(5) - - # first decode step should trigger async collection of metrics - engine.step() - - # wait one second to allow H2D transfer to finish - time.sleep(1) - - # second decode step should now be able to collect the spec - # decode stats and the request should also be finished - engine.step() - - # must have finisehd now - assert not engine.has_unfinished_requests() - - # wait to ensure logging occurs - time.sleep(log_interval) - - # force logging - engine.step() - - # Note that the purpose of this test is to verify spec decode - # metrics instead of functional correctness, so the expected values - # are intended to be loose. 
- metric_name_to_expected_fn = { - "gauge_spec_decode_draft_acceptance_rate": lambda v: 0 <= v <= 1, - "gauge_spec_decode_efficiency": lambda v: 0 <= v <= 1, - "counter_spec_decode_num_accepted_tokens": lambda v: 0 <= v <= k, - "counter_spec_decode_num_draft_tokens": lambda v: v == k, - "counter_spec_decode_num_emitted_tokens": - lambda v: 0 <= v <= k + 1, - } - - for metric_name, is_expected in metric_name_to_expected_fn.items(): - metric_val = getattr( - stat_logger.metrics, - metric_name).labels(**stat_logger.labels)._value.get() - assert is_expected(metric_val), ( - f"the value of metric {metric_name} ({metric_val}) " - "does not meet expectation") - - finally: - del engine - cleanup_dist_env_and_memory() - - def assert_metrics(model: str, engine: LLMEngine, disable_log_stats: bool, num_requests: int) -> None: if disable_log_stats: diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index ac31064d9212..721478f42442 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -46,20 +46,15 @@ def test_guided_logits_processors(zephyr_7B_tokenzer, sample_regex, whitespace_pattern=None, reasoner=None) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an example IPv4 address with this regex: {sample_regex}") tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - regex_LP(token_ids, tensor) + tensor = regex_LP([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}" - ) tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - json_LP(token_ids, tensor) + tensor = json_LP([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) @@ -81,8 +76,6 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, seed=0, dtype="bfloat16", ) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an example IPv4 address with this regex: {sample_regex}") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = get_local_guided_decoding_logits_processor( @@ -92,13 +85,11 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - tensor = regex_lp(token_ids, tensor) + # allowed tokens at state 0 + tensor = regex_lp([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - token_ids = zephyr_7B_tokenzer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}" - ) json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = await get_guided_decoding_logits_processor( @@ -106,7 +97,7 @@ async def test_guided_logits_processor_black_box(backend: str, is_local: bool, assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) - tensor = json_lp(token_ids, tensor) + tensor = json_lp([], tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) @@ -130,7 +121,6 @@ async def test_guided_logits_processor_with_reasoning( dtype="bfloat16", ) token_ids = deepseek_r1_qwen_tokenizer.encode( - f"Give an example IPv4 address with this regex: {sample_regex}." 
"<think>here is the thinking process") regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) @@ -141,14 +131,13 @@ async def test_guided_logits_processor_with_reasoning( regex_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert regex_lp is not None - tensor = torch.rand(32000) + tensor = torch.rand(151664) original_tensor = torch.clone(tensor) tensor = regex_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape assert torch.allclose(tensor, original_tensor) token_ids = deepseek_r1_qwen_tokenizer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}." "<think>here is the thinking process") json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) @@ -158,7 +147,7 @@ async def test_guided_logits_processor_with_reasoning( await get_guided_decoding_logits_processor( json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert json_lp is not None - tensor = torch.rand(32000) + tensor = torch.rand(151664) original_tensor = torch.clone(tensor) tensor = json_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape @@ -166,8 +155,7 @@ async def test_guided_logits_processor_with_reasoning( # Thinking is over, so the tensor should change. token_ids = deepseek_r1_qwen_tokenizer.encode( - f"Give an employee profile that fits this schema: {sample_json_schema}." - "<think>here is the thinking process</think> Then") + "<think>here is the thinking process</think>") json_request = GuidedDecodingParams(json=sample_json_schema, backend=backend) json_lp = get_local_guided_decoding_logits_processor( @@ -176,7 +164,7 @@ async def test_guided_logits_processor_with_reasoning( await get_guided_decoding_logits_processor( json_request, deepseek_r1_qwen_tokenizer, config, reasoning_backend) assert json_lp is not None - tensor = torch.rand(32000) + tensor = torch.rand(151664) original_tensor = torch.clone(tensor) tensor = json_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape @@ -201,19 +189,6 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") -def test_guided_decoding_backend_options(): - """Test backend-specific options""" - with pytest.warns(DeprecationWarning): - guided_decoding_params = GuidedDecodingParams( - backend= - "xgrammar:no-fallback,disable-any-whitespace,no-additional-properties" - ) - assert guided_decoding_params.backend == "xgrammar" - assert guided_decoding_params.disable_fallback - assert guided_decoding_params.disable_any_whitespace - assert guided_decoding_params.disable_additional_properties - - def test_pickle_xgrammar_tokenizer_data(): try: import xgrammar as xgr diff --git a/tests/model_executor/test_model_load_with_params.py b/tests/model_executor/test_model_load_with_params.py index 4bdb651e5170..aae9a4d1ef11 100644 --- a/tests/model_executor/test_model_load_with_params.py +++ b/tests/model_executor/test_model_load_with_params.py @@ -5,7 +5,8 @@ import pytest -from vllm.model_executor.layers.pooler import CLSPool, MeanPool, PoolingType +from vllm.model_executor.layers.pooler import (CLSPool, DispatchPooler, + MeanPool, PoolingType) from vllm.model_executor.models.bert import BertEmbeddingModel from vllm.model_executor.models.roberta import RobertaEmbeddingModel from vllm.platforms import current_platform @@ -32,8 +33,8 @@ def test_model_loading_with_params(vllm_runner): output = vllm_model.embed("Write a short story about a robot that" " dreams for the 
first time.\n") - model_config = vllm_model.model.llm_engine.model_config - model_tokenizer = vllm_model.model.llm_engine.tokenizer + model_config = vllm_model.llm.llm_engine.model_config + model_tokenizer = vllm_model.llm.llm_engine.tokenizer # asserts on the bert model config file assert model_config.encoder_config["max_seq_length"] == 512 @@ -49,7 +50,8 @@ def test_model_loading_with_params(vllm_runner): def check_model(model): assert isinstance(model, BertEmbeddingModel) - assert isinstance(model._pooler, CLSPool) + assert isinstance(pooler := model.pooler, DispatchPooler) + assert isinstance(pooler.poolers_by_task["embed"].pooling, CLSPool) vllm_model.apply_model(check_model) @@ -70,8 +72,8 @@ def test_roberta_model_loading_with_params(vllm_runner): output = vllm_model.embed("Write a short story about a robot that" " dreams for the first time.\n") - model_config = vllm_model.model.llm_engine.model_config - model_tokenizer = vllm_model.model.llm_engine.tokenizer + model_config = vllm_model.llm.llm_engine.model_config + model_tokenizer = vllm_model.llm.llm_engine.tokenizer # asserts on the bert model config file assert model_config.encoder_config["max_seq_length"] == 512 @@ -87,7 +89,9 @@ def test_roberta_model_loading_with_params(vllm_runner): def check_model(model): assert isinstance(model, RobertaEmbeddingModel) - assert isinstance(model._pooler, MeanPool) + assert isinstance(pooler := model.pooler, DispatchPooler) + assert isinstance(pooler.poolers_by_task["embed"].pooling, + MeanPool) vllm_model.apply_model(check_model) @@ -108,13 +112,14 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner): output = vllm_model.embed("Write a short story about a robot that" " dreams for the first time.\n") - model_tokenizer = vllm_model.model.llm_engine.tokenizer + model_tokenizer = vllm_model.llm.llm_engine.tokenizer assert model_tokenizer.tokenizer_id == model_name def check_model(model): assert isinstance(model, RobertaEmbeddingModel) assert not hasattr(model, "lm_head") - assert isinstance(model._pooler, CLSPool) + assert isinstance(pooler := model.pooler, DispatchPooler) + assert isinstance(pooler.poolers_by_task["embed"].pooling, CLSPool) vllm_model.apply_model(check_model) diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index 7d7a62eec118..ea240d227889 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -39,7 +39,7 @@ [ pytest.param( "bigscience/bloom-560m", # bloom - testing alibi slopes - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[pytest.mark.core_model], ), pytest.param( "openai-community/gpt2", # gpt2 @@ -87,7 +87,11 @@ pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( "TitanML/tiny-mixtral", # mixtral - marks=[pytest.mark.core_model, pytest.mark.cpu_model], + marks=[pytest.mark.core_model], + ), + pytest.param( + "allenai/OLMoE-1B-7B-0924-Instruct", + marks=[pytest.mark.cpu_model], ) ]) @pytest.mark.parametrize("max_tokens", [32]) diff --git a/tests/models/language/generation/test_gemma.py b/tests/models/language/generation/test_gemma.py index 5be4ae874e61..60a4bc14be88 100644 --- a/tests/models/language/generation/test_gemma.py +++ b/tests/models/language/generation/test_gemma.py @@ -15,13 +15,13 @@ def test_dummy_loader(vllm_runner, monkeypatch, model: str) -> None: load_format="dummy", ) as llm: if model == "google/gemma-3-4b-it": - normalizers = llm.model.collective_rpc( + normalizers = 
llm.llm.collective_rpc( lambda self: self.model_runner.model.language_model.model. normalizer.cpu().item()) - config = llm.model.llm_engine.model_config.hf_config.text_config + config = llm.llm.llm_engine.model_config.hf_config.text_config else: - normalizers = llm.model.collective_rpc( + normalizers = llm.llm.collective_rpc( lambda self: self.model_runner.model.model.normalizer.cpu( ).item()) - config = llm.model.llm_engine.model_config.hf_config + config = llm.llm.llm_engine.model_config.hf_config assert np.allclose(normalizers, config.hidden_size**0.5, rtol=2e-3) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index ecaae3ec1fc4..2238924c1b50 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -61,14 +61,6 @@ "tiiuae/Falcon-H1-0.5B-Base", ] -ATTN_BLOCK_SIZES = { - "ibm-ai-platform/Bamba-9B-v1": 528, - "Zyphra/Zamba2-1.2B-instruct": 80, - "nvidia/Nemotron-H-8B-Base-8K": 528, - "ibm-granite/granite-4.0-tiny-preview": 400, - "tiiuae/Falcon-H1-0.5B-Base": 800, -} - # Avoid OOM MAX_NUM_SEQS = 4 @@ -105,11 +97,6 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: - if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES: - block_size = ATTN_BLOCK_SIZES[model] - else: - block_size = 16 - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: @@ -117,9 +104,7 @@ def test_models( m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, - enforce_eager=True, - enable_prefix_caching=False, - block_size=block_size) as vllm_model: + enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: @@ -289,7 +274,7 @@ def test_models_preemption_recompute( Tests that outputs are identical with and w/o preemptions (recompute). 
""" with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model: - scheduler = vllm_model.model.llm_engine.scheduler[0] + scheduler = vllm_model.llm.llm_engine.scheduler[0] scheduler.ENABLE_ARTIFICIAL_PREEMPT = True preempt_vllm_outputs = vllm_model.generate_greedy( example_prompts, max_tokens) diff --git a/tests/models/language/generation/test_mistral.py b/tests/models/language/generation/test_mistral.py index c70698ede37a..81a88f2d485e 100644 --- a/tests/models/language/generation/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -238,8 +238,8 @@ def test_mistral_symbolic_languages(vllm_runner, model: str, load_format="mistral") as vllm_model: for prompt in SYMBOLIC_LANG_PROMPTS: msg = {"role": "user", "content": prompt} - outputs = vllm_model.model.chat([msg], - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.llm.chat([msg], + sampling_params=SAMPLING_PARAMS) assert "�" not in outputs[0].outputs[0].text.strip() @@ -253,11 +253,11 @@ def test_mistral_function_calling(vllm_runner, model: str, dtype: str) -> None: load_format="mistral") as vllm_model: msgs = copy.deepcopy(MSGS) - outputs = vllm_model.model.chat(msgs, - tools=TOOLS, - sampling_params=SAMPLING_PARAMS) + outputs = vllm_model.llm.chat(msgs, + tools=TOOLS, + sampling_params=SAMPLING_PARAMS) - tokenizer = vllm_model.model.get_tokenizer() + tokenizer = vllm_model.llm.get_tokenizer() tool_parser = MistralToolParser(tokenizer) model_output = outputs[0].outputs[0].text.strip() @@ -308,7 +308,7 @@ def test_mistral_guided_decoding( f"Give an example JSON for an employee profile that " f"fits this schema: {SAMPLE_JSON_SCHEMA}" }] - outputs = vllm_model.model.chat(messages, sampling_params=params) + outputs = vllm_model.llm.chat(messages, sampling_params=params) generated_text = outputs[0].outputs[0].text json_response = json.loads(generated_text) diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py index a83d25818584..97362f641665 100644 --- a/tests/models/language/pooling/mteb_utils.py +++ b/tests/models/language/pooling/mteb_utils.py @@ -23,14 +23,14 @@ # See #19344 MTEB_RERANK_TASKS = ["NFCorpus"] MTEB_RERANK_LANGS = ["en"] -MTEB_RERANK_TOL = 1e-3 +MTEB_RERANK_TOL = 2e-3 class VllmMtebEncoder(mteb.Encoder): def __init__(self, vllm_model): super().__init__() - self.model = vllm_model + self.llm = vllm_model self.rng = np.random.default_rng(seed=42) def encode( @@ -43,7 +43,7 @@ def encode( # issues by randomizing the order. 
r = self.rng.permutation(len(sentences)) sentences = [sentences[i] for i in r] - outputs = self.model.embed(sentences, use_tqdm=False) + outputs = self.llm.embed(sentences, use_tqdm=False) embeds = np.array(outputs) embeds = embeds[np.argsort(r)] return embeds @@ -61,10 +61,10 @@ def predict( queries = [s[0] for s in sentences] corpus = [s[1] for s in sentences] - outputs = self.model.score(queries, - corpus, - truncate_prompt_tokens=-1, - use_tqdm=False) + outputs = self.llm.score(queries, + corpus, + truncate_prompt_tokens=-1, + use_tqdm=False) scores = np.array(outputs) scores = scores[np.argsort(r)] return scores @@ -178,11 +178,11 @@ def mteb_test_embed_models(hf_runner, if model_info.architecture: assert (model_info.architecture - in vllm_model.model.llm_engine.model_config.architectures) + in vllm_model.llm.llm_engine.model_config.architectures) vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), MTEB_EMBED_TASKS) - vllm_dtype = vllm_model.model.llm_engine.model_config.dtype + vllm_dtype = vllm_model.llm.llm_engine.model_config.dtype with hf_runner(model_info.name, is_sentence_transformer=True, @@ -267,7 +267,9 @@ def mteb_test_rerank_models(hf_runner, vllm_runner, model_info: RerankModelInfo, vllm_extra_kwargs=None, - hf_model_callback=None): + hf_model_callback=None, + vllm_mteb_encoder=VllmMtebEncoder, + atol=MTEB_RERANK_TOL): if not model_info.enable_test: # A model family has many models with the same architecture, # and we don't need to test each one. @@ -282,13 +284,13 @@ def mteb_test_rerank_models(hf_runner, max_num_seqs=8, **vllm_extra_kwargs) as vllm_model: - model_config = vllm_model.model.llm_engine.model_config + model_config = vllm_model.llm.llm_engine.model_config if model_info.architecture: assert (model_info.architecture in model_config.architectures) assert model_config.hf_config.num_labels == 1 - vllm_main_score = run_mteb_rerank(VllmMtebEncoder(vllm_model), + vllm_main_score = run_mteb_rerank(vllm_mteb_encoder(vllm_model), tasks=MTEB_RERANK_TASKS, languages=MTEB_RERANK_LANGS) vllm_dtype = model_config.dtype @@ -300,4 +302,4 @@ def mteb_test_rerank_models(hf_runner, print("SentenceTransformers:", st_dtype, st_main_score) print("Difference:", st_main_score - vllm_main_score) - assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_RERANK_TOL) + assert st_main_score == pytest.approx(vllm_main_score, abs=atol) diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py index 3990e8ea92c8..64a8f25220da 100644 --- a/tests/models/language/pooling/test_baai.py +++ b/tests/models/language/pooling/test_baai.py @@ -68,7 +68,6 @@ enable_test=False), RerankModelInfo("BAAI/bge-reranker-v2-m3", architecture="XLMRobertaForSequenceClassification", - dtype="float32", enable_test=False) ] diff --git a/tests/models/language/pooling/test_bge_reranker_v2_gemma.py b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py new file mode 100644 index 000000000000..7fa9485dbc7f --- /dev/null +++ b/tests/models/language/pooling/test_bge_reranker_v2_gemma.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import numpy as np +import pytest +import torch + +from tests.conftest import HfRunner + +from .mteb_utils import (RerankModelInfo, VllmMtebEncoder, + mteb_test_rerank_models) + +RERANK_MODELS = [ + RerankModelInfo("BAAI/bge-reranker-v2-gemma", + architecture="GemmaForSequenceClassification"), +] + +PROMPT = 
"Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." # noqa: E501 + + +class GemmaRerankerHfRunner(HfRunner): + + def __init__(self, + model_name: str, + dtype: str = "auto", + *args: Any, + **kwargs: Any) -> None: + from transformers import AutoModelForCausalLM, AutoTokenizer + super().__init__(model_name, dtype, auto_cls=AutoModelForCausalLM) + self.tokenizer = AutoTokenizer.from_pretrained(model_name, + padding_side='left') + self.yes_loc = self.tokenizer.convert_tokens_to_ids("Yes") + + @torch.no_grad() + def predict(self, prompts: list[list[str]], *args, + **kwargs) -> torch.Tensor: + + def get_inputs(pairs, tokenizer, prompt=None): + if prompt is None: + prompt = PROMPT + + sep = "\n" + prompt_inputs = tokenizer(prompt, + return_tensors=None, + add_special_tokens=False)["input_ids"] + sep_inputs = tokenizer(sep, + return_tensors=None, + add_special_tokens=False)["input_ids"] + inputs = [] + for query, passage in pairs: + query_inputs = tokenizer( + f"A: {query}", + return_tensors=None, + add_special_tokens=False, + truncation=True, + ) + passage_inputs = tokenizer( + f"B: {passage}", + return_tensors=None, + add_special_tokens=False, + truncation=True, + ) + item = tokenizer.prepare_for_model( + [tokenizer.bos_token_id] + query_inputs["input_ids"], + sep_inputs + passage_inputs["input_ids"], + truncation="only_second", + padding=False, + return_attention_mask=False, + return_token_type_ids=False, + add_special_tokens=False, + ) + item["input_ids"] = item[ + "input_ids"] + sep_inputs + prompt_inputs + item["attention_mask"] = [1] * len(item["input_ids"]) + inputs.append(item) + return tokenizer.pad( + inputs, + padding=True, + return_tensors="pt", + ) + + scores = [] + for query, doc, *_ in prompts: + pairs = [(query, doc)] + inputs = get_inputs(pairs, self.tokenizer) + inputs = inputs.to(self.model.device) + _n_tokens = inputs["input_ids"].shape[1] + logits = self.model(**inputs, return_dict=True).logits + _scores = (logits[:, -1, + self.yes_loc].view(-1, ).float().sigmoid()) + scores.append(_scores[0].item()) + return torch.Tensor(scores) + + +class GemmaMtebEncoder(VllmMtebEncoder): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt = PROMPT + self.query_template = "A: {query}\n" + self.document_template = "B: {doc}\n{prompt}" + + def predict( + self, + sentences: list[tuple[str, str, + Optional[str]]], # query, corpus, prompt + *args, + **kwargs, + ) -> np.ndarray: + + _sentences = [] + for query, corpus, prompt in sentences: + query = self.query_template.format(query=query) + corpus = self.document_template.format(doc=corpus, prompt=prompt) + _sentences.append((query, corpus, prompt)) + + return super().predict(_sentences, *args, **kwargs) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo, + monkeypatch) -> None: + monkeypatch.setenv("VLLM_USE_V1", "0") + + assert model_info.architecture == "GemmaForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "hf_overrides": { + "architectures": ["GemmaForSequenceClassification"], + "classifier_from_token": ["Yes"], + "method": "no_post_processing", + } + } + + mteb_test_rerank_models(GemmaRerankerHfRunner, + vllm_runner, + model_info, + vllm_extra_kwargs, + vllm_mteb_encoder=GemmaMtebEncoder) diff --git a/tests/models/language/pooling/test_embedding.py b/tests/models/language/pooling/test_embedding.py index 
05fcf4101ff9..cc9e4102d5b7 100644 --- a/tests/models/language/pooling/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import os from typing import Optional import pytest @@ -29,8 +28,10 @@ def v1(run_with_both_engines): # [Decoder-only] pytest.param("BAAI/bge-multilingual-gemma2", marks=[pytest.mark.core_model]), - pytest.param("intfloat/e5-mistral-7b-instruct", - marks=[pytest.mark.core_model, pytest.mark.cpu_model]), + pytest.param( + "intfloat/e5-mistral-7b-instruct", + # CPU v1 doesn't support sliding window + marks=[pytest.mark.core_model]), # the qwen models interfere with each other (see PR # https://github.com/vllm-project/vllm/pull/18720). # To avoid this problem, for now we skip v0 since it will be @@ -38,11 +39,13 @@ def v1(run_with_both_engines): pytest.param("ssmits/Qwen2-7B-Instruct-embed-base", marks=[pytest.mark.skip_v0, pytest.mark.cpu_model]), # [Encoder-only] - pytest.param("BAAI/bge-base-en-v1.5", - marks=[ - pytest.mark.core_model, pytest.mark.cpu_model, - pytest.mark.skip_v1 - ]), + pytest.param( + "BAAI/bge-base-en-v1.5", + marks=[ + # CPU only supports V1 + pytest.mark.core_model, + pytest.mark.skip_v1 + ]), pytest.param("sentence-transformers/all-MiniLM-L12-v2", marks=[pytest.mark.skip_v1]), pytest.param("intfloat/multilingual-e5-small", @@ -61,10 +64,6 @@ def test_models( model, monkeypatch, ) -> None: - if model == "intfloat/e5-mistral-7b-instruct" and current_platform.is_cpu( - ) and os.environ.get("VLLM_USE_V1", "0") == "1": - pytest.skip("CPU V1 doesn't support sliding window") - if model == "BAAI/bge-multilingual-gemma2" and current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend diff --git a/tests/models/language/pooling/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py index c2f70bb647a4..efa119bb7659 100644 --- a/tests/models/language/pooling/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -2,9 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from __future__ import annotations -import importlib.util -from array import array - +import numpy as np import openai import pytest from scipy.spatial.distance import cosine @@ -14,10 +12,6 @@ from ....utils import RemoteOpenAIServer -# GritLM embedding implementation is only supported by XFormers backend. -pytestmark = pytest.mark.skipif(not importlib.util.find_spec("xformers"), - reason="GritLM requires XFormers") - MODEL_NAME = "parasail-ai/GritLM-7B-vllm" MAX_MODEL_LEN = 4000 @@ -26,11 +20,11 @@ def _arr(arr): """ Convert a list of integers to an array of integers. 
""" - return array("i", arr) + return np.array(arr) def test_find_array(): - from vllm.model_executor.models.gritlm import GritLMPooler + from vllm.model_executor.models.gritlm import GritLMMeanPool model_config = ModelConfig( MODEL_NAME, @@ -41,17 +35,19 @@ def test_find_array(): dtype="bfloat16", seed=0, ) - pooler = GritLMPooler(model_config=model_config) + pooling = GritLMMeanPool(model_config=model_config) arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 + assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert pooling._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=3) == -1 + assert pooling._find_array(arr, _arr([3, 4, 5]), end_idx=4) == 3 + assert pooling._find_array(arr, _arr([3, 5]), start_idx=0) == -1 with pytest.raises(ValueError): - pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) + pooling._find_array(arr, _arr([3, 4, 5]), start_idx=-1) def run_llm_encode( @@ -124,7 +120,7 @@ def test_gritlm_offline_embedding(vllm_runner): task="embed", max_model_len=MAX_MODEL_LEN, ) as vllm_model: - llm = vllm_model.model + llm = vllm_model.llm d_rep = run_llm_encode( llm, @@ -171,7 +167,7 @@ def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner): task="generate", max_model_len=MAX_MODEL_LEN, ) as vllm_model: - llm = vllm_model.model + llm = vllm_model.llm sampling_params = SamplingParams(temperature=0.0, max_tokens=256) outputs = llm.generate(input, sampling_params=sampling_params) diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py index 0bc189d82b8a..16c711407aea 100644 --- a/tests/models/language/pooling/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -18,11 +18,8 @@ ] RERANK_MODELS = [ - RerankModelInfo( - "jinaai/jina-reranker-v2-base-multilingual", - architecture="XLMRobertaForSequenceClassification", - dtype="float32", - ) + RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual", + architecture="XLMRobertaForSequenceClassification") ] @@ -90,10 +87,10 @@ def test_matryoshka( task="embed", dtype=dtype, max_model_len=None) as vllm_model: - assert vllm_model.model.llm_engine.model_config.is_matryoshka + assert vllm_model.llm.llm_engine.model_config.is_matryoshka matryoshka_dimensions = ( - vllm_model.model.llm_engine.model_config.matryoshka_dimensions) + vllm_model.llm.llm_engine.model_config.matryoshka_dimensions) assert matryoshka_dimensions is not None if dimensions not in matryoshka_dimensions: diff --git a/tests/models/language/pooling/test_mxbai_rerank.py b/tests/models/language/pooling/test_mxbai_rerank.py index a1293a95bfd5..e74c58744dd2 100644 --- a/tests/models/language/pooling/test_mxbai_rerank.py +++ b/tests/models/language/pooling/test_mxbai_rerank.py @@ -12,11 +12,9 @@ RERANK_MODELS = [ RerankModelInfo("mixedbread-ai/mxbai-rerank-base-v2", architecture="Qwen2ForSequenceClassification", - dtype="float32", enable_test=True), RerankModelInfo("mixedbread-ai/mxbai-rerank-large-v2", architecture="Qwen2ForSequenceClassification", - dtype="float32", enable_test=False) ] diff --git a/tests/models/language/pooling/test_nomic_max_model_len.py 
b/tests/models/language/pooling/test_nomic_max_model_len.py index 250b3a52835a..7413ef578e38 100644 --- a/tests/models/language/pooling/test_nomic_max_model_len.py +++ b/tests/models/language/pooling/test_nomic_max_model_len.py @@ -23,7 +23,7 @@ def test_default(model_info, vllm_runner): with vllm_runner(model_info.name, task="embed", max_model_len=None) as vllm_model: - model_config = vllm_model.model.llm_engine.model_config + model_config = vllm_model.llm.llm_engine.model_config if model_info.name == "nomic-ai/nomic-embed-text-v2-moe": # For nomic-embed-text-v2-moe the length is set to 512 # by sentence_bert_config.json. @@ -38,7 +38,7 @@ def test_set_max_model_len_legal(model_info, vllm_runner): # set max_model_len <= 512 with vllm_runner(model_info.name, task="embed", max_model_len=256) as vllm_model: - model_config = vllm_model.model.llm_engine.model_config + model_config = vllm_model.llm.llm_engine.model_config assert model_config.max_model_len == 256 # set 512 < max_model_len <= 2048 @@ -52,7 +52,7 @@ def test_set_max_model_len_legal(model_info, vllm_runner): else: with vllm_runner(model_info.name, task="embed", max_model_len=1024) as vllm_model: - model_config = vllm_model.model.llm_engine.model_config + model_config = vllm_model.llm.llm_engine.model_config assert model_config.max_model_len == 1024 diff --git a/tests/models/language/pooling/test_qwen3_reranker.py b/tests/models/language/pooling/test_qwen3_reranker.py index b1e8fd6294ca..9c6a833b4138 100644 --- a/tests/models/language/pooling/test_qwen3_reranker.py +++ b/tests/models/language/pooling/test_qwen3_reranker.py @@ -6,17 +6,16 @@ import torch from tests.conftest import HfRunner +from tests.utils import multi_gpu_test from .mteb_utils import RerankModelInfo, mteb_test_rerank_models RERANK_MODELS = [ RerankModelInfo("Qwen/Qwen3-Reranker-0.6B", architecture="Qwen3ForSequenceClassification", - dtype="float32", enable_test=True), RerankModelInfo("Qwen/Qwen3-Reranker-4B", architecture="Qwen3ForSequenceClassification", - dtype="float32", enable_test=False) ] @@ -89,3 +88,29 @@ def test_rerank_models_mteb(vllm_runner, model_info: RerankModelInfo) -> None: mteb_test_rerank_models(Qwen3RerankerHfRunner, vllm_runner, model_info, vllm_extra_kwargs) + + +@pytest.mark.parametrize("model_info", RERANK_MODELS) +@multi_gpu_test(num_gpus=2) +def test_rerank_models_mteb_tp(vllm_runner, + model_info: RerankModelInfo) -> None: + + assert model_info.architecture == "Qwen3ForSequenceClassification" + + vllm_extra_kwargs: dict[str, Any] = { + "hf_overrides": { + "architectures": ["Qwen3ForSequenceClassification"], + "classifier_from_token": ["no", "yes"], + "is_original_qwen3_reranker": True, + }, + "tensor_parallel_size": 2, + } + + if model_info.name == "Qwen/Qwen3-Reranker-4B": + vllm_extra_kwargs["max_num_seqs"] = 1 + + mteb_test_rerank_models(Qwen3RerankerHfRunner, + vllm_runner, + model_info, + vllm_extra_kwargs, + atol=1.2e-2) diff --git a/tests/models/language/pooling/test_reward.py b/tests/models/language/pooling/test_reward.py index ec3d25ee22a9..3b7fab3ba5c9 100644 --- a/tests/models/language/pooling/test_reward.py +++ b/tests/models/language/pooling/test_reward.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + import pytest import torch import torch.nn.functional as F @@ -84,6 +86,9 @@ def test_prm_models( dtype: str, monkeypatch, ) -> None: + if current_platform.is_cpu() and os.environ.get("VLLM_USE_V1", "0") == "0": + pytest.skip("CPU only 
supports V1") + if current_platform.is_rocm(): # ROCm Triton FA does not currently support sliding window attention # switch to use ROCm CK FA backend diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py index 33aff1c873fc..c7399e01c735 100644 --- a/tests/models/language/pooling/test_truncation_control.py +++ b/tests/models/language/pooling/test_truncation_control.py @@ -28,7 +28,7 @@ def test_smaller_truncation_size(vllm_runner, with vllm_runner(model_name, task="embed", max_model_len=max_model_len) as vllm_model: - vllm_output = vllm_model.model.encode( + vllm_output = vllm_model.llm.encode( input_str, truncate_prompt_tokens=truncate_prompt_tokens) prompt_tokens = vllm_output[0].prompt_token_ids @@ -43,7 +43,7 @@ def test_max_truncation_size(vllm_runner, with vllm_runner(model_name, task="embed", max_model_len=max_model_len) as vllm_model: - vllm_output = vllm_model.model.encode( + vllm_output = vllm_model.llm.encode( input_str, truncate_prompt_tokens=truncate_prompt_tokens) prompt_tokens = vllm_output[0].prompt_token_ids @@ -61,7 +61,7 @@ def test_bigger_truncation_size(vllm_runner, model_name, task="embed", max_model_len=max_model_len) as vllm_model: - llm_output = vllm_model.model.encode( + llm_output = vllm_model.llm.encode( input_str, truncate_prompt_tokens=truncate_prompt_tokens) assert llm_output == f"""truncate_prompt_tokens value diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index cbc2e9c87a64..e2e35e9b2721 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -35,6 +35,8 @@ REQUIRES_V0_MODELS = [ # V1 Test: not enough KV cache space in C1. 
"fuyu", + # V1 Test: Deadlock issue when processing mm_inputs + "llava-onevision-transformers", ] # yapf: disable @@ -152,6 +154,7 @@ video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501 max_model_len=4096, max_num_seqs=2, + num_logprobs= 6 if current_platform.is_cpu() else 5, auto_cls=AutoModelForTextToWaveform, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner, @@ -169,6 +172,71 @@ hf_output_post_proc=model_utils.ultravox_trunc_hf_output, marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + #### Transformers fallback to test + ## To reduce test burden, we only test batching arbitrary image size + # Dynamic image length and number of patches + "llava-onevision-transformers": VLMTestInfo( + models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + max_model_len=16384, + hf_model_kwargs=model_utils.llava_onevision_hf_model_kwargs("llava-hf/llava-onevision-qwen2-0.5b-ov-hf"), # noqa: E501 + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[pytest.mark.core_model], + ), + # FIXME(Isotr0py): Enable this test after + # https://github.com/huggingface/transformers/pull/39470 released + # "idefics3-transformers": VLMTestInfo( + # models=["HuggingFaceTB/SmolVLM-256M-Instruct"], + # test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + # prompt_formatter=lambda img_prompt:f"<|begin_of_text|>User:{img_prompt}<end_of_utterance>\nAssistant:", # noqa: E501 + # img_idx_to_prompt=lambda idx: "<image>", + # max_model_len=8192, + # max_num_seqs=2, + # auto_cls=AutoModelForImageTextToText, + # hf_output_post_proc=model_utils.idefics3_trunc_hf_output, + # image_size_factors=[(0.25, 0.5, 1.0)], + # vllm_runner_kwargs={ + # "model_impl": "transformers", + # }, + # marks=[pytest.mark.core_model], + # ), + # Pixel values from processor are not 4D or 5D arrays + "qwen2_5_vl-transformers": VLMTestInfo( + models=["Qwen/Qwen2.5-VL-3B-Instruct"], + test_type=VLMTestType.IMAGE, + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForImageTextToText, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + image_size_factors=[(0.25, 0.2, 0.15)], + vllm_runner_kwargs={ + "model_impl": "transformers", + }, + marks=[large_gpu_mark(min_gb=32)], + ), + # Check "auto" with fallback to transformers + "internvl-transformers": VLMTestInfo( + models=["OpenGVLab/InternVL3-1B-hf"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<IMG_CONTEXT>", + max_model_len=4096, + use_tokenizer_eos=True, + image_size_factors=[(0.25, 0.5, 1.0)], + vllm_runner_kwargs={ + "model_impl": "auto", + }, + auto_cls=AutoModelForImageTextToText, + marks=[pytest.mark.core_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -317,6 +385,7 @@ num_logprobs=10, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 
0.2, 0.15)], auto_cls=AutoModelForImageTextToText, + marks=[large_gpu_mark(min_gb=32)], ), "glm4_1v-video": VLMTestInfo( models=["THUDM/GLM-4.1V-9B-Thinking"], @@ -330,8 +399,7 @@ inputs=custom_inputs.video_with_metadata_glm4_1v(), limit_mm_per_prompt={"video": 1}, )], - # This is needed to run on machine with 24GB VRAM - vllm_runner_kwargs={"gpu_memory_utilization": 0.95}, + marks=[large_gpu_mark(min_gb=32)], ), "h2ovl": VLMTestInfo( models = [ diff --git a/tests/models/multimodal/generation/test_maverick.py b/tests/models/multimodal/generation/test_maverick.py new file mode 100644 index 000000000000..306cf39002df --- /dev/null +++ b/tests/models/multimodal/generation/test_maverick.py @@ -0,0 +1,652 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Create a reduced-layer version of the Maverick model for testing purposes. + +This script creates a new model with fewer layers by: +1. Loading the original Maverick model configuration +2. Creating a reduced configuration +3. Generating compatible safetensors files with appropriate weights +4. Creating the necessary index files for vLLM compatibility +""" + +import json +import shutil +from pathlib import Path +from typing import Any + +import pytest +import torch +from safetensors.torch import save_file +from transformers import (AutoConfig, AutoProcessor, AutoTokenizer, + GenerationConfig) + +from vllm import LLM, SamplingParams + +from ....utils import multi_gpu_test + +# Sample prompts for testing +PROMPTS: list[str] = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +def run_maverick_serving(model: str): + """Test Llama-4-Maverick model with vLLM LLM class using CLI equivalent + options with reduced layers. + """ + + try: + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + llm = LLM( + model=model, + max_model_len=2048, + enforce_eager=True, + tensor_parallel_size=8, + enable_expert_parallel=True, + trust_remote_code=True, + gpu_memory_utilization=0.4, + kv_cache_dtype="fp8", + ) + + outputs = llm.generate(PROMPTS, sampling_params) + + # Print the outputs + print("\nGenerated Outputs:\n" + "-" * 60) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}") + print(f"Output: {generated_text!r}") + print("-" * 60) + + except Exception as e: + print(f"Error initializing or running model: {e}") + raise + + +def create_reduced_maverick_model( + original_model_name: + str = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", + output_dir: str = "/tmp/reduced_maverick", + text_layers: int = 4, + num_experts: int = 4, + vision_layers: int = 2, + force_recreate: bool = False, +) -> str: + """ + Create a reduced-layer version of the Maverick model. 
+ + Args: + original_model_name: Name of the original Maverick model + output_dir: Directory to save the reduced model + text_layers: Number of text transformer layers + num_experts: Number of experts per layer + vision_layers: Number of vision transformer layers + force_recreate: Whether to recreate if output_dir already exists + + Returns: + Path to the created reduced model directory + """ + + print( + f"Creating reduced Maverick model with {text_layers} text layers and " + f"{vision_layers} vision layers...") + + # Create output directory + output_path = Path(output_dir) + if output_path.exists(): + if force_recreate: + shutil.rmtree(output_path) + else: + print(f"Output directory {output_dir} already exists. " + "Use --force-recreate to overwrite.") + return str(output_path) + + output_path.mkdir(parents=True, exist_ok=True) + + try: + print("Loading original model configuration...") + original_config = AutoConfig.from_pretrained(original_model_name, + trust_remote_code=True) + + print("Creating reduced configuration...") + reduced_config = create_reduced_config(original_config, text_layers, + num_experts, vision_layers) + + config_path = output_path / "config.json" + with open(config_path, "w") as f: + json.dump(reduced_config, f, indent=2) + print(f"Saved reduced config to {config_path}") + + print("Copying tokenizer files...") + copy_tokenizer_files(original_model_name, output_path) + + print("Creating reduced safetensors files...") + create_reduced_safetensors(original_config, reduced_config, + output_path) + + print("Creating preprocessor config...") + create_preprocessor_config(original_config, output_path) + + try: + gen_config = GenerationConfig.from_pretrained(original_model_name) + gen_config.save_pretrained(output_path) + print("Copied generation config") + except Exception as e: + print(f"Could not copy generation config: {e}") + + print(f"Successfully created reduced Maverick model at {output_path}") + return str(output_path) + + except Exception as e: + print(f"Error creating reduced model: {e}") + # Clean up on failure + if output_path.exists(): + shutil.rmtree(output_path) + raise + + +def create_reduced_config(original_config: Any, text_layers: int, + num_experts: int, + vision_layers: int) -> dict[str, Any]: + """Create a reduced configuration based on the original.""" + + # Convert config to dictionary + config_dict = original_config.to_dict() + + # Reduce text layers + if "text_config" in config_dict: + original_text_layers = config_dict["text_config"]["num_hidden_layers"] + config_dict["text_config"]["num_hidden_layers"] = text_layers + print( + f"Reduced text layers from {original_text_layers} to {text_layers}" + ) + + original_num_experts = config_dict["text_config"]["num_local_experts"] + config_dict["text_config"]["num_local_experts"] = num_experts + print( + f"Reduced num experts from {original_num_experts} to {num_experts}" + ) + + hidden_dim_divisor = 4 + + original_hidden_size = config_dict["text_config"]["hidden_size"] + new_hidden_size = original_hidden_size // hidden_dim_divisor + config_dict["text_config"]["hidden_size"] = new_hidden_size + print(f"Reduced hidden size from {original_hidden_size} to " + f"{new_hidden_size}") + + original_head_dim = config_dict["text_config"]["head_dim"] + new_head_dim = original_head_dim // hidden_dim_divisor + config_dict["text_config"]["head_dim"] = new_head_dim + print(f"Reduced head dim from {original_head_dim} to {new_head_dim}") + + # Reduce vision layers + if "vision_config" in config_dict: + 
+
+
+def copy_tokenizer_files(original_model_name: str, output_path: Path) -> None:
+    """Copy tokenizer files from the original model."""
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(original_model_name,
+                                                  trust_remote_code=True)
+        tokenizer.save_pretrained(output_path)
+        print("Tokenizer files copied successfully")
+    except Exception as e:
+        print(f"Warning: Could not copy tokenizer files: {e}")
+
+
+def create_preprocessor_config(original_config: Any,
+                               output_path: Path) -> None:
+    """Create preprocessor_config.json for multimodal model."""
+
+    # Try to load the original preprocessor config
+    try:
+        processor = AutoProcessor.from_pretrained(
+            original_config._name_or_path
+            or "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
+            trust_remote_code=True,
+        )
+        processor.save_pretrained(output_path)
+        print("Copied original preprocessor config")
+        return
+    except Exception as e:
+        print(f"Could not copy original preprocessor config: {e}")
+        raise
+
+
+def create_reduced_safetensors(original_config: Any,
+                               reduced_config: dict[str, Any],
+                               output_path: Path) -> None:
+    """Create safetensors files with weights for the reduced model."""
+
+    print("Generating synthetic weights for reduced model...")
+
+    text_config = reduced_config["text_config"]
+    vision_config = reduced_config["vision_config"]
+
+    weights = {}
+
+    print("Creating text model weights...")
+    weights.update(create_text_model_weights(text_config))
+
+    print("Creating vision model weights...")
+    weights.update(create_vision_model_weights(vision_config))
+
+    print("Creating shared model weights...")
+    weights.update(create_shared_weights(text_config, vision_config))
+
+    print("Saving weights to safetensors files...")
+    save_weights_to_safetensors(weights, output_path)
+
+
+def create_text_model_weights(
+        text_config: dict[str, Any]) -> dict[str, torch.Tensor]:
+    """Create synthetic weights for the text model with MoE structure."""
+
+    weights = {}
+
+    vocab_size = text_config["vocab_size"]
+    hidden_size = text_config["hidden_size"]
+    intermediate_size = text_config["intermediate_size"]
+    intermediate_size_mlp = text_config["intermediate_size_mlp"]
+    num_layers = text_config["num_hidden_layers"]
+    num_attention_heads = text_config["num_attention_heads"]
+    num_key_value_heads = text_config.get("num_key_value_heads",
+                                          num_attention_heads)
+
+    # MoE specific parameters
+    num_experts = text_config.get("num_local_experts")
+    assert (num_experts
+            is not None), "num_local_experts must be specified for MoE"
+
+    head_dim = hidden_size // num_attention_heads
+
+    # Embedding layers
+    weights["language_model.model.embed_tokens.weight"] = torch.randn(
+        vocab_size, hidden_size, dtype=torch.float16)
+
+    # Transformer layers
+    for layer_idx in range(num_layers):
+        layer_prefix = f"language_model.model.layers.{layer_idx}"
+        print(f"Creating weights for layer {layer_prefix}...")
+
+        # Self-attention weights (separate q, k, v projections)
+        weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn(
+            hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn(
+            hidden_size, num_key_value_heads * head_dim, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn(
+            num_key_value_heads * head_dim, hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.o_proj.weight"] = torch.randn(
+            hidden_size, num_attention_heads * head_dim, dtype=torch.bfloat16)
+        print("Self-attention weights created.")
+
+        # Feed-forward weights - MoE pattern based on interleave_moe_layer_step
+        # For interleave_moe_layer_step=2: layers 1,3,5,... are MoE, layers
+        # 0,2,4,... are dense
+        interleave_step = text_config.get("interleave_moe_layer_step", 1)
+        is_moe_layer = (interleave_step > 0
+                        and (layer_idx + 1) % interleave_step == 0)
+
+        if is_moe_layer:
+            # MoE layer structure
+            # 1. Router weights
+            weights[
+                f"{layer_prefix}.feed_forward.router.weight"] = torch.randn(
+                    num_experts, hidden_size, dtype=torch.float16)
+
+            # 2. Individual expert weights (not fused)
+            for expert_idx in range(num_experts):
+                expert_prefix = (
+                    f"{layer_prefix}.feed_forward.experts.{expert_idx}")
+
+                weights[f"{expert_prefix}.gate_proj.weight"] = torch.randn(
+                    intermediate_size, hidden_size, dtype=torch.bfloat16)
+                weights[f"{expert_prefix}.up_proj.weight"] = torch.randn(
+                    intermediate_size, hidden_size, dtype=torch.bfloat16)
+                weights[f"{expert_prefix}.down_proj.weight"] = torch.randn(
+                    hidden_size, intermediate_size, dtype=torch.bfloat16)
+
+                # Expert weight scales (FP8 quantization)
+                weights[
+                    f"{expert_prefix}.gate_proj.weight_scale"] = torch.ones(
+                        intermediate_size, 1, dtype=torch.bfloat16)
+                weights[f"{expert_prefix}.up_proj.weight_scale"] = torch.ones(
+                    intermediate_size, 1, dtype=torch.bfloat16)
+                weights[
+                    f"{expert_prefix}.down_proj.weight_scale"] = torch.ones(
+                        hidden_size, 1, dtype=torch.bfloat16)
+
+            # 3. Shared expert weights
+            shared_expert_prefix = f"{layer_prefix}.feed_forward.shared_expert"
+            weights[f"{shared_expert_prefix}.gate_proj.weight"] = torch.randn(
+                intermediate_size, hidden_size, dtype=torch.bfloat16)
+            weights[f"{shared_expert_prefix}.up_proj.weight"] = torch.randn(
+                intermediate_size, hidden_size, dtype=torch.bfloat16)
+            weights[f"{shared_expert_prefix}.down_proj.weight"] = torch.randn(
+                hidden_size, intermediate_size, dtype=torch.bfloat16)
+            print(f"MoE feed-forward weights created for layer {layer_idx}.")
+        else:
+            # Dense layer structure
+            weights[f"{layer_prefix}.feed_forward.gate_proj.weight"] = (
+                torch.randn(intermediate_size_mlp,
+                            hidden_size,
+                            dtype=torch.bfloat16))
+            weights[f"{layer_prefix}.feed_forward.up_proj.weight"] = (
+                torch.randn(intermediate_size_mlp,
+                            hidden_size,
+                            dtype=torch.bfloat16))
+            weights[f"{layer_prefix}.feed_forward.down_proj.weight"] = (
+                torch.randn(hidden_size,
+                            intermediate_size_mlp,
+                            dtype=torch.bfloat16))
+            print(f"Dense feed-forward weights created for layer {layer_idx}.")
+
+        # Layer norms
+        weights[f"{layer_prefix}.input_layernorm.weight"] = torch.ones(
+            hidden_size, dtype=torch.bfloat16)
+        weights[
+            f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones(
+                hidden_size, dtype=torch.bfloat16)
+        print("Layer norms created.")
+
+    # Final layer norm and output projection
+    weights["language_model.model.norm.weight"] = torch.ones(
+        hidden_size, dtype=torch.bfloat16)
+    weights["language_model.lm_head.weight"] = torch.randn(
+        vocab_size, hidden_size, dtype=torch.bfloat16)
+
+    return weights
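# Illustrative sketch, not part of the patch above: the MoE placement rule
# used inside create_text_model_weights(), pulled out on its own. With
# interleave_moe_layer_step=2 and 4 layers it yields [False, True, False,
# True], i.e. layers 1 and 3 get router/expert weights while layers 0 and 2
# get dense feed-forward weights, matching the comment in the loop above.
def moe_layer_mask(num_layers: int, interleave_step: int) -> list[bool]:
    return [
        interleave_step > 0 and (layer_idx + 1) % interleave_step == 0
        for layer_idx in range(num_layers)
    ]


assert moe_layer_mask(4, 2) == [False, True, False, True]
# With the default step of 1, every layer is an MoE layer.
assert moe_layer_mask(4, 1) == [True, True, True, True]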
+
+
+def create_vision_model_weights(
+        vision_config: dict[str, Any]) -> dict[str, torch.Tensor]:
+    """Create synthetic weights for the vision model."""
+
+    weights = {}
+
+    hidden_size = vision_config["hidden_size"]
+    intermediate_size = vision_config["intermediate_size"]
+    num_layers = vision_config["num_hidden_layers"]
+
+    # Vision transformer layers
+    for layer_idx in range(num_layers):
+        layer_prefix = f"vision_model.model.layers.{layer_idx}"
+
+        weights[f"{layer_prefix}.self_attn.q_proj.weight"] = torch.randn(
+            hidden_size, hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.q_proj.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.k_proj.weight"] = torch.randn(
+            hidden_size, hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.k_proj.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.v_proj.weight"] = torch.randn(
+            hidden_size, hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.v_proj.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.o_proj.weight"] = torch.randn(
+            hidden_size, hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.self_attn.o_proj.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+
+        weights[f"{layer_prefix}.mlp.fc1.weight"] = torch.randn(
+            intermediate_size, hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.mlp.fc1.bias"] = torch.zeros(
+            intermediate_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.mlp.fc2.weight"] = torch.randn(
+            hidden_size, intermediate_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.mlp.fc2.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+
+        weights[f"{layer_prefix}.input_layernorm.weight"] = torch.ones(
+            hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.input_layernorm.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+        weights[
+            f"{layer_prefix}.post_attention_layernorm.weight"] = torch.ones(
+                hidden_size, dtype=torch.bfloat16)
+        weights[f"{layer_prefix}.post_attention_layernorm.bias"] = torch.zeros(
+            hidden_size, dtype=torch.bfloat16)
+
+    return weights
+
+
+def create_shared_weights(
+        text_config: dict[str, Any],
+        vision_config: dict[str, Any]) -> dict[str, torch.Tensor]:
+    """Create weights for shared components (vision-language connector)"""
+
+    weights = {}
+
+    text_hidden_size = text_config["hidden_size"]
+    projector_input_dim = vision_config["projector_input_dim"]
+
+    # Vision-language connector (projects vision features to text space)
+    weights["multi_modal_projector.linear_1.weight"] = torch.randn(
+        text_hidden_size, projector_input_dim, dtype=torch.bfloat16)
+
+    return weights
+
+
+def save_weights_to_safetensors(weights: dict[str, torch.Tensor],
+                                output_path: Path) -> None:
+    """Save weights to safetensors files and create index."""
+
+    # Determine how to shard the weights
+    max_shard_size = 5 * 1024 * 1024 * 1024  # 5GB per shard
+
+    # Calculate sizes and create shards
+    shards = []
+    current_shard: dict[str, torch.Tensor] = {}
+    current_size = 0
+
+    for name, tensor in weights.items():
+        tensor_size = tensor.numel() * tensor.element_size()
+
+        if current_size + tensor_size > max_shard_size and current_shard:
+            shards.append(current_shard)
+            current_shard = {}
+            current_size = 0
+
+        current_shard[name] = tensor
+        current_size += tensor_size
+
+    if current_shard:
+        shards.append(current_shard)
+
+    # Save shards and create index
+    weight_map = {}
+
+    if len(shards) == 1:
+        # Single file
+        filename = "model.safetensors"
+        save_file(shards[0], output_path / filename)
+        weight_map = {name: filename for name in shards[0]}
+        print(f"Saved weights to single file: {filename}")
+    else:
+        # Multiple shards
+        for i, shard in enumerate(shards):
+            filename = f"model-{i+1:05d}-of-{len(shards):05d}.safetensors"
+            save_file(shard, output_path / filename)
+            for name in shard:
+                weight_map[name] = filename
+            print(f"Saved shard {i+1}/{len(shards)}: {filename}")
+
+    # Create index file
+    index_data = {
+        "metadata": {
+            "total_size":
+            sum(tensor.numel() * tensor.element_size()
+                for tensor in weights.values())
+        },
+        "weight_map": weight_map,
+    }
+
+    index_path = output_path / "model.safetensors.index.json"
+    with open(index_path, "w") as f:
+        json.dump(index_data, f, indent=2)
+
+    print(f"Created index file: {index_path}")
+    print(f"Total model size: "
+          f"{index_data['metadata']['total_size'] / (1024**3):.2f} GB")
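# Illustrative sketch, not part of the patch above: how the index written by
# save_weights_to_safetensors() is typically consumed. weight_map maps each
# tensor name to the shard file holding it, so one tensor can be read without
# touching the other shards. Assumes the safetensors package is installed;
# the helper name is hypothetical.
import json
from pathlib import Path

import torch
from safetensors import safe_open


def load_single_tensor(model_dir: str, name: str) -> torch.Tensor:
    index = json.loads(
        (Path(model_dir) / "model.safetensors.index.json").read_text())
    shard_file = index["weight_map"][name]  # e.g. "model.safetensors"
    with safe_open(Path(model_dir) / shard_file, framework="pt") as f:
        return f.get_tensor(name)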
+
+
+def run_reduced_model(model_path: str,
+                      should_profile: bool = False,
+                      **kwargs) -> None:
+    """Test the created reduced model with vLLM."""
+
+    print(f"\nTesting reduced model at {model_path}...")
+
+    llm = LLM(
+        model=model_path,
+        trust_remote_code=True,
+        max_model_len=512,  # Small context for testing
+        gpu_memory_utilization=0.3,  # Conservative memory usage
+        **kwargs,
+    )
+
+    sampling_params = SamplingParams(temperature=0.8,
+                                     top_p=0.95,
+                                     max_tokens=50)
+
+    if should_profile:
+        llm.start_profile()
+    outputs = llm.generate(PROMPTS, sampling_params)
+    if should_profile:
+        llm.stop_profile()
+
+    print("Test generation successful!")
+    for output in outputs:
+        print(f"Prompt: {output.prompt}")
+        print(f"Output: "
+              f"{output.outputs[0].text}")
+        print("-" * 40)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize(
+    "original_model_name,text_layers,num_experts,vision_layers,",
+    [("meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", 4, 4, 2)])
+@pytest.mark.parametrize("enforce_eager", [True, False])
+@pytest.mark.parametrize("tp,ep", [(2, True)])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_dummy_maverick(
+    original_model_name: str,
+    text_layers: int,
+    num_experts: int,
+    vision_layers: int,
+    enforce_eager: bool,
+    tp: int,
+    ep: bool,
+    output_dir: str = "/tmp/reduced_maverick",
+    force_recreate: bool = True,
+    profile: bool = False,
+) -> None:
+    model_path = create_reduced_maverick_model(
+        original_model_name=original_model_name,
+        output_dir=output_dir,
+        text_layers=text_layers,
+        num_experts=num_experts,
+        vision_layers=vision_layers,
+        force_recreate=force_recreate,
+    )
+
+    print(f"\nReduced model created successfully at: {model_path}")
+
+    run_reduced_model(model_path=model_path,
+                      should_profile=profile,
+                      enforce_eager=enforce_eager,
+                      tensor_parallel_size=tp,
+                      enable_expert_parallel=ep)
+
+
+def main():
+    """Main function to create and test the reduced model."""
+
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Create a reduced-layer Maverick model")
+    parser.add_argument(
+        "--output-dir",
+        default="/tmp/reduced_maverick",
+        help="Output directory for the reduced model",
+    )
+    parser.add_argument(
+        "--text-layers",
+        type=int,
+        default=4,
+        help="Number of text transformer layers",
+    )
+    parser.add_argument("--num-experts",
+                        type=int,
+                        default=4,
+                        help="Number of experts")
+    parser.add_argument(
+        "--vision-layers",
+        type=int,
+        default=2,
+        help="Number of vision transformer layers",
+    )
+    parser.add_argument(
+        "--force-recreate",
+        action="store_true",
+        help="Force recreation if output directory exists",
+    )
+    parser.add_argument("--test",
+                        action="store_true",
+                        help="Test the created model with vLLM")
+    parser.add_argument("--profile",
+                        action="store_true",
+                        help="Profile the created model with vLLM")
+    parser.add_argument(
+        "--test-original",
+        action="store_true",
+        help="Test the original model with vLLM",
+    )
+    parser.add_argument(
+        "--original-model",
+        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
+        help="Original model name to base the reduction on",
+    )
+
+    args = parser.parse_args()
+
+    if args.test:
+        test_dummy_maverick(original_model_name=args.original_model,
+                            output_dir=args.output_dir,
+                            text_layers=args.text_layers,
+                            num_experts=args.num_experts,
+                            vision_layers=args.vision_layers,
+                            force_recreate=args.force_recreate,
+                            tp=2,
+                            ep=True,
+                            enforce_eager=True,
+                            profile=args.profile)
+
+    if args.test_original:
+        run_maverick_serving(args.original_model)
+
+
+if __name__ == "__main__":
+    exit(main())
diff --git a/tests/models/multimodal/generation/test_pixtral.py b/tests/models/multimodal/generation/test_pixtral.py
index 1def825ab087..e157d6f4a79d 100644
--- a/tests/models/multimodal/generation/test_pixtral.py
+++ b/tests/models/multimodal/generation/test_pixtral.py
@@ -180,8 +180,7 @@ def test_chat(
     ) as vllm_model:
         outputs = []
         for msg in MSGS:
-            output = vllm_model.model.chat(msg,
-                                           sampling_params=SAMPLING_PARAMS)
+            output = vllm_model.llm.chat(msg, sampling_params=SAMPLING_PARAMS)
 
             outputs.extend(output)
 
@@ -217,7 +216,7 @@ def test_multi_modal_placeholders(vllm_runner, prompt,
         max_model_len=8192,
         limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
     ) as vllm_model:
-        outputs = vllm_model.model.generate(prompt)
+        outputs = vllm_model.llm.generate(prompt)
 
         assert len(outputs) == 1, f"{len(outputs)=}"
         output: RequestOutput = outputs[0]
diff --git a/tests/models/multimodal/generation/test_voxtral.py
b/tests/models/multimodal/generation/test_voxtral.py new file mode 100644 index 000000000000..b4439dfe020c --- /dev/null +++ b/tests/models/multimodal/generation/test_voxtral.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import pytest +import pytest_asyncio +from mistral_common.audio import Audio +from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, + TextChunk, UserMessage) + +from vllm.transformers_utils.tokenizer import MistralTokenizer + +from ....conftest import AudioTestAssets +from ....utils import RemoteOpenAIServer +from .test_ultravox import MULTI_AUDIO_PROMPT, run_multi_audio_test + +MODEL_NAME = "mistralai/Voxtral-Mini-3B-2507" +MISTRAL_FORMAT_ARGS = [ + "--tokenizer_mode", "mistral", "--config_format", "mistral", + "--load_format", "mistral" +] + + +@pytest.fixture() +def server(request, audio_assets: AudioTestAssets): + args = [ + "--enforce-eager", + "--limit-mm-per-prompt", + json.dumps({"audio": len(audio_assets)}), + ] + MISTRAL_FORMAT_ARGS + + with RemoteOpenAIServer(MODEL_NAME, + args, + env_dict={"VLLM_AUDIO_FETCH_TIMEOUT": + "30"}) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +def _get_prompt(audio_assets, question): + tokenizer = MistralTokenizer.from_pretrained(MODEL_NAME) + + audios = [ + Audio.from_file(str(audio_assets[i].get_local_path()), strict=False) + for i in range(len(audio_assets)) + ] + audio_chunks = [ + AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios + ] + + text_chunk = TextChunk(text=question) + messages = [UserMessage(content=[*audio_chunks, text_chunk]).to_openai()] + + return tokenizer.apply_chat_template(messages=messages) + + +@pytest.mark.core_model +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_with_multiple_audios(vllm_runner, + audio_assets: AudioTestAssets, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + vllm_prompt = _get_prompt(audio_assets, MULTI_AUDIO_PROMPT) + run_multi_audio_test( + vllm_runner, + [(vllm_prompt, [audio.audio_and_sample_rate + for audio in audio_assets])], + MODEL_NAME, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tokenizer_mode="mistral", + ) + + +@pytest.mark.asyncio +async def test_online_serving(client, audio_assets: AudioTestAssets): + """Exercises online serving with/without chunked prefill enabled.""" + + def asset_to_chunk(asset): + audio = Audio.from_file(str(asset.get_local_path()), strict=False) + audio.format = "wav" + audio_dict = AudioChunk.from_audio(audio).to_openai() + return audio_dict + + audio_chunks = [asset_to_chunk(asset) for asset in audio_assets] + messages = [{ + "role": + "user", + "content": [ + *audio_chunks, + { + "type": + "text", + "text": + f"What's happening in these {len(audio_assets)} audio clips?" 
+ }, + ], + }] + + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10) + + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "length" diff --git a/tests/models/multimodal/generation/test_whisper.py b/tests/models/multimodal/generation/test_whisper.py index 363d55153aac..4a65e8c95204 100644 --- a/tests/models/multimodal/generation/test_whisper.py +++ b/tests/models/multimodal/generation/test_whisper.py @@ -106,7 +106,7 @@ def run_test( tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, ) as vllm_model: - llm = vllm_model.model + llm = vllm_model.llm sampling_params = SamplingParams( temperature=0, diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index 8c83d8f8a8a2..cf8962ce4975 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -85,7 +85,7 @@ def run_test( enforce_eager=enforce_eager, task=task, **vllm_runner_kwargs_) as vllm_model: - tokenizer = vllm_model.model.get_tokenizer() + tokenizer = vllm_model.llm.get_tokenizer() vllm_kwargs: dict[str, Any] = {} if get_stop_token_ids is not None: diff --git a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py index f889eea5e839..a6f5aeccf94e 100644 --- a/tests/models/multimodal/pooling/test_dse_qwen2_vl.py +++ b/tests/models/multimodal/pooling/test_dse_qwen2_vl.py @@ -96,7 +96,7 @@ def _run_test( dtype=dtype, enforce_eager=True, max_model_len=8192) as vllm_model: - tokenizer = vllm_model.model.get_tokenizer() + tokenizer = vllm_model.llm.get_tokenizer() texts = [ # this is necessary because vllm_model.embed will not apply any # templating to the prompt, and therefore lacks an image_pad diff --git a/tests/models/multimodal/pooling/test_jinavl_reranker.py b/tests/models/multimodal/pooling/test_jinavl_reranker.py new file mode 100644 index 000000000000..712b6801de45 --- /dev/null +++ b/tests/models/multimodal/pooling/test_jinavl_reranker.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Union + +import pytest +from transformers import AutoModel + +from vllm.entrypoints.chat_utils import ChatCompletionContentPartImageParam +from vllm.entrypoints.score_utils import ScoreMultiModalParam + +from ....conftest import HfRunner, VllmRunner + +model_name = "jinaai/jina-reranker-m0" + +mm_processor_kwargs = { + "min_pixels": 3136, + "max_pixels": 602112, +} + +limit_mm_per_prompt = {"image": 2} + + +def vllm_reranker( + vllm_runner: type[VllmRunner], + model_name: str, + dtype: str, + query_strs: list[str], + document_strs: list[str], + query_type: str = "text", + doc_type: str = "text", +): + + def create_image_param(url: str) -> ChatCompletionContentPartImageParam: + return {"type": "image_url", "image_url": {"url": f"{url}"}} + + query: Union[list[str], ScoreMultiModalParam] + if query_type == "text": + query = query_strs + elif query_type == "image": + query = ScoreMultiModalParam( + content=[create_image_param(url) for url in query_strs]) + + documents: Union[list[str], ScoreMultiModalParam] + if doc_type == "text": + documents = document_strs + elif doc_type == "image": + documents = ScoreMultiModalParam( + content=[create_image_param(url) for url in document_strs]) + + with 
vllm_runner( + model_name, + task="score", + dtype=dtype, + max_num_seqs=2, + max_model_len=2048, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, + ) as vllm_model: + outputs = vllm_model.llm.score(query, documents) + + return [output.outputs.score for output in outputs] + + +def hf_reranker( + hf_runner: type[HfRunner], + model_name: str, + dtype: str, + query_strs: list[str], + document_strs: list[str], + query_type: str = "text", + doc_type: str = "text", +): + checkpoint_to_hf_mapper = { + "visual.": "model.visual.", + "model.": "model.language_model.", + } + + data_pairs = [[query_strs[0], d] for d in document_strs] + + with hf_runner( + model_name, + dtype=dtype, + trust_remote_code=True, + auto_cls=AutoModel, + model_kwargs={"key_mapping": checkpoint_to_hf_mapper}, + ) as hf_model: + return hf_model.model.compute_score(data_pairs, + max_length=2048, + query_type=query_type, + doc_type=doc_type) + + +# Visual Documents Reranking +@pytest.mark.parametrize("model_name", [model_name]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_text_image(hf_runner, vllm_runner, model_name, dtype): + query = ["slm markdown"] + documents = [ + "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png", + "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", + ] + + hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, + "text", "image") + vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, + documents, "text", "image") + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) + + +# Textual Documents Reranking +@pytest.mark.parametrize("model_name", [model_name]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_text_text(hf_runner, vllm_runner, model_name, dtype): + query = ["slm markdown"] + documents = [ + """We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient + web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML + into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding + large language models. The models effectiveness results from two key innovations: (1) a three-stage + data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, + refining, and critiquing web content extraction; and (2) a unified training framework combining + continuous pre-training with multi-objective optimization. 
Intensive evaluation demonstrates that + ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated + benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly + lower computational requirements.""", # noqa: E501 + "数据提取么?为什么不用正则啊,你用正则不就全解决了么?", + ] + hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, + "text", "text") + vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, + documents, "text", "text") + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) + + +# Image Querying for Textual Documents +@pytest.mark.parametrize("model_name", [model_name]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_image_text(hf_runner, vllm_runner, model_name, dtype): + query = [ + "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + ] + documents = [ + """We present ReaderLM-v2, a compact 1.5 billion parameter language model designed for efficient + web content extraction. Our model processes documents up to 512K tokens, transforming messy HTML + into clean Markdown or JSON formats with high accuracy -- making it an ideal tool for grounding + large language models. The models effectiveness results from two key innovations: (1) a three-stage + data synthesis pipeline that generates high quality, diverse training data by iteratively drafting, + refining, and critiquing web content extraction; and (2) a unified training framework combining + continuous pre-training with multi-objective optimization. Intensive evaluation demonstrates that + ReaderLM-v2 outperforms GPT-4o-2024-08-06 and other larger models by 15-20% on carefully curated + benchmarks, particularly excelling at documents exceeding 100K tokens, while maintaining significantly + lower computational requirements.""", # noqa: E501 + "数据提取么?为什么不用正则啊,你用正则不就全解决了么?", + ] + + hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, + "image", "text") + vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, + documents, "image", "text") + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) + + +# Image Querying for Image Documents +@pytest.mark.parametrize("model_name", [model_name]) +@pytest.mark.parametrize("dtype", ["half"]) +def test_model_image_image(hf_runner, vllm_runner, model_name, dtype): + query = [ + "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png" + ] + documents = [ + "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/handelsblatt-preview.png", + "https://raw.githubusercontent.com/jina-ai/multimodal-reranker-test/main/paper-11.png", + ] + + hf_outputs = hf_reranker(hf_runner, model_name, dtype, query, documents, + "image", "image") + vllm_outputs = vllm_reranker(vllm_runner, model_name, dtype, query, + documents, "image", "image") + + assert hf_outputs[0] == pytest.approx(vllm_outputs[0], rel=0.02) + assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.02) diff --git a/tests/models/multimodal/pooling/test_prithvi_mae.py b/tests/models/multimodal/pooling/test_prithvi_mae.py new file mode 100644 index 000000000000..f08d83c08212 --- /dev/null +++ b/tests/models/multimodal/pooling/test_prithvi_mae.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 
+ +import pytest +import torch + +from vllm.utils import set_default_torch_num_threads + +from ....conftest import VllmRunner + + +def generate_test_mm_data(): + mm_data = { + "pixel_values": torch.full((6, 512, 512), 1.0, dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), + } + return mm_data + + +def _run_test( + vllm_runner: type[VllmRunner], + model: str, +) -> None: + + prompt = [ + { + # This model deals with no text input + "prompt_token_ids": [1], + "multi_modal_data": generate_test_mm_data(), + } for _ in range(10) + ] + + with ( + set_default_torch_num_threads(1), + vllm_runner( + model, + task="embed", + dtype=torch.float16, + enforce_eager=True, + skip_tokenizer_init=True, + # Limit the maximum number of sequences to avoid the + # test going OOM during the warmup run + max_num_seqs=32, + ) as vllm_model, + ): + vllm_model.encode(prompt) + + +MODELS = ["christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"] + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", MODELS) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, +) -> None: + _run_test( + vllm_runner, + model, + ) diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 0f33225eda2d..fd5842523178 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -159,6 +159,7 @@ def _test_processing_correctness( _ADD_SPECIAL_TOKENS_OVERRIDES = { "mllama": False, "ovis": False, + "paligemma": False, "ultravox": False, "whisper": False, } @@ -290,6 +291,7 @@ def _test_processing_correctness_one( "allenai/Molmo-7B-D-0924", "allenai/Molmo-7B-O-0924", "nvidia/NVLM-D-72B", + "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", "AIDC-AI/Ovis1.6-Gemma2-9B", "AIDC-AI/Ovis1.6-Llama3.2-3B", "AIDC-AI/Ovis2-1B", diff --git a/tests/models/multimodal/processing/test_nemotron_vl.py b/tests/models/multimodal/processing/test_nemotron_vl.py new file mode 100644 index 000000000000..3ce88bc427f5 --- /dev/null +++ b/tests/models/multimodal/processing/test_nemotron_vl.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for Nemotron-Nano-VL's multimodal preprocessing kwargs.""" +from collections.abc import Mapping +from typing import Optional + +import pytest +from PIL import Image +from transformers import PretrainedConfig + +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import rescale_image_size +from vllm.multimodal.processing import BaseMultiModalProcessor + +from ....conftest import ImageTestAssets +from ...utils import build_model_context + + +def _get_expected_num_patches( + config: PretrainedConfig, + image: Image.Image, + num_imgs: int, + min_num: int, + max_num: int, +): + from vllm.model_executor.models.internvl import ( + calculate_internvl_targets, get_internvl_target_ratios) + + width, height = image.size + + blocks, _, _ = calculate_internvl_targets( + orig_width=width, + orig_height=height, + target_ratios=get_internvl_target_ratios( + min_num, + max_num, + ), + image_size=config.force_image_size, + use_thumbnail=False, + ) + expected_num_patches = blocks + + if config.use_thumbnail and expected_num_patches > 1: + expected_num_patches += 1 + + return expected_num_patches + + +def _run_check( + processor: BaseMultiModalProcessor, + images: list[Image.Image], + min_num: int, + max_num: int, + mm_processor_kwargs: Mapping[str, object], +): + 
tokenizer = processor.info.get_tokenizer() + config = processor.info.get_hf_config() + image_processor = processor.info.get_image_processor() + + config.use_thumbnail = image_processor.use_thumbnail + prompt = "<image>" * len(images) + mm_data = {"image": images} + + total_expected_num_patches = sum( + _get_expected_num_patches(config, image, len(images), min_num, max_num) + for image in images) + print(total_expected_num_patches) + processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.convert_tokens_to_ids("<image>") + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape + print("Image token count:", img_tok_count, "Pixel shape:", pixel_shape) + assert img_tok_count == 256 * total_expected_num_patches + assert pixel_shape[0] == total_expected_num_patches + + +@pytest.mark.parametrize("model_id", + ["nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"]) +@pytest.mark.parametrize( + "size_factors", + [ + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + [4.0, 2.0, 1.0], + ], +) +@pytest.mark.parametrize( + ("min_dynamic_patch", "max_dynamic_patch"), + [(1, 1), (1, 2), (1, 4), (1, 8), (2, 4), (4, 8)], +) +@pytest.mark.parametrize("dynamic_image_size", [True, False]) +@pytest.mark.parametrize("kwargs_on_init", [True, False]) +def test_processor_override( + model_id: str, + image_assets: ImageTestAssets, + size_factors: list[int], + min_dynamic_patch: int, + max_dynamic_patch: int, + dynamic_image_size: Optional[bool], + kwargs_on_init: bool, +): + mm_processor_kwargs = { + "min_dynamic_patch": min_dynamic_patch, + "max_dynamic_patch": max_dynamic_patch, + "dynamic_image_size": dynamic_image_size, + } + + ctx = build_model_context( + model_id, + mm_processor_kwargs=mm_processor_kwargs if kwargs_on_init else None, + limit_mm_per_prompt={"image": len(size_factors)}, + ) + processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config) + hf_processor_mm_kwargs = {} if kwargs_on_init else mm_processor_kwargs + + min_num = min_dynamic_patch if dynamic_image_size else 1 + max_num = max_dynamic_patch if dynamic_image_size else 1 + + _run_check( + processor, + [ + rescale_image_size(image_assets[0].pil_image, f) + for f in size_factors + ], + min_num, + max_num, + hf_processor_mm_kwargs, + ) diff --git a/tests/models/multimodal/processing/test_transformers.py b/tests/models/multimodal/processing/test_transformers.py new file mode 100644 index 000000000000..c7d1b5271ff7 --- /dev/null +++ b/tests/models/multimodal/processing/test_transformers.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.assets.image import ImageAsset +from vllm.config import ModelConfig +from vllm.multimodal import MULTIMODAL_REGISTRY + + +# yapf: disable +@pytest.mark.parametrize("model_id", + ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +def test_multimodal_processor(model_id): + model_config = ModelConfig( + model=model_id, + model_impl="transformers", + ) + + mm_processor = MULTIMODAL_REGISTRY.create_processor(model_config, ) + + image_pil = ImageAsset('cherry_blossom').pil_image + mm_data = {"image": image_pil} + str_prompt = "<|im_start|>user <image>\nWhat is the content of this image?<|im_end|><|im_start|>assistant\n" # noqa: E501 + str_processed_inputs = 
mm_processor.apply( + prompt=str_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + ids_prompt = [ + 151644, 872, 220, 151646, 198, 3838, 374, 279, 2213, 315, 419, 2168, + 30, 151645, 151644, 77091, 198 + ] + ids_processed_inputs = mm_processor.apply( + prompt=ids_prompt, + mm_data=mm_data, + hf_processor_mm_kwargs={}, + ) + + assert str_processed_inputs["prompt"] == ids_processed_inputs["prompt"] diff --git a/tests/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py similarity index 82% rename from tests/quantization/test_bitsandbytes.py rename to tests/models/quantization/test_bitsandbytes.py index 363daa6d27ef..e53902cdb8f4 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/models/quantization/test_bitsandbytes.py @@ -13,8 +13,8 @@ from tests.quantization.utils import is_quant_method_supported -from ..models.utils import check_embeddings_close -from ..utils import compare_two_settings, create_new_process_for_each_test +from ...utils import compare_two_settings, multi_gpu_test +from ..utils import check_embeddings_close, check_logprobs_close models_4bit_to_test = [ ("facebook/opt-125m", "quantize opt model inflight"), @@ -26,6 +26,10 @@ ("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"), ] +models_4bit_to_moe_test = [ + ("allenai/OLMoE-1B-7B-0125-Instruct", "quantize moe model inflight"), +] + models_pre_qaunt_4bit_to_test = [ ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', 'read pre-quantized 4-bit FP4 model'), @@ -42,7 +46,6 @@ @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -@create_new_process_for_each_test() def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -56,7 +59,6 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_qaunt_4bit_to_test) -@create_new_process_for_each_test() def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -68,7 +70,6 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_pre_quant_8bit_to_test) -@create_new_process_for_each_test() def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -76,12 +77,10 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, True) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason='Test requires at least 2 GPUs.') @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_test) -@create_new_process_for_each_test() +@multi_gpu_test(num_gpus=2) def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, model_name, description) -> None: @@ -96,12 +95,10 @@ def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, vllm_tp_size=2) -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason='Test requires at least 2 GPUs.') @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') 
@pytest.mark.parametrize("model_name, description", models_4bit_to_test) -@create_new_process_for_each_test() +@multi_gpu_test(num_gpus=2) def test_load_pp_4bit_bnb_model(model_name, description) -> None: common_args = [ "--disable-log-stats", @@ -122,12 +119,40 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None: compare_two_settings(model_name, common_args, pp_args) +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_moe_test) +def test_4bit_bnb_moe_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, + )) + with vllm_runner(model_name, + quantization='bitsandbytes', + enforce_eager=False) as llm: + vllm_outputs = llm.generate_greedy_logprobs(example_prompts, + max_tokens=32, + num_logprobs=5) + + with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + transformers_outputs = llm.generate_greedy_logprobs_limit( + example_prompts, max_tokens=32, num_logprobs=5) + check_logprobs_close( + outputs_0_lst=transformers_outputs, + outputs_1_lst=vllm_outputs, + name_0="transformers", + name_1="vllm", + ) + + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') @pytest.mark.parametrize("model_name, description", models_4bit_to_embedding_test) @pytest.mark.parametrize("dtype", ["half"]) -@create_new_process_for_each_test() def test_4bit_bnb_embedding_model( model_name, description, @@ -146,6 +171,13 @@ def test_4bit_bnb_embedding_model( example_prompts = [str(s).strip() for s in example_prompts] # Inflight 4bit quantization + with vllm_runner(model_name, + task="embed", + dtype=dtype, + gpu_memory_utilization=0.5, + quantization="bitsandbytes") as vllm_model: + vllm_outputs = vllm_model.embed(example_prompts) + hf_model_kwargs = dict(quantization_config=BitsAndBytesConfig( load_in_4bit=True)) with hf_runner( @@ -156,12 +188,6 @@ def test_4bit_bnb_embedding_model( ) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model_name, - task="embed", - dtype=dtype, - gpu_memory_utilization=0.5, - quantization="bitsandbytes") as vllm_model: - vllm_outputs = vllm_model.embed(example_prompts) check_embeddings_close( embeddings_0_lst=hf_outputs, embeddings_1_lst=vllm_outputs, @@ -189,7 +215,8 @@ def validate_generated_texts(hf_runner, model_name, pre_quant=False, hf_model_kwargs=None, - vllm_tp_size=1): + vllm_tp_size=1, + max_tokens=8): # NOTE: run vLLM first, as it requires a clean process # when using distributed inference @@ -197,7 +224,8 @@ def validate_generated_texts(hf_runner, quantization=None if pre_quant else 'bitsandbytes', tensor_parallel_size=vllm_tp_size, enforce_eager=False) as llm: - vllm_outputs = llm.generate_greedy(prompts, 8) + + vllm_outputs = llm.generate_greedy(prompts, max_tokens) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") # Clean up the GPU memory for the next test @@ -209,19 +237,17 @@ def validate_generated_texts(hf_runner, # Run with HF runner with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: - hf_outputs = llm.generate_greedy(prompts, 8) + hf_outputs = llm.generate_greedy(prompts, max_tokens) hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") # Clean up the GPU memory for the next test 
gc.collect() torch.cuda.empty_cache() - # Compare the generated strings for hf_log, vllm_log in zip(hf_logs, vllm_logs): hf_str = hf_log["generated_text"] vllm_str = vllm_log["generated_text"] prompt = hf_log["prompt"] - assert hf_str == vllm_str, (f"Model: {model_name}" f"Mismatch between HF and vLLM outputs:\n" f"Prompt: {prompt}\n" diff --git a/tests/models/quantization/test_modelopt.py b/tests/models/quantization/test_modelopt.py index 6ad526cc893f..e23d4d9d211d 100644 --- a/tests/models/quantization/test_modelopt.py +++ b/tests/models/quantization/test_modelopt.py @@ -45,7 +45,7 @@ reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: - model = LLM( + llm = LLM( model=model_name, max_model_len=MAX_MODEL_LEN, trust_remote_code=True, @@ -68,9 +68,9 @@ def test_models(example_prompts, model_name) -> None: # Note: these need to be run 1 at a time due to numerical precision, # since the expected strs were generated this way. for prompt in formatted_prompts: - outputs = model.generate(prompt, params) + outputs = llm.generate(prompt, params) generations.append(outputs[0].outputs[0].text) - del model + del llm print(model_name, generations) expected_strs = EXPECTED_STRS_MAP[model_name] diff --git a/tests/models/quantization/test_nvfp4.py b/tests/models/quantization/test_nvfp4.py index b95dad9a4eff..b3c217e729e4 100644 --- a/tests/models/quantization/test_nvfp4.py +++ b/tests/models/quantization/test_nvfp4.py @@ -46,7 +46,7 @@ reason="modelopt_fp4 is not supported on this GPU type.") @pytest.mark.parametrize("model_name", MODELS) def test_models(example_prompts, model_name) -> None: - model = LLM( + llm = LLM( model=model_name, max_model_len=MAX_MODEL_LEN, trust_remote_code=True, @@ -69,9 +69,9 @@ def test_models(example_prompts, model_name) -> None: # Note: these need to be run 1 at a time due to numerical precision, # since the expected strs were generated this way. 
for prompt in formatted_prompts: - outputs = model.generate(prompt, params) + outputs = llm.generate(prompt, params) generations.append(outputs[0].outputs[0].text) - del model + del llm print(model_name, generations) expected_strs = EXPECTED_STRS_MAP[model_name] diff --git a/tests/models/registry.py b/tests/models/registry.py index aba01cefe993..84ca0bc60003 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -135,12 +135,16 @@ def check_available_online( trust_remote_code=True), "AquilaForCausalLM": _HfExamplesInfo("BAAI/AquilaChat2-7B", trust_remote_code=True), + "ArceeForCausalLM": _HfExamplesInfo("arcee-ai/AFM-4.5B-Base", + is_available_online=False), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", trust_remote_code=True), + "BailingMoeForCausalLM": _HfExamplesInfo("inclusionAI/Ling-lite-1.5", + trust_remote_code=True), "BambaForCausalLM": _HfExamplesInfo("ibm-ai-platform/Bamba-9B", extras={"tiny": "hmellor/tiny-random-BambaForCausalLM"}), # noqa: E501 "BloomForCausalLM": _HfExamplesInfo("bigscience/bloom-560m", @@ -163,10 +167,11 @@ def check_available_online( "DeepseekV3ForCausalLM": _HfExamplesInfo("deepseek-ai/DeepSeek-V3", # noqa: E501 trust_remote_code=True), "Ernie4_5_ForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-0.3B-PT", - trust_remote_code=True), + min_transformers_version="4.54"), "Ernie4_5_MoeForCausalLM": _HfExamplesInfo("baidu/ERNIE-4.5-21B-A3B-PT", - trust_remote_code=True), + min_transformers_version="4.54"), "ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501 + "Exaone4ForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-4.0-32B"), # noqa: E501 "Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501 "FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"), "FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base", @@ -194,6 +199,8 @@ def check_available_online( trust_remote_code=True), "HunYuanMoEV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-A13B-Instruct", trust_remote_code=True), + "HunYuanDenseV1ForCausalLM":_HfExamplesInfo("tencent/Hunyuan-7B-Instruct-0124", + trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", trust_remote_code=True), "InternLM2ForCausalLM": _HfExamplesInfo("internlm/internlm2-chat-7b", @@ -218,6 +225,8 @@ def check_available_online( trust_remote_code=True), "MiniCPM3ForCausalLM": _HfExamplesInfo("openbmb/MiniCPM3-4B", trust_remote_code=True), + "MiniMaxForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01-hf", + min_transformers_version="4.53"), "MiniMaxText01ForCausalLM": _HfExamplesInfo("MiniMaxAI/MiniMax-Text-01", trust_remote_code=True, revision="a59aa9cbc53b9fb8742ca4e9e1531b9802b6fdc3"), # noqa: E501 @@ -242,10 +251,10 @@ def check_available_online( "PersimmonForCausalLM": _HfExamplesInfo("adept/persimmon-8b-chat"), "PhiForCausalLM": _HfExamplesInfo("microsoft/phi-2"), "Phi3ForCausalLM": _HfExamplesInfo("microsoft/Phi-3-mini-4k-instruct"), - # Blocksparse attention not supported in V1 yet - "Phi3SmallForCausalLM": _HfExamplesInfo("microsoft/Phi-3-small-8k-instruct", - trust_remote_code=True, - v0_only=True), + "Phi4FlashForCausalLM": _HfExamplesInfo("microsoft/Phi-4-mini-flash-reasoning", # noqa: E501 + trust_remote_code=True, + v0_only=True, + max_model_len=10240), 
"PhiMoEForCausalLM": _HfExamplesInfo("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True), "Plamo2ForCausalLM": _HfExamplesInfo("pfnet/plamo-2-1b", @@ -257,7 +266,6 @@ def check_available_online( "Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"), "Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"), "Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"), - "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 "RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"), "StableLMEpochForCausalLM": _HfExamplesInfo("stabilityai/stablelm-zephyr-3b"), # noqa: E501 "StableLmForCausalLM": _HfExamplesInfo("stabilityai/stablelm-3b-4e1t"), @@ -284,7 +292,6 @@ def check_available_online( # [Text-only] "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5", v0_only=True), "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2", v0_only=True), # noqa: E501 - "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"), "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0", trust_remote_code=True), @@ -303,7 +310,6 @@ def check_available_online( "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"), "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"), - "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2", v0_only=True), # noqa: E501 "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1", v0_only=True), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-small", v0_only=True), # noqa: E501 @@ -316,12 +322,27 @@ def check_available_online( is_available_online=False), # noqa: E501 } -_CROSS_ENCODER_EXAMPLE_MODELS = { - # [Text-only] +_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS = { + # [Decoder-only] + "GPT2ForSequenceClassification": _HfExamplesInfo("nie3e/sentiment-polish-gpt2-small"), # noqa: E501 + + # [Cross-encoder] "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2", v0_only=True), # noqa: E501 + "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501 "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501 - "ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501 +} + +_AUTOMATIC_CONVERTED_MODELS = { + # Use as_seq_cls_model for automatic conversion + "GemmaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-gemma", # noqa: E501 + v0_only=True, + hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501 + "classifier_from_token": ["Yes"], # noqa: E501 + "method": "no_post_processing"}), # noqa: E501 + "LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501 + "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 + "Qwen3ForSequenceClassification": _HfExamplesInfo("tomaarsen/Qwen3-Reranker-0.6B-seq-cls"), # noqa: E501 } _MULTIMODAL_EXAMPLE_MODELS = { @@ -343,6 +364,9 @@ def 
check_available_online( trust_remote_code=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501 "Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501 + "Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4.5", + min_transformers_version="4.54", + is_available_online=False), # noqa: E501 "H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m", extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501 max_transformers_version="4.48", # noqa: E501 @@ -387,6 +411,8 @@ def check_available_online( trust_remote_code=True), "NVLM_D": _HfExamplesInfo("nvidia/NVLM-D-72B", trust_remote_code=True), + "Llama_Nemotron_Nano_VL" : _HfExamplesInfo("nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1", # noqa: E501 + trust_remote_code=True), "PaliGemmaForConditionalGeneration": _HfExamplesInfo("google/paligemma-3b-mix-224", # noqa: E501 extras={"v2": "google/paligemma2-3b-ft-docci-448"}), # noqa: E501 "Phi3VForCausalLM": _HfExamplesInfo("microsoft/Phi-3-vision-128k-instruct", @@ -407,7 +433,8 @@ def check_available_online( hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}), # noqa: E501 "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 - "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 + max_model_len=4096), "Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B"), "Qwen2_5OmniForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B-AWQ"), # noqa: E501 "SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"), @@ -418,6 +445,12 @@ def check_available_online( hf_overrides={"architectures": ["TarsierForConditionalGeneration"]}), # noqa: E501 "Tarsier2ForConditionalGeneration": _HfExamplesInfo("omni-research/Tarsier2-Recap-7b", # noqa: E501 hf_overrides={"architectures": ["Tarsier2ForConditionalGeneration"]}), # noqa: E501 + "VoxtralForConditionalGeneration": _HfExamplesInfo( + "mistralai/Voxtral-Mini-3B-2507", + min_transformers_version="4.54", + # disable this temporarily until we support HF format + is_available_online=False, + ), # [Encoder-decoder] # Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer # Therefore, we borrow the BartTokenizer from the original Bart model @@ -426,15 +459,18 @@ def check_available_online( trust_remote_code=True), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 + # [Cross-encoder] + "JinaVLForRanking": _HfExamplesInfo("jinaai/jina-reranker-m0"), # noqa: E501 } + _SPECULATIVE_DECODING_EXAMPLE_MODELS = { - "EAGLEModel": _HfExamplesInfo("JackFram/llama-68m", - speculative_model="abhigoyal/vllm-eagle-llama-68m-random"), # noqa: E501 "MedusaModel": _HfExamplesInfo("JackFram/llama-68m", speculative_model="abhigoyal/vllm-medusa-llama-68m-random"), # noqa: E501 - "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", - speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 + # Temporarily disabled. + # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. 
+ # "MLPSpeculatorPreTrainedModel": _HfExamplesInfo("JackFram/llama-160m", + # speculative_model="ibm-ai-platform/llama-160m-accelerator"), # noqa: E501 "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), @@ -446,24 +482,34 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + "EagleLlama4ForCausalLM": _HfExamplesInfo( + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + trust_remote_code=True, + speculative_model="morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", + tokenizer="meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501 "EagleMiniCPMForCausalLM": _HfExamplesInfo("openbmb/MiniCPM-1B-sft-bf16", trust_remote_code=True, is_available_online=False, speculative_model="openbmb/MiniCPM-2B-sft-bf16", tokenizer="openbmb/MiniCPM-2B-sft-bf16"), + "Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4.5", + speculative_model="THUDM/GLM-4.5", + min_transformers_version="4.54", + is_available_online=False), "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL", trust_remote_code=True, speculative_model="XiaomiMiMo/MiMo-7B-RL") } _TRANSFORMERS_MODELS = { - "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 + "TransformersForCausalLM": _HfExamplesInfo("hmellor/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501 + "TransformersForMultimodalLM": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), } _EXAMPLE_MODELS = { **_TEXT_GENERATION_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS, - **_CROSS_ENCODER_EXAMPLE_MODELS, + **_SEQUENCE_CLASSIFICATION_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS, **_TRANSFORMERS_MODELS, @@ -495,4 +541,5 @@ def find_hf_info(self, model_id: str) -> _HfExamplesInfo: raise ValueError(f"No example model defined for {model_id}") -HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) \ No newline at end of file +HF_EXAMPLE_MODELS = HfExampleModels(_EXAMPLE_MODELS) +AUTO_EXAMPLE_MODELS = HfExampleModels(_AUTOMATIC_CONVERTED_MODELS) diff --git a/tests/models/test_initialization.py b/tests/models/test_initialization.py index 25bc96bf3266..14d243012b2f 100644 --- a/tests/models/test_initialization.py +++ b/tests/models/test_initialization.py @@ -12,20 +12,36 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config from vllm.v1.engine.core import EngineCore as V1EngineCore -from .registry import HF_EXAMPLE_MODELS - - -@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) -def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): - model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch) +from ..utils import create_new_process_for_each_test +from .registry import AUTO_EXAMPLE_MODELS, HF_EXAMPLE_MODELS, HfExampleModels + + +@create_new_process_for_each_test() +def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch, + EXAMPLE_MODELS: HfExampleModels): + """The reason for using create_new_process_for_each_test is to avoid + the WARNING: + "We must use the 'spawn' multiprocessing start method. Overriding + VLLM_WORKER_MULTIPROC_METHOD to 'spawn'." + The spawn process causes the _initialize_kv_caches_v1 function below to + become ineffective. 
+ """ + + model_info = EXAMPLE_MODELS.get_hf_info(model_arch) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") # FIXME: Possible memory leak in the previous tests? - if model_arch in ("GraniteSpeechForConditionalGeneration", + if model_arch in ("Glm4vForConditionalGeneration", + "GraniteSpeechForConditionalGeneration", "KimiVLForConditionalGeneration"): pytest.skip("Avoid OOM") + if model_arch in ("Llama4ForCausalLM", "EagleLlama4ForCausalLM"): + from vllm.model_executor.models.llama4 import Llama4ForCausalLM + from vllm.model_executor.models.registry import ModelRegistry + ModelRegistry.register_model("Llama4ForCausalLM", Llama4ForCausalLM) + # Avoid OOM and reduce initialization time by only using 1 layer def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: hf_config.update(model_info.hf_overrides) @@ -33,13 +49,18 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: text_config = hf_config.get_text_config() # Ensure at least 2 expert per group - # Since `grouped_topk` assums top-2 + # Since `grouped_topk` assumes top-2 n_group = getattr(text_config, 'n_group', None) num_experts = n_group * 2 if n_group is not None else 2 + # we use three layers for Gemma-3n to check + # both normal layer and kv_shared_layer + num_hidden_layers = (3 if model_arch + == "Gemma3nForConditionalGeneration" else 1) + text_config.update({ "num_layers": 1, - "num_hidden_layers": 1, + "num_hidden_layers": num_hidden_layers, "num_experts": num_experts, "num_experts_per_tok": 2, "num_local_experts": num_experts, @@ -47,6 +68,8 @@ def hf_overrides(hf_config: PretrainedConfig) -> PretrainedConfig: "first_k_dense_replace": 0, # To avoid OOM on DeepSeek-V3 "n_routed_experts": num_experts, + # For Gemma-3n + "num_kv_shared_layers": 1, }) if hasattr(hf_config, "vision_config"): @@ -86,6 +109,9 @@ def _initialize_kv_caches_v1(self, vllm_config): _initialize_kv_caches_v1), monkeypatch.context() as m): if model_info.v0_only: m.setenv("VLLM_USE_V1", "0") + if model_arch == "Phi4FlashForCausalLM": + # Phi4FlashForCausalLM only supports DIFFERENTIAL_FLASH_ATTN backend + m.setenv("VLLM_ATTENTION_BACKEND", "DIFFERENTIAL_FLASH_ATTN") LLM( model_info.default, tokenizer=model_info.tokenizer, @@ -102,3 +128,15 @@ def _initialize_kv_caches_v1(self, vllm_config): load_format="dummy", hf_overrides=hf_overrides, ) + + +@pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs()) +def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch): + can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS) + + +@pytest.mark.parametrize("model_arch", + AUTO_EXAMPLE_MODELS.get_supported_archs()) +def test_implicit_converted_models(model_arch: str, + monkeypatch: pytest.MonkeyPatch): + can_initialize(model_arch, monkeypatch, AUTO_EXAMPLE_MODELS) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index 01b2260abe8c..1ce90070c5c8 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -72,11 +72,15 @@ def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): @create_new_process_for_each_test() -@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ - ("MLPSpeculatorPreTrainedModel", False, False), - ("DeepseekV2ForCausalLM", True, False), - ("Qwen2VLForConditionalGeneration", True, True), -]) +@pytest.mark.parametrize( + "model_arch,is_pp,init_cuda", + [ + # TODO(woosuk): Re-enable this once the MLP Speculator is supported + # in V1. 
+ # ("MLPSpeculatorPreTrainedModel", False, False), + ("DeepseekV2ForCausalLM", True, False), + ("Qwen2VLForConditionalGeneration", True, True), + ]) def test_registry_is_pp(model_arch, is_pp, init_cuda): assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py index b7b99ce41cbb..cd5b6193d001 100644 --- a/tests/models/test_transformers.py +++ b/tests/models/test_transformers.py @@ -56,7 +56,7 @@ def check_implementation( "model,model_impl", [ ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), - ("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE + ("hmellor/Ilama-3.2-1B", "auto"), # CUSTOM CODE ]) # trust_remote_code=True by default def test_models( hf_runner: type[HfRunner], @@ -138,3 +138,38 @@ def test_quantization( name_0="transformers", name_1="vllm", ) + + +@pytest.mark.parametrize( + "model", + ["jason9693/Qwen2.5-1.5B-apeach"], +) +@pytest.mark.parametrize("dtype", ["float"]) +def test_classify( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + monkeypatch, +) -> None: + import torch + from transformers import AutoModelForSequenceClassification + + with vllm_runner(model, + max_model_len=512, + dtype=dtype, + model_impl="transformers") as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, + 1e-3 if dtype == "float" else 1e-2) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index b642e5c0ad47..3fdf7e33ca5f 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -39,7 +39,7 @@ TEST_VIDEO_URLS = [ "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4", - "https://filesamples.com/samples/video/avi/sample_640x360.avi", + "https://github.com/opencv/opencv/raw/refs/tags/4.12.0/samples/data/vtest.avi", ] diff --git a/tests/multimodal/test_video.py b/tests/multimodal/test_video.py index 897c9c33461a..05b7b84be7f3 100644 --- a/tests/multimodal/test_video.py +++ b/tests/multimodal/test_video.py @@ -1,14 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import tempfile +from pathlib import Path + import numpy as np import numpy.typing as npt import pytest +from PIL import Image -from vllm import envs +from vllm.assets.base import get_vllm_public_assets +from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list from vllm.multimodal.image import ImageMediaIO from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader, VideoMediaIO) +from .utils import cosine_similarity, create_video_from_image, normalize_image + NUM_FRAMES = 10 FAKE_OUTPUT_1 = np.random.rand(NUM_FRAMES, 1280, 720, 3) FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3) @@ -59,30 +67,79 @@ def load_bytes(cls, return FAKE_OUTPUT_2 -def test_video_media_io_kwargs(): - envs.VLLM_VIDEO_LOADER_BACKEND = "assert_10_frames_1_fps" - imageio = ImageMediaIO() +def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch): + with monkeypatch.context() as m: + m.setenv("VLLM_VIDEO_LOADER_BACKEND", "assert_10_frames_1_fps") + imageio = ImageMediaIO() - # Verify that 
different args pass/fail assertions as expected. - videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0}) - _ = videoio.load_bytes(b"test") - - videoio = VideoMediaIO( - imageio, **{ - "num_frames": 10, - "fps": 1.0, - "not_used": "not_used" - }) - _ = videoio.load_bytes(b"test") - - with pytest.raises(AssertionError, match="bad num_frames"): - videoio = VideoMediaIO(imageio, **{}) + # Verify that different args pass/fail assertions as expected. + videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 1.0}) _ = videoio.load_bytes(b"test") - with pytest.raises(AssertionError, match="bad num_frames"): - videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0}) + videoio = VideoMediaIO( + imageio, **{ + "num_frames": 10, + "fps": 1.0, + "not_used": "not_used" + }) _ = videoio.load_bytes(b"test") - with pytest.raises(AssertionError, match="bad fps"): - videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0}) - _ = videoio.load_bytes(b"test") + with pytest.raises(AssertionError, match="bad num_frames"): + videoio = VideoMediaIO(imageio, **{}) + _ = videoio.load_bytes(b"test") + + with pytest.raises(AssertionError, match="bad num_frames"): + videoio = VideoMediaIO(imageio, **{"num_frames": 9, "fps": 1.0}) + _ = videoio.load_bytes(b"test") + + with pytest.raises(AssertionError, match="bad fps"): + videoio = VideoMediaIO(imageio, **{"num_frames": 10, "fps": 2.0}) + _ = videoio.load_bytes(b"test") + + +@pytest.mark.parametrize("is_color", [True, False]) +@pytest.mark.parametrize("fourcc, ext", [("mp4v", "mp4"), ("XVID", "avi")]) +def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str): + """ + Test all functions that use OpenCV for video I/O return RGB format. + Both RGB and grayscale videos are tested. + """ + image_path = get_vllm_public_assets(filename="stop_sign.jpg", + s3_prefix="vision_model_images") + image = Image.open(image_path) + with tempfile.TemporaryDirectory() as tmpdir: + if not is_color: + image_path = f"{tmpdir}/test_grayscale_image.png" + image = image.convert("L") + image.save(image_path) + # Convert to gray RGB for comparison + image = image.convert("RGB") + video_path = f"{tmpdir}/test_RGB_video.{ext}" + create_video_from_image( + image_path, + video_path, + num_frames=2, + is_color=is_color, + fourcc=fourcc, + ) + + frames = video_to_ndarrays(video_path) + for frame in frames: + sim = cosine_similarity(normalize_image(np.array(frame)), + normalize_image(np.array(image))) + assert np.sum(np.isnan(sim)) / sim.size < 0.001 + assert np.nanmean(sim) > 0.99 + + pil_frames = video_to_pil_images_list(video_path) + for frame in pil_frames: + sim = cosine_similarity(normalize_image(np.array(frame)), + normalize_image(np.array(image))) + assert np.sum(np.isnan(sim)) / sim.size < 0.001 + assert np.nanmean(sim) > 0.99 + + io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path)) + for frame in io_frames: + sim = cosine_similarity(normalize_image(np.array(frame)), + normalize_image(np.array(image))) + assert np.sum(np.isnan(sim)) / sim.size < 0.001 + assert np.nanmean(sim) > 0.99 diff --git a/tests/multimodal/utils.py b/tests/multimodal/utils.py index 23346509a06f..9a58292f9f4a 100644 --- a/tests/multimodal/utils.py +++ b/tests/multimodal/utils.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import cv2 import numpy as np +import numpy.typing as npt from PIL import Image @@ -31,3 +33,47 @@ def random_audio( ): audio_len = rng.randint(min_len, 
max_len) return rng.rand(audio_len), sr + + +def create_video_from_image( + image_path: str, + video_path: str, + num_frames: int = 10, + fps: float = 1.0, + is_color: bool = True, + fourcc: str = "mp4v", +): + image = cv2.imread(image_path) + if not is_color: + # Convert to grayscale if is_color is False + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + height, width = image.shape + else: + height, width, _ = image.shape + + video_writer = cv2.VideoWriter( + video_path, + cv2.VideoWriter_fourcc(*fourcc), + fps, + (width, height), + isColor=is_color, + ) + + for _ in range(num_frames): + video_writer.write(image) + + video_writer.release() + return video_path + + +def cosine_similarity(A: npt.NDArray, + B: npt.NDArray, + axis: int = -1) -> npt.NDArray: + """Compute cosine similarity between two vectors.""" + return (np.sum(A * B, axis=axis) / + (np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis))) + + +def normalize_image(image: npt.NDArray) -> npt.NDArray: + """Normalize image to [0, 1] range.""" + return image.astype(np.float32) / 255.0 \ No newline at end of file diff --git a/tests/neuron/2_core/test_mistral.py b/tests/neuron/2_core/test_mistral.py index d02fff943e90..ff59be1725b6 100644 --- a/tests/neuron/2_core/test_mistral.py +++ b/tests/neuron/2_core/test_mistral.py @@ -9,7 +9,6 @@ def test_mistral(): tensor_parallel_size=2, max_num_seqs=4, max_model_len=128, - use_v2_block_manager=True, override_neuron_config={ "sequence_parallel_enabled": False, "skip_warmup": True diff --git a/tests/neuron/2_core/test_multi_lora.py b/tests/neuron/2_core/test_multi_lora.py index 6b97f47d4db3..52ca9fe7b666 100644 --- a/tests/neuron/2_core/test_multi_lora.py +++ b/tests/neuron/2_core/test_multi_lora.py @@ -14,7 +14,6 @@ def test_llama_single_lora(): tensor_parallel_size=2, max_num_seqs=4, max_model_len=512, - use_v2_block_manager=True, override_neuron_config={ "sequence_parallel_enabled": False, "skip_warmup": True, @@ -57,7 +56,6 @@ def test_llama_multiple_lora(): tensor_parallel_size=2, max_num_seqs=4, max_model_len=512, - use_v2_block_manager=True, override_neuron_config={ "sequence_parallel_enabled": False, diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index aff3498567d2..fc654f20fff2 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -8,14 +8,16 @@ import torch.nn as nn from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.models.gemma2 import Gemma2Model from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors class MyGemma2Embedding(nn.Module): + + is_pooling_model = True + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -24,12 +26,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = Gemma2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self._pooler = Pooler.from_config_with_defaults( - vllm_config.model_config.pooler_config, - pooling_type=PoolingType.LAST, - 
normalize=True, - softmax=False, - ) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + }) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -54,13 +57,6 @@ def forward( # Return all-zero embeddings return torch.zeros_like(hidden_states) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weights = self.hf_to_vllm_mapper.apply(weights) diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py index f00a8f6998cb..b940ab416e67 100644 --- a/tests/prefix_caching/test_disable_sliding_window.py +++ b/tests/prefix_caching/test_disable_sliding_window.py @@ -25,25 +25,25 @@ @pytest.mark.parametrize("model_len_len", MODEL_LEN_LEN) def test_disable_sliding_window(model_len_len, ): model, sliding_len, full_len = model_len_len - vllm_disabled_model = LLM(model, disable_sliding_window=True) - vllm_disabled_model.generate("Hi my name is") - model_config = vllm_disabled_model.llm_engine.model_config + disabled_llm = LLM(model, disable_sliding_window=True) + disabled_llm.generate("Hi my name is") + model_config = disabled_llm.llm_engine.model_config assert model_config.max_model_len == sliding_len, ( "Max len expected to equal sliding_len of %s, but got %s", sliding_len, model_config.max_model_len) - del vllm_disabled_model + del disabled_llm cleanup_dist_env_and_memory() - vllm_enabled_model = LLM(model, - enforce_eager=True, - disable_sliding_window=False, - enable_prefix_caching=False) - vllm_enabled_model.generate("Hi my name is") - model_config = vllm_enabled_model.llm_engine.model_config + enabled_llm = LLM(model, + enforce_eager=True, + disable_sliding_window=False, + enable_prefix_caching=False) + enabled_llm.generate("Hi my name is") + model_config = enabled_llm.llm_engine.model_config assert model_config.max_model_len == full_len, ( "Max len expected to equal full_len of %s, but got %s", full_len, model_config.max_model_len) - del vllm_enabled_model + del enabled_llm cleanup_dist_env_and_memory() diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index a65fc934b16a..5bf6ed957c74 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -93,8 +93,8 @@ def test_mixed_requests( # Run all the promopts greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - req_outputs = vllm_model.model.generate(example_prompts, - greedy_params) + req_outputs = vllm_model.llm.generate(example_prompts, + greedy_params) # Verify number of cached tokens for i in range(len(req_outputs)): @@ -161,7 +161,7 @@ def test_fully_cached_prefill_needs_uncached_token(model): max_num_batched_tokens=max_num_batched_tokens, max_num_seqs=max_num_batched_tokens, ) - engine: LLMEngine = runner.model.llm_engine + engine: LLMEngine = runner.llm.llm_engine scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore engine.scheduler[0] = scheduler diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py deleted file mode 100644 index 2b603fe8f022..000000000000 --- a/tests/prompt_adapter/test_bloom.py +++ /dev/null @@ -1,48 +0,0 
@@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -import vllm -from vllm.prompt_adapter.request import PromptAdapterRequest - -MODEL_PATH = "bigscience/bloomz-560m" -PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' - - -def do_sample(llm, pa_name: str, pa_id: int): - - prompts = [ - "Tweet text : @nationalgridus I have no water and the bill is \ - current and paid. Can you do something about this? Label : ", - "Tweet text : @nationalgridus Looks good thanks! Label : " - ] - sampling_params = vllm.SamplingParams(temperature=0.0, - max_tokens=3, - stop_token_ids=[3]) - - outputs = llm.generate(prompts, - sampling_params, - prompt_adapter_request=PromptAdapterRequest( - pa_name, pa_id, PA_PATH, 8) if pa_id else None) - - # Print the outputs. - generated_texts = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@pytest.mark.parametrize("enforce_eager", [True, False]) -def test_twitter_prompt_adapter(enforce_eager: bool): - llm = vllm.LLM(MODEL_PATH, - enforce_eager=enforce_eager, - enable_prompt_adapter=True, - max_prompt_adapter_token=8) - - expected_output = ['complaint', 'no complaint'] - - assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py deleted file mode 100644 index 4f273afb4e36..000000000000 --- a/tests/prompt_adapter/test_multi_adapter_inference.py +++ /dev/null @@ -1,56 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm import EngineArgs, LLMEngine, SamplingParams -from vllm.prompt_adapter.request import PromptAdapterRequest - -MODEL_PATH = "bigscience/bloomz-560m" -pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' -pa_path2 = 'swapnilbp/angry_tweet_ptune' - - -def do_sample(engine): - - prompts = [ - ("Tweet text: I have complaints! Label: ", - SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), - PromptAdapterRequest("hate_speech", 1, pa_path2, 8)), - ("Tweet text: I have no problems Label: ", - SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), - PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)), - ("Tweet text: I have complaints! 
Label: ", - SamplingParams(temperature=0.0, max_tokens=3), None), - ("Tweet text: I have no problems Label: ", - SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), - PromptAdapterRequest("complain", 3, pa_path, 8)), - ] - - request_id = 0 - results = set() - while prompts or engine.has_unfinished_requests(): - if prompts: - prompt, sampling_params, pa_request = prompts.pop(0) - engine.add_request(str(request_id), - prompt, - sampling_params, - prompt_adapter_request=pa_request) - request_id += 1 - - request_outputs = engine.step() - - for request_output in request_outputs: - if request_output.finished: - results.add(request_output.outputs[0].text) - return results - - -def test_multi_prompt_adapters(): - engine_args = EngineArgs(model=MODEL_PATH, - max_prompt_adapters=3, - enable_prompt_adapter=True, - max_prompt_adapter_token=8) - engine = LLMEngine.from_engine_args(engine_args) - expected_output = { - ' quot;I', 'hate speech', 'no complaint', 'not hate speech' - } - assert do_sample(engine) == expected_output diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py deleted file mode 100644 index ba2e15b81bc1..000000000000 --- a/tests/prompt_adapter/test_pa_lora.py +++ /dev/null @@ -1,64 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from huggingface_hub import snapshot_download - -from vllm import EngineArgs, LLMEngine, SamplingParams -from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest - -MODEL_PATH = "meta-llama/Llama-2-7b-hf" -pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune") -lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") - - -def do_sample(engine): - - prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501 - - # first prompt with a prompt adapter and second without adapter - prompts = [ - (prompt_text, - SamplingParams(temperature=0.0, max_tokens=100, - stop=["[/assistant]"]), - PromptAdapterRequest("hate_speech", 1, pa_path, - 8), LoRARequest("sql_test", 1, lora_path)), - (prompt_text, - SamplingParams(temperature=0.0, max_tokens=100, - stop=["[/assistant]"]), None, - LoRARequest("sql_test", 1, lora_path)), - ] - - request_id = 0 - results = set() - while prompts or engine.has_unfinished_requests(): - if prompts: - prompt, sampling_params, pa_request, lora_request = prompts.pop(0) - engine.add_request(str(request_id), - prompt, - sampling_params, - prompt_adapter_request=pa_request, - lora_request=lora_request) - request_id += 1 - - request_outputs = engine.step() - - for request_output in request_outputs: - if request_output.finished: - results.add(request_output.outputs[0].text) - return results - - -def test_lora_prompt_adapter(): - engine_args = EngineArgs(model=MODEL_PATH, - enable_prompt_adapter=True, - enable_lora=True, - max_num_seqs=60, - max_prompt_adapter_token=8) - engine = LLMEngine.from_engine_args(engine_args) - result = do_sample(engine) - - expected_output = { - " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501 - } - assert result == expected_output diff --git a/tests/quantization/reference_mxfp4.py b/tests/quantization/reference_mxfp4.py new file mode 100644 index 000000000000..2ef251933f68 --- /dev/null 
+++ b/tests/quantization/reference_mxfp4.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +BFLOAT16_EXP_BIAS = 127 +BFLOAT16_MANTISSA_BITS = 7 +BFLOAT16_EXP_BITS = 8 + +FLOAT16_EXP_BIAS = 15 +FLOAT16_MANTISSA_BITS = 10 +FLOAT16_EXP_BITS = 5 + +FLOAT8_E8M0_MAX_EXP = 127 +FLOAT4_EXP_BIAS = 1 +FLOAT4_MANTISSA_BITS = 1 + +FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) +FLOAT16_SIGN_EXPONENT_MASK = (( + (1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS) + +BFLOAT16_VAL_TO_ADD = (1 << + (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)) +BFLOAT16_SIGN_EXPONENT_MASK = (( + (1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS) + + +def e8m0_to_half(scale, half_dtype: torch.dtype): + assert scale.dtype == torch.uint8 + + scale_exp = scale.to(torch.int16) - 127 + + # This can be implemented with bitwise operations in a proper kernel. + scale_half = 2.0**(scale_exp.to(torch.float)) + + return scale_half.to(half_dtype) + + +def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype, + half_exp_bias: int, half_mantissa_bits: int): + assert val.dtype == torch.uint8 + + unpacked = torch.zeros(*val.shape[:-1], + val.shape[-1] * 2, + dtype=torch.uint8, + device=val.device) + unpacked[..., 1::2] = (val >> 4) & 0x0F # Extract high 4 bits. + unpacked[..., ::2] = val & 0x0F # Extract low 4 bits. + + # Takes one float4 values represented as b0000xxxx, + # and converts it to the corresponding float16 value. + + sign = unpacked >> 3 + + exp = (unpacked >> 1) & 3 + new_mantissa = unpacked & 1 + + # if exp == 0 and new_mantissa == 0: + # new_exp = 0 + # else: + # new_exp = exp - FLOAT4_EXP_BIAS + FLOAT16_EXP_BIAS + + # int8_t works with float16, but may overflow with bfloat16. + new_exp = exp - FLOAT4_EXP_BIAS + half_exp_bias + + # Cast b0000 to 0. in fp16/bf16. + new_exp = new_exp * torch.logical_or(exp > 0, new_mantissa > 0) + + # Cast b0001 to 0.5 in fp16/bf16. + new_mantissa = torch.logical_and(new_mantissa, exp > 0) + + new_mantissa = new_mantissa.to(torch.int32) + new_exp = new_exp.to(torch.int32) + sign = sign.to(torch.int32) + + qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( + new_mantissa << (half_mantissa_bits - 1)) + + assert qdq_val.max() <= 65535 + assert qdq_val.min() >= 0 + qdq_val = qdq_val.to(torch.uint16) + + result = qdq_val.view(float_dtype) + + return result + + +def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: + assert x.dtype == torch.uint8 + assert scale.dtype == torch.uint8 + + if float_dtype == torch.float16: + half_exp_bias = FLOAT16_EXP_BIAS + half_mantissa_bits = FLOAT16_MANTISSA_BITS + elif float_dtype == torch.bfloat16: + half_exp_bias = BFLOAT16_EXP_BIAS + half_mantissa_bits = BFLOAT16_MANTISSA_BITS + + scale_half = e8m0_to_half(scale, half_dtype=float_dtype) + + x_half = upcast_fp4_to_fp16_or_bf16(x, + float_dtype=float_dtype, + half_exp_bias=half_exp_bias, + half_mantissa_bits=half_mantissa_bits) + + x_half = x_half.reshape(*x_half.shape[:-1], -1, 32) + x_half = x_half * scale_half[..., None] + x_half = x_half.reshape(*x_half.shape[:-2], -1) + + return x_half + + +def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int, + half_exp_bias: int): + # Casts an fp16/bf16 input to the restricted values of float4_e2m1, + # that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, + # -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]. 
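# Illustrative sketch (not part of this diff): the float4_e2m1 value set named in
# the comment above can be reproduced by decoding all 16 bit patterns under the
# usual E2M1 layout (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit).
# The helper below is for orientation only and is not used by the reference code.
def _decode_e2m1(nibble: int) -> float:
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    man = nibble & 0b1
    if exp == 0:
        return sign * 0.5 * man  # subnormals: +/-0.0 and +/-0.5
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

assert sorted(_decode_e2m1(b) for b in range(8)) == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]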
+ + float_type = val.dtype + + # "rshift_cuda" not implemented for 'UInt16' + val_view = val.view(torch.int16) #.to(torch.int32) + + exp = val_view >> half_mantissa_bits + exp = exp & ((1 << half_exp_bits) - 1) + + exp = exp.view(torch.uint16).to(torch.int32) + + sign = (val_view >> (half_mantissa_bits + half_exp_bits)) & 1 + + mantissa_last = (val_view >> (half_mantissa_bits - 1)) & 1 + + exp_unbias = exp - half_exp_bias + new_exp = exp_unbias + FLOAT4_EXP_BIAS + + exp_shift = (new_exp <= 0) * (1 - new_exp) + + # Typically 9. + # Take the min to prevent overflow on `uint16_t half`. This is the case for + # very small values, correctly mapped to `round_close`. + tail_bits = half_mantissa_bits - FLOAT4_MANTISSA_BITS + exp_shift + tail_bits[tail_bits >= 16] = 16 + + mantissa_plus_one = val_view & ((1 << (half_mantissa_bits + 1)) - 1) + + half = 1 << (tail_bits - 1) + + tail = mantissa_plus_one & ((1 << tail_bits) - 1) + + round_close = (tail < half) # round towards 0 + round_away = (tail > half) # round away from 0 + tie = tail == half + + new_mantissa_close = torch.zeros(val.shape, + device=val.device, + dtype=torch.bool) + new_exp_close = torch.zeros(val.shape, + device=val.device, + dtype=torch.uint16) + + new_mantissa_away = torch.zeros(val.shape, + device=val.device, + dtype=torch.bool) + new_exp_away = torch.zeros(val.shape, + device=val.device, + dtype=torch.uint16) + + new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16) + + # 1. round down + # if new_exp == 0: # case [0.5, 0.749999] + # new_mantissa = 0 + # elif new_exp < 0: # case [0, 0.24999] + # new_mantissa = 0 + # else: + # new_mantissa = mantissa_last + + new_mantissa_close = (new_exp > 0) * mantissa_last + new_exp_close = exp + + # # 2. round up + # if new_exp <= 0: # case [0.250001, 0.499999] and [0.75001, 0.99999] + # new_mantissa = 0 + # new_exp += 1 + # elif mantissa_last == 0: + # new_mantissa = 1 + # else: + # new_mantissa = 0 + # new_exp += 1 + + new_mantissa_away = torch.logical_and(new_exp > 0, mantissa_last == 0) + new_exp_away = exp + torch.logical_or(new_exp <= 0, mantissa_last == 1) + + # # 3. tie + # 0.25 -> 0. (handled by `exp > (half_exp_bias - 2)`) + # 0.75 -> 1. + # 1.25 -> 1. + # 1.75 -> 2. + # 2.5 -> 2. + # 3.5 -> 4. + # 5. -> 4. + new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1)) + + # Gather round up, round down and tie. + new_exp = round_away * new_exp_away \ + + round_close * new_exp_close \ + + tie * new_exp_tie + + new_mantissa = round_away * new_mantissa_away \ + + round_close * new_mantissa_close + + # if new_exp > 3: + # new_mantissa = 1 + new_mantissa = new_mantissa + (new_exp > + (2 + half_exp_bias)) * (new_mantissa == 0) + + # Clamp the exponent to acceptable values. 
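# Illustrative sketch (not part of this diff): the tie cases enumerated in the
# comments above follow round-half-to-even on the float4_e2m1 grid. A tiny
# grid-based reference (not used by this module):
_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # codes 0b000..0b111

def _round_e2m1_nearest_even(x: float) -> float:
    # Nearest grid point; on a tie prefer the candidate whose mantissa bit
    # (the low bit of its code, i.e. its index in the grid) is 0.
    return min(_E2M1_GRID, key=lambda g: (abs(x - g), _E2M1_GRID.index(g) & 1))

assert [_round_e2m1_nearest_even(v) for v in (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0)] \
    == [0.0, 1.0, 1.0, 2.0, 2.0, 4.0, 4.0]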
+ new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp( + new_exp, half_exp_bias - 2, half_exp_bias + 2) + + sign = sign.to(torch.int32) + new_mantissa = new_mantissa.to(torch.int32) + + qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + ( + new_mantissa << (half_mantissa_bits - 1)) + + assert qdq_val.max() <= 65535 + assert qdq_val.min() >= 0 + assert qdq_val.dtype == torch.int32 + qdq_val = qdq_val.to(torch.uint16) + + result = qdq_val.view(float_type) + return result + + +def qdq_mxfp4_torch(x: torch.Tensor, + scale_calculation_mode: str = "even") -> torch.Tensor: + half_dtype = x.dtype + + if half_dtype == torch.float16: + half_mantissa_bits = FLOAT16_MANTISSA_BITS + half_exp_bits = FLOAT16_EXP_BITS + half_exp_bias = FLOAT16_EXP_BIAS + val_to_add = FLOAT16_VAL_TO_ADD + sign_exponent_mask = FLOAT16_SIGN_EXPONENT_MASK + elif half_dtype == torch.bfloat16: + half_mantissa_bits = BFLOAT16_MANTISSA_BITS + half_exp_bits = BFLOAT16_EXP_BITS + half_exp_bias = BFLOAT16_EXP_BIAS + val_to_add = BFLOAT16_VAL_TO_ADD + sign_exponent_mask = BFLOAT16_SIGN_EXPONENT_MASK + else: + raise ValueError("not implemented") + + x = x.reshape(*x.shape[:-1], -1, 32) + + block_max = torch.max(torch.abs(x), dim=-1).values + + block_max = block_max.view(torch.uint16).to(torch.int32) + + block_max_uint = torch.bitwise_and(block_max + val_to_add, + sign_exponent_mask) + + assert block_max_uint.max() <= 65535 + assert block_max_uint.min() >= 0 + assert block_max_uint.dtype == torch.int32 + block_max_uint = block_max_uint.to(torch.uint16) + + block_max = block_max_uint.view(half_dtype) + + scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to( + torch.int32) - 2 + + scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP) + + scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP) + scale = scale.to(half_dtype) + + x = x / scale[..., None] + + x_fp4 = fp16_to_fp4_simulate(x, + half_exp_bits=half_exp_bits, + half_mantissa_bits=half_mantissa_bits, + half_exp_bias=half_exp_bias) + + x_fp4 = x_fp4 * scale[..., None] + return x_fp4.reshape(*x_fp4.shape[:-2], -1) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 3646ad6c481b..db7e50eff72b 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -45,7 +45,8 @@ def use_v0_only(monkeypatch): """ This module relies on V0 internals, so set VLLM_USE_V1=0. """ - monkeypatch.setenv('VLLM_USE_V1', '0') + if not current_platform.is_cpu(): + monkeypatch.setenv('VLLM_USE_V1', '0') @pytest.mark.parametrize( diff --git a/tests/quantization/test_gptq_dynamic.py b/tests/quantization/test_gptq_dynamic.py index 23b999e7c679..aea50e99c1dd 100644 --- a/tests/quantization/test_gptq_dynamic.py +++ b/tests/quantization/test_gptq_dynamic.py @@ -39,7 +39,7 @@ def test_gptq_with_dynamic(vllm_runner, model_id: str, use_marlin_kernel: bool, linear_method_cls = GPTQMarlinLinearMethod if use_marlin_kernel else ( GPTQLinearMethod) - for name, submodule in (vllm_model.model.llm_engine.model_executor. + for name, submodule in (vllm_model.llm.llm_engine.model_executor. 
driver_worker.model_runner.model.named_modules()): if name == "lm_head": assert isinstance(submodule.quant_method, linear_method_cls) diff --git a/tests/quantization/test_modelopt.py b/tests/quantization/test_modelopt.py new file mode 100644 index 000000000000..fcbfa681d75c --- /dev/null +++ b/tests/quantization/test_modelopt.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test ModelOpt quantization method setup and weight loading. + +Run `pytest tests/quantization/test_modelopt.py`. +""" + +import os + +import pytest +import torch + +from tests.quantization.utils import is_quant_method_supported +from vllm.platforms import current_platform + + +@pytest.fixture(scope="function", autouse=True) +def use_v0_only(monkeypatch): + """ + This module relies on V0 internals, so set VLLM_USE_V1=0. + """ + if not current_platform.is_cpu(): + monkeypatch.setenv('VLLM_USE_V1', '0') + + +@pytest.mark.skipif(not is_quant_method_supported("modelopt"), + reason="ModelOpt FP8 is not supported on this GPU type.") +def test_modelopt_fp8_checkpoint_setup(vllm_runner): + """Test ModelOpt FP8 checkpoint loading and structure validation.""" + # TODO: provide a small publically available test checkpoint + model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/" + "TinyLlama-1.1B-Chat-v1.0-fp8-0710") + + # Skip test if checkpoint doesn't exist + if not os.path.exists(model_path): + pytest.skip(f"Test checkpoint not found at {model_path}. " + "This test requires a local ModelOpt FP8 checkpoint.") + + with vllm_runner(model_path, quantization="modelopt", + enforce_eager=True) as llm: + + def check_model(model): + layer = model.model.layers[0] + + qkv_proj = layer.self_attn.qkv_proj + o_proj = layer.self_attn.o_proj + gate_up_proj = layer.mlp.gate_up_proj + down_proj = layer.mlp.down_proj + + # Check that ModelOpt quantization method is properly applied + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptFp8LinearMethod) + assert isinstance(qkv_proj.quant_method, ModelOptFp8LinearMethod) + assert isinstance(o_proj.quant_method, ModelOptFp8LinearMethod) + assert isinstance(gate_up_proj.quant_method, + ModelOptFp8LinearMethod) + assert isinstance(down_proj.quant_method, ModelOptFp8LinearMethod) + + # Check weight dtype is FP8 + assert qkv_proj.weight.dtype == torch.float8_e4m3fn + assert o_proj.weight.dtype == torch.float8_e4m3fn + assert gate_up_proj.weight.dtype == torch.float8_e4m3fn + assert down_proj.weight.dtype == torch.float8_e4m3fn + + # Check scales are present and have correct dtype + assert hasattr(qkv_proj, 'weight_scale') + assert hasattr(qkv_proj, 'input_scale') + assert qkv_proj.weight_scale.dtype == torch.float32 + assert qkv_proj.input_scale.dtype == torch.float32 + + assert hasattr(o_proj, 'weight_scale') + assert hasattr(o_proj, 'input_scale') + assert o_proj.weight_scale.dtype == torch.float32 + assert o_proj.input_scale.dtype == torch.float32 + + assert hasattr(gate_up_proj, 'weight_scale') + assert hasattr(gate_up_proj, 'input_scale') + assert gate_up_proj.weight_scale.dtype == torch.float32 + assert gate_up_proj.input_scale.dtype == torch.float32 + + assert hasattr(down_proj, 'weight_scale') + assert hasattr(down_proj, 'input_scale') + assert down_proj.weight_scale.dtype == torch.float32 + assert down_proj.input_scale.dtype == torch.float32 + + llm.apply_model(check_model) + + # Run a simple generation test to ensure the model works + output = llm.generate_greedy(["Hello my name 
is"], max_tokens=20) + assert output + print(f"ModelOpt FP8 output: {output}") diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index 3571f773fb02..4a0c8ba4d8a9 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -3,15 +3,44 @@ """Test model set-up and weight loading for quark-quantized models. Run `pytest tests/quantization/test_quark.py`. + +See also `tests/kernels/moe/test_mxfp4_moe.py`. """ +import importlib +import importlib.metadata +import os +from dataclasses import dataclass + +import huggingface_hub +import lm_eval import pytest import torch +from packaging import version from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.platforms import current_platform +from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch + +QUARK_MXFP4_AVAILABLE = importlib.util.find_spec( + "quark") is not None and version.parse( + importlib.metadata.version("amd-quark")) >= version.parse('0.8.99') + +if QUARK_MXFP4_AVAILABLE: + from quark.torch.export.nn.modules.realquantizer import ( + StaticScaledRealQuantizer) + from quark.torch.kernel import mx as mx_kernel + from quark.torch.quantization.config.config import FP4PerGroupSpec + +try: + huggingface_hub.list_repo_refs( + "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ") + HF_HUB_AMD_ORG_ACCESS = True +except huggingface_hub.errors.RepositoryNotFoundError: + HF_HUB_AMD_ORG_ACCESS = False + @pytest.fixture(scope="function", autouse=True) def use_v0_only(monkeypatch): @@ -78,11 +107,11 @@ def test_quark_fp8_parity(vllm_runner): } with (vllm_runner(quark_model_id, **llm_kwargs) as quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle): - quark_model = (quark_handle.model.llm_engine.model_executor. + quark_model = (quark_handle.llm.llm_engine.model_executor. driver_worker.model_runner.model) quark_state_dict = quark_model.state_dict() - fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker. + fp8_model = (fp8_handle.llm.llm_engine.model_executor.driver_worker. model_runner.model) fp8_state_dict = fp8_model.state_dict() @@ -90,3 +119,145 @@ def test_quark_fp8_parity(vllm_runner): for key in fp8_state_dict: assert torch.equal(fp8_state_dict[key], quark_state_dict[key]) + + +@dataclass +class ModelCase: + model_id: str + tp: int + + +@dataclass +class GSM8KAccuracyTestConfig: + model_name: str + excepted_value: float + + def get_model_args(self) -> str: + return ( + f"pretrained={self.model_name}," + "dtype=auto,add_bos_token=True,tensor_parallel_size=8,gpu_memory_utilization=0.7,max_model_len=38768" + ) + + +ACCURACY_CONFIGS = [ + # Private model. 
+ GSM8KAccuracyTestConfig( + model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant", + excepted_value=0.96), +] + + +@pytest.mark.parametrize("config", ACCURACY_CONFIGS) +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +@pytest.mark.skipif( + not HF_HUB_AMD_ORG_ACCESS, + reason="Read access to huggingface.co/amd is required for this test.") +def test_mxfp4_gsm8k_correctness(config: GSM8KAccuracyTestConfig): + if torch.cuda.device_count() < 8: + pytest.skip( + f"This test requires >=8 gpus, got only {torch.cuda.device_count()}" + ) + + task = "gsm8k" + rtol = 0.03 + + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" + + results = lm_eval.simple_evaluate( + model="vllm", + model_args=config.get_model_args(), + tasks=task, + batch_size=64, + num_fewshot=8, + ) + + EXPECTED_VALUE = config.excepted_value + measured_value = results["results"][task]["exact_match,strict-match"] + assert (measured_value - rtol < EXPECTED_VALUE + and measured_value + rtol > EXPECTED_VALUE + ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" + + del os.environ["VLLM_USE_TRITON_FLASH_ATTN"] + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("scalings", + [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, + scalings: list[int]): + torch.manual_seed(0) + + hidden_size = 64 * 32 + inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - + 0.5) * 2 + for i in range(hidden_size // 32): + inp[:, i * 32:(i + 1) * + 32] = inp[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + + inp_kernel = inp.clone() + inp_kernel_clone = inp_kernel.clone() + + res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even") + res_torch = qdq_mxfp4_torch(inp_kernel, "even") + + for i in range(hidden_size // 32): + assert torch.all(torch.isfinite(res_hip[:, i * 32:(i + 1) * 32])) + assert torch.all(torch.isfinite(res_torch[:, i * 32:(i + 1) * 32])) + + torch.testing.assert_close(res_hip[:, i * 32:(i + 1) * 32], + res_torch[:, i * 32:(i + 1) * 32]) + + +@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, + reason="amd-quark>=0.9 is not available") +@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("scalings", + [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]]) +def test_mxfp4_dequant_kernel_match_quark(float_dtype: torch.dtype, + scalings: list[int]): + qspec = FP4PerGroupSpec( + ch_axis=-1, + group_size=32, + scale_format="e8m0", + scale_calculation_mode="even", + is_dynamic=False, + ).to_quantization_spec() + + weight_quantizer = StaticScaledRealQuantizer( + qspec=qspec, + quantizer=None, + reorder=False, + real_quantized=True, + float_dtype=float_dtype, + device="cuda", + ) + + observer = qspec.observer_cls(qspec, device="cuda") + + hidden_size = 512 + shape = (11008, hidden_size) + + w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2 + + # Make it so that different groups have different scales. 
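# Illustrative sketch (not part of this diff): a minimal version of the MX-style
# per-group scale this test relies on, assuming 32-element groups and a
# power-of-two (e8m0-style) scale chosen so each group's max magnitude lands in
# the float4_e2m1 range. It mirrors the structure of qdq_mxfp4_torch above but
# skips the block-max rounding and does not handle all-zero groups.
import torch

def _naive_mx_group_scale(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    groups = x.reshape(*x.shape[:-1], -1, group_size)
    block_max = groups.abs().amax(dim=-1)
    # 2**(floor(log2(max)) - 2): dividing by this maps each group max into [4, 8).
    return torch.exp2(torch.floor(torch.log2(block_max)) - 2)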
+ for i in range(hidden_size // 32): + w[:, i * 32:(i + 1) * + 32] = w[:, i * 32:(i + 1) * 32] * scalings[i % len(scalings)] + + observer(w) + scale, _ = observer._calculate_qparams() + weight_quantizer.scale = scale + + w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda") + weight_quantizer.maybe_convert_and_transpose_scale() + + scale = weight_quantizer.scale + + out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype) + + out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype) + + assert torch.equal(out_hip, out_torch) diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py index 6c541fdbeeae..84705e92c85b 100644 --- a/tests/quantization/test_register_quantization_config.py +++ b/tests/quantization/test_register_quantization_config.py @@ -111,7 +111,7 @@ def test_custom_quant(vllm_runner, model, monkeypatch): quantization="custom_quant", enforce_eager=True) as llm: - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 + model = llm.llm.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 layer = model.model.layers[0] qkv_proj = layer.self_attn.qkv_proj diff --git a/tests/reasoning/test_hunyuan_reasoning_parser.py b/tests/reasoning/test_hunyuan_reasoning_parser.py new file mode 100644 index 000000000000..f9238267f02e --- /dev/null +++ b/tests/reasoning/test_hunyuan_reasoning_parser.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from transformers import AutoTokenizer + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +parser_name = "hunyuan_a13b" +START_REASONING = "<think>\n" +START_RESPONSE = "\n</think>\n<answer>\n" +END_RESPONSE = "\n</answer>" + +NO_REASONING_QUICK_THROUGHT = { + "output": + f"{START_REASONING}{START_RESPONSE}This is the rest{END_RESPONSE}", #noqa: E501 + "reasoning_content": None, + "content": "This is the rest", +} + +SIMPLE_REASONING = { + "output": + f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest{END_RESPONSE}", #noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING = { + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}", + "reasoning_content": "This is a reasoning section", + "content": None, +} + +COMPLETE_REASONING_WITH_SYMBOL = { + "output": f"{START_REASONING}This is a reasoning section!{START_RESPONSE}", + "reasoning_content": "This is a reasoning section!", + "content": None, +} +NO_REASONING = { + "output": "This is content", + "reasoning_content": None, + "content": "This is content", +} +MULTIPLE_LINES = { + "output": + f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} +REASONING_WITH_THINK = { + "output": + f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501 + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING_WITH_THINK = { + "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}", + "reasoning_content": "This is a reasoning section", + "content": None, +} +MULTIPLE_LINES_WITH_THINK = { + "output": + f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat", + "reasoning_content": 
"This\nThat", + "content": "This is the rest\nThat", +} + +TEST_CASES = [ + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_SYMBOL, + id="complete_reasoning_with_symbol", + ), + pytest.param( + False, + NO_REASONING, + id="no_reasoning", + ), + pytest.param(False, NO_REASONING_QUICK_THROUGHT, id="no_reasoning_quick"), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + True, + NO_REASONING, + id="no_reasoning_streaming", + ), + pytest.param(True, + NO_REASONING_QUICK_THROUGHT, + id="no_reasoning_quick_stream"), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), +] + +# Global tokenizer initialization to avoid repeated loading +tokenizer = AutoTokenizer.from_pretrained("tencent/Hunyuan-A13B-Instruct", + trust_remote_code=True) + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, +): + output = tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: list[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(tokenizer) + + reasoning, content = run_reasoning_extraction(parser, + output_tokens, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/tests/reasoning/test_mistral_reasoning_parser.py b/tests/reasoning/test_mistral_reasoning_parser.py new file mode 100644 index 000000000000..91a22f6f5d72 --- /dev/null +++ b/tests/reasoning/test_mistral_reasoning_parser.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +from mistral_common.tokens.tokenizers.base import SpecialTokens +from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo, + Tekkenizer) + +from tests.reasoning.utils import run_reasoning_extraction_mistral +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +parser_name = "mistral" + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + # TODO(Julien): upon model release change to a tokenizer already configured. 
+ # ================================================================= + mistral_tokenizer = MistralTokenizer.from_pretrained( + "mistralai/Devstral-Small-2507") + assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer) + # Add think special tokens to the tokenizer + mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo( + rank=35, is_control=True, token_str=SpecialTokens.begin_think.value) + mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo( + rank=36, is_control=True, token_str=SpecialTokens.end_think.value) + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = { + k: v + for k, v in + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items() + if v not in {35, 36} + } + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.begin_think.value] = 35 + mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[ + SpecialTokens.end_think.value] = 36 + mistral_tokenizer.instruct.BEGIN_THINK = 35 + mistral_tokenizer.instruct.END_THINK = 36 + # ================================================================= + return mistral_tokenizer + + +SIMPLE_REASONING = { + "output": "This is a reasoning section[/THINK]This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING = { + "output": "This is a reasoning section[/THINK]", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +NO_CONTENT = { + "output": "This is content", + "reasoning_content": "This is content", + "content": None, + "is_reasoning_end": False, +} +NO_REASONING_STREAMING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +MULTIPLE_LINES = { + "output": "This\nThat[/THINK]This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +SHORTEST_REASONING_NO_STREAMING = { + "output": "[/THINK]This is the rest", + "reasoning_content": "", + "content": "This is the rest", + "is_reasoning_end": True, +} +SHORTEST_REASONING = { + "output": "[/THINK]This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +REASONING_WITH_THINK = { + "output": "[THINK]This is a reasoning section[/THINK]This is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING_WITH_THINK = { + "output": "[THINK]This is a reasoning section[/THINK]", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +MULTIPLE_LINES_WITH_THINK = { + "output": "[THINK]This\nThat[/THINK]This is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +SHORTEST_REASONING_NO_STREAMING_WITH_THINK = { + "output": "[/THINK]This is the rest", + "reasoning_content": "", + "content": "This is the rest", + "is_reasoning_end": True, +} +SHORTEST_REASONING_WITH_THINK = { + "output": "[/THINK]This is the rest", + "reasoning_content": None, + "content": "This is the rest", + "is_reasoning_end": True, +} +THINK_NO_END = { + "output": "[THINK]This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +EMPTY = { + "output": "", + "reasoning_content": "", + "content": None, + 
"is_reasoning_end": False, +} +EMPTY_STREAMING = { + "output": "", + "reasoning_content": None, + "content": None, + "is_reasoning_end": False, +} +NEW_LINE = { + "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "\nThis is the rest", + "is_reasoning_end": True, +} +# Streaming cannot handle new lines at the beginning of the output +# because we need to support [THINK]...[/THINK] and [/THINK]... +# We cannot know if the text before [THINK] is reasoning content +# or not. +NEW_LINE_STREAMING = { + "output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest", + "reasoning_content": "\nThis is a reasoning section", + "content": "\nThis is the rest", + "is_reasoning_end": True, +} + +TEST_CASES = [ + pytest.param( + False, + SIMPLE_REASONING, + id="simple_reasoning", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_reasoning_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_reasoning", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_reasoning_streaming", + ), + pytest.param( + False, + NO_CONTENT, + id="no_content_token", + ), + pytest.param( + True, + NO_REASONING_STREAMING, + id="no_reasoning_token_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + SHORTEST_REASONING, + id="shortest", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING, + id="shortest_streaming", + ), + pytest.param( + False, + REASONING_WITH_THINK, + id="reasoning_with_think", + ), + pytest.param( + True, + REASONING_WITH_THINK, + id="reasoning_with_think_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think", + ), + pytest.param( + True, + COMPLETE_REASONING_WITH_THINK, + id="complete_reasoning_with_think_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think", + ), + pytest.param( + True, + MULTIPLE_LINES_WITH_THINK, + id="multiple_lines_with_think_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING_WITH_THINK, + id="shortest_with_think", + ), + pytest.param( + True, + SHORTEST_REASONING_WITH_THINK, + id="shortest_with_think_streaming", + ), + pytest.param( + False, + THINK_NO_END, + id="think_no_end", + ), + pytest.param( + True, + THINK_NO_END, + id="think_no_end_streaming", + ), + pytest.param( + False, + EMPTY, + id="empty", + ), + pytest.param( + True, + EMPTY_STREAMING, + id="empty_streaming", + ), + pytest.param( + False, + NEW_LINE, + id="new_line", + ), + pytest.param( + True, + NEW_LINE_STREAMING, + id="new_line_streaming", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_mistral_reasoning( + streaming: bool, + param_dict: dict, + mistral_tokenizer: MistralTokenizer, +): + output = param_dict["output"] + + index_think = output.find("[THINK]") + len_think = len("[THINK]") + index_end_think = output.find("[/THINK]") + len_end_think = len("[/THINK]") + + # encode everything to tokens ids + output_tokens = [] + if index_think != -1: + output_before_think = output[:index_think] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_before_think, False, False) + output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK] + + if index_end_think != -1: + output_middle = output[index_think + len_think:index_end_think] + 
output_after_think = output[index_end_think + len_end_think:] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_middle, False, False) + output_tokens += [mistral_tokenizer.instruct.END_THINK] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_after_think, False, False) + else: + output_middle = output[index_think + len_think:] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_middle, False, False) + elif index_end_think != -1: + output_before_think = output[:index_end_think] + output_after_think = output[index_end_think + len_end_think:] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_before_think, False, False) + output_tokens += [mistral_tokenizer.instruct.END_THINK] + output_tokens += mistral_tokenizer.tokenizer.encode( + output_after_think, False, False) + else: + output_tokens += mistral_tokenizer.tokenizer.encode( + output, False, False) + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(mistral_tokenizer) + + reasoning, content = run_reasoning_extraction_mistral(parser, + output_tokens, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] + + # Test is_reasoning_end + is_reasoning_end = parser.is_reasoning_end(output_tokens) + assert is_reasoning_end == param_dict["is_reasoning_end"] + + # Test extract_content + if param_dict["content"] is not None: + content = parser.extract_content_ids(output_tokens) + assert content == mistral_tokenizer.tokenizer.encode( + param_dict["content"], bos=False, eos=False) + else: + content = parser.extract_content_ids(output_tokens) + assert content == [] diff --git a/tests/reasoning/utils.py b/tests/reasoning/utils.py index ddcf89796fb5..9af5fa5addbc 100644 --- a/tests/reasoning/utils.py +++ b/tests/reasoning/utils.py @@ -6,6 +6,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) from vllm.reasoning import ReasoningParser +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer class StreamingReasoningReconstructor: @@ -54,6 +55,32 @@ def run_reasoning_extraction( return reasoning, content +def run_reasoning_extraction_mistral( + reasoning_parser: ReasoningParser, + model_output: list[int], + request: Union[ChatCompletionRequest, None] = None, + streaming: bool = False, +) -> tuple[Optional[str], Optional[str]]: + assert isinstance(reasoning_parser.model_tokenizer, + MistralTokenizer), type(reasoning_parser.model_tokenizer) + if streaming: + reconstructor = run_reasoning_extraction_streaming_mistral( + reasoning_parser, + model_output, + request, + ) + return ( + reconstructor.reasoning_content, + reconstructor.other_content or None, + ) + else: + str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens( + model_output) + reasoning, content = run_reasoning_extraction_nonstreaming( + reasoning_parser, str_output, request) + return reasoning, content + + def run_reasoning_extraction_nonstreaming( reasoning_parser: ReasoningParser, model_output: list[str], @@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming( previous_text = current_text previous_tokens = current_tokens return reconstructor + + +def run_reasoning_extraction_streaming_mistral( + reasoning_parser: ReasoningParser, + model_deltas: list[int], + request: Union[ChatCompletionRequest, None] = None, +) -> StreamingReasoningReconstructor: + assert isinstance(reasoning_parser.model_tokenizer, + MistralTokenizer), type(reasoning_parser.model_tokenizer) + request = 
request or ChatCompletionRequest(messages=[], model="test-model") + reconstructor = StreamingReasoningReconstructor() + previous_text = "" + previous_tokens: list[int] = [] + for model_delta in model_deltas: + token_delta = [model_delta] + delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens( + [model_delta])[0] + current_text = previous_text + delta + current_tokens = previous_tokens + token_delta + delta_message = reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta, + previous_tokens, + current_tokens, + token_delta, + ) + if delta_message is not None: + reconstructor.append_delta(delta_message) + previous_text = current_text + previous_tokens = current_tokens + return reconstructor diff --git a/tests/samplers/test_ignore_eos.py b/tests/samplers/test_ignore_eos.py index 7eb9c0b5fb8c..ea4a17dd2306 100644 --- a/tests/samplers/test_ignore_eos.py +++ b/tests/samplers/test_ignore_eos.py @@ -36,7 +36,7 @@ def test_ignore_eos( ignore_eos=True) for prompt in example_prompts: - ignore_eos_output = vllm_model.model.generate( + ignore_eos_output = vllm_model.llm.generate( prompt, sampling_params=sampling_params) output_length = len(ignore_eos_output[0].outputs[0].token_ids) assert output_length == max_tokens diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 901c87591264..123f9595e97b 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -26,7 +26,7 @@ def test_logits_processor_force_generate( dtype: str, ) -> None: with vllm_runner(model, dtype=dtype) as vllm_model: - tokenizer = vllm_model.model.get_tokenizer() + tokenizer = vllm_model.llm.get_tokenizer() repeat_times = 2 enforced_answers = " vLLM" vllm_token_ids = tokenizer.encode(enforced_answers, @@ -45,13 +45,13 @@ def pick_vllm(token_ids, logits): ) # test logits_processors when prompt_logprobs is not None - vllm_model.model._add_request( + vllm_model.llm._add_request( example_prompts[0], params=params_with_logprobs, ) # test prompt_logprobs is not None - vllm_model.model._add_request( + vllm_model.llm._add_request( example_prompts[1], params=SamplingParams( prompt_logprobs=3, @@ -60,11 +60,11 @@ def pick_vllm(token_ids, logits): ) # test grouped requests - vllm_model.model._add_request( + vllm_model.llm._add_request( example_prompts[2], params=SamplingParams(max_tokens=max_tokens), ) - outputs = vllm_model.model._run_engine(use_tqdm=False) + outputs = vllm_model.llm._run_engine(use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_logprobs.py b/tests/samplers/test_logprobs.py index 86c8a03eee10..87f40b100531 100644 --- a/tests/samplers/test_logprobs.py +++ b/tests/samplers/test_logprobs.py @@ -64,7 +64,7 @@ def test_get_prompt_logprobs( prompt_logprobs=num_top_logprobs, temperature=0.0, detokenize=detokenize) - vllm_results = vllm_model.model.generate( + vllm_results = vllm_model.llm.generate( example_prompts, sampling_params=vllm_sampling_params) # Test whether logprobs are included in the results. 
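# Illustrative sketch (not part of this diff): the sampler-test hunks above only
# rename the runner attribute from `.model` to `.llm`; the behaviour exercised is
# unchanged. For orientation, this is the kind of logprobs request
# test_get_prompt_logprobs drives through vLLM's public API (the model name and
# prompt here are placeholders, not taken from the test):
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, max_tokens=5,
                        logprobs=3, prompt_logprobs=3)
out = llm.generate(["Hello, my name is"], params)[0]
# Top-3 logprobs for each generated token.
print(out.outputs[0].logprobs)
# Logprobs for each prompt token (the first entry is None).
print(out.prompt_logprobs)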
@@ -174,7 +174,7 @@ def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int, logprobs=None, temperature=0.0, detokenize=detokenize) - results_logprobs_none = vllm_model.model.generate( + results_logprobs_none = vllm_model.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_none) for i in range(len(results_logprobs_none)): diff --git a/tests/samplers/test_no_bad_words.py b/tests/samplers/test_no_bad_words.py index 42b529ae169d..11803b8d7a5e 100644 --- a/tests/samplers/test_no_bad_words.py +++ b/tests/samplers/test_no_bad_words.py @@ -20,7 +20,7 @@ def v1(run_with_both_engines): def _generate( - model: LLM, + llm: LLM, prompt: str, num_prompt_tokens: int, temperature: float = 0, @@ -32,7 +32,7 @@ def _generate( ) # [([output_token_ids, ], [output_text, ]), ] - output = model.generate([prompt], sampling_params=sampling_params) + output = llm.generate([prompt], sampling_params=sampling_params) output_token_ids = output[0][0][0][num_prompt_tokens:] # [0] first (and only) request output @@ -66,10 +66,10 @@ def test_one_token_bad_word(self, vllm_runner): assert self.target_token_id not in output_token_ids def _generate(self, - model: LLM, + llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: return _generate( - model=model, + llm=llm, prompt=self.PROMPT, num_prompt_tokens=self.num_prompt_tokens, bad_words=bad_words, @@ -156,10 +156,10 @@ def test_two_token_bad_word(self, vllm_runner): or (self.neighbour_token_id2 in output_token_ids)) def _generate(self, - model: LLM, + llm: LLM, bad_words: Optional[list[str]] = None) -> list[int]: return _generate( - model=model, + llm=llm, prompt=self.PROMPT, num_prompt_tokens=self.num_prompt_tokens, bad_words=bad_words, diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py deleted file mode 100644 index 3b93c64113da..000000000000 --- a/tests/samplers/test_rejection_sampler.py +++ /dev/null @@ -1,577 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for rejection sampling.""" - -import pytest -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.model_executor.utils import set_random_seed - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This file tests V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - - -def mock_causal_accepted_tensor( - k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor: - """Generate an "accepted" tensor which should yield causally-accepted tokens - up to last accepted indices. - - Tokens after last_accepted_indices+1 may also be accepted, although they - will not be causally accepted. - """ - batch_size = last_accepted_indices.shape[0] - - accepted = (torch.arange(k).expand(batch_size, k) - <= last_accepted_indices.unsqueeze(-1).broadcast_to( - batch_size, k)) - - # Sprinkle accepted values after the contiguous initial accepted values. - # This replicates the behavior of rejection sampling, which may "accept" - # a token that cannot be accepted because of causality. 
- sprinkle_candidates = (torch.arange(k).expand( - batch_size, - k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + - 1) - sprinkle = torch.rand(batch_size, k) > 0.5 - accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates] - return accepted - - -@pytest.mark.parametrize("seed", list(range(10))) -@pytest.mark.parametrize( - "which_tokens_accepted", - ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_flashinfer", [True, False]) -@torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, seed: int, - device: str, use_flashinfer: bool): - """Verify the output has correct format given predetermined accepted matrix. - """ - set_random_seed(seed) - torch.set_default_device(device) - - batch_size = 10 - k = 5 - vocab_size = 3000 - - if which_tokens_accepted == "all_tokens_accepted": - accepted = mock_causal_accepted_tensor( - k, -1 + k * torch.ones((batch_size, ), dtype=torch.long)) - elif which_tokens_accepted == "no_tokens_accepted": - accepted = mock_causal_accepted_tensor( - k, -torch.ones((batch_size, ), dtype=torch.long)) - elif which_tokens_accepted == "some_tokens_accepted": - last_accepted_indices = torch.randint(low=-1, - high=k, - size=(batch_size, )) - accepted = mock_causal_accepted_tensor(k, last_accepted_indices) - else: - raise AssertionError() - - recovered_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer) - rejection_sampler.init_gpu_tensors(device=device) - output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access - accepted, - recovered_token_ids, - draft_token_ids, - bonus_token_ids, - ) - - expected_bonus_token_ids = bonus_token_ids.clone() - - if which_tokens_accepted == "all_tokens_accepted": - # Expect all tokens to be equal to draft tokens. - assert torch.equal(output_token_ids[:, :-1], draft_token_ids) - - # Expect all bonus tokens to be included. - assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids) - elif which_tokens_accepted == "no_tokens_accepted": - # Expect first token to be equal to recovered tokens. - assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0]) - - # Expect everything else to be -1. - assert torch.equal(output_token_ids[:, 1:], - torch.ones_like(output_token_ids[:, 1:]) * -1) - elif which_tokens_accepted == "some_tokens_accepted": - recovered_plus_bonus = torch.cat( - (recovered_token_ids, expected_bonus_token_ids), dim=-1) - # Assert first rejected token is a recovered token or bonus token. - assert torch.equal( - recovered_plus_bonus[torch.arange(0, batch_size), - last_accepted_indices + 1], - output_token_ids[torch.arange(0, batch_size), - last_accepted_indices + 1]) - - # Assert every subsequent token is -1. 
- subsequent_mask = torch.arange(0, k + 1).expand( - batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1) - assert torch.all(output_token_ids[subsequent_mask] == -1) - - -@pytest.mark.parametrize("k", list(range(1, 6))) -@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) -@pytest.mark.parametrize("batch_size", list(range(1, 32))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_flashinfer", [True, False]) -@torch.inference_mode() -def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, - device: str, use_flashinfer: bool): - torch.set_default_device(device) - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer) - rejection_sampler.init_gpu_tensors(device=device) - - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - - rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids) - - -@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0]) -@pytest.mark.parametrize("k", [1, 3, 6]) -@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) -@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) -@pytest.mark.parametrize("n_rep", [100]) -@pytest.mark.parametrize("device", CUDA_DEVICES) -# @pytest.mark.parametrize("use_flashinfer", [True, False]) -# Not testing FlashInfer now, since 0.2.3 API removed the ability -# to pass in uniform samples. -@pytest.mark.parametrize("use_flashinfer", [False]) -@torch.inference_mode() -def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, - frac_seeded: float, n_rep: int, device: str, - use_flashinfer: bool): - torch.set_default_device(device) - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer) - rejection_sampler.init_gpu_tensors(device=device) - - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - - seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded - - results = [] - for _ in range(n_rep): - seeded_seqs = { - i: torch.Generator(device=device).manual_seed(i) - for i in range(batch_size) if seeded_mask[i] - } - results.append( - rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids, seeded_seqs)) - - for i in range(batch_size): - if seeded_mask[i]: - for j in range(1, n_rep): - assert torch.equal(results[j][i], results[0][i]) - - -@pytest.mark.parametrize("k", [1, 3, 6]) -@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) -@pytest.mark.parametrize("batch_size", [3, 8, 32, 128]) -@pytest.mark.parametrize("device", CUDA_DEVICES) -# @pytest.mark.parametrize("use_flashinfer", [True, False]) -# Not testing FlashInfer now, since 0.2.3 API removed the ability -# to pass in uniform samples. 
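The V0 tests being removed here exercise standard speculative-decoding rejection sampling: a draft token x is accepted with probability min(1, p_target(x) / p_draft(x)), and on rejection a replacement token is drawn from the positive residual of p_target - p_draft. A minimal per-position sketch of that rule in plain PyTorch, offered purely as a conceptual reference and not as vLLM's RejectionSampler kernel (all tensor names here are illustrative), could look like:

import torch


def accept_or_resample(draft_probs: torch.Tensor,
                       target_probs: torch.Tensor,
                       draft_token_ids: torch.Tensor) -> torch.Tensor:
    """Sketch of the rejection rule for a single speculative position.

    draft_probs, target_probs: [batch_size, vocab_size]
    draft_token_ids: [batch_size]
    Returns [batch_size] token ids (accepted draft tokens or resampled ones).
    """
    batch = torch.arange(draft_token_ids.shape[0])
    p = target_probs[batch, draft_token_ids]
    q = draft_probs[batch, draft_token_ids]
    # Accept the draft token with probability min(1, p/q).
    accept = torch.rand_like(p) < torch.clamp(p / q, max=1.0)
    # On rejection, resample from the positive residual max(0, p - q).
    # torch.multinomial normalizes internally; the epsilon guards the
    # degenerate case where draft and target distributions are identical.
    residual = torch.clamp(target_probs - draft_probs, min=0.0) + 1e-9
    resampled = torch.multinomial(residual, num_samples=1).squeeze(-1)
    return torch.where(accept, draft_token_ids, resampled)

The bonus token sampled from the target model is only emitted when every draft token in the window is accepted, which is why these tests thread bonus_token_ids through each sampler call.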
-@pytest.mark.parametrize("use_flashinfer", [False]) -@torch.inference_mode() -def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int, - device: str, use_flashinfer: bool): - torch.set_default_device(device) - set_random_seed(0) - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - - single_batches = [] - for i in range(batch_size): - single_batches.append((draft_probs[i].clone().unsqueeze(0), - draft_token_ids[i].clone().unsqueeze(0), - target_probs[i].clone().unsqueeze(0), - bonus_token_ids[i].clone().unsqueeze(0), - draft_token_ids[i].clone().unsqueeze(0))) - - set_random_seed(0) - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer) - rejection_sampler.init_gpu_tensors(device=device) - - results = [] - seeded_seqs = { - i: torch.Generator(device=device).manual_seed(i) - for i in range(1, batch_size) # 0 is seed None - } - batch_result = rejection_sampler(target_probs.clone(), - bonus_token_ids.clone(), - draft_probs.clone(), - draft_token_ids.clone(), seeded_seqs) - - set_random_seed(0) - - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer) - rejection_sampler.init_gpu_tensors(device=device) - for i in range(batch_size): - request_seeded_seqs = { - 0: torch.Generator(device=device).manual_seed(i) - } if seeded_seqs.get(i) is not None else None - (draft_probs, draft_token_ids, target_probs, bonus_token_ids, - draft_token_ids) = single_batches[i] - results.append( - rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids, request_seeded_seqs)) - for i in range(batch_size): - assert torch.equal(batch_result[i], results[i].squeeze(0)) - - -@pytest.mark.parametrize("k", [1, 3, 6]) -@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) -@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_compare_nonflashinfer_backend(k: int, vocab_size: int, - batch_size: int, device: str): - """ - Test the flashinfer and nonflashinfer backend generate - the same output metrics. - """ - - pytest.skip("Not testing FlashInfer now, since 0.2.3 API removed " - "the ability to pass in uniform samples.") - - torch.set_default_device(device) - torch.manual_seed(0) - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - - num_accepted_tokens = [] - num_emitted_tokens = [] - num_draft_tokens = [] - - def get_seeded_seqs(): - return { - i: torch.Generator(device=device).manual_seed(i) - for i in range(batch_size) - } - - for use_flashinfer in [True, False]: - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer) - rejection_sampler.init_gpu_tensors(device=device) - # We use seeded sequences to ensure the same tokens are accepted - # for both flashinfer and nonflashinfer backends. 
- seeded_seqs = get_seeded_seqs() - rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids, seeded_seqs) - num_accepted_tokens.append(rejection_sampler.num_accepted_tokens) - num_emitted_tokens.append(rejection_sampler.num_emitted_tokens) - num_draft_tokens.append(rejection_sampler.num_draft_tokens) - - assert num_accepted_tokens[0] == num_accepted_tokens[1] - assert num_emitted_tokens[0] == num_emitted_tokens[1] - assert num_draft_tokens[0] == num_draft_tokens[1] - - -@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) -@pytest.mark.parametrize("which_token_ids", - ["bonus_token_ids", "draft_token_ids"]) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@pytest.mark.parametrize("use_flashinfer", [True, False]) -@torch.inference_mode() -def test_raises_when_vocab_oob(above_or_below_vocab_range: str, - which_token_ids: str, device: str, - use_flashinfer: bool): - k = 3 - batch_size = 5 - vocab_size = 30_000 - torch.set_default_device(device) - - rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer, - strict_mode=True) - rejection_sampler.init_gpu_tensors(device=device) - - draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - - oob_token_ids = None - if which_token_ids == "bonus_token_ids": - oob_token_ids = bonus_token_ids - elif which_token_ids == "draft_token_ids": - oob_token_ids = draft_token_ids - else: - raise AssertionError() - - if above_or_below_vocab_range == "above": - rogue_token_id = vocab_size + 1 - elif above_or_below_vocab_range == "below": - rogue_token_id = -1 - else: - raise AssertionError() - - oob_token_ids[0][0] = rogue_token_id - - with pytest.raises(AssertionError): - rejection_sampler(target_probs, bonus_token_ids, draft_probs, - draft_token_ids) - - -@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) -@pytest.mark.parametrize("seed", list(range(5))) -@pytest.mark.parametrize("use_flashinfer", [True, False]) -@torch.inference_mode() -def test_rejection_sampling_approximates_target_distribution( - seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool): - """Verify rejection sampling approximates target distribution, - despite sampling from a potentially distinct draft distribution. - - This is done by first creating a random target probability - distribution and a random draft probability distribution. We then - sample token ids from the rejection sampler using these draft - and target distributions. The samples are used to estimate - the output probability distribution, which we expect to approximate - the target distribution. - - A basic distance metric is used to determine similarity between - distributions. - - We expect that as we increase the number of samples, - the distance between the observed distribution and the target - distribution decreases. To measure this, we compare the distance - of the observed distribution against both the target distribution - and a uniform random distribution. We expect the distance between - the observed distribution and the target distribution to improve - much more than the distance improvement between the observed - distribution and the random distribution. 
- - When draft_and_target_probs_equal=True, the draft and target - probabilities are exactly equal. Rejection sampling should - still work without any NaNs or exceptions. - """ - torch.set_default_device("cpu") - set_random_seed(seed) - helper = _CorrectnessTestHelper( - vocab_size=10, - rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer), - ) - - draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( - draft_and_target_probs_equal) - - sample_sizes = [10, 100, 1_000, 10_000, 100_000] - distance_wrt_reference: list[float] = [] - distance_wrt_target: list[float] = [] - - for num_samples in sample_sizes: - (reference_vs_rejsample_dist, - target_vs_rejsample_dist) = helper.run_and_compare_distributions( - draft_probs, - target_probs, - reference_probs, - num_samples, - ) - - distance_wrt_reference.append(reference_vs_rejsample_dist) - distance_wrt_target.append(target_vs_rejsample_dist) - - relative_change_in_distance_wrt_target = get_ratio_first_to_last( - distance_wrt_target) - relative_change_in_distance_wrt_reference = get_ratio_first_to_last( - distance_wrt_reference) - - print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} " - f"{reference_vs_rejsample_dist=:.05f}") - print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} " - f"{relative_change_in_distance_wrt_reference=:.02f}") - - relative_change_in_distance_wrt_target = get_ratio_first_to_last( - distance_wrt_target) - relative_change_in_distance_wrt_reference = get_ratio_first_to_last( - distance_wrt_reference) - - expected_improvement_multiplier = 20 - assert (relative_change_in_distance_wrt_target - > relative_change_in_distance_wrt_reference * - expected_improvement_multiplier) - - -def get_ratio_first_to_last(elements: list[float]) -> float: - return elements[0] / elements[-1] - - -class _CorrectnessTestHelper: - """Class that packages together logic required for the unit-level - rejection sampling correctness test. - """ - - def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler): - self.rejection_sampler = rejection_sampler - self.vocab_size = vocab_size - self.vocab_range = (0, vocab_size) - - self.rejection_sampler.init_gpu_tensors(device=0) - - # Keep test simple, use k=1 - self.k = 1 - - # Bonus tokens not used, but rejection sampler requires - # correct shape. - self.num_bonus_tokens = 1 - - def generate_probs_for_test( - self, draft_and_target_probs_equal: bool - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - draft_probs, target_probs = (F.softmax( - torch.rand(self.vocab_size, dtype=torch.float32), - dim=-1, - ) for _ in range(2)) - - num_reference_probs = 100 - reference_probs = F.softmax( - torch.rand(num_reference_probs, - self.vocab_size, - dtype=torch.float32), - dim=-1, - ) - - if draft_and_target_probs_equal: - target_probs = draft_probs.clone() - - return draft_probs, target_probs, reference_probs - - def run_and_compare_distributions(self, draft_probs: torch.Tensor, - target_probs: torch.Tensor, - reference_probs: torch.Tensor, - num_samples: int) -> tuple[float, float]: - # Sample using rejection sampling. - rej_sample_probs = self._estimate_rejection_sampling_pdf( - draft_probs, target_probs, num_samples) - - # Average distance from reference probs. 
- reference_vs_rejsample_dist = torch.dist( - reference_probs, - rej_sample_probs).item() / reference_probs.shape[0] - target_vs_rejsample_dist = torch.dist(target_probs, - rej_sample_probs).item() - - return reference_vs_rejsample_dist, target_vs_rejsample_dist - - def _estimate_rejection_sampling_pdf( - self, - draft_probs: torch.Tensor, - target_probs: torch.Tensor, - num_samples: int, - ) -> torch.Tensor: - # Repeat draft probs num_samples times. - draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( - num_samples, 1, 1) - - # Repeat target probs num_samples * (k + 1) times. - # Rejection sampler requires bonus token probs, but they aren't used. - target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( - num_samples, self.k + 1, 1) - - # Randomly sample draft token ids from draft probs. - draft_token_ids = torch.multinomial(draft_probs[:, 0, :], - num_samples=1, - replacement=True).reshape( - num_samples, self.k) - - # Bonus tokens not used but required. - bonus_token_ids = torch.zeros((1, self.num_bonus_tokens), - dtype=torch.int64, - device="cuda").repeat(num_samples, 1) - - # Get output tokens via rejection sampling. - output_token_ids = self.rejection_sampler(target_probs.to("cuda"), - bonus_token_ids.to("cuda"), - draft_probs.to("cuda"), - draft_token_ids.to("cuda")) - - # Remove bonus tokens - output_token_ids = output_token_ids[:, :-1].flatten() - - # Estimate probability density function - hist = torch.histogram(output_token_ids.to(dtype=torch.float, - device="cpu"), - bins=self.vocab_size, - range=self.vocab_range, - density=True) - - return hist.hist diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index b339b4b2ddf3..5a0efd98acc1 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -49,7 +49,7 @@ def test_random_sample_with_seed( sampling_params_seed_2 = copy.deepcopy(sampling_params) sampling_params_seed_2.seed = 200 - llm = vllm_model.model + llm = vllm_model.llm for prompt in example_prompts: for params in ( diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py deleted file mode 100644 index 119841470bfb..000000000000 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ /dev/null @@ -1,480 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for rejection sampling.""" - -import pytest -import torch - -from vllm.model_executor.layers.typical_acceptance_sampler import ( - TypicalAcceptanceSampler) -from vllm.model_executor.utils import set_random_seed - -CUDA_DEVICES = [f"cuda:{i}" for i in range(1)] - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - This file tests V0 internals, so set VLLM_USE_V1=0. - """ - monkeypatch.setenv('VLLM_USE_V1', '0') - - -def get_zero_temperature_prob_dist(batch_size, k, vocab_size): - """ - Generates a fake temperature zero probability distribution. - Returns: - 1. A fake temperature zero probability distribution of shape - [batch_size, k, vocab_size] - 2. Tensor of shape [batch_size, k] containing the token ids - of the probability 1.0 tokens at each position. 
- """ - # Simulate temperature 0 probability distribution for target probabilities - # and create target probabilities such that only 1 token id has - # probability 1.0 - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - probs = torch.rand(batch_size, k, vocab_size) - _, zero_temperature_token_ids = torch.max(probs, dim=-1) - # set the probability of the tokens with ids in zero_temperature_token_ids - # to 1 and the rest to 0. - target_probs = torch.zeros_like(probs).scatter_( - -1, zero_temperature_token_ids.unsqueeze(-1), 1.0) - return target_probs, zero_temperature_token_ids - - -def get_draft_token_ids(batch_size: int, k: int, vocab_size: int, - token_ids_to_exclude: torch.Tensor): - """ - Returns a tensor of shape [batch_size, k] of fake draft token ids - drawn randomly from a vocab of size vocab_size. We however ensure - that token_ids from token_ids_to_exclude are excluded at the - corresponding positions. - """ - draft_token_ids = torch.empty(batch_size, k, dtype=torch.long) - for i in range(batch_size): - for j in range(k): - # Generate a random token ID excluding token_ids_to_exclude[i, j] - while True: - token_id = torch.randint(0, vocab_size, (1, )).item() - if token_id != token_ids_to_exclude[i, j]: - draft_token_ids[i, j] = token_id - break - return draft_token_ids - - -def get_acceptance_sampler( - posterior_threshold: float = 0.03, - posterior_alpha: float = 0.9, - strict_mode: bool = False, -) -> TypicalAcceptanceSampler: - """ - Initializes and returns a TypicalAcceptanceSampler. - """ - return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha, - strict_mode) - - -@pytest.mark.parametrize("k", list(range(1, 6))) -@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) -@pytest.mark.parametrize("batch_size", list(range(1, 32))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, - device: str): - """ - Tests that the TypicalAcceptancSampler forward succeeds for - different combinations of k, vocab_size, batch_size and num devices. - """ - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler() - typical_acceptance_sampler.init_gpu_tensors(device=device) - target_with_bonus_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - # Verify that sampling succeeds for all cases. - typical_acceptance_sampler(target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - - -@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) -@pytest.mark.parametrize("which_token_ids", - ["bonus_token_ids", "draft_token_ids"]) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_raises_when_vocab_oob(above_or_below_vocab_range: str, - which_token_ids: str, device: str): - """ - Tests that we throw an exception of the token ids fall outside - the bound of the provided vocabulary. 
- """ - k = 3 - batch_size = 5 - vocab_size = 30_000 - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - target_with_bonus_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - # Verify that appropriate exceptions are thrown for out - # of bound vocabs. - oob_token_ids = None - if which_token_ids == "bonus_token_ids": - oob_token_ids = bonus_token_ids - elif which_token_ids == "draft_token_ids": - oob_token_ids = draft_token_ids - else: - raise AssertionError() - - if above_or_below_vocab_range == "above": - rogue_token_id = vocab_size + 1 - elif above_or_below_vocab_range == "below": - rogue_token_id = -1 - else: - raise AssertionError() - - oob_token_ids[0][0] = rogue_token_id - - with pytest.raises(AssertionError): - typical_acceptance_sampler(target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - - -@pytest.mark.parametrize("seed", list(range(10))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_uniform_target_distribution_accepts_all_tokens( - seed: int, device: str): - """ - Test the TypicalAcceptanceSampler with a uniform target probability - distribution. - - This test verifies that when provided with a uniform target probability - distribution, the TypicalAcceptanceSampler accepts all draft tokens. The - entropy of the uniform target distribution being high should lead to all - draft tokens being accepted. - """ - set_random_seed(seed) - k = 3 - batch_size = 5 - vocab_size = 30_000 - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - target_with_bonus_probs = torch.rand(batch_size, - k + 1, - vocab_size, - dtype=torch.float32) - draft_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - # We are using a uniform target probability distribution. - # For a uniform distribution the entropy is very high and it - # should lead to all draft tokens being accepted. Verify that. - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze()) - - assert torch.all(output_token_ids[:, :k] == draft_token_ids) - - -@pytest.mark.parametrize("seed", list(range(10))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_temperature_zero_target_distribution(seed: int, device: str): - """ - Test the TypicalAcceptanceSampler with a zero-temperature target - probability distribution. - - This test verifies that when using a zero-temperature target probability - distribution, where only one token has a probability of 1.0, the - TypicalAcceptanceSampler correctly rejects all draft tokens that do not - match this probability. 
Additionally, it ensures that when all draft - tokens are rejected, the sampler falls back to greedy sampling to select a - single token from the target distribution. - """ - set_random_seed(seed) - k = 3 - batch_size = 5 - vocab_size = 30_000 - torch.set_default_device(device) - - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - # Simulate temperature 0 probability distribution for target probabilities - # and create target probabilities such that only 1 token id has - # probability 1.0 - target_with_bonus_probs, zero_temperature_token_ids = \ - get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) - zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] - # Populate draft_token_ids such that they exclude the token_ids - # with probability = 1.0 - draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, - zero_temperature_token_ids) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - # The target probaility distribution is a temperature zero distribution - # with zero entropy. Since our draft token ids don't match the probability - # 1.0 tokens in the target distribution we will reject all of them and - # fallback to the greedy sampling for selecting 1 token for each sequence. - # Verify the same. - output_token_ids = typical_acceptance_sampler( - target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - assert torch.all(output_token_ids[:, -1] == -1) - assert torch.all(output_token_ids[:, 0] == zero_temperature_token_ids[:, - 0]) - - -@pytest.mark.parametrize("seed", list(range(10))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_mixed_target_distribution(seed: int, device: str): - """ - Test the TypicalAcceptanceSampler with a mixed target probability - distribution. - - This test ensures that the TypicalAcceptanceSampler handles a mixed - target probability distribution correctly. Specifically, it uses a - zero-temperature distribution for some sequences and a uniform - distribution for others. The test verifies that: - - - For sequences with a zero-temperature distribution, only the token - with a probability of 1.0 is accepted, and all other tokens are rejected. - - For sequences with a uniform distribution, all draft tokens are - accepted. - """ - set_random_seed(seed) - k = 3 - batch_size = 4 - vocab_size = 30_000 - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - # For sequences 0 and 2 set the distribution to a temperature - # zero distribution. For sequences 1 and 3 set it to a uniform - # distribution. 
- target_with_bonus_probs, zero_temperature_token_ids = \ - get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) - zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] - target_probs = target_with_bonus_probs[:, :-1] - draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, - zero_temperature_token_ids) - uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) - target_probs[[1, 3]] = uniform_probs - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - # verify the shape of output_token_ids - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - # For sequences 0 and 2 verify that only 1 token is accepted - # which is the token with probability 1.0 in the target distribution - # at position 0. - assert torch.all(output_token_ids[[0, 2], 1:] == -1) - assert (torch.all(output_token_ids[[0, 2], - 0] == zero_temperature_token_ids[[0, 2], - 0])) - # For sequences 1 and 3 verify that all tokens are accepted since the - # target probability distribution is uniform. In addition verify that - # we also accept the bonus tokens. - assert torch.all( - output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :]) - assert torch.all(output_token_ids[[1, 3], -1] != -1) - - -@pytest.mark.parametrize("seed", list(range(10))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_accept_tokens_partially(seed: int, device: str): - """ - Test the TypicalAcceptanceSampler's behavior when only a subset of draft - tokens should be accepted. - - This test verifies that the TypicalAcceptanceSampler correctly accepts or - rejects draft tokens based on a zero-temperature target probability - distribution. Specifically, it ensures that: - - - When all draft tokens match tokens with a probability of 1.0 in the - target distribution, all draft tokens are accepted. - - When only some draft tokens match tokens with a probability of 1.0 in - the target distribution, only those matching tokens are accepted, and the - rest are rejected. - """ - set_random_seed(seed) - k = 5 - batch_size = 1 - vocab_size = 30_000 - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - # Create a temperature zero target probability distribution and ensure - # all draft token ids correspond to the tokens with 1.0 probability. - # Verify that all of them are accepted. - target_with_bonus_probs, zero_temperature_token_ids = \ - get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) - zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] - draft_token_ids = zero_temperature_token_ids - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) - assert torch.all(output_token_ids[:, -1] == bonus_token_ids) - # Next only keep the first 2 draft tokens same as the zero temperature - # tokens. For the remaining 3 choose some other tokens. 
In the - # response we will expect the first 2 tokens to be the same as the - # draft tokens and the recovered token and rest as -1 - draft_token_ids_to_replace = get_draft_token_ids( - batch_size, k, vocab_size, zero_temperature_token_ids) - draft_token_ids = torch.cat( - (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) - output_token_ids = typical_acceptance_sampler( - target_with_bonus_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) - assert torch.all( - output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2]) - assert torch.all(output_token_ids[:, -3:] == -1) - - -@pytest.mark.parametrize("seed", list(range(1))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_accept_tokens_set_non_default_posteriors(seed: int, device: str): - """ - Test the TypicalAcceptanceSampler with custom posterior thresholds and - alpha values. This test verifies that by modifying the posterior - thresholds and alpha values we can change the acceptance behavior of the - sampler. - """ - set_random_seed(seed) - k = 5 - batch_size = 1 - vocab_size = 30_000 - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - # Simulate temperature 0 probability distribution for target - # probabilities and create target probabilities such that only 1 token - # id has probability 1.0 and others have a very low probability of - # 0.00001. Populate draft_token_ids such that they exclude the token_ids - # with probability = 1.0. Without any changes to the posterior thresholds - # none of the draft tokens are accepted. - target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( - batch_size, k + 1, vocab_size) - zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] - target_probs[target_probs == 0] = 0.00001 - draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, - zero_temperature_token_ids) - bonus_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, 1), - dtype=torch.int64) - output_token_ids = typical_acceptance_sampler( - target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - assert torch.all(output_token_ids[:, 1:-1] == -1) - - # Change the posterior threshold values to 0.0 so that we will - # now accept even draft tokens with very low probability in the - # target distribution. Simulate and verify the same. 
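The acceptance behavior these tests describe follows a Medusa-style typical acceptance criterion: a draft token is kept when the probability the target model assigns to it exceeds an entropy-dependent threshold controlled by posterior_threshold and posterior_alpha. A rough sketch of that criterion, reusing the default values of get_acceptance_sampler above and not claiming to match the exact scaling inside vLLM's TypicalAcceptanceSampler, is:

import torch


def typical_acceptance_mask(target_probs: torch.Tensor,
                            draft_token_ids: torch.Tensor,
                            posterior_threshold: float = 0.03,
                            posterior_alpha: float = 0.9) -> torch.Tensor:
    """Entropy-aware acceptance test for draft tokens.

    target_probs: [batch_size, k, vocab_size]
    draft_token_ids: [batch_size, k]
    Returns a boolean mask of shape [batch_size, k].
    """
    # Probability the target model assigns to each proposed draft token.
    candidate_prob = torch.gather(target_probs, -1,
                                  draft_token_ids.unsqueeze(-1)).squeeze(-1)
    # Entropy of the target distribution at each position. A uniform
    # (high-entropy) target drives the threshold toward zero, so nearly all
    # draft tokens pass; a temperature-zero target only accepts its argmax.
    entropy = -torch.sum(target_probs * torch.log(target_probs + 1e-5),
                         dim=-1)
    threshold = torch.minimum(
        torch.full_like(entropy, posterior_threshold),
        posterior_alpha * torch.exp(-entropy))
    return candidate_prob > threshold

Setting both parameters to 0.0, as the test does next, collapses the threshold to zero, so every draft token is accepted and the bonus token is emitted as well.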
- typical_acceptance_sampler = TypicalAcceptanceSampler( - strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0) - typical_acceptance_sampler.init_gpu_tensors(device=device) - output_token_ids = typical_acceptance_sampler( - target_probs, - bonus_token_ids, - draft_probs=None, - draft_token_ids=draft_token_ids) - assert output_token_ids.shape[0] == batch_size - assert output_token_ids.shape[1] == (k + 1) - assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids) - assert torch.all(output_token_ids[:, -1] == bonus_token_ids) - - -@pytest.mark.parametrize("seed", list(range(10))) -@pytest.mark.parametrize("device", CUDA_DEVICES) -@torch.inference_mode() -def test_get_recovered_token_ids(seed: int, device: str): - """ - Test the TypicalAcceptanceSampler's method for generating - replacement token IDs. - - This test verifies that the `_get_recovered_token_ids` method of the - TypicalAcceptanceSampler correctly identifies the token IDs to be used - as recovered token IDs based on the target probability distribution. - Specifically, it ensures that the method correctly identifies the - tokens with the highest probability for each sequence in the batch. - """ - set_random_seed(seed) - k = 10 - batch_size = 5 - vocab_size = 30_000 - torch.set_default_device(device) - typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) - typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - expected_replacement_tokens = torch.argmax(target_probs, dim=-1) - actual_replacement_tokens = ( - typical_acceptance_sampler._get_recovered_token_ids(target_probs)) - assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/tests/spec_decode/conftest.py b/tests/spec_decode/conftest.py deleted file mode 100644 index 375b248ebeda..000000000000 --- a/tests/spec_decode/conftest.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest - - -@pytest.fixture(scope="function", autouse=True) -def use_v0_only(monkeypatch): - """ - Since this module is V0 only, set VLLM_USE_V1=0 for - all tests in the module. 
- """ - monkeypatch.setenv('VLLM_USE_V1', '0') diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py deleted file mode 100644 index f3fe9db3f79e..000000000000 --- a/tests/spec_decode/e2e/conftest.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Sequence -from itertools import cycle -from typing import Optional, Union - -import pytest -import torch - -from vllm import LLM, SamplingParams -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import PromptLogprobs, SampleLogprobs - -from ...models.utils import (TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs, - check_logprobs_close, check_outputs_equal) -from ...utils import RemoteOpenAIServer - -PROMPTS = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - "San Francisco is know for its", - "Facebook was created in 2004 by", - "Curious George is a", - "Python 3.11 brings improvements to its", -] - - -@pytest.fixture -def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, - test_llm_kwargs, seed): - - def generate(): - kwargs = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - llm = LLM(**kwargs) - - if seed is not None: - set_random_seed(seed) - - yield llm - - del llm - cleanup_dist_env_and_memory() - - return generate - - -def maybe_assert_ngram_worker(llm): - # Verify the proposer worker is ngram if ngram is specified. - if (llm.llm_engine.speculative_config is not None - and llm.llm_engine.speculative_config.method == "ngram"): - from vllm.spec_decode.ngram_worker import NGramWorker - assert isinstance( - llm.llm_engine.model_executor.driver_worker.proposer_worker, - NGramWorker) - - -def get_output_from_llm_generator( - llm_generator, prompts, - sampling_params) -> tuple[list[str], list[list[int]], float]: - tokens: list[str] = [] - token_ids: list[list[int]] = [] - acceptance_rate: float = -1.0 - for llm in llm_generator(): - maybe_assert_ngram_worker(llm) - - outputs = llm.generate(prompts, sampling_params, use_tqdm=True) - - token_ids = [output.outputs[0].token_ids for output in outputs] - tokens = [output.outputs[0].text for output in outputs] - - # Fetch acceptance rate if logging is enabled. - if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None): - stat_logger = stat_loggers["prometheus"] - acceptance_rate = (stat_logger.metrics. - gauge_spec_decode_draft_acceptance_rate.labels( - **stat_logger.labels)._value.get()) - del llm - - return tokens, token_ids, acceptance_rate - - -def check_logprobs_correctness( - spec_outputs: Sequence[Union[TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs]], - baseline_outputs: Sequence[Union[TokensTextLogprobs, - TokensTextLogprobsPromptLogprobs]], - disable_logprobs: bool = False, -): - """Compare sampled and prompt logprobs between baseline and spec decoding - """ - if not disable_logprobs: - return check_logprobs_close( - outputs_0_lst=baseline_outputs, - outputs_1_lst=spec_outputs, - name_0="org", - name_1="sd", - ) - - # Check correctness when disable_logprobs == True - for spec_output, baseline_output in zip(spec_outputs, baseline_outputs): - # Check generated token logprobs. 
- spec_logprobs = spec_output[2] - baseline_logprobs = baseline_output[2] - _check_logprobs_when_output_disabled(spec_logprobs, - baseline_logprobs, - is_prompt_logprobs=False) - - # Check prompt logprobs too, if they exist - if len(baseline_output) == 4: - assert len(spec_output) == 4 - spec_prompt_logprobs = spec_output[3] - baseline_prompt_logprobs = baseline_output[3] - _check_logprobs_when_output_disabled(spec_prompt_logprobs, - baseline_prompt_logprobs, - is_prompt_logprobs=True) - - -def _check_logprobs_when_output_disabled( - spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], - baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], - is_prompt_logprobs: bool = False, -): - # Prompt logprobs are optional - if is_prompt_logprobs and baseline_logprobs is None: - assert spec_logprobs is None - return - - assert spec_logprobs is not None - assert baseline_logprobs is not None - assert len(spec_logprobs) == len(baseline_logprobs) - - # For each generated position of the sequence. - for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( - zip(spec_logprobs, baseline_logprobs)): - - # First prompt logprob is expected to be None - if is_prompt_logprobs and baseline_pos_logprobs is None: - assert spec_pos_logprobs is None - assert pos == 0 - continue - - assert spec_pos_logprobs is not None - assert baseline_pos_logprobs is not None - - # When disabled, the 1 logprob is returned with dummy values for the - # score and rank, but the token id should match the baseline model - assert len(spec_pos_logprobs) == 1 - (spec_pos_logprob_token_id, - spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) - assert spec_pos_logprob.rank == -1 - assert spec_pos_logprob.logprob == 0.0 - if isinstance(spec_pos_logprob_token_id, torch.Tensor): - spec_pos_logprob_token_id = spec_pos_logprob_token_id.item() - assert spec_pos_logprob_token_id in baseline_pos_logprobs - - -def run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: Optional[int] = 0, - temperature: float = 0.0, - disable_seed: bool = False, - ignore_eos: bool = True, - ensure_all_accepted: bool = False, - expected_acceptance_rate: Optional[float] = None, - logprobs: Optional[int] = None, - prompt_logprobs: Optional[int] = None, - disable_logprobs: bool = False): - - org_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **baseline_llm_kwargs, - } - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - - if disable_seed: - seed = None - - sampling_params = SamplingParams(temperature=temperature, - max_tokens=max_output_len, - seed=seed, - ignore_eos=ignore_eos, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs) - - with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - with vllm_runner(**sd_args) as vllm_model: - if ensure_all_accepted or expected_acceptance_rate is not None: - # Force log interval to be 0 to catch all metrics. - stat_logger = vllm_model.model.llm_engine.stat_loggers[ - 'prometheus'] - stat_logger.local_interval = -100 - - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - if ensure_all_accepted or expected_acceptance_rate is not None: - acceptance_rate = (stat_logger.metrics. 
- gauge_spec_decode_draft_acceptance_rate.labels( - **stat_logger.labels)._value.get()) - - if ensure_all_accepted: - assert True - # FIXME: ci fails to log acceptance rate. - # It works locally. - # assert acceptance_rate == 1.0 - - if expected_acceptance_rate is not None: - assert acceptance_rate >= expected_acceptance_rate - 1e-2 - - # Only pass token entries, not the logprobs - check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs], - outputs_1_lst=[out[0:2] for out in sd_outputs], - name_0="org", - name_1="sd") - - # Check logprobs if requested - if logprobs is not None or prompt_logprobs is not None: - check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs) - - -def run_equality_correctness_test_tp(model, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: int = 0, - temperature: float = 0.0, - logprobs: Optional[int] = None): - """Helper method that compares the outputs of both the baseline LLM and - the test LLM. It asserts greedy equality, e.g. that the outputs are exactly - the same when temperature is zero. - """ - arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs - arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs - env1 = env2 = None - - max_wait_seconds = 240 - results = [] - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - for args, env in ((arg1, env1), (arg2, env2)): - with RemoteOpenAIServer(model, - args, - env_dict=env, - max_wait_seconds=max_wait_seconds) as server: - client = server.get_client() - - completion = client.completions.create(model=model, - prompt=prompts, - max_tokens=max_output_len, - seed=seed, - temperature=temperature, - logprobs=logprobs) - - results.append({ - "test": - "seeded_sampling", - "text": [choice.text for choice in completion.choices], - "logprobs": [choice.logprobs for choice in completion.choices], - "finish_reason": - [choice.finish_reason for choice in completion.choices], - "usage": - completion.usage, - }) - - n = len(results) // 2 - arg1_results = results[:n] - arg2_results = results[n:] - # Separate logprobs to avoid asserting exact equality. - arg1_logprobs = [r.pop("logprobs") for r in arg1_results] - arg2_logprobs = [r.pop("logprobs") for r in arg2_results] - - for arg1_result, arg2_result in zip(arg1_results, arg2_results): - assert arg1_result == arg2_result, ( - f"Results for {model=} are not the same with {arg1=} and {arg2=}. " - f"{arg1_result=} != {arg2_result=}") - if logprobs: - for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs): - for l1, l2 in zip(logs1, logs2): - assert l1.tokens == l2.tokens diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py deleted file mode 100644 index 6c453879a6a6..000000000000 --- a/tests/spec_decode/e2e/test_compatibility.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm import SamplingParams - -from .conftest import get_output_from_llm_generator - - -@pytest.mark.parametrize("common_llm_kwargs", - [{ - "model": "meta-llama/Llama-3.2-1B-Instruct", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Speculative max model len > overridden max model len should raise. 
- "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "max_model_len": 129, - }, - "max_model_len": 128, - }, - { - # Speculative max model len > draft max model len should raise. - # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12 - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "max_model_len": 2048 + 1, - }, - }, - { - # Speculative max model len > target max model len should raise. - # https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/blob/9213176726f574b556790deb65791e0c5aa438b6/config.json#L18 - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "max_model_len": 131072 + 1, - }, - }, - ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): - """Verify that speculative decoding validates speculative_max_model_len. - """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises(ValueError, match="cannot be larger than"): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py deleted file mode 100644 index 7c369feec415..000000000000 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ /dev/null @@ -1,480 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""This docstring details important information on the testing methodology. - -Most of the tests rely on "greedy equality", where we expect the output of -speculative decoding on a sequence to exactly match the output of normal non- -speculative decoding. - -Since speculative decoding with rejection sampling guarantees that the output -distribution matches the target model's output distribution (up to hardware -numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. - -However, we still need to verify below scenario could be passed: - * Batch size 1 greedy equality - * Batch size >1 greedy equality - * Test greedy equality under preemption - * Test greedy equality under various number of speculative tokens. - -With those tests, we can say at least, EAGLE would not break the -correctness for the target model outputs. -""" - -import pytest - -from .conftest import run_equality_correctness_test - -# main model -MAIN_MODEL = "JackFram/llama-68m" - -# speculative model -SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random" - -# max. number of speculative tokens: this corresponds to -# num_heads in the config.json of the speculator model. -MAX_SPEC_TOKENS = 4 - -# precision -PRECISION = "float32" - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. 
- "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int): - - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs": False, - }, -}, { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs": True, - }, -}]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int, - logprobs: int): - - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "enforce_eager": False, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -def test_eagle_e2e_greedy_correctness_cuda_graph( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality with cuda graph enabled and different - batch sizes.""" - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 8, - # 2 for small prompt, 256//8 for generated. 
- "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 128, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_eagle_e2e_greedy_correctness_with_preemption( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality, even when some sequences are preempted mid- - generation. - """ - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": k, - }, - } - # Try a range of num. speculative tokens - for k in range(1, 1 + MAX_SPEC_TOKENS) - ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_eagle_different_k(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that eagle speculative decoding produces exact equality - to without spec decode with different values of num_speculative_tokens. - """ - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_by_batch_size": 4, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_eagle_disable_queue(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that eagle speculative decoding produces exact equality - to without spec decode when speculation is disabled for large - batch sizes. 
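The preemption variant above starves the KV cache on purpose; written out, the sizing arithmetic behind `num_gpu_blocks_override` and `max_model_len` is:

```python
# Restating the preemption test's KV-cache sizing (numbers from the config above).
block_size = 8
blocks_per_seq = 2 + 256 // 8                 # 2 prompt blocks + 32 generation blocks
num_gpu_blocks_override = blocks_per_seq      # only 34 blocks for the whole engine
max_model_len = blocks_per_seq * block_size   # 272 tokens

# With batch_size=4 and output_len=128, the four sequences together need far
# more than 34 blocks, so the scheduler must preempt some of them and resume
# them later, which is exactly the path the "with_preemption" test exercises.
assert (num_gpu_blocks_override, max_model_len) == (34, 272)
```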
- """ - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": "float16", - - # Main model - "model_name": "meta-llama/Llama-2-7b-chat-hf", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "yuhuili/EAGLE-llama2-chat-7B", - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize("seed", [1]) -def test_llama2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - output_len: int, seed: int): - - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # 2 for small prompt, 256//16 for generated. - "num_gpu_blocks_override": 2 + 256 // 16, - "max_model_len": (2 + 256 // 16) * 16, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": "float16", - - # Main model - "model_name": "meta-llama/Meta-Llama-3-8B-Instruct", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize("seed", [1]) -def test_llama3_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - output_len: int, seed: int): - - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # 2 for small prompt, 256//16 for generated. - "num_gpu_blocks_override": 2 + 256 // 16, - "max_model_len": (2 + 256 // 16) * 16, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": "float16", - - # Main model - "model_name": "Qwen/Qwen2-7B-Instruct", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "yuhuili/EAGLE-Qwen2-7B-Instruct", - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. 
- 32, - ]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize("seed", [1]) -def test_qwen2_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - output_len: int, seed: int): - - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0) - - -if __name__ == "__main__": - import pytest - pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py deleted file mode 100644 index f15a9224c003..000000000000 --- a/tests/spec_decode/e2e/test_integration.py +++ /dev/null @@ -1,161 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests which cover integration of the speculative decoding framework with -other features, e.g. cuda graphs. -""" - -import pytest - -from .conftest import run_equality_correctness_test - -MAIN_MODEL = "JackFram/llama-68m" - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Verify equality when cuda graphs allowed. - "enforce_eager": False, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - # Identical models. - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("output_len", [32]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int): - """Verify spec decode equality when cuda graphs are enabled. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. 
- "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", []) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - # Explicitly specify draft model quantization - { - "speculative_config": { - "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", - "num_speculative_tokens": 5, - "quantization": "gptq", - }, - }, - # Explicitly specify GPTQ-based draft model to use marlin quantization - { - "speculative_config": { - "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", - "num_speculative_tokens": 5, - "quantization": "marlin", - }, - }, - # Not explicitly specify draft model quantization - { - "speculative_config": { - "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", - "num_speculative_tokens": 5, - "quantization": None, - }, - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize("seed", [1]) -def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, seed: int): - """Verify spec decode works well with draft model quantization configs. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=32, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": MAIN_MODEL, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_mqa_scorer": True, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, seed: int): - """Verify that speculative decoding generates the same output - with batch expansion scorer and mqa scorer. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py deleted file mode 100644 index a18be80c50dd..000000000000 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ /dev/null @@ -1,247 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests which cover integration of the speculative decoding framework with -tensor parallelism. -""" - -import json -from typing import Optional - -import pytest -import torch - -from vllm.platforms import current_platform - -from .conftest import run_equality_correctness_test_tp - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [[ - # Skip cuda graph recording for fast test. 
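The quantization cases above differ only in the `quantization` key of the draft configuration; as plain dicts, the three variants are (checkpoint name taken from the test above):

```python
# The three draft-quantization variants exercised by
# test_speculative_model_quantization_config, written out as config dicts.
gptq_draft = {
    "model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    "num_speculative_tokens": 5,
    "quantization": "gptq",  # force the GPTQ kernels
}
marlin_draft = {**gptq_draft, "quantization": "marlin"}  # same checkpoint, Marlin kernels
auto_draft = {**gptq_draft, "quantization": None}        # let vLLM infer it from the checkpoint
```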
- "--enforce-eager", - "--tensor-parallel-size", - "2" - ]]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) -@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize("test_llm_kwargs", [ - [ - "--speculative_config", - json.dumps({ - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - }), - ], - [ - "--speculative_config", - json.dumps({ - "model": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - }), - ], -]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int): - """Verify greedy equality when tensor parallelism is used. - """ - if current_platform.is_rocm(): - pytest.skip("hip is not well-supported yet") - run_equality_correctness_test_tp("JackFram/llama-68m", - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0) - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [[ - # Skip cuda graph recording for fast test. - "--enforce-eager", - "--tensor_parallel_size", - "2", - - # precision - "--dtype", - "bfloat16", - ]]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) -@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize( - "model, test_llm_kwargs", - [("JackFram/llama-68m", [ - "--speculative_config", - json.dumps({ - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "draft_tensor_parallel_size": 1, - }), - ]), - ("ibm-granite/granite-3b-code-instruct", [ - "--speculative_config", - json.dumps({ - "model": "ibm-granite/granite-3b-code-instruct", - "num_speculative_tokens": 5, - "draft_tensor_parallel_size": 1, - }), - ])]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize("seed", [1]) -def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - seed: int): - """Verify spec decode works well with smaller tp for draft models. - """ - run_equality_correctness_test_tp(model, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=32, - seed=seed, - temperature=0.0) - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [[ - # Skip cuda graph recording for fast test. 
- "--enforce-eager", - "--tensor_parallel_size", - "2", - - # precision - "--dtype", - "bfloat16", - ]]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [["--enable-chunked-prefill", "False"], - [ - "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4", - "--max-num-seqs", "4" - ]]) -@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize("model, test_llm_kwargs", - [("JackFram/llama-68m", [ - "--speculative_config", - json.dumps({ - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - }), - ]), - ("JackFram/llama-68m", [ - "--speculative_config", - json.dumps({ - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "draft_tensor_parallel_size": 1, - }), - ])]) -@pytest.mark.parametrize("logprobs", [None]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - logprobs: Optional[int], - batch_size: int, seed: int): - """Verify spec decode works well with same and different TP size for - the draft model with chunked prefill. - """ - run_equality_correctness_test_tp(model, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=32, - seed=seed, - temperature=0.0, - logprobs=logprobs) - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, - reason="Need at least 2 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [[ - # Skip cuda graph recording for fast test. - "--enforce-eager", - "--tensor_parallel_size", - "2", - - # precision - "--dtype", - "bfloat16", - ]]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [["--enable-chunked-prefill", "False"], - [ - "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4", - "--max-num-seqs", "4" - ]]) -@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize("model, test_llm_kwargs", - [("JackFram/llama-68m", [ - "--speculative_config", - json.dumps({ - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs": False, - }), - ]), - ("JackFram/llama-68m", [ - "--speculative_config", - json.dumps({ - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "draft_tensor_parallel_size": 1, - "disable_logprobs": False, - }), - ])]) -@pytest.mark.parametrize("logprobs", [2]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_chunked_prefill_tp2_with_logprobs( - model, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, logprobs: Optional[int], - batch_size: int, seed: int): - """Verify spec decode works well with same and different TP size for - the draft model with chunked prefill. 
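Unlike the single-GPU suites, these tensor-parallel tests configure the engine through CLI-style argument lists, so the speculative settings travel as one JSON string. The argument lists used above can be rebuilt as follows; the flags and values are the ones in this file, while how the fixture consumes them is defined in `conftest.py` and not shown here.

```python
import json

# Engine-level flags shared by the baseline and the speculative run: TP=2, eager mode.
common_args = ["--enforce-eager", "--tensor_parallel_size", "2", "--dtype", "bfloat16"]

# Chunked-prefill variant used by the chunked-prefill tests above.
chunked_prefill_args = [
    "--enable-chunked-prefill", "True",
    "--max-num-batched-tokens", "4",
    "--max-num-seqs", "4",
]

# All speculative settings are packed into a single JSON blob, including a
# smaller tensor-parallel size for the draft model than for the target.
spec_args = [
    "--speculative_config",
    json.dumps({
        "model": "JackFram/llama-68m",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,
    }),
]

print(common_args + chunked_prefill_args + spec_args)
```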
- """ - run_equality_correctness_test_tp(model, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=32, - seed=seed, - temperature=0.0, - logprobs=logprobs) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py deleted file mode 100644 index 039eec8fd2cc..000000000000 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests which cover integration of the speculative decoding framework with -tensor parallelism. -""" - -import json - -import openai -import pytest -import torch - -from .conftest import run_equality_correctness_test_tp - -MAIN_MODEL = "JackFram/llama-68m" -SPEC_MODEL = "JackFram/llama-68m" - - -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [[ - # Skip cuda graph recording for fast test. - "--enforce_eager", - "--tensor-parallel-size", - "4", - ]]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - [], -]) -@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - #TODO(wooyeon): add spec_draft_dp=2 case - [ - "--speculative_config", - json.dumps({ - "model": f"{SPEC_MODEL}", - "num_speculative_tokens": 5, - "draft_tensor_parallel_size": 1, - }), - ], - ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize("seed", [1]) -def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - seed: int): - """Verify spec decode works well with smaller tp for draft models. - """ - run_equality_correctness_test_tp(MAIN_MODEL, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=32, - seed=seed, - temperature=0.0) - - -@pytest.mark.skipif(torch.cuda.device_count() < 4, - reason="Need at least 4 GPUs to run the test.") -@pytest.mark.parametrize( - "common_llm_kwargs", - [[ - - # Skip cuda graph recording for fast test. - "--enforce-eager", - "--tensor-parallel-size", - "4", - ]]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [[]]) -@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - [ - # Artificially limit the draft model max model len; this forces vLLM - # to skip speculation once the sequences grow beyond 32-k tokens. - "--speculative_config", - json.dumps({ - "model": f"{SPEC_MODEL}", - "num_speculative_tokens": 5, - "max_model_len": 32, - }), - ], - ]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize( - "output_len", - [ - # This must be a good bit larger than speculative_max_model_len so that - # we can test the case where all seqs are skipped, but still small to - # ensure fast test. - 64, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_skip_speculation(common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int): - """Verify job failure with RuntimeError when all sequences skip speculation. - We do this by setting the max model len of the draft model to an - artificially low value, such that when the sequences grow beyond it, they - are skipped in speculative decoding. 
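The numbers in this test are chosen so that every sequence eventually outgrows the draft model's window:

```python
# Relationship between the deliberately small draft window and the output
# length in the TP=4 skip-speculation test above.
draft_max_model_len = 32  # "max_model_len" inside the speculative_config JSON
output_len = 64           # generated tokens per request, per the parametrization

# Once a sequence grows past the draft window, the draft model cannot be run
# for it and speculation has to be skipped; with output_len twice the window,
# every sequence in the batch hits that path before finishing, which is the
# condition this test pins down (it currently expects the server to error out).
assert output_len > draft_max_model_len
```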
- - TODO: fix it to pass without raising Error. (#5814) - """ - with pytest.raises( - (openai.APIConnectionError, openai.InternalServerError)): - run_equality_correctness_test_tp(MAIN_MODEL, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py deleted file mode 100644 index 4de7ee05605a..000000000000 --- a/tests/spec_decode/e2e/test_logprobs.py +++ /dev/null @@ -1,315 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from itertools import cycle - -import pytest - -from vllm import SamplingParams - -from ..utils import maybe_enable_chunked_prefill -from .conftest import run_equality_correctness_test - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs": False, - }, -}, { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs": True, - }, -}]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 7, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12]) -def test_logprobs_equality(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int, logprobs: int, prefill_chunk_size: int): - """Verify output logprobs are equal with and without speculative decoding, - as well as with and without chunked prefill. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs": False, - }, -}, { - "speculative_config": { - "model": "JackFram/llama-160m", - "num_speculative_tokens": 6, - "disable_logprobs": False, - }, -}]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. 
- 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -def test_logprobs_different_k(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - output_len: int, seed: int, logprobs: int): - """Veriy logprob greedy equality with different speculation lens. - """ - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [{ - "speculative_config": { - "model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs": False, - # Artificially limit the draft model max model len; this forces - # vLLM to skip speculation once the sequences grow beyond 32-k - # tokens. - "max_model_len": 32, - }, - }]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1]) -def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int, logprobs: int): - """Verify logprobs greedy equality when some sequences skip speculation. - """ - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - "disable_logprobs": False, - }, -}]) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [6]) -def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int, logprobs: int): - """Verify at least one logprob result has num_logprobs+1, which tests the - case where the sampled token is not in top-k logprobs. - - Ideally, this test should validate equality with non-spec by getting - logprobs. This is left as future improvement. 
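Stripped of the test fixtures, and of the speculative config, which does not change the property being asserted, the temperature-1.0 check amounts to the following sketch, assuming the usual `LLM.generate` return shape:

```python
from vllm import LLM, SamplingParams

# Restating the logprob property checked by the temperature-1.0 test, using
# the public API directly and leaving out the speculative config for brevity.
llm = LLM(model="JackFram/llama-68m", enforce_eager=True, dtype="float32")
params = SamplingParams(temperature=1.0, max_tokens=32, ignore_eos=True, logprobs=6)

completion = llm.generate(["Hello, my name is"], params)[0].outputs[0]

# completion.logprobs holds one dict of top-k candidates per generated token;
# when the sampled token falls outside the top-k, it is added to that dict as
# well, so at least one step should report more than `logprobs` entries.
assert any(len(step) > 6 for step in completion.logprobs)
```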
- """ - temperature = 1.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - "San Francisco is know for its", - "Facebook was created in 2004 by", - "Curious George is a", - "Python 3.11 brings improvements to its", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - logprobs=logprobs, - ) - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - with vllm_runner(**sd_args) as vllm_model: - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - num_returned_logprobs = [ - len(seq_logprobs) for seq_logprobs in sd_outputs[-1] - ] - - # Assert one of the returned logprobs has > num_logprobs (indicating the - # sampled token is not in top-k). - assert any( - [num_returned > logprobs for num_returned in num_returned_logprobs]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs": True, - }, -}]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("logprobs", [0]) -def test_logprobs_disabled(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int, logprobs: int): - """Check the behavior when logprobs are disabled. - Token choices should match with the base model. - """ - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py deleted file mode 100644 index bc9501bd5737..000000000000 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ /dev/null @@ -1,417 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""This docstring details important information on the testing methodology. - -Most of the tests rely on "greedy equality", where we expect the output of -speculative decoding on a sequence to exactly match the output of normal non- -speculative decoding. - -Since speculative decoding with rejection sampling guarantees that the output -distribution matches the target model's output distribution (up to hardware -numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. - -However, we still need to verify below scenario could be passed: - * Batch size 1 greedy equality - * Batch size >1 greedy equality - * Test greedy equality under preemption - * Test greedy equality under various number of speculative tokens. 
- -With those tests, we can say at least, Medusa would not break the -correctness for the target model outputs. -""" - -import pytest - -from ..utils import maybe_enable_chunked_prefill -from .conftest import run_equality_correctness_test - -# main model -# lmsys/vicuna-7b-v1.3 was to be used but it's causing -# OOM in CI pipeline, so using a smaller model. -MAIN_MODEL = "JackFram/llama-68m" - -# speculative model -SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random" - -# max number of speculative tokens: this corresponds to -# num_heads in the config.json of the speculator model. -MAX_SPEC_TOKENS = 5 - -# precision -PRECISION = "float32" - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int, prefill_chunk_size: int): - """Verify greedy equality with different batch size.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. 
- "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs": False, - }, - }, - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs": True, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 8, -]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int, logprobs: int, - prefill_chunk_size: int): - """Verify greedy equality with different batch size.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "enforce_eager": False, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_medusa_e2e_greedy_correctness_cuda_graph( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int, prefill_chunk_size: int): - """Verify greedy equality with cuda graph enabled and different - batch sizes.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 16, - # 2 for small prompt, 256//8 for generated. - "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. 
- 128, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_medusa_e2e_greedy_correctness_with_preemption( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int, prefill_chunk_size: int): - """Verify greedy equality, even when some sequences are preempted mid- - generation. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": k, - }, - } - # Try a range of num. speculative tokens - for k in range(1, 1 + MAX_SPEC_TOKENS) - ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_medusa_different_k(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int, prefill_chunk_size: int): - """Verify that medusa speculative decoding produces exact equality - to without spec decode with different values of num_speculative_tokens. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_by_batch_size": 4, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - output_len: int, seed: int, - prefill_chunk_size: int): - """Verify that medusa speculative decoding produces exact equality - to without spec decode when speculation is disabled for large - batch sizes. 
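Several of these suites toggle chunked prefill through `maybe_enable_chunked_prefill` from `..utils`, whose body is not part of this diff. Judging only by how `prefill_chunk_size` is parametrized (with `-1` meaning "leave the default"), a plausible reconstruction looks like the sketch below; treat it as hypothetical rather than the actual helper.

```python
# Hypothetical reconstruction of the helper imported from tests/spec_decode/utils.py;
# the real implementation is not shown in this diff.
def maybe_enable_chunked_prefill(prefill_chunk_size: int, llm_kwargs: dict) -> None:
    """Mutate llm_kwargs in place to enable chunked prefill for positive sizes."""
    if prefill_chunk_size < 0:
        return  # -1 in the parametrization: keep the engine defaults
    llm_kwargs.update(
        enable_chunked_prefill=True,
        max_num_batched_tokens=prefill_chunk_size,
    )
```

Because the kwargs dict is mutated in place, the tests call the helper on `test_llm_kwargs` (and, where the comparison is sensitive to scheduling, on `baseline_llm_kwargs` too) immediately before `run_equality_correctness_test`.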
- """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_by_batch_size": 4, - "disable_mqa_scorer": True, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, seed: int, prefill_chunk_size: int): - """Verify that speculative decoding generates the same output - with batch expansion scorer and mqa scorer. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -if __name__ == "__main__": - import pytest - pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py deleted file mode 100644 index 0e41d93eaa19..000000000000 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ /dev/null @@ -1,533 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""This docstring details important information on the testing methodology. - -Most of the tests rely on "greedy equality", where we expect the output of -speculative decoding on a sequence to exactly match the output of normal non- -speculative decoding. - -Since speculative decoding with rejection sampling guarantees that the output -distribution matches the target model's output distribution (up to hardware -numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. - -However, we still need to verify below scenario could be passed: - * Batch size 1 greedy equality - * Batch size >1 greedy equality - * Test greedy equality under preemption - * Test greedy equality under various number of speculative tokens. - -With those tests, we can say at least, MLPSpeculator would not break the -correctness for the target model outputs. -""" - -from unittest.mock import patch - -import pytest - -from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size - -from ..utils import maybe_enable_chunked_prefill -from .conftest import run_equality_correctness_test - -# main model -MAIN_MODEL = "JackFram/llama-160m" - -# speculative model -SPEC_MODEL = "ibm-ai-platform/llama-160m-accelerator" - -# max. number of speculative tokens: this corresponds to -# n_predict in the config.json of the speculator model. 
-MAX_SPEC_TOKENS = 3 - -# precision -PRECISION = "float32" - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [4, 32]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) -def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int, prefill_chunk_size: int): - """Verify greedy equality with different batch size.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - "disable_logprobs": False, - }, - }, - { - "speculative_config": { - "model": SPEC_MODEL, - "disable_logprobs": True, - }, - }, -]) -@pytest.mark.parametrize("output_len", [8]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int, - logprobs: int, prefill_chunk_size: int): - """Verify greedy equality with different batch size.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - # NOTE Test is sensitive enough st if we don't enable chunked prefill - # scheduling on baseline too, we get slightly different logprobs, ending - # up sampling different tokens at the tail (ie top tokens don't change). - # TL;DR: sd+cp == org+cp but sd+cp != org..is this expected? - maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. 
- "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - }, - }, -]) -@pytest.mark.parametrize("output_len", [2048]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - prefill_chunk_size: int, seed: int): - """Verify acceptance rate with different batch size and large output - length.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - temperature=0.0, - seed=seed, - expected_acceptance_rate=0.48) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - - # Speculative config - "speculative_config": { - "model": SPEC_MODEL, - }, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) -@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) -@pytest.mark.parametrize("output_len", [64]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("temperature", [1.0]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -@pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - temperature: float, - prefill_chunk_size: int, seed: int): - """Verify seeded runs produce the same output.""" - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - temperature=temperature, - seed=seed) - - # Ensure this same test does fail if we _don't_ include per-request seeds - with pytest.raises(AssertionError): - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - temperature=temperature, - seed=seed, - disable_seed=True) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 16, - # 2 for small prompt, 256//8 for generated. - "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. 
- "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 128, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -@pytest.mark.parametrize("seed", [1]) -def test_mlp_e2e_greedy_correctness_with_preemption( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - prefill_chunk_size: int, seed: int): - """Verify greedy equality, even when some sequences are preempted mid- - generation. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 16, - # 2 for small prompt, 256//8 for generated. - "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": SPEC_MODEL, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 128, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -def test_mlp_e2e_greedy_correctness_with_padding( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - prefill_chunk_size: int, seed: int): - """Verify greedy equality when the vocab dimension is padded - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - - # Default pad_to is 64, test model has vocab_size of 32000 - def patched_pad_vocab_size(vocab_size, pad_to=None): - return pad_vocab_size(vocab_size, pad_to=32064) - - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", - patched_pad_vocab_size): - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "model": SPEC_MODEL, - "num_speculative_tokens": k, - }, - } - # Try a range of num. 
speculative tokens - for k in range(1, 1 + MAX_SPEC_TOKENS) - ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -@pytest.mark.parametrize("seed", [1]) -def test_mlp_different_k(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - prefill_chunk_size: int, seed: int, output_len: int): - """Verify that mlp speculative decoding produces exact equality - to without spec decode with different values of num_speculative_tokens. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": SPEC_MODEL, - "disable_by_batch_size": 4, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -# Speculative decoding is disabled when sequences reach decoding and the batch -# consists of single-token requests. Hence we set `max_num_seqs` -# >= `speculative_disable_by_batch_size` to test feature interaction. -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -@pytest.mark.parametrize("seed", [1]) -def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - prefill_chunk_size: int, seed: int, - output_len: int): - """Verify that mlp speculative decoding produces exact equality - to without spec decode when speculation is disabled for large - batch sizes. - """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": MAIN_MODEL, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": SPEC_MODEL, - "disable_mqa_scorer": True, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -@pytest.mark.parametrize("seed", [1]) -def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, prefill_chunk_size: int, seed: int): - """Verify that speculative decoding generates the same output - with batch expansion scorer and mqa scorer. 
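The scorer comparison above hinges on a single switch in the speculative config; written out, the two configurations being contrasted are:

```python
# The two scorer setups compared by the mqa_scorer tests: identical draft
# settings, with only the scoring implementation switched.
mqa_scorer_config = {
    "model": "ibm-ai-platform/llama-160m-accelerator",
}
batch_expansion_config = {
    "model": "ibm-ai-platform/llama-160m-accelerator",
    "disable_mqa_scorer": True,  # fall back to the batch-expansion scorer
}

# Both runs must match the non-speculative greedy baseline, so the two scorers
# are implicitly checked against each other as well.
```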
- """ - maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) diff --git a/tests/spec_decode/e2e/test_mtp_correctness.py b/tests/spec_decode/e2e/test_mtp_correctness.py deleted file mode 100644 index d9c7be8ffe71..000000000000 --- a/tests/spec_decode/e2e/test_mtp_correctness.py +++ /dev/null @@ -1,333 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""This docstring details important information on the testing methodology. - -Most of the tests rely on "greedy equality", where we expect the output of -speculative decoding on a sequence to exactly match the output of normal non- -speculative decoding. - -Since speculative decoding with rejection sampling guarantees that the output -distribution matches the target model's output distribution (up to hardware -numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. - -However, we still need to verify below scenario could be passed: - * Batch size 1 greedy equality - * Batch size >1 greedy equality - * Test greedy equality under preemption - * Test greedy equality under various number of speculative tokens. - -With those tests, we can say at least, mtp would not break the -correctness for the target model outputs. -""" - -import pytest - -from .conftest import run_equality_correctness_test - -# main model -MAIN_MODEL = "luccafong/deepseek_mtp_main_random" - -# max. number of speculative tokens: this corresponds to -# num_nextn_predict_layers in the config.json of the speculator model. -MAX_SPEC_TOKENS = 1 - -# precision -PRECISION = "bfloat16" - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - - # GPU memory utilization - "gpu_memory_utilization": 0.85 - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -def test_mtp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int): - - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. 
- "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - - # GPU memory utilization - "gpu_memory_utilization": 0.85 - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs": False, - }, - }, - { - "speculative_config": { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_logprobs": True, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -def test_mtp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int, - logprobs: int): - - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "enforce_eager": False, - - # Print spec metrics. - "disable_log_stats": False, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - "gpu_memory_utilization": 0.85 - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 128, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -def test_mtp_e2e_greedy_correctness_cuda_graph(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - output_len: int, seed: int): - """Verify greedy equality with cuda graph enabled and different - batch sizes.""" - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 8, - # 2 for small prompt, 256//8 for generated. - "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - - # GPU memory utilization - "gpu_memory_utilization": 0.9 - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "num_speculative_tokens": MAX_SPEC_TOKENS, - }, - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 128, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_mtp_e2e_greedy_correctness_with_preemption( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality, even when some sequences are preempted mid- - generation. 
- """ - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - - # GPU memory utilization - "gpu_memory_utilization": 0.9 - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "num_speculative_tokens": k, - }, - } - # Try a range of num. speculative tokens - for k in range(1, 1 + MAX_SPEC_TOKENS) - ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_mtp_different_k(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that mtp speculative decoding produces exact equality - to without spec decode with different values of num_speculative_tokens. - """ - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Precision - "dtype": PRECISION, - - # Main model - "model_name": MAIN_MODEL, - - # GPU memory utilization - "gpu_memory_utilization": 0.9 - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "num_speculative_tokens": MAX_SPEC_TOKENS, - "disable_by_batch_size": 4 - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_mtp_disable_queue(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that mtp speculative decoding produces exact equality - to without spec decode when speculation is disabled for large - batch sizes. - """ - run_equality_correctness_test(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size, output_len, seed) - - -if __name__ == "__main__": - import pytest - pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py deleted file mode 100644 index ccc8e745ab37..000000000000 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ /dev/null @@ -1,842 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""The tests in this file verify end-to-end speculative decoding correctness. - -This docstring details important information on the testing methodology. - -Most of the tests rely on "greedy equality", where we expect the output of -speculative decoding on a sequence to exactly match the output of normal non- -speculative decoding. 
- -Since speculative decoding with rejection sampling guarantees that the output -distribution matches the target model's output distribution (up to hardware -numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. This gives us good coverage of temp=0. - -At temp=0, the TypicalAcceptanceSampler ensures that only the tokens with the -highest probability in the target distribution are accepted. Therefore, we can -expect greedy equality for the TypicalAcceptanceSampler at temp=0. - -For temp>0, we rely on unit tests on the rejection sampler to verify that the -output distribution is the same with spec decode vs. no spec decode (this would -be prohibitively expensive to run with a real model). Similarly, for the -TypicalAcceptance sampler also, we rely on unit tests to validate temp>0 -test cases. - -NOTE: Speculative decoding's distribution equality requires that the measured -distributions of the target model and proposal model be deterministic given the -same input. vLLM largely guarantees this. - -@cadedaniel has seen cases where the output probabilities of a draft/target -model change slightly with certain batch sizes or prompts, even with Torch -determinism flags set. It is unclear if this is a bug in vLLM, due to non- -determinism in on-device batched operations, a bug in vLLM's spec decode -implementation, or the "hardware numerics" limitations. Either way, rejection -sampling ensures the output distribution matches the target model, but it breaks -greedy-equality tests for those batch sizes/prompts. -""" - -from itertools import cycle - -import pytest -from transformers import AutoTokenizer - -from vllm import SamplingParams - -from ...utils import create_new_process_for_each_test -from .conftest import (get_output_from_llm_generator, - run_equality_correctness_test) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - # Note this is repeated in the test body; to initialize a tokenizer. - "model": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - # Chunked prefill enabled with small value - # to make sure we get mixed batches. - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, - { - # Verify the detokenizer assertions in the test work when spec - # decode is disabled. - }, - ]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_with_detokenization(test_llm_generator, - batch_size: int): - """Run generation with speculative decoding on a batch. Verify the engine - generates the correct number of tokens (via ignore_eos=True), and that the - detokenization matches HF transformers. 
- """ - output_len = 32 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - batch_tokens, batch_token_ids, _ = get_output_from_llm_generator( - test_llm_generator, prompts, sampling_params) - - # Expect a generation for each prompt in the batch. - assert len(batch_token_ids) == len(prompts) - - # Expect each generation to have expected number of tokens (note ignore_eos - # is True). - assert [len(token_ids) - for token_ids in batch_token_ids] == ([output_len] * batch_size) - - # Expect detokenized string to match. - tok = AutoTokenizer.from_pretrained("JackFram/llama-68m") - for actual_tokens, actual_token_ids in zip(batch_tokens, batch_token_ids): - expected_tokens = tok.decode(actual_token_ids) - print(f"{actual_token_ids=}") - assert actual_tokens.strip() == expected_tokens.strip() - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - # Try two different tiny base models. - # Note that one is equal to the draft model, another isn't. - { - "model_name": "JackFram/llama-68m", - }, - { - "model_name": "JackFram/llama-160m", - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "disable_logprobs": False, - }, - "enable_chunked_prefill": False, -}, { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 3, - "disable_logprobs": False, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, -}]) -@pytest.mark.parametrize( - "output_len", - [ - # Use long output len for the small model test. - 10, - ]) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality on a tiny model with batch size of one. - - Since this test is cheaper than other e2e correctness tests, we generate - with a higher output_len. - - When the draft model is the same as the target model, we further check - whether all speculative tokens are accepted. - """ - ensure_all_accepted = per_test_common_llm_kwargs.get( - "model_name") == test_llm_kwargs.get("speculative_config")["model"] - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - prompt_logprobs=2, - logprobs=2, - disable_logprobs=False, - temperature=0.0, - ensure_all_accepted=ensure_all_accepted) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. 
- "disable_log_stats": False, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - # Try two different tiny base models. - # Note that one is equal to the draft model, another isn't. - { - "model_name": "JackFram/llama-68m", - }, - { - "model_name": "JackFram/llama-160m", - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 256, - ]) -@pytest.mark.parametrize("batch_size", [64]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality on a tiny model and large batch size. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - # Try two different tiny base models. - # Note that one is equal to the draft model, another isn't. - { - "model_name": "JackFram/llama-68m", - }, - { - "model_name": "JackFram/llama-160m", - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize("max_output_len", [ - 256, -]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - max_output_len: int, seed: int): - """Verify greedy equality on a tiny model, with a large batch size, and when - sampling respects the EOS token. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len, - seed=seed, - temperature=0.0, - ignore_eos=False) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # A "real" model (not tiny). - "model_name": "meta-llama/Llama-2-7b-chat-hf", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. 
- "disable_log_stats": False, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize("batch_size", [1]) -@pytest.mark.parametrize( - "output_len", - [ - # Use decently long output len for a high quality test. - 256, - ]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_greedy_correctness_real_model_bs1( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality on a "real" model and batch size of 1. This is - separate from large BS tests to make identifying the source of bugs easier. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # A "real" model (not tiny). - "model_name": "meta-llama/Llama-2-7b-chat-hf", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize("batch_size", [32]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 64, - ]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality with a "real" model on a nontrivial batch size. - This is the closest test to a real production workload. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 16, - # 2 for small prompt, 256//8 for generated. - "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - # The original model is float32, keep it for numerical stability. 
- "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "model_name": "JackFram/llama-160m", - }, -]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 256, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_e2e_greedy_correctness_with_preemption( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality, even when some sequences are preempted mid- - generation. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize( - "per_test_common_llm_kwargs", - [ - # https://github.com/triton-lang/triton/issues/2266 tl.dot - # doesn't support embedding < 16 - { - "block_size": 16, - }, - { - "block_size": 32, - }, - ]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int): - """Verify greedy equality over different block sizes. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - - # Artificially limit the draft model max model len; this forces vLLM - # to skip speculation once the sequences grow beyond 32-k tokens. 
- "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "max_model_len": 32, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "max_model_len": 32, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, - }, - ]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize( - "output_len", - [ - # This must be a good bit larger than speculative_max_model_len so that - # we can test the case where all seqs are skipped, but still small to - # ensure fast test. - 64, - ]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_skip_speculation(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality when some (or all) sequences skip speculation. - We do this by setting the max model len of the draft model to an - artificially low value, such that when the sequences grow beyond it, they - are skipped in speculative decoding. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "disable_by_batch_size": 2, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "disable_by_batch_size": 2, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, - }, -]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("output_len", [10]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_disable_speculation(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality when all sequences disable speculation. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": k, - }, - "enable_chunked_prefill": False, - } - # Try a range of common k, as well as large speculation. 
- for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63] - ] + [{ - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": k, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4, - } for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, - output_len: int, seed: int): - """Verify that speculative decoding produces exact equality to without spec - decode with many different values of k. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-160m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": k, - "acceptance_method": "typical_acceptance_sampler", - }, - "enable_chunked_prefill": False - } - # Try a range of common k. - for k in [1, 2, 3] - ] + [{ - "speculative_config": { - "model": "JackFram/llama-68m", - "num_speculative_tokens": k, - "acceptance_method": "typical_acceptance_sampler", - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - } for k in [1, 2, 3]]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -@create_new_process_for_each_test() -def test_typical_acceptance_sampling(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - seed: int): - """Verify that speculative decoding produces exact equality to without spec - decode with TypicalAcceptanceSampler as the draft token acceptance - sampling method. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py deleted file mode 100644 index 58d1a6ca7add..000000000000 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ /dev/null @@ -1,392 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""This docstring details important information on the testing methodology. - -Most of the tests rely on "greedy equality", where we expect the output of -speculative decoding on a sequence to exactly match the output of normal non- -speculative decoding. 
- -Since speculative decoding with rejection sampling guarantees that the output -distribution matches the target model's output distribution (up to hardware -numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy -equality. - -For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding, -and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775. -Since there is no model is needed for generate the proposal, we could make -the testcase much simpler than drafter multi-step one. - -However, we still need to verify below scenario could be passed: - * Batch size 1 greedy equality - * Batch size >1 greedy equality - * Test greedy equality under preemption - * Test greedy equality under various ngram sizes / speculative sizes - -With those tests, we can say at least, ngram spec would not break the -correctness for the target model outputs. -""" - -import pytest - -from ..utils import maybe_enable_chunked_prefill -from .conftest import run_equality_correctness_test - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "model_name": "JackFram/llama-68m", - }, -]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_mqa_scorer": False, - }, - }, - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_mqa_scorer": True, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 256, -]) -@pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) -@pytest.mark.parametrize("seed", [1]) -def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, - prefill_chunk_size: int, seed: int): - """Verify greedy equality on a tiny model with different batch size.""" - maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs) - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # Print spec metrics. - "disable_log_stats": False, - - # The original model is float32, keep it for numerical stability. 
- "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "model_name": "JackFram/llama-68m", - }, -]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_logprobs": False, - }, - }, - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_logprobs": True, - }, - }, -]) -@pytest.mark.parametrize("output_len", [ - 8, -]) -@pytest.mark.parametrize("batch_size", [8]) -@pytest.mark.parametrize("seed", [1]) -@pytest.mark.parametrize("logprobs", [1, 6]) -def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, output_len: int, seed: int, - logprobs: int): - """Verify greedy equality on a tiny model with different batch size.""" - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0, - logprobs=logprobs, - prompt_logprobs=logprobs, - disable_logprobs=test_llm_kwargs["speculative_config"] - ["disable_logprobs"]) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "block_size": 16, - # 2 for small prompt, 256//8 for generated. - "num_gpu_blocks_override": 2 + 256 // 8, - "max_model_len": (2 + 256 // 8) * 8, - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "model_name": "JackFram/llama-160m", - }, -]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [ - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - }, - "enable_chunked_prefill": False, - }, - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_mqa_scorer": True, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 - }, -]) -@pytest.mark.parametrize( - "output_len", - [ - # Use small output len for fast test. - 256, - ]) -@pytest.mark.parametrize("batch_size", [4]) -@pytest.mark.parametrize("seed", [1]) -def test_ngram_e2e_greedy_correctness_with_preemption( - vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify greedy equality, even when some sequences are preempted mid- - generation. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - temperature=0, - seed=seed) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. 
- "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize( - "test_llm_kwargs", - [ - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": k, - "prompt_lookup_max": 3, - }, - } - # Try a range of common k, as well as large speculation. - for k in [1, 3, 5] - ] + [ - { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": k, - "prompt_lookup_max": 1, - }, - } - # Try a range of common k, as well as large speculation. - for k in [1, 3, 5] - ]) -@pytest.mark.parametrize("batch_size", [2]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_ngram_different_k(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that ngram speculative decoding produces exact equality - to without spec decode with many different values of k and - different ngram prompt_lookup_max. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. - "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_by_batch_size": 4 - }, -}, { - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_by_batch_size": 4, - "disable_mqa_scorer": True, - }, - "enable_chunked_prefill": True, - "max_num_batched_tokens": 4, - "max_num_seqs": 4 -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that ngram speculative decoding produces exact equality - to without spec decode with many different values of k and - different ngram prompt_lookup_max. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # The original model is float32, keep it for numerical stability. 
- "dtype": "float32", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "speculative_config": { - "method": "ngram", - "num_speculative_tokens": 5, - "prompt_lookup_max": 3, - "disable_mqa_scorer": True, - }, -}]) -@pytest.mark.parametrize("batch_size", [1, 5]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. - 32, - ]) -@pytest.mark.parametrize("seed", [1]) -def test_ngram_scorer(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, output_len: int, - seed: int): - """Verify that ngram speculative decoding generates the same output - with batch expansion scorer and mqa scorer. - """ - run_equality_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - seed=seed, - temperature=0.0) diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py deleted file mode 100644 index 4cf373809dba..000000000000 --- a/tests/spec_decode/e2e/test_seed.py +++ /dev/null @@ -1,70 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from .conftest import run_equality_correctness_test - -# main model -MAIN_MODEL = "JackFram/llama-68m" - -# speculative model -SPEC_MODEL = "JackFram/llama-160m" - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model_name": "JackFram/llama-68m", - - # Skip cuda graph recording for fast test. - "enforce_eager": True, - - # speculative config - "speculative_config": { - "model": "JackFram/llama-160m", - "num_speculative_tokens": 3, - }, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}]) -@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}]) -@pytest.mark.parametrize("batch_size", [1, 8, 32]) -@pytest.mark.parametrize("temperature", [0.1, 1.0]) -@pytest.mark.parametrize( - "output_len", - [ - # Use smaller output len for fast test. 
- 20, - ]) -def test_seeded_consistency(vllm_runner, common_llm_kwargs, - per_test_common_llm_kwargs, baseline_llm_kwargs, - test_llm_kwargs, batch_size: int, - temperature: float, output_len: int): - """Verify outputs are consistent across multiple runs with same seed - """ - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - temperature=temperature, - disable_seed=False, - ) - - # Ensure this same test does fail if we _don't_ include per-request seeds - with pytest.raises(AssertionError): - run_equality_correctness_test( - vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - max_output_len=output_len, - temperature=temperature, - disable_seed=True, - ) diff --git a/tests/spec_decode/test_batch_expansion.py b/tests/spec_decode/test_batch_expansion.py deleted file mode 100644 index d20c549b0905..000000000000 --- a/tests/spec_decode/test_batch_expansion.py +++ /dev/null @@ -1,110 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest -import torch - -from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer - -from .utils import create_seq_group_metadata_from_prompts, mock_worker - - -@pytest.mark.parametrize('num_target_seq_ids', [100]) -@pytest.mark.skip_global_cleanup -def test_create_target_seq_id_iterator(num_target_seq_ids: int): - """Verify all new sequence ids are greater than all input - seq ids. - """ - scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) - - all_seq_ids = [ - [1, 3, 5, 7], - list(range(100)) + [0], - [100], - ] - - for seq_ids in all_seq_ids: - max_seq_id = max(seq_ids) - iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access - for _ in range(num_target_seq_ids): - assert next(iterator) > max_seq_id - - -@pytest.mark.parametrize('k', [1, 2, 6]) -@pytest.mark.skip_global_cleanup -def test_get_token_ids_to_score(k: int): - """Verify correct tokens are selected for scoring. - """ - proposal_token_ids = torch.tensor( - list(range(k)), - dtype=torch.int64, - device='cuda', - ) - - expected_output: list[list[int]] = [ - [], - ] - for i in range(proposal_token_ids.shape[0]): - expected_output.append(proposal_token_ids[:i + 1].tolist()) - - scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) - actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access - - actual_output = [ - x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output - ] - - assert actual_output == expected_output - - -@pytest.mark.parametrize('k', [1, 2, 6]) -@pytest.mark.skip_global_cleanup -def test_create_single_target_seq_group_metadata(k: int): - """Verify correct creation of a batch-expanded seq group metadata. 
- """ - - prompt_tokens = [1, 2, 3] - prev_output_tokens = [4, 5, 6] - - token_ids = list(range(k)) - - num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1 - - final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len( - token_ids) - - block_size = 32 - input_seq_group_metadata = create_seq_group_metadata_from_prompts( - [prompt_tokens], 2048 // block_size, block_size, [final_seq_len], - [prev_output_tokens], [num_tokens_processed])[0] - - input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0] - target_seq_id = 100 - - scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000) - output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access - input_seq_group_metadata, - input_seq_id, - target_seq_id, - token_ids, - input_seq_group_metadata.sampling_params, - ) - - assert output.request_id == input_seq_group_metadata.request_id - assert output.sampling_params.repetition_penalty == \ - input_seq_group_metadata.sampling_params.repetition_penalty - assert output.sampling_params.temperature == \ - input_seq_group_metadata.sampling_params.temperature - assert output.sampling_params.top_p == \ - input_seq_group_metadata.sampling_params.top_p - assert output.sampling_params.top_k == \ - input_seq_group_metadata.sampling_params.top_k - assert len(output.seq_data) == 1 - assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple( - prompt_tokens) - assert output.seq_data[target_seq_id].get_output_token_ids() == tuple( - prev_output_tokens + token_ids) - - assert len(output.block_tables) == 1 - assert output.block_tables[ - target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id] diff --git a/tests/spec_decode/test_dynamic_spec_decode.py b/tests/spec_decode/test_dynamic_spec_decode.py deleted file mode 100644 index 407786ad3c64..000000000000 --- a/tests/spec_decode/test_dynamic_spec_decode.py +++ /dev/null @@ -1,90 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock, patch - -import pytest -import torch - -from vllm.sequence import ExecuteModelRequest -from vllm.spec_decode.metrics import AsyncMetricsCollector -from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker -from vllm.spec_decode.top1_proposer import Top1Proposer - -from .test_utils import mock_spec_decode_sampler -from .utils import create_batch, mock_worker - - -@pytest.mark.parametrize('queue_size', [4]) -@pytest.mark.parametrize('batch_size', [1]) -@pytest.mark.parametrize('k', [1]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_disable_spec_tokens(queue_size: int, batch_size: int, k: int, - acceptance_sampler_method: str): - """Verify that speculative tokens are disabled when the batch size - exceeds the threshold. 
- """ - disable_by_batch_size = 3 - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(proposer_worker=draft_worker, - scorer_worker=target_worker, - spec_decode_sampler=mock_spec_decode_sampler( - acceptance_sampler_method), - disable_logprobs=False, - metrics_collector=metrics_collector, - disable_by_batch_size=disable_by_batch_size) - - exception_secret = 'artificial stop' - draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k, - running_queue_size=queue_size) - - if queue_size > disable_by_batch_size: - with patch.object(worker, - '_run_no_spec', - side_effect=ValueError(exception_secret)), \ - pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) - - # When the batch size is larger than the threshold, - # we expect no speculative tokens (0). - expected_num_spec_tokens = None if queue_size < disable_by_batch_size else 0 - assert seq_group_metadata_list[ - 0].num_speculative_tokens == expected_num_spec_tokens - - draft_worker.sampler_output.side_effect = ValueError(exception_secret) - - proposer = Top1Proposer( - worker=draft_worker, - device='cpu', # not used - vocab_size=100, # not used - # Must be long enough to avoid being skipped due to length. - max_proposal_len=1024, - ) - - if queue_size < disable_by_batch_size: - # Should raise exception when executing the mocked draft model. - with pytest.raises(ValueError, match=exception_secret): - proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) - else: - # Should not execute the draft model because spec decode is disabled - # for all requests. Accordingly, the proposal length should be 0. - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) - assert proposals.proposal_lens.tolist() == [0] * batch_size diff --git a/tests/spec_decode/test_memory_usage.py b/tests/spec_decode/test_memory_usage.py deleted file mode 100644 index 5d9dd3f72a78..000000000000 --- a/tests/spec_decode/test_memory_usage.py +++ /dev/null @@ -1,91 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""This docstring details important information on the testing methodology. - -This test verifies that memory usage remains constant (or never grows) when -we enable / disable speculation via --speculative-disable-by-batch-size. - -There are a lot of things we try to keep track of between batches of requests -and if certain tensors are not freed from memory, can result in CUDA ooms. - -This is particularly relevant for production situations where speculation might -be enabled during off hours, but disabled once traffic peaks during the workday. -Since traffic will stay high for a long period of time, verifying we do not -increase our memory usage over time is essential to prevent possible CUDA ooms. 
-""" - -import torch - -import vllm -from tests.core.utils import create_dummy_prompt -from vllm.sequence import SequenceGroup - -ITERATIONS = 100 -MAIN_MODEL = "JackFram/llama-68m" - -# speculative model -SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random" - -BATCH_SIZE = 5 -SPEC_DISABLE_BATCH_SIZE = 2 - - -def add_seq_group_to_engine(engine: vllm.LLMEngine, seq_group: SequenceGroup): - scheduler = engine.scheduler[0] - scheduler.add_seq_group(seq_group) - - -""" -Since we are using a batch size greater than the disabled batch size, -we can ensure we go through the _no_spec codepath for most of our engine steps. -""" - - -def test_memory_usage_no_spec(): - previous_memory_allocated = None - llm = vllm.LLM(model=MAIN_MODEL, - speculative_config={ - "model": SPEC_MODEL, - "num_speculative_tokens": 3, - "disable_by_batch_size": SPEC_DISABLE_BATCH_SIZE, - }) - - batch_sequences = set() - engine = llm.llm_engine - - for i in range(ITERATIONS): - seq, seq_group = create_dummy_prompt(request_id=str(i), - prompt_length=10, - min_tokens=10, - max_tokens=10) - - add_seq_group_to_engine(engine, seq_group) - - batch_sequences.add(seq) - engine.step() - for seq in list(batch_sequences): - if seq.is_finished(): - batch_sequences.remove(seq) - - # If we aren't at our batch size yet, continue - if len(batch_sequences) <= BATCH_SIZE: - continue - - # Otherwise, loop until at least one request is done - while not any(seq.is_finished() for seq in batch_sequences): - engine.step() - - # Remove it from the set - for seq in list(batch_sequences): - if seq.is_finished(): - batch_sequences.remove(seq) - - # At this point, we are always at the case where we have finished - # processing some number of requests from the batch after running - # several _no_spec executions. The memory should not have - # increased between the previous time this was recorded and the - # current time. - if previous_memory_allocated is None: - previous_memory_allocated = torch.cuda.memory_allocated() - else: - assert previous_memory_allocated == torch.cuda.memory_allocated() diff --git a/tests/spec_decode/test_metrics.py b/tests/spec_decode/test_metrics.py deleted file mode 100644 index e8de410f8a94..000000000000 --- a/tests/spec_decode/test_metrics.py +++ /dev/null @@ -1,205 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from unittest.mock import MagicMock - -import pytest -import torch - -from vllm.spec_decode.metrics import AsyncMetricsCollector - - -def test_initial_call_returns_none(): - """Expect first call to get metrics to return None. - """ - spec_decode_sampler = MagicMock() - spec_decode_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_draft_tokens = 0 - - collector = AsyncMetricsCollector(spec_decode_sampler) - collector.init_gpu_tensors(rank=0) - maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert maybe_metrics is None - - -def test_second_call_returns_metrics(): - """Expect second call to not return None. 
- """ - spec_decode_sampler = MagicMock() - spec_decode_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_draft_tokens = 0 - - collect_interval_s = 5.0 - timer = MagicMock() - timer.side_effect = [ - 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 - ] - - collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, - timer=timer, - collect_interval_s=collect_interval_s) - collector.init_gpu_tensors(rank=0) - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is not None - - -@pytest.mark.parametrize("rank", [1, 2, 3, 4]) -def test_nonzero_rank_noop(rank): - """Verify nonzero ranks don't collect metrics. - """ - spec_decode_sampler = MagicMock() - spec_decode_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_draft_tokens = 0 - - collector = AsyncMetricsCollector(spec_decode_sampler) - collector.init_gpu_tensors(rank=rank) - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is None - - -def test_noop_until_time(): - """Verify metrics aren't collected until enough time passes. - """ - spec_decode_sampler = MagicMock() - spec_decode_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_draft_tokens = 0 - - collect_interval_s = 5.0 - timer = MagicMock() - timer.side_effect = [ - 0.0, collect_interval_s - 0.1, collect_interval_s - 0.1, - collect_interval_s + 0.1, collect_interval_s + 0.1 - ] - - collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, - timer=timer, - collect_interval_s=collect_interval_s) - collector.init_gpu_tensors(rank=0) - - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is None - - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is not None - - -def test_timer_is_reset(): - """Verify that the internal timer inside AsyncMetricsCollector - is reset after collection. 
- """ - spec_decode_sampler = MagicMock() - spec_decode_sampler.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_draft_tokens = 0 - - collect_interval_s = 5.0 - timer = MagicMock() - timer.side_effect = [ - 0.0, - collect_interval_s + 0.1, - collect_interval_s + 0.1, - collect_interval_s + 0.2, - collect_interval_s + 0.2, - 2 * collect_interval_s + 0.1, - 2 * collect_interval_s + 0.1, - ] - - collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, - timer=timer, - collect_interval_s=collect_interval_s) - collector.init_gpu_tensors(rank=0) - - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is not None - - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is None - - _ = collector.maybe_collect_rejsample_metrics(k=5) - metrics = collector.maybe_collect_rejsample_metrics(k=5) - assert metrics is not None - - -@pytest.mark.parametrize("has_data", [True, False]) -def test_initial_metrics_has_correct_values(has_data: bool): - """Test correctness of metrics data. - """ - if has_data: - num_accepted_tokens = 103 - num_emitted_tokens = 104 - num_draft_tokens = 105 - else: - num_accepted_tokens = 0 - num_emitted_tokens = 0 - num_draft_tokens = 0 - k = 5 - - max_num_emitted_tokens = AsyncMetricsCollector.get_max_num_emitted_tokens( - num_draft_tokens, k) - - spec_decode_sampler = MagicMock() - spec_decode_sampler.num_accepted_tokens = torch.tensor(num_accepted_tokens, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_emitted_tokens = torch.tensor(num_emitted_tokens, - dtype=torch.long, - device='cuda') - spec_decode_sampler.num_draft_tokens = num_draft_tokens - - collect_interval_s = 5.0 - timer = MagicMock() - timer.side_effect = [ - 0.0, collect_interval_s + 0.1, collect_interval_s + 0.2 - ] - - collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler, - timer=timer, - collect_interval_s=collect_interval_s) - collector.init_gpu_tensors(rank=0) - _ = collector.maybe_collect_rejsample_metrics(k) - metrics = collector.maybe_collect_rejsample_metrics(k) - - assert metrics.num_spec_tokens == k - assert metrics.accepted_tokens == num_accepted_tokens - assert metrics.draft_tokens == num_draft_tokens - assert metrics.emitted_tokens == num_emitted_tokens - - if has_data: - assert (metrics.draft_acceptance_rate == num_accepted_tokens / - num_draft_tokens) - assert (metrics.system_efficiency == num_emitted_tokens / - max_num_emitted_tokens) - else: - assert math.isnan(metrics.draft_acceptance_rate) - assert math.isnan(metrics.system_efficiency) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py deleted file mode 100644 index f2d93203b8e1..000000000000 --- a/tests/spec_decode/test_multi_step_worker.py +++ /dev/null @@ -1,838 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from unittest.mock import MagicMock - -import pytest -import torch - -from vllm.attention.selector import (_Backend, - global_force_attn_backend_context_manager) -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, - 
get_all_seq_ids) -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner -from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker - -from .utils import (assert_logprobs_dict_allclose, create_batch, - create_seq_group_metadata_from_prompts, create_worker, - patch_execute_model_with_seeds, zero_kv_cache) - - -@pytest.mark.parametrize('num_steps', list(range(1, 17))) -def test_assert_enough_kv_space(num_steps: int): - """Test that the multi step worker checks for sufficient space in the KV - cache. It should throw if it cannot run all the steps. - """ - block_size = 16 - num_gpu_blocks = 2048 // block_size - - prompts = [ - list(range(block_size * 3)), - list(range(block_size * 2)), - ] - - prev_output_tokens = [ - list(range(block_size * 1)), - list(range(block_size * 2)), - ] - - final_prompt_lens = [ - len(prompt + output) + num_steps - for prompt, output in zip(prompts, prev_output_tokens) - ] - - inputs = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens, - continuations=prev_output_tokens) - - assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access - worker = MagicMock() - worker.model_runner.block_size = block_size - - for seq_group_metadata in inputs: - original_block_tables = seq_group_metadata.block_tables - - # No exception. - assert_enough_kv_space(worker, inputs, num_steps) - - seq_group_metadata.block_tables = { - seq_id: [] - for seq_id, physical_blocks in original_block_tables.items() - } - - # Expect exception. - with pytest.raises(ValueError, - match='times but found insufficient KV space for'): - assert_enough_kv_space(worker, inputs, num_steps) - - seq_group_metadata.block_tables = original_block_tables - - -@torch.inference_mode() -def test_same_output_for_single_step(): - """Verify the multi step worker produces the same output as the normal - worker for num_steps=1. 
- """ - seed = 100 - model_name = 'JackFram/llama-68m' - - block_size = 32 - num_gpu_blocks = 2048 // block_size - multi_step_worker = create_worker( - MultiStepWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - model_runner_cls=TP1DraftModelRunner, - ) - worker = create_worker( - Worker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - # multi_step_worker.model_runner = worker.model_runner - # multi_step_worker.cache_engine = worker.cache_engine - - num_steps = 1 - - prompts = [ - [1, 2, 3, 4, 5], - [6, 7, 8, 9, 10], - ] - - final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] - - multi_step_seq_group = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) - - zero_kv_cache(multi_step_worker.cache_engine) - set_random_seed(seed) - actual_output, _ = multi_step_worker.sampler_output( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=multi_step_seq_group), - sample_len=num_steps, - seq_ids_with_bonus_token_in_last_step=set()) - assert len(actual_output) == num_steps - actual_output = actual_output[0] - - single_step_seq_group = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) - - zero_kv_cache(worker.cache_engine) - set_random_seed(seed) - expected_output = worker.execute_model( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=single_step_seq_group))[0] - - actual_token_ids = [ - output.samples[0].output_token for output in actual_output - ] - actual_logprobs = [output.samples[0].logprobs for output in actual_output] - - expected_token_ids = [ - output.samples[0].output_token for output in expected_output - ] - expected_logprobs = [ - output.samples[0].logprobs for output in expected_output - ] - - assert actual_token_ids == expected_token_ids - - print(f'{actual_logprobs=}') - print(f'{expected_logprobs=}') - assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs) - - -@torch.inference_mode() -def test_same_output_for_multi_step(): - """Verify the multi-step worker produces the same output as the normal - worker when num_steps > 1. This test runs the multi-step worker once, and - then runs the worker num_steps times, and compares the output. - """ - seed = 100 - model_name = 'JackFram/llama-68m' - - block_size = 16 - num_gpu_blocks = 2048 // block_size - multi_step_worker = create_worker( - MultiStepWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - - worker = create_worker( - Worker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - - # Make sure we go over the block boundary. - num_steps = block_size + 1 - - random.seed(seed) - prompts = [[ - random.randint(0, 1000) for _ in range(random.randint(10, 20)) - ] for _ in range(10)] - - final_prompt_lens = [len(prompt) + num_steps for prompt in prompts] - - rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) - multi_step_worker.execute_model = patch_execute_model_with_seeds( - multi_step_worker, rand_seeds) - worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) - - continuations = [[1] for _ in prompts] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens) - - # Run multi-step. 
- zero_kv_cache(multi_step_worker.cache_engine) - set_random_seed(seed) - multi_step_output, _ = multi_step_worker.sampler_output( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), - sample_len=num_steps, - seq_ids_with_bonus_token_in_last_step=set()) - - # Run single-step repeatedly. - zero_kv_cache(worker.cache_engine) - single_step_output: list[SamplerOutput] = [] - continuations = [[1] for _ in prompts] - set_random_seed(seed) - - for _ in multi_step_output: - - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens) - - single_step_output.extend( - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list))) - - # Append output tokens to new sequence data. - for i, seq_group_output in enumerate(single_step_output[-1]): - continuations[i].append(seq_group_output.samples[0].output_token) - - # Get token ids and logprobs for comparison. - multi_step_output_logprobs: list[list[dict[int, - Logprob]]] = [[] - for _ in prompts] - single_step_output_logprobs: list[list[dict[int, - Logprob]]] = [[] - for _ in prompts] - - multi_step_output_token_ids: list[list[int]] = [[] for _ in prompts] - single_step_output_token_ids: list[list[int]] = [[] for _ in prompts] - for i, _ in enumerate(prompts): - for multi_step, single_step in zip(multi_step_output, - single_step_output): - multi_step_output_token_ids[i].append( - multi_step[i].samples[0].output_token) - single_step_output_token_ids[i].append( - single_step[i].samples[0].output_token) - - multi_step_output_logprobs[i].append( - multi_step[i].samples[0].logprobs) - single_step_output_logprobs[i].append( - single_step[i].samples[0].logprobs) - - # Print per-sequence token ids - for i, (multi_step_tokens, single_step_tokens) in enumerate( - zip(multi_step_output_token_ids, single_step_output_token_ids)): - print(f'{i=} {multi_step_tokens=}') - print(f'{i=} {single_step_tokens=}') - print(f'{i=} equal {multi_step_tokens == single_step_tokens}') - - # Assert token ids are equal. - for multi_step_tokens, single_step_tokens in zip( - multi_step_output_token_ids, single_step_output_token_ids): - assert multi_step_tokens == single_step_tokens - - # Assert logprobs are equal. - for multi_step_logprobs, single_step_logprobs in zip( - multi_step_output_logprobs, single_step_output_logprobs): - assert_logprobs_dict_allclose(multi_step_logprobs, - single_step_logprobs) - - -@torch.inference_mode() -def test_multi_step_with_batch_expansion_correct_output(): - """ - In this test we verify that the MultiStepWorker is able to handle bonus - tokens correctly. The test verifies that if a sequence has a - bonus token then the MultiStepWorker is able to expand the batch by adding - new sequences corresponding to the sequences with bonus tokens. The - expanded batch is then used for predicting the next tokens. 
- """ - seed = 100 - model_name = 'JackFram/llama-68m' - - block_size = 16 - num_gpu_blocks = 2048 // block_size - batch_size = 128 - multi_step_worker = create_worker( - MultiStepWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - model_runner_cls=TP1DraftModelRunner, - ) - multi_step_worker.set_include_gpu_probs_tensor() - worker = create_worker( - Worker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - random.seed(seed) - prompts = [[0] for _ in range(batch_size)] - num_steps = 2 - final_prompt_lens = [(num_steps + 1) for prompt in prompts] - rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) - multi_step_worker.execute_model = patch_execute_model_with_seeds( - multi_step_worker, rand_seeds) - worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) - # Create the test continuations - continuations = [[random.randint(0, 1000)] for _ in prompts] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens) - - # Run single-step twice to generate 2 tokens. This - # will simulate the bonus token case with the second token - # being the bonus token. - zero_kv_cache(worker.cache_engine) - single_step_output: list[SamplerOutput] = [] - set_random_seed(seed) - for _ in range(num_steps): - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens) - single_step_output.extend( - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list))) - # Append output tokens to new sequence data. - for i, seq_group_output in enumerate(single_step_output[-1]): - continuations[i].append(seq_group_output.samples[0].output_token) - - # Create continuations for the MultiStepWorker. The continuations have - # 2 tokens in order to simulate the bonus token case. - multi_step_continuations = [] - for continuation in continuations: - multi_step_continuations.append(continuation[:2]) - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=multi_step_continuations, - final_prompt_lens=final_prompt_lens) - - # Run multi-step and verify that the third token prediction is accurate - # for all sequences. - zero_kv_cache(multi_step_worker.cache_engine) - all_seq_ids = {i for i in range(batch_size)} - multi_step_output, _ = multi_step_worker.sampler_output( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), - sample_len=1, - seq_ids_with_bonus_token_in_last_step=all_seq_ids) - for index, output in enumerate(multi_step_output[-1].outputs): - assert (continuations[index][-1] == output.samples[0].output_token) - - -@torch.inference_mode() -def test_multi_step_with_batch_expansion_incorrect_output(): - """ - Tests the MultiStepWorker's ability to handle batch expansion with bonus - tokens in a negative case scenario. This test provides the MultiStepWorker - with a batch containing sequences with bonus tokens but specifies the - sequence IDs with bonus tokens incorrectly. The test verifies that the - MultiStepWorker generates correct tokens for the sequences where the - sequence ID is specified correctly and incorrect tokens for those where - the sequence ID is specified incorrectly. 
- """ - seed = 100 - model_name = 'JackFram/llama-68m' - - block_size = 16 - num_gpu_blocks = 2048 // block_size - batch_size = 128 - multi_step_worker = create_worker( - MultiStepWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - model_runner_cls=TP1DraftModelRunner, - ) - multi_step_worker.set_include_gpu_probs_tensor() - worker = create_worker( - Worker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - random.seed(seed) - prompts = [[0] for _ in range(batch_size)] - num_steps = 2 - final_prompt_lens = [(num_steps + 1) for prompt in prompts] - rand_seeds = list(random.randint(0, 100) for _ in range(num_steps)) - multi_step_worker.execute_model = patch_execute_model_with_seeds( - multi_step_worker, rand_seeds) - worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds) - # Create the test continuations - continuations = [[random.randint(0, 1000)] for _ in prompts] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens) - # Run single-step twice to generate 2 tokens. This - # will simulate the bonus token case with the second token - # being the bonus token. - zero_kv_cache(worker.cache_engine) - single_step_output: list[SamplerOutput] = [] - set_random_seed(seed) - for _ in range(num_steps): - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=continuations, - final_prompt_lens=final_prompt_lens) - single_step_output.extend( - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list))) - # Append output tokens to new sequence data. - for i, seq_group_output in enumerate(single_step_output[-1]): - continuations[i].append(seq_group_output.samples[0].output_token) - - # Create continuations for the MultiStepWorker. The continuations have - # 2 tokens in order to simulate the bonus token case. - multi_step_continuations = [] - for continuation in continuations: - multi_step_continuations.append(continuation[:2]) - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=multi_step_continuations, - final_prompt_lens=final_prompt_lens) - - # Run multi-step. In this run INCORRECTLY specify that only the odd number - # sequences have bonus tokens. Verify that with this setting the third token - # prediction is accurate only for the odd numbered sequences. Also verify - # that the prediction might be wrong for some of the even numbered - # sequences. - zero_kv_cache(multi_step_worker.cache_engine) - set_random_seed(seed) - odd_seq_ids = {i for i in range(batch_size) if i % 2 != 0} - multi_step_output, _ = multi_step_worker.sampler_output( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), - sample_len=1, - seq_ids_with_bonus_token_in_last_step=odd_seq_ids) - num_mismatch = 0 - for index, output in enumerate(multi_step_output[-1].outputs): - if (index % 2) != 0: - assert (continuations[index][-1] == output.samples[0].output_token) - elif (continuations[index][-1] != output.samples[0].output_token): - num_mismatch += 1 - # The prediction is accurate for some of the sequences even without proper - # handling of the bonus tokens. Hence verify that the number of sequences - # for which there is a mismatch is > 0. 
- assert (num_mismatch > 0) - - -@torch.inference_mode() -@pytest.mark.parametrize('num_steps', [1, 2, 3, 4]) -# The choice of backends forces the multi_step_worker to choose between -# the vanilla model_runner and TP1DraftModelRunner and that we can test -# both code paths. -@pytest.mark.parametrize('attn_backend', - [_Backend.XFORMERS, _Backend.FLASH_ATTN]) -def test_multi_step_correct_kvcache(num_steps, attn_backend): - """Verify that the KV cache of the draft model - is correctly updated for sequences with bonus token. - """ - seed = 100 - model_name = "JackFram/llama-68m" - - block_size = 16 - num_gpu_blocks = 2048 // block_size - batch_size = 1 - - with global_force_attn_backend_context_manager(attn_backend): - dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32' - multi_step_worker = create_worker(MultiStepWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - model_runner_cls=TP1DraftModelRunner, - dtype=dtype) - multi_step_worker.set_include_gpu_probs_tensor() - worker = create_worker(Worker, - model_name, - block_size, - num_gpu_blocks, - seed, - dtype=dtype) - - prompts = [[0] for _ in range(batch_size)] - # Already generate two tokens for the sequence - # so that we can simulate the bonus token case - multi_step_continuations = [[ - random.randint(0, 1000), - random.randint(0, 1000) - ] for _ in prompts] - final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts] - - seq_ids_with_bonus_token_in_last_step = set(range(batch_size)) - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=multi_step_continuations, - final_prompt_lens=final_prompt_lens) - - # Run multi-step. - zero_kv_cache(multi_step_worker.cache_engine) - multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list), - sample_len=num_steps, - seq_ids_with_bonus_token_in_last_step= - seq_ids_with_bonus_token_in_last_step) - - # Run single-step repeatedly. - zero_kv_cache(worker.cache_engine) - # Generate the kv cache for the bonus token first - single_step_continuations = [c[:1] for c in multi_step_continuations] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=single_step_continuations, - final_prompt_lens=final_prompt_lens) - single_step_output = worker.execute_model( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list)) - for _ in range(num_steps): - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - continuations=multi_step_continuations, - final_prompt_lens=final_prompt_lens) - - single_step_output = worker.execute_model( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list)) - - for i, seq_group_output in enumerate(single_step_output[-1]): - multi_step_continuations[i].append( - seq_group_output.samples[0].output_token) - - # Verify that the KV cache of the single-step and - # multi-step workers are the same. 
- single_step_gpu_cache = worker.cache_engine[0].gpu_cache - multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache - num_layers = len(single_step_gpu_cache) - allclose = lambda a, b: torch.allclose( - a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2) - for i in range(num_layers): - assert allclose(single_step_gpu_cache[i][0], - multi_step_gpu_cache[i][0]) - assert allclose(single_step_gpu_cache[i][1], - multi_step_gpu_cache[i][1]) - - -@torch.inference_mode() -def test_draft_proposals_full_speculation_len(): - """Verify Top1Proposer correctly handles case where all sequences - can speculate. - """ - k = 10 - batch_size = 32 - vocab_size = 32_000 - device = 'cuda:0' - - draft_worker = MagicMock() - proposer = Top1Proposer( - worker=draft_worker, - device=device, - vocab_size=vocab_size, - max_proposal_len=2048, - ) - draft_worker.sampler_output.return_value = [ - SamplerOutput( - outputs=[], - sampled_token_probs=torch.rand(batch_size, - vocab_size, - device=device, - dtype=torch.float32), - logprobs=torch.rand(batch_size, - vocab_size, - device=device, - dtype=torch.float32), - sampled_token_ids=torch.randint(low=0, - high=vocab_size, - size=(batch_size, ), - device=device, - dtype=torch.long), - ) for _ in range(k) - ], True - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) - - assert torch.is_tensor(proposals.proposal_token_ids) - assert torch.is_tensor(proposals.proposal_probs) - - assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) - - assert proposals.proposal_lens.shape == torch.Size([batch_size]) - assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)] - - -@torch.inference_mode() -def test_draft_proposals_no_speculations(): - """Verify Top1Proposer correctly handles case where no sequences - can speculate. - """ - k = 10 - batch_size = 32 - vocab_size = 32_000 - device = 'cuda:0' - prompt_len = 10 - - draft_worker = MagicMock() - proposer = Top1Proposer( - worker=draft_worker, - device=device, - vocab_size=vocab_size, - max_proposal_len=prompt_len + k - 1, - ) - - seq_group_metadata_list, _, _ = create_batch(batch_size, - k, - prompt_len=prompt_len) - - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) - - assert torch.is_tensor(proposals.proposal_token_ids) - assert torch.is_tensor(proposals.proposal_probs) - - assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) - - assert proposals.proposal_lens.shape == torch.Size([batch_size]) - assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)] - - -@torch.inference_mode() -def test_draft_proposals_mixed_k(): - """Verify Top1Proposer correctly handles case some sequences can - speculate and some can't. 
- """ - k = 10 - batch_size = 32 - vocab_size = 32_000 - device = 'cuda:0' - - small_prompt_len = 5 - long_prompt_len = 10 - prev_output_token_len = 20 - - expected_num_proposal_seqs = 6 - expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs - - prompt_len = [ - small_prompt_len for _ in range(expected_num_proposal_seqs - 1) - ] + [long_prompt_len - for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len] - - draft_worker = MagicMock() - proposer = Top1Proposer( - worker=draft_worker, - device=device, - vocab_size=vocab_size, - max_proposal_len=long_prompt_len + prev_output_token_len + k - 1, - ) - - draft_worker.sampler_output.return_value = [ - SamplerOutput( - outputs=[], - sampled_token_probs=torch.rand(expected_num_proposal_seqs, - vocab_size, - device=device, - dtype=torch.float32), - logprobs=torch.rand(expected_num_proposal_seqs, - vocab_size, - device=device, - dtype=torch.float32), - sampled_token_ids=torch.randint( - low=0, - high=vocab_size, - size=(expected_num_proposal_seqs, ), - device=device, - dtype=torch.long), - ) for _ in range(k) - ], True - - seq_group_metadata_list, _, _ = create_batch( - batch_size, - k, - prompt_len=prompt_len, - prev_output_token_len=prev_output_token_len, - ) - - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k), - seq_ids_with_bonus_token_in_last_step=set()) - - assert torch.is_tensor(proposals.proposal_token_ids) - assert torch.is_tensor(proposals.proposal_probs) - - assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k]) - - assert proposals.proposal_lens.shape == torch.Size([batch_size]) - assert proposals.proposal_lens.tolist() == [ - k for _ in range(expected_num_proposal_seqs - 1) - ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k] - - -@torch.inference_mode() -def test_use_draft_model_runner_advance_step(): - """Verify that draft model runner triggers advance step - when applicable. - """ - seed = 100 - model_name = 'JackFram/llama-68m' - - k = 5 - batch_size = 32 - block_size = 32 - num_gpu_blocks = 2048 // block_size - worker = create_worker( - MultiStepWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - model_runner_cls=TP1DraftModelRunner, - ) - - # Mock "_gpu_advance_step" to raise an exception when called. - exception_secret = "artificial stop" - worker.model_runner._gpu_advance_step = MagicMock() - worker.model_runner._gpu_advance_step.side_effect = ValueError( - exception_secret) - - seq_group_metadata_list, _, _ = create_batch(batch_size, - k, - block_size=block_size, - num_gpu_blocks=num_gpu_blocks) - - # Fallback (should not call) when num_steps=1. - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k, - num_steps=1) - worker.execute_model(execute_model_req=execute_model_req) - - # Expect exception if _gpu_advance_step is called. 
- execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k, - num_steps=k) - - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) - call_args_list = worker.model_runner._gpu_advance_step.call_args_list - assert len(call_args_list) == 1 - - -@torch.inference_mode() -def test_expand_execute_model_request_sync_with_expand_hidden_states(): - """ - In this test we verify that the logic for expanding the - seq_group_metadata_list remains in sync with the expansion logic of - the HiddenStates in _expand_execute_model_request. - """ - k = 5 - batch_size = 16 - seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15] - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - - execute_model_request = ExecuteModelRequest( - seq_group_metadata_list, - previous_hidden_states=HiddenStates( - torch.arange(batch_size), seq_group_metadata_list, - torch.arange(batch_size, 2 * batch_size))) - - expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\ - _expand_execute_model_request(execute_model_request, - seq_with_bonus_token_in_last_step) - - all_seq_ids = torch.tensor( - get_all_seq_ids( - expanded_execute_model_request.seq_group_metadata_list)) - ref_expanded_hidden_states = all_seq_ids + batch_size - ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size - - assert (ref_expanded_hidden_states == expanded_execute_model_request. - previous_hidden_states.hidden_states).all().item() diff --git a/tests/spec_decode/test_ngram_worker.py b/tests/spec_decode/test_ngram_worker.py deleted file mode 100644 index 8a7c11485681..000000000000 --- a/tests/spec_decode/test_ngram_worker.py +++ /dev/null @@ -1,221 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.sequence import ExecuteModelRequest -from vllm.spec_decode.ngram_worker import NGramWorker -from vllm.spec_decode.top1_proposer import Top1Proposer - -from .utils import create_seq_group_metadata_from_prompts, create_worker - - -def test_ngram_algo_correctness_for_single_no_match(): - """Verify our ngram algo find the right candidate in the prompt - - For the scenario cannot find any candidate in one single batch - """ - block_size = 32 - num_gpu_blocks = 2048 // block_size - seed = 100 - model_name = 'JackFram/llama-68m' - vocab_size = 32_000 - device = 'cuda:0' - - ngram_worker = create_worker( - NGramWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - - proposer = Top1Proposer( - worker=ngram_worker, - device=device, - vocab_size=vocab_size, - max_proposal_len=20, - ) - - # set ngram window [1, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(1, 3) - - prompts = [ - # shall find no candidate - [1, 2, 3, 4, 5, 6, 7], - ] - - proposal_len = 5 - final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) - - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), - seq_ids_with_bonus_token_in_last_step=None) - - assert torch.is_tensor(proposals.proposal_token_ids) - assert torch.is_tensor(proposals.proposal_probs) - - assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len]) - assert 
proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len]) - assert proposals.proposal_lens.shape == torch.Size([1]) - assert proposals.proposal_lens.tolist() == [0] - - -def test_ngram_algo_correctness_for_batches_not_match_all(): - """Verify our ngram algo find the right candidate in the prompt - - For the scenario find some candidate not full in batchs - """ - block_size = 32 - num_gpu_blocks = 2048 // block_size - seed = 100 - model_name = 'JackFram/llama-68m' - vocab_size = 32_000 - device = 'cuda:0' - - ngram_worker = create_worker( - NGramWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - - proposer = Top1Proposer( - worker=ngram_worker, - device=device, - vocab_size=vocab_size, - max_proposal_len=20, - ) - - # set ngram window [1, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(1, 3) - - prompts = [ - # shall find no candidate - [1, 2, 3, 4, 5, 6, 7], - # shall find candidate 12,13,14,15,16 - [11, 12, 13, 14, 15, 16, 11], - # shall find candidate 23,24,25,26,21 - [21, 21, 22, 23, 24, 25, 26, 21, 22], - # shall find candidate 34,35,36,37,38 - [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33], - # shall find no candidate as exceed max_proposal_len - [ - 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36, 37, - 38, 31, 32, 33 - ], - ] - - proposal_len = 5 - final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) - for sg in seq_group_metadata_list: - sg.is_prompt = False - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), - seq_ids_with_bonus_token_in_last_step=None) - - assert torch.is_tensor(proposals.proposal_token_ids) - assert torch.is_tensor(proposals.proposal_probs) - - assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len]) - assert proposals.proposal_lens.shape == torch.Size([5]) - - # the first sequence has no match so proposal_len should be overwritten to 0 - assert proposals.proposal_lens.tolist( - ) == [0] + [proposal_len for _ in range(3)] + [0] - - for i in range(proposal_len): - assert proposals.proposal_token_ids[0][i] == -1 - assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1] - assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3] - assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5] - assert proposals.proposal_token_ids[4][i] == -1 - - -def test_ngram_algo_correctness_for_batches_match_all(): - """Verify our ngram algo find the right candidate in the prompt - - For the scenario find candidate in all batches - """ - - block_size = 32 - num_gpu_blocks = 2048 // block_size - seed = 100 - model_name = 'JackFram/llama-68m' - vocab_size = 32_000 - device = 'cuda:0' - - ngram_worker = create_worker( - NGramWorker, - model_name, - block_size, - num_gpu_blocks, - seed, - ) - - proposer = Top1Proposer( - worker=ngram_worker, - device=device, - vocab_size=vocab_size, - max_proposal_len=20, - ) - - # set ngram window [0, 3], which is window=1/2/3 - ngram_worker.set_ngram_window_size(1, 3) - - prompts = [ - # shall find candidate 12,13,14,15,16 - [11, 12, 13, 14, 15, 16, 11], - # shall find candidate 23,24,25,26,21 - [21, 21, 22, 23, 24, 25, 26, 21, 22], - # shall find candidate 34,35,36,37,38 - [31, 32, 31, 32, 33, 
34, 35, 36, 37, 38, 31, 32, 33], - ] - - proposal_len = 5 - final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts] - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, - num_gpu_blocks, - block_size, - final_prompt_lens=final_prompt_lens) - - # Normally drafter is run on decode requests only; here we check the output - # of the ngram worker as it is the sole proposer that has no forward. - for sg in seq_group_metadata_list: - sg.is_prompt = False - proposals = proposer.get_spec_proposals( - execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=proposal_len), - seq_ids_with_bonus_token_in_last_step=None) - - assert torch.is_tensor(proposals.proposal_token_ids) - assert torch.is_tensor(proposals.proposal_probs) - - assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len]) - assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len]) - assert proposals.proposal_lens.shape == torch.Size([3]) - - assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)] - - for i in range(proposal_len): - assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1] - assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3] - assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5] diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py deleted file mode 100644 index 55fcf0055747..000000000000 --- a/tests/spec_decode/test_scorer.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random - -import pytest -import torch - -from vllm.sequence import ExecuteModelRequest -from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores -from vllm.spec_decode.mqa_scorer import MQAScorer -from vllm.worker.worker import Worker - -from .utils import create_batch, create_worker - - -def create_proposal(propose_lens: list[int], vocab_size: int, - device: str) -> SpeculativeProposals: - batch_size = len(propose_lens) - max_propose_len = max(propose_lens) - proposal_probs = torch.rand((batch_size, max_propose_len, vocab_size), - device=device) - - proposal_token_ids = torch.full((batch_size, max_propose_len), - fill_value=-1, - device=device) - for i in range(batch_size): - proposal_token_ids[i][:propose_lens[i]] = torch.argmax( - proposal_probs[i][:propose_lens[i]], dim=-1) - - propose_lens = torch.tensor(propose_lens, device=device) - return SpeculativeProposals(proposal_token_ids, proposal_probs, - propose_lens) - - -def assert_score_equal(score1: SpeculativeScores, - score2: SpeculativeScores) -> None: - assert torch.allclose(score1.probs, score2.probs) - assert torch.allclose(score1.logprobs, score2.logprobs) - assert torch.equal( - score1.token_ids, - score2.token_ids), f"{score1.token_ids}, {score2.token_ids}" - - -@pytest.mark.parametrize('model_name', ['facebook/opt-125m']) -@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16]) -@pytest.mark.parametrize('max_propose_len', [1, 3, 5]) -@pytest.mark.parametrize('mixed_propose_len', [True]) -@pytest.mark.parametrize('device', ['cuda']) -@pytest.mark.parametrize('prefill_chunking', [False, True]) -def test_scorer(model_name: str, batch_size: int, max_propose_len: int, - mixed_propose_len: bool, device: str, - prefill_chunking: bool) -> None: - """ - Compare the batch expansion scorer and mqa scorer 
return the same score. - We test for both queries with the same propose length and different - propose length, as well as mixed prefill-decode batches. - """ - seed = 0 - block_size = 32 - num_gpu_blocks = 2048 // block_size - scorer_worker = create_worker(Worker, model_name, block_size, - num_gpu_blocks, seed) - scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer - scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True - scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True - - vocab_size = scorer_worker.vocab_size - - if not mixed_propose_len: - propose_lens = [max_propose_len] * batch_size - else: - # There must be at least 1 decode request, otherwise - # we have nothing to score (`_run_no_spec`). - non_zero_cnt = random.randint(1, batch_size) - propose_lens = [max_propose_len - ] * non_zero_cnt + [0] * (batch_size - non_zero_cnt) - random.shuffle(propose_lens) - - seq_group_metadatalist, _, _ = create_batch(batch_size, - max_propose_len, - block_size=block_size, - num_gpu_blocks=num_gpu_blocks) - - if mixed_propose_len and prefill_chunking and (n_prefills := - batch_size - non_zero_cnt): - prefill, _, _ = create_batch(n_prefills, - None, - prefill_chunk_size=4, - block_size=block_size, - num_gpu_blocks=num_gpu_blocks, - seq_ids=list( - range(batch_size, - batch_size + n_prefills))) - # re-order to guarantee prefill|decode order - target_group_metadatalist = [ - seq_group_metadatalist[i] for i, p in enumerate(propose_lens) - if p > 0 - ] - seq_group_metadatalist = prefill + target_group_metadatalist - propose_lens = [0] * n_prefills + [p for p in propose_lens if p > 0] - - proposals = create_proposal(propose_lens, vocab_size, device) - requests = ExecuteModelRequest(seq_group_metadatalist, - num_lookahead_slots=max_propose_len) - - batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device, - vocab_size) - batch_expansion_score = batch_expansion_scorer.score_proposals( - requests, proposals) - - mqa_scorer = MQAScorer(scorer_worker, device, vocab_size) - mqa_score = mqa_scorer.score_proposals(requests, proposals) - - assert_score_equal(batch_expansion_score, mqa_score) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py deleted file mode 100644 index 8aceaadff8d3..000000000000 --- a/tests/spec_decode/test_spec_decode_worker.py +++ /dev/null @@ -1,945 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import random -from collections import defaultdict -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest -import torch - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.utils import set_random_seed -from vllm.sequence import ExecuteModelRequest, SequenceOutput -from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer -from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner -from vllm.spec_decode.interfaces import SpeculativeProposals -from vllm.spec_decode.metrics import (AsyncMetricsCollector, - SpecDecodeWorkerMetrics) -from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker, - split_num_cache_blocks_evenly) -from vllm.worker.worker import Worker - -from .test_utils import mock_spec_decode_sampler -from .utils import (create_batch, create_sampler_output_list, create_worker, - mock_worker) - - -@pytest.mark.parametrize('k', [1, 
2, 6]) -@pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_correctly_calls_draft_model(k: int, batch_size: int, - acceptance_sampler_method: str): - """Verify SpecDecodeWorker calls the draft worker with correct - inputs. Everything else is mocked out. - """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker( - draft_worker, - target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), - disable_logprobs=False, - metrics_collector=metrics_collector) - exception_secret = 'artificial stop' - draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) - - call_args_list = draft_worker.get_spec_proposals.call_args_list - assert len(call_args_list) == 1 - - for args, _ in call_args_list: - actual_execute_model_data = args[0] - assert actual_execute_model_data == execute_model_req - - -@pytest.mark.parametrize('k', [1, 2, 6]) -@pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_batch_expansion_correctly_calls_target_model( - k: int, batch_size: int, acceptance_sampler_method: str): - """Verify SpecDecodeWorker calls the target model with correct - inputs with batch expansion. Everything else is mocked out. 
- """ - draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) - target_worker = mock_worker(use_spec=False) - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - - draft_worker.device = 'cuda' - target_worker.device = 'cuda' - - set_random_seed(1) - - worker = SpecDecodeWorker( - draft_worker, - target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), - disable_logprobs=False, - metrics_collector=metrics_collector, - disable_mqa_scorer=True) - worker.init_device() - - vocab_size = 32_000 - - proposal_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64, - device='cuda') - proposal_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device='cuda') - proposal_lens = torch.ones(batch_size, dtype=torch.int64, - device='cuda') * k - - seq_group_metadata_list, prompts, prev_output_tokens = create_batch( - batch_size, k) - - draft_worker.get_spec_proposals.return_value = SpeculativeProposals( - proposal_token_ids=proposal_token_ids, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens) - - exception_secret = 'artificial stop' - target_worker.execute_model.side_effect = ValueError(exception_secret) - - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k)) - - seen_contexts: list[list[int]] = [] - - call_args_list = target_worker.execute_model.call_args_list - assert len(call_args_list) == 1 - for _, kwargs in call_args_list: - seq_group_metadata_list = kwargs[ - "execute_model_req"].seq_group_metadata_list - - assert len(seq_group_metadata_list) == (k + 1) * batch_size - for seq_group_metadata in seq_group_metadata_list: - for seq_data in seq_group_metadata.seq_data.values(): - seen_contexts.append(seq_data.get_token_ids()) - - expected_seen_contexts: list[list[int]] = [] - - for prompt, prev_generated, draft_tokens in zip( - prompts, prev_output_tokens, proposal_token_ids.tolist()): - - for i in range(len(draft_tokens) + 1): - expected_seen_contexts.append(prompt + prev_generated + - draft_tokens[:i]) - - seen_contexts.sort() - expected_seen_contexts.sort() - assert expected_seen_contexts == seen_contexts - - -@pytest.mark.parametrize('k', [1, 2, 6]) -@pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, - acceptance_sampler_method: str): - """Verify SpecDecodeWorker calls the rejection sampler with - correct inputs. Everything else is mocked out. 
- """ - vocab_size = 32_000 - - draft_worker = mock_worker(cls=MultiStepWorker, - vocab_size=vocab_size, - use_spec=False) - target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - draft_worker.device = 'cuda' - target_worker.device = 'cuda' - - set_random_seed(1) - - worker = SpecDecodeWorker(draft_worker, - target_worker, - spec_decode_sampler, - disable_logprobs=False, - metrics_collector=metrics_collector) - worker.init_device() - - proposal_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64, - device='cuda') - proposal_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device='cuda') - - proposal_lens = torch.ones(batch_size, dtype=torch.int64, - device='cuda') * k - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - - draft_worker.get_spec_proposals.return_value = SpeculativeProposals( - proposal_token_ids=proposal_token_ids, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens) - - target_token_ids = torch.randint(low=0, - high=vocab_size, - size=(1, batch_size * (k + 1)), - dtype=torch.int64, - device='cuda') - target_token_probs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_token_logprobs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_output = create_sampler_output_list(target_token_ids, - target_token_probs, - target_token_logprobs) - - target_worker.execute_model.return_value = [target_output[0]] - - exception_secret = 'artificial stop' - - spec_decode_sampler.side_effect = ValueError(exception_secret) - - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k)) - - assert len(spec_decode_sampler.call_args_list) == 1 - _, kwargs = spec_decode_sampler.call_args_list[0] - actual = SimpleNamespace(**kwargs) - - assert torch.equal(actual.bonus_token_ids, - target_token_ids.reshape(batch_size, k + 1)[:, -1:]) - assert torch.equal(actual.target_with_bonus_probs, - target_token_probs.reshape(batch_size, k + 1, -1)) - assert torch.equal(actual.draft_token_ids, proposal_token_ids) - assert torch.equal(actual.draft_probs, proposal_probs) - - -@pytest.mark.parametrize('k', [1, 2, 6]) -@pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_correctly_formats_output(k: int, batch_size: int, - acceptance_sampler_method: str): - """Verify SpecDecodeWorker formats sampler output correctly. - Everything else is mocked out. 
- """ - vocab_size = 32_000 - - draft_worker = mock_worker(cls=MultiStepWorker, - vocab_size=vocab_size, - use_spec=False) - target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - draft_worker.device = 'cuda' - target_worker.device = 'cuda' - - set_random_seed(1) - spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - worker = SpecDecodeWorker(draft_worker, - target_worker, - spec_decode_sampler, - disable_logprobs=False, - metrics_collector=metrics_collector) - worker.init_device() - - proposal_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64, - device='cuda') - proposal_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device='cuda') - - proposal_lens = torch.ones(batch_size, dtype=torch.int64, - device='cuda') * k - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - - draft_worker.get_spec_proposals.return_value = SpeculativeProposals( - proposal_token_ids=proposal_token_ids, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens) - - target_token_ids = torch.randint(low=0, - high=vocab_size, - size=(1, batch_size * (k + 1)), - dtype=torch.int64, - device='cuda') - target_token_probs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_token_logprobs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_output = create_sampler_output_list(target_token_ids, - target_token_probs, - target_token_logprobs) - - target_worker.execute_model.return_value = [target_output[0]] - - spec_decode_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') - for i in range(batch_size): - minimum_accepted_tokens = 1 - spec_decode_sampler_output[i][ - -random.randint(minimum_accepted_tokens, k + 1):] = -1 - - spec_decode_sampler.return_value = spec_decode_sampler_output - output = worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k)) - - expected_output = create_sampler_output_list( - token_ids=spec_decode_sampler_output.transpose(0, 1), - probs=[None for _ in range(k + 1)], - logprobs=[None for _ in range(k + 1)]) - - seq_ids = [ - next(iter(seq_group_metadata.seq_data.keys())) - for seq_group_metadata in seq_group_metadata_list - ] - actual_output_by_seq: dict[int, list[SequenceOutput]] = { - seq_id: [] - for seq_id in seq_ids - } - expected_output_by_seq: dict[int, list[SequenceOutput]] = { - seq_id: [] - for seq_id in seq_ids - } - - for step in output: - for seq_group in step: - for sample in seq_group.samples: - seq_id = sample.parent_seq_id - actual_output_by_seq[seq_id].append(sample) - - for step in expected_output: - for seq_group in step: - for sample in seq_group.samples: - seq_id = sample.parent_seq_id - expected_output_by_seq[seq_id].append(sample) - - all_seen_seq_ids = set( - list(actual_output_by_seq.keys()) + - list(expected_output_by_seq.keys())) - for seq_id in all_seen_seq_ids: - actual_by_step = actual_output_by_seq[seq_id] - expected_by_step = expected_output_by_seq[seq_id] - - for i in range(k + 1): - if i >= len(actual_by_step): - assert expected_by_step[i].output_token == -1 - continue - assert actual_by_step[i].output_token == expected_by_step[ - i].output_token - - -@pytest.mark.parametrize('k', [1, 2]) -@pytest.mark.parametrize('batch_size', [1]) 
-@pytest.mark.parametrize('returns_metrics', [True, False]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool, - acceptance_sampler_method: str): - """Verify SpecDecodeWorker collects metrics. - """ - vocab_size = 32_000 - - draft_worker = mock_worker(cls=MultiStepWorker, - vocab_size=vocab_size, - use_spec=False) - target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - draft_worker.device = 'cuda' - target_worker.device = 'cuda' - - set_random_seed(1) - - worker = SpecDecodeWorker(draft_worker, - target_worker, - spec_decode_sampler, - disable_logprobs=False, - metrics_collector=metrics_collector) - worker.init_device() - - proposal_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k), - dtype=torch.int64, - device='cuda') - proposal_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device='cuda') - - proposal_lens = torch.ones(batch_size, dtype=torch.int64, - device='cuda') * k - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - - draft_worker.get_spec_proposals.return_value = SpeculativeProposals( - proposal_token_ids=proposal_token_ids, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens) - - target_token_ids = torch.randint(low=0, - high=vocab_size, - size=(1, batch_size * (k + 1)), - dtype=torch.int64, - device='cuda') - target_token_probs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_token_logprobs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_output = create_sampler_output_list(target_token_ids, - target_token_probs, - target_token_logprobs) - - target_worker.execute_model.return_value = [target_output[0]] - - spec_decode_sampler_output = torch.randint(low=0, - high=vocab_size, - size=(batch_size, k + 1), - dtype=torch.int64, - device='cuda') - for i in range(batch_size): - minimum_accepted_tokens = 1 - spec_decode_sampler_output[i][ - -random.randint(minimum_accepted_tokens, k + 1):] = -1 - spec_decode_sampler.return_value = spec_decode_sampler_output - - mock_rejsample_metrics = MagicMock( - spec=SpecDecodeWorkerMetrics) if returns_metrics else None - metrics_collector.maybe_collect_rejsample_metrics.return_value = ( - mock_rejsample_metrics) - - output = worker.execute_model(execute_model_req=ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k)) - assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics - - call_args_list = ( - metrics_collector.maybe_collect_rejsample_metrics.call_args_list) - assert len(call_args_list) == 1 - args, kwargs = call_args_list[0] - assert args[0] == k or kwargs.get('k', -1) == k - - -@pytest.mark.parametrize('k', [0]) -@pytest.mark.parametrize('batch_size', [1, 2, 32]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_k_equals_zero(k: int, batch_size: int, - acceptance_sampler_method: str): - """Verify that the SpecDecodeWorker calls the draft and target workers - when k is zero. This happens during prefill. 
- """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - - sampler_output = MagicMock(spec=SamplerOutput) - sampler_output.hidden_states = None - target_worker.execute_model.return_value = [sampler_output] - - draft_worker.device = 'cuda' - target_worker.device = 'cuda' - - set_random_seed(1) - - worker = SpecDecodeWorker( - proposer_worker=draft_worker, - scorer_worker=target_worker, - spec_decode_sampler=mock_spec_decode_sampler( - acceptance_sampler_method), - disable_logprobs=False, - metrics_collector=metrics_collector, - ) - - seq_group_metadata_list, _, _ = create_batch(batch_size, - k, - prev_output_token_len=0) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - - out = worker.execute_model(execute_model_req=execute_model_req) - - assert len(out) == 1, f"expected only one token output when {k=}" - assert out[0].sampled_token_probs is None, ( - "expect gpu tensor references to be None") - assert out[ - 0].sampled_token_ids is None, "expect gpu tensor references to be None" - - draft_worker.execute_model.assert_called_once_with(execute_model_req) - target_worker.execute_model.assert_called_once_with(execute_model_req) - - -@pytest.mark.parametrize('k', [0, 5]) -@pytest.mark.parametrize('batch_size', [0]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_empty_input_batch(k: int, batch_size: int, - acceptance_sampler_method: str): - """Verify that the SpecDecodeWorker calls the draft and target workers - when the input batch is empty. This can happen if the engine communicates - to the workers information without scheduling a batch. - """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - - sampler_output = MagicMock(spec=SamplerOutput) - sampler_output.hidden_states = None - target_worker.execute_model.return_value = [sampler_output] - - draft_worker.device = 'cuda' - target_worker.device = 'cuda' - - set_random_seed(1) - - worker = SpecDecodeWorker( - proposer_worker=draft_worker, - scorer_worker=target_worker, - spec_decode_sampler=mock_spec_decode_sampler( - acceptance_sampler_method), - disable_logprobs=False, - metrics_collector=metrics_collector, - ) - - seq_group_metadata_list, _, _ = create_batch(batch_size, - k, - prev_output_token_len=0) - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, num_lookahead_slots=k) - - out = worker.execute_model(execute_model_req=execute_model_req) - - assert len(out) == 1, f"expected only one token output when {k=}" - assert out[0].sampled_token_probs is None, ( - "expect gpu tensor references to be None") - assert out[ - 0].sampled_token_ids is None, "expect gpu tensor references to be None" - - draft_worker.execute_model.assert_called_once_with(execute_model_req) - target_worker.execute_model.assert_called_once_with(execute_model_req) - - -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@pytest.mark.skip_global_cleanup -def test_init_device(acceptance_sampler_method: str): - """Verify SpecDecodeWorker invokes proposer/scorer worker init_device, as - well as other GPU initialization. 
- """ - draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) - target_worker = mock_worker(use_spec=False) - spec_decode_sampler = mock_spec_decode_sampler(acceptance_sampler_method) - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - - worker = SpecDecodeWorker( - proposer_worker=draft_worker, - scorer_worker=target_worker, - spec_decode_sampler=spec_decode_sampler, - disable_logprobs=False, - metrics_collector=metrics_collector, - ) - worker.init_device() - - draft_worker.init_device.assert_called_once() - - target_worker.init_device.assert_called_once() - - metrics_collector.init_tensors.assert_called_once() - spec_decode_sampler.init_tensors.assert_called_once() - - -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@torch.inference_mode() -def test_initialize_cache(acceptance_sampler_method): - """Verify SpecDecodeWorker invokes initialize_cache on proposer/scorer - workers. - """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - - worker = SpecDecodeWorker(proposer_worker=draft_worker, - scorer_worker=target_worker, - spec_decode_sampler=mock_spec_decode_sampler( - acceptance_sampler_method), - metrics_collector=metrics_collector) - - kwargs = {"num_gpu_blocks": 1024, "num_cpu_blocks": 1023} - worker.initialize_cache(**kwargs) - - draft_worker.initialize_cache.assert_called_once_with(**kwargs) - target_worker.initialize_cache.assert_called_once_with(**kwargs) - - -@pytest.mark.parametrize('available_gpu_blocks', [1, 1024]) -@pytest.mark.parametrize('available_cpu_blocks', [500]) -@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096]) -@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) -@pytest.mark.parametrize("acceptance_sampler_method", - ["rejection_sampler", "typical_acceptance_sampler"]) -@pytest.mark.skip_global_cleanup -def test_determine_num_available_blocks(available_gpu_blocks: int, - available_cpu_blocks: int, - target_cache_block_size_bytes: int, - draft_kv_size_bytes: int, - acceptance_sampler_method: str): - """Verify SpecDecodeWorker correctly profiles num available GPU blocks. - Specifically, it should run profiling in the scorer worker, and then evenly - split the blocks between proposer and scorer worker. 
- """ - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - - target_worker.determine_num_available_blocks.return_value = ( - available_gpu_blocks, available_cpu_blocks) - target_worker.get_cache_block_size_bytes.return_value = ( - target_cache_block_size_bytes) - draft_worker.get_cache_block_size_bytes.return_value = draft_kv_size_bytes - - worker = SpecDecodeWorker( - draft_worker, target_worker, - mock_spec_decode_sampler(acceptance_sampler_method), metrics_collector) - - num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks() - - target_worker.determine_num_available_blocks.assert_called_once() - assert num_cpu_blocks == available_cpu_blocks - - assert num_gpu_blocks == split_num_cache_blocks_evenly( - target_cache_block_size_bytes, draft_kv_size_bytes, - available_gpu_blocks) - - -@pytest.mark.parametrize('available_gpu_blocks', - list(range(20)) + [1024, 1024**2]) -@pytest.mark.parametrize('target_cache_block_size_bytes', - [2 * 2 * 4096, 2 * 2 * 8192]) -@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096]) -@pytest.mark.skip_global_cleanup -def test_split_num_cache_blocks_evenly(available_gpu_blocks: int, - target_cache_block_size_bytes: int, - draft_kv_size_bytes: int): - """Verify split_num_cache_blocks_evenly does not exceed original memory - allocation in bytes. - """ - num_blocks = split_num_cache_blocks_evenly(target_cache_block_size_bytes, - draft_kv_size_bytes, - available_gpu_blocks) - assert (num_blocks * target_cache_block_size_bytes) + ( - num_blocks * draft_kv_size_bytes) <= (available_gpu_blocks * - target_cache_block_size_bytes) - - -@torch.inference_mode() -def test_populate_seq_ids_with_bonus_tokens(): - """ - Verify that a call to _create_output_sampler_list correctly updates - seq_with_bonus_token_in_last_step. - - seq_with_bonus_token_in_last_step is an internal data structure in - SpecDecodeWorker that tracks the sequence IDs which are assigned bonus - tokens by the target model in their last forward pass. This state is - maintained only for models relying on the KV cache, such as those using - the MultiStepWorker. - """ - batch_size = 10 - k = 5 - vocab_size = 10000 - num_sequences_with_bonus_tokens = 5 - target_worker = mock_worker(vocab_size=vocab_size, use_spec=False) - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - target_worker.execute_model.return_value = [MagicMock(spec=SamplerOutput)] - target_worker.device = 'cuda' - - set_random_seed(1) - draft_worker = mock_worker(cls=MultiStepWorker) - draft_worker.device = 'cuda' - # The sequence_ids attached to each sequence in the batch. 
- # The sequence at index i has seq_id assigned_seq_ids[i] - assigned_seq_ids = list(range(batch_size)) - seq_group_metadata_list, _, _ = create_batch(batch_size, - k, - seq_ids=assigned_seq_ids, - prev_output_token_len=10) - target_token_logprobs = torch.rand(batch_size, (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - accepted_token_ids = torch.randint(low=0, - high=vocab_size, - size=(batch_size, (k + 1)), - dtype=torch.int64, - device='cuda') - expected_request_id_seq_ids_mapping: dict[str, set[int]] = defaultdict(set) - for seq_group_metadata in seq_group_metadata_list: - for seq_id in seq_group_metadata.seq_data: - expected_request_id_seq_ids_mapping[ - seq_group_metadata.request_id].add(seq_id) - # Generate a random sample of sequence indexes with bonus tokens - seq_indexes_with_bonus_tokens = random.sample( - range(batch_size), num_sequences_with_bonus_tokens) - # Create a mask that is True for indices in seq_indexes_with_bonus_tokens - mask = torch.ones(batch_size, dtype=torch.bool, device='cuda') - mask[seq_indexes_with_bonus_tokens] = False - # Set the last token ID to -1 for all indices not in - # seq_indexes_with_bonus_tokens to indicate the lack of bonus token in - # those indices. - accepted_token_ids[mask, -1:] = -1 - worker = SpecDecodeWorker(draft_worker, - target_worker, - mock_spec_decode_sampler("rejection_sampler"), - disable_logprobs=False, - metrics_collector=metrics_collector) - # Initialize _seq_with_bonus_token_in_last_step with a set of sequence IDs. - # This set includes all sequence IDs in the batch as well as an additional - # `num_extra_sequence_ids` sequence IDs. Note that the sequence IDs are in - # the range [0, batch_size + num_extra_sequence_ids). - num_extra_sequence_ids = 10 - worker._seq_with_bonus_token_in_last_step = set( - range(batch_size + num_extra_sequence_ids)) - worker._create_output_sampler_list( - seq_group_metadata_list=seq_group_metadata_list, - accepted_token_ids=accepted_token_ids, - target_logprobs=target_token_logprobs, - prompt_logprobs=None, - k=k, - stage_times=(0, 0, 0)) - # Verify that _seq_with_bonus_token_in_last_step contains the following: - # 1. Sequence IDs that were already present in - # _seq_with_bonus_token_in_last_step but were not part of the current - # batch are retained. - # 2. Of the sequence IDs present in the current batch, only those with a - # bonus token are retained in _seq_with_bonus_token_in_last_step. - # Sequence IDs that are present in the current batch but do not have - # bonus tokens are removed from _seq_with_bonus_token_in_last_step. - expected_seq_ids_with_bonus_tokens = \ - set([assigned_seq_ids[i] for i in seq_indexes_with_bonus_tokens]) - additional_sequence_ids = \ - set(range(batch_size, batch_size + num_extra_sequence_ids)) - assert worker._seq_with_bonus_token_in_last_step == \ - expected_seq_ids_with_bonus_tokens.union(additional_sequence_ids) - assert worker._request_id_seq_id_mapping == \ - expected_request_id_seq_ids_mapping - - -@torch.inference_mode() -def test_handle_finished_requests(): - """ - Test to verify that finished request IDs are appropriately processed to - update the internal state of the SpecDecodeWorker. - - This test initializes the SpecDecodeWorker with mock data, marks certain - requests as finished, and ensures that the corresponding sequence IDs are - correctly removed from the internal mappings. 
- """ - batch_size = 32 - k = 3 - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, target_worker, - mock_spec_decode_sampler("rejection_sampler"), - metrics_collector) - # Initialize the request_id_seq_id_mapping mapping dict with a few fake - # request ids and corresponding sequence ids. - worker._request_id_seq_id_mapping = \ - {'request-1': {1,2,3}, 'request-2': {4,5,6,7}, - 'request-3': {8,9}, 'request-4': {10,11}} - # Initialize seq_with_bonus_token_in_last_step with a few fake - # sequence ids. - worker._seq_with_bonus_token_in_last_step = {1, 4, 5, 8, 9, 10} - exception_secret = 'artificial stop' - draft_worker.get_spec_proposals.side_effect = ValueError(exception_secret) - - seq_group_metadata_list, _, _ = create_batch(batch_size, k) - # Mark requests with ids request-1 and request-3 as finished. - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=seq_group_metadata_list, - num_lookahead_slots=k, - finished_requests_ids=['request-1', 'request-3']) - - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) - # Verify that request-1 and request-3 are removed from - # request_id_seq_id_mapping - assert worker._request_id_seq_id_mapping == \ - {'request-2': {4,5,6,7}, 'request-4': {10,11}} - # Verify that all sequence ids corresponding to 'request-1' - # and 'request-3' are removed from seq_with_bonus_token_in_last_step. - assert worker._seq_with_bonus_token_in_last_step == \ - {4,5,10} - - -@pytest.mark.parametrize('k', [3]) -@pytest.mark.parametrize('batch_size', [2, 32]) -@pytest.mark.parametrize("batch_composition", - ["prefill_only", "decode_only", "mixed"]) -@torch.inference_mode() -def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str): - """ - Verify SpecDecodeWorker calls match the expected flow. - """ - vocab_size = 32_000 - draft_worker = mock_worker(cls=MultiStepWorker) - target_worker = mock_worker() - metrics_collector = MagicMock(spec=AsyncMetricsCollector) - worker = SpecDecodeWorker(draft_worker, - target_worker, - mock_spec_decode_sampler("rejection_sampler"), - disable_logprobs=False, - metrics_collector=metrics_collector) - exception_secret = 'artificial stop' - worker.scorer = mock_worker(BatchExpansionTop1Scorer) - worker.scorer.score_proposals.side_effect = ValueError(exception_secret) - - # Create batch with combination of terminal/non-terminal prefill chunks - # and decodes (different seq_ids). - decodes, _, _ = create_batch(batch_size, k) - # Pre-chunking here, get 'batch_size' chunks. - prefill, _, _ = create_batch(batch_size, - k, - prefill_chunk_size=4, - seq_ids=list(range(batch_size, - batch_size * 2))) - - if batch_composition == "prefill_only": - n_prefills = batch_size - elif batch_composition == "decode_only": - n_prefills = 0 - else: - n_prefills = random.randint(1, batch_size - 1) - n_decodes = batch_size - n_prefills - - prefill = random.sample(prefill, n_prefills) - decodes = random.sample(decodes, n_decodes) - target_group_metadata_list = prefill + decodes - execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=target_group_metadata_list, - # For prefill only batches we expect num_lookahead_slots = 0. 
- num_lookahead_slots=k if n_decodes > 0 else 0) - - target_token_ids = torch.randint(low=0, - high=vocab_size, - size=(1, batch_size * (k + 1)), - dtype=torch.int64, - device='cuda') - target_token_probs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_token_logprobs = torch.rand(1, - batch_size * (k + 1), - vocab_size, - dtype=torch.float32, - device='cuda') - target_output = create_sampler_output_list(target_token_ids, - target_token_probs, - target_token_logprobs) - - target_worker.execute_model.return_value = [target_output[0]] - - if not len(decodes): - worker.execute_model(execute_model_req=execute_model_req) - # no spec run (prefill only) - draft_worker.execute_model.assert_called_once_with(execute_model_req) - target_worker.execute_model.assert_called_once_with(execute_model_req) - else: - # Decode-only run OR mixed batch, scorer call fails (it's mocked) - with pytest.raises(ValueError, match=exception_secret): - worker.execute_model(execute_model_req=execute_model_req) - # but first draft still counted - assert draft_worker.get_spec_proposals.call_count == 1 - - -def test_correctly_load_weight_for_eagle(): - """ - Verify SpecDecodeWorker loads lm_head weight for eagle correctly. - """ - seed = 100 - block_size = 32 - num_gpu_blocks = 8096 // block_size - target_worker = create_worker( - Worker, - "JackFram/llama-68m", - block_size, - num_gpu_blocks, - seed, - ) - draft_worker = create_worker( - MultiStepWorker, - "abhigoyal/vllm-eagle-llama-68m-random", - block_size, - num_gpu_blocks, - seed, - model_runner_cls=TP1DraftModelRunner, - ) - - spec_decode_sampler = mock_spec_decode_sampler("rejection_sampler") - worker = SpecDecodeWorker(draft_worker, - target_worker, - spec_decode_sampler, - disable_logprobs=False) - worker.proposer_worker.maybe_load_lm_head_weight( - target_worker.model_runner.model.lm_head.weight.data) - assert torch.allclose( - worker.proposer_worker.worker.model_runner.model.lm_head.weight.data, - worker.scorer_worker.model_runner.model.lm_head.weight.data) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py deleted file mode 100644 index 9cfc618b9d95..000000000000 --- a/tests/spec_decode/test_utils.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from unittest.mock import MagicMock - -import pytest -import torch - -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.model_executor.layers.sampler import _get_ranks -from vllm.model_executor.layers.typical_acceptance_sampler import ( - TypicalAcceptanceSampler) -from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids -from vllm.spec_decode.util import (get_sampled_token_logprobs, - split_batch_by_proposal_len) - - -def test_get_all_seq_ids(): - """Verify get_all_seq_ids extracts all seq ids. 
- """ - expected_seq_ids = list(range(10)) + list(range(100, 110)) - - seq_group_metadata_list = [ - SequenceGroupMetadata( - request_id=str(seq_id), - is_prompt=True, - seq_data={ - seq_id: MagicMock(), - }, - sampling_params=MagicMock(), - block_tables={ - seq_id: MagicMock(), - }, - lora_request=None, - ) for seq_id in expected_seq_ids - ] - - actual_seq_ids = get_all_seq_ids(seq_group_metadata_list) - assert actual_seq_ids == expected_seq_ids - - -@pytest.fixture -def fake_sequence_group_metadata(): - seq_ids = list(range(3)) - return [ - SequenceGroupMetadata( - request_id=str(i), - is_prompt=True, - seq_data={ - i: MagicMock(), - }, - sampling_params=MagicMock(), - block_tables={ - i: MagicMock(), - }, - lora_request=None, - ) for i in seq_ids - ] - - -def test_filter_zero_length_proposals(fake_sequence_group_metadata): - proposal_lens = [0, 1, 0] - _, (filtered_groups, - indices) = split_batch_by_proposal_len(fake_sequence_group_metadata, - proposal_lens) - - expected_groups = [ - fake_sequence_group_metadata[0], fake_sequence_group_metadata[2] - ] - expected_indices = [0, 2] - - assert filtered_groups == expected_groups - assert indices == expected_indices - - -def test_filter_non_zero_length_proposals(fake_sequence_group_metadata): - proposal_lens = [0, 1, 2] - (filtered_groups, - indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata, - proposal_lens) - - expected_groups = [ - fake_sequence_group_metadata[1], fake_sequence_group_metadata[2] - ] - expected_indices = [1, 2] - - assert filtered_groups == expected_groups - assert indices == expected_indices - - -def test_empty_inputs(): - _, (filtered_groups, indices) = split_batch_by_proposal_len([], []) - - assert filtered_groups == [] - assert indices == [] - - -def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata): - proposal_lens = [0, 0, 0] - (filtered_groups, - indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata, - proposal_lens) - - assert filtered_groups == [] - assert indices == [] - - -def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata): - proposal_lens = [1, 1, 1] - _, (filtered_groups, - indices) = split_batch_by_proposal_len(fake_sequence_group_metadata, - proposal_lens) - - assert filtered_groups == [] - assert indices == [] - - -def mock_spec_decode_sampler(acceptance_sampler_method): - """ - Returns either a RejectionSampler or TypicalAcceptanceSampler - object depending on whether acceptance_sampler_method is - 'rejection_sampler' or 'typical_acceptance_sampler' respectively. - """ - if acceptance_sampler_method == "rejection_sampler": - sampler = MagicMock(spec=RejectionSampler) - sampler.token_id_dtype = torch.int64 - return sampler - elif acceptance_sampler_method == "typical_acceptance_sampler": - sampler = MagicMock(spec=TypicalAcceptanceSampler) - sampler.token_id_dtype = torch.int64 - return sampler - else: - raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") - - -def test_get_sampled_token_logprobs(): - """Verify get_sampled_token_logprobs returns consistent rankings - with regular get_ranks when probabilities match exactly. 
- """ - logprob_tensor = torch.tensor( - [[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size) - sampled_token_tensor = torch.tensor([[1, - 0]]) # shape (num_steps, batch_size) - ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor, - sampled_token_tensor) - - ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)), - sampled_token_tensor.reshape(-1)) - - assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py deleted file mode 100644 index 1733f66feec0..000000000000 --- a/tests/spec_decode/utils.py +++ /dev/null @@ -1,290 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Sequence as GenericSequence -from itertools import count -from typing import Callable, Optional, TypeVar, Union -from unittest.mock import MagicMock - -import torch - -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.utils import set_random_seed -from vllm.sampling_params import SamplingParams -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceData, SequenceGroupMetadata, SequenceOutput) -from vllm.utils import get_distributed_init_method, get_ip, get_open_port -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.model_runner import ModelRunner -from vllm.worker.worker import Worker - -T = TypeVar("T", bound=Worker) - - -def round_up_to_next_block(seq_len: int, block_size: int) -> int: - return (seq_len + block_size - 1) // block_size - - -def mock_worker(cls=None, - vocab_size: int = 30_000, - max_model_len: int = 2048, - rank: int = 0, - use_spec: bool = True) -> MagicMock: - if cls is None: - cls = Worker - - spec = cls if use_spec else None - - worker = MagicMock(spec=spec) - worker.vocab_size = vocab_size - worker.max_model_len = max_model_len - worker.rank = rank - worker.device = 'cuda:0' - return worker - - -def patch_execute_model_with_seeds(worker: Worker, rand_seeds: list[int]): - seed_iter = iter(rand_seeds) - original_execute_model = worker.execute_model - - def new_execute_model(*args, **kwargs): - result = original_execute_model(*args, **kwargs) - set_random_seed(next(seed_iter)) - return result - - return new_execute_model - - -def zero_kv_cache(cache_engine: list[CacheEngine]): - assert cache_engine[0].gpu_cache - for key_blocks, value_blocks in cache_engine[0].gpu_cache: - key_blocks.zero_() - value_blocks.zero_() - - -def create_worker(cls: Callable[..., T], - model_name: str, - block_size: int, - num_gpu_blocks: int, - seed: int, - is_driver_worker: bool = True, - enforce_eager: bool = True, - model_runner_cls: Optional[ModelRunner] = None, - dtype: Optional[str] = "auto") -> T: - engine_args = EngineArgs( - model=model_name, - seed=seed, - block_size=block_size, - enforce_eager=enforce_eager, - dtype=dtype, - ) - engine_config = engine_args.create_engine_config() - - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - - worker = cls( - vllm_config=engine_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - is_driver_worker=is_driver_worker, - model_runner_cls=model_runner_cls, - ) - - worker.init_device() - worker.load_model() - - engine_config.cache_config.num_gpu_blocks = num_gpu_blocks - engine_config.cache_config.num_cpu_blocks = 0 - worker.initialize_cache( - 
num_gpu_blocks=engine_config.cache_config.num_gpu_blocks, - num_cpu_blocks=engine_config.cache_config.num_cpu_blocks) - - return worker - - -def create_seq_group_metadata_from_prompts( - prompts: list[list[int]], - num_gpu_blocks: int, - block_size: int, - final_prompt_lens: list[int], - continuations: Optional[list[list[int]]] = None, - seq_ids: Optional[list[int]] = None, -) -> list[SequenceGroupMetadata]: - - if continuations is None: - continuations = [[] for _ in prompts] - - if seq_ids is None: - seq_ids = list(i for i, _ in enumerate(prompts)) - - free_gpu_blocks = list(range(num_gpu_blocks)) - - block_allocations = { - i: [ - free_gpu_blocks.pop() - for _ in range(round_up_to_next_block(final_len, block_size)) - ] - for i, final_len in enumerate(final_prompt_lens) - } - - seq_grou_metadata_list = [] - for i, (prompt_token_ids, - cont_token_ids) in enumerate(zip(prompts, continuations)): - data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids) - data.update_num_computed_tokens( - len(prompt_token_ids) + len(cont_token_ids) - 1) - seq_data = {i: data} - seq_grou_metadata_list.append( - SequenceGroupMetadata( - request_id=str(i), - is_prompt=len(cont_token_ids) == 0, - seq_data=seq_data, - sampling_params=SamplingParams(temperature=0.0), - block_tables={i: block_allocations[i][:]}, - )) - return seq_grou_metadata_list - - -def create_chunked_seq_group_metadata_from_prompt( - prompt: list[int], - num_gpu_blocks: int, - chunk_size: int, - block_size: int, - seq_id: Optional[int] = None) -> list[SequenceGroupMetadata]: - - if seq_id is None: - seq_id = 0 - - free_gpu_blocks = list(range(num_gpu_blocks)) - - block_allocations = [ - free_gpu_blocks.pop() - for _ in range(round_up_to_next_block(len(prompt), block_size)) - ] - - seq_group_metadata_list = [] - for i, idx in enumerate(range(0, len(prompt), chunk_size)): - chunk_ids = prompt[idx:idx + chunk_size] - data = SequenceData.from_seqs(prompt) - data.update_num_computed_tokens(idx) - seq_data = {i: data} - seq_group_metadata_list.append( - SequenceGroupMetadata( - request_id=str(seq_id), - is_prompt=True, - do_sample=idx + chunk_size >= len(prompt), # terminal chunk - seq_data=seq_data, - sampling_params=SamplingParams(temperature=0.0), - block_tables={i: block_allocations}, - token_chunk_size=len(chunk_ids))) - return seq_group_metadata_list - - -def assert_logprobs_dict_allclose( - actual_logprobs: list[dict[int, Logprob]], - expected_logprobs: list[dict[int, Logprob]]) -> None: - for single_step_actual_logprobs, single_step_expected_logprobs in zip( - actual_logprobs, expected_logprobs): - assert set(single_step_actual_logprobs.keys()) == set( - single_step_expected_logprobs.keys()) - for token_id in single_step_actual_logprobs: - actual = torch.tensor( - single_step_actual_logprobs[token_id].logprob) - expected = torch.tensor( - single_step_expected_logprobs[token_id].logprob) - torch.testing.assert_close(actual, expected) - - -def create_sampler_output_list( - token_ids: torch.Tensor, - probs: GenericSequence[Optional[torch.Tensor]], - logprobs: GenericSequence[Optional[torch.Tensor]], - seq_ids: Optional[list[int]] = None) -> list[SamplerOutput]: - num_steps, batch_size = token_ids.shape - token_ids_by_step = token_ids.tolist() - - if seq_ids is None: - seq_ids = list(range(batch_size)) - - return [ - SamplerOutput(outputs=[ - CompletionSequenceGroupOutput( - samples=[ - SequenceOutput( - output_token=token_id, - parent_seq_id=seq_ids[seq_index], - logprobs={token_id: Logprob(0)}, - ) - ], - prompt_logprobs=None, - ) for 
seq_index, token_id in enumerate(token_ids_by_step[step]) - ], - sampled_token_probs=probs[step], - logprobs=logprobs[step], - sampled_token_ids=token_ids[step]) - for step in range(num_steps) - ] - - -def create_batch(batch_size, - k, - prompt_len: Union[int, list[int]] = 10, - prev_output_token_len: int = 10, - seq_ids: Optional[list[int]] = None, - num_gpu_blocks: Optional[int] = None, - block_size: Optional[int] = None, - prefill_chunk_size: Optional[int] = None): - if block_size is None: - block_size = 8 - - if num_gpu_blocks is None: - num_gpu_blocks = 2048 // block_size - - iterator = count() - - if isinstance(prompt_len, int): - prompt_lens = [prompt_len for _ in range(batch_size)] - else: - prompt_lens = prompt_len - - prompts = [[next(iterator) for _ in range(p_len)] for p_len in prompt_lens] - - if prefill_chunk_size: - # Create a batch of chunked prompts. - if not seq_ids: - seq_ids = list(range(len(prompts))) - seq_group_metadata_list = [] - for p, sid in zip(prompts, seq_ids): - seq_group_metadata_list += \ - create_chunked_seq_group_metadata_from_prompt( - p, num_gpu_blocks, prefill_chunk_size, block_size, sid) - seq_group_metadata_list = seq_group_metadata_list[:batch_size] - prev_output_tokens = [] - else: - prev_output_tokens = [[ - next(iterator) for _ in range(prev_output_token_len) - ] for _ in range(batch_size)] - final_prompt_lens = [ - len(prompt) + len(prev_output_token) + k + 1 - for prompt, prev_output_token in zip(prompts, prev_output_tokens) - ] - - seq_group_metadata_list = create_seq_group_metadata_from_prompts( - prompts, num_gpu_blocks, block_size, final_prompt_lens, - prev_output_tokens, seq_ids) - return seq_group_metadata_list, prompts, prev_output_tokens - - -def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs): - if prefill_chunk_size > 0: - llm_kwargs.update( - **{ - "enable_chunked_prefill": True, - "max_num_batched_tokens": prefill_chunk_size, - "max_num_seqs": prefill_chunk_size - }) - else: - llm_kwargs["enable_chunked_prefill"] = False diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index cd59d579e8d6..18aa4c88c033 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -1,9 +1,28 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Callable + import pytest +from vllm import LLM, EngineArgs from vllm.distributed import cleanup_dist_env_and_memory +from vllm.model_executor.model_loader import tensorizer as tensorizer_mod from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor.abstract import UniProcExecutor +from vllm.worker.worker_base import WorkerWrapperBase + +MODEL_REF = "facebook/opt-125m" + + +@pytest.fixture() +def model_ref(): + return MODEL_REF + + +@pytest.fixture(autouse=True) +def allow_insecure_serialization(monkeypatch): + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") @pytest.fixture(autouse=True) @@ -11,7 +30,73 @@ def cleanup(): cleanup_dist_env_and_memory(shutdown_ray=True) +@pytest.fixture() +def just_serialize_model_tensors(model_ref, monkeypatch, tmp_path): + + def noop(*args, **kwargs): + return None + + args = EngineArgs(model=model_ref) + tc = TensorizerConfig(tensorizer_uri=f"{tmp_path}/model.tensors") + + monkeypatch.setattr(tensorizer_mod, "serialize_extra_artifacts", noop) + + tensorizer_mod.tensorize_vllm_model(args, tc) + yield 
tmp_path + + @pytest.fixture(autouse=True) def tensorizer_config(): config = TensorizerConfig(tensorizer_uri="vllm") return config + + +@pytest.fixture() +def model_path(model_ref, tmp_path): + yield tmp_path / model_ref / "model.tensors" + + +def assert_from_collective_rpc(engine: LLM, closure: Callable, + closure_kwargs: dict): + res = engine.collective_rpc(method=closure, kwargs=closure_kwargs) + return all(res) + + +# This is an object pulled from tests/v1/engine/test_engine_core.py +# Modified to strip the `load_model` method from its `_init_executor` +# method. It's purely used as a dummy utility to run methods that test +# Tensorizer functionality. +class DummyExecutor(UniProcExecutor): + + def _init_executor(self) -> None: + """Initialize the worker without loading the model. + """ + self.driver_worker = WorkerWrapperBase(vllm_config=self.vllm_config, + rpc_rank=0) + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + local_rank = 0 + # set local rank as the device index if specified + device_info = self.vllm_config.device_config.device.__str__().split( + ":") + if len(device_info) > 1: + local_rank = int(device_info[1]) + rank = 0 + is_driver_worker = True + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=is_driver_worker, + ) + self.collective_rpc("init_worker", args=([kwargs], )) + self.collective_rpc("init_device") + + @property + def max_concurrent_batches(self) -> int: + return 2 + + def shutdown(self): + if hasattr(self, 'thread_pool'): + self.thread_pool.shutdown(wait=False) diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index c97f5968d58a..b8d7892e57f2 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -1,36 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import gc +import json import os import pathlib import subprocess +import sys +from typing import Any import pytest import torch -from vllm import SamplingParams +import vllm.model_executor.model_loader.tensorizer +from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -# yapf conflicts with isort for this docstring # yapf: disable from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig, TensorSerializer, is_vllm_tensorized, open_stream, tensorize_vllm_model) +from vllm.model_executor.model_loader.tensorizer_loader import ( + BLACKLISTED_TENSORIZER_ARGS) # yapf: enable from vllm.utils import PlaceholderModule -from ..utils import VLLM_PATH +from ..utils import VLLM_PATH, RemoteOpenAIServer +from .conftest import DummyExecutor, assert_from_collective_rpc try: + import tensorizer from tensorizer import EncryptionParams except ImportError: tensorizer = PlaceholderModule("tensorizer") # type: ignore[assignment] EncryptionParams = tensorizer.placeholder_attr("EncryptionParams") + +class TensorizerCaughtError(Exception): + pass + + EXAMPLES_PATH = VLLM_PATH / "examples" +pytest_plugins = "pytest_asyncio", + prompts = [ "Hello, my name is", "The president of the United States is", @@ -40,9 +55,37 @@ # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0) -model_ref = "facebook/opt-125m" -tensorize_model_for_testing_script = os.path.join( - os.path.dirname(__file__), "tensorize_vllm_model_for_testing.py") + +def patch_init_and_catch_error(self, obj, method_name, + expected_error: type[Exception]): + original = getattr(obj, method_name, None) + if original is None: + raise ValueError("Method '{}' not found.".format(method_name)) + + def wrapper(*args, **kwargs): + try: + return original(*args, **kwargs) + except expected_error as err: + raise TensorizerCaughtError from err + + setattr(obj, method_name, wrapper) + + self.load_model() + + +def assert_specific_tensorizer_error_is_raised( + executor, + obj: Any, + method_name: str, + expected_error: type[Exception], +): + with pytest.raises(TensorizerCaughtError): + executor.collective_rpc(patch_init_and_catch_error, + args=( + obj, + method_name, + expected_error, + )) def is_curl_installed(): @@ -60,32 +103,12 @@ def write_keyfile(keyfile_path: str): f.write(encryption_params.key) -@pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") -def test_can_deserialize_s3(vllm_runner): - model_ref = "EleutherAI/pythia-1.4b" - tensorized_path = f"s3://tensorized/{model_ref}/fp16/model.tensors" - - with vllm_runner(model_ref, - load_format="tensorizer", - model_loader_extra_config=TensorizerConfig( - tensorizer_uri=tensorized_path, - num_readers=1, - s3_endpoint="object.ord1.coreweave.com", - )) as loaded_hf_model: - deserialized_outputs = loaded_hf_model.generate( - prompts, sampling_params) - # noqa: E501 - - assert deserialized_outputs - - @pytest.mark.skipif(not is_curl_installed(), reason="cURL is not installed") def test_deserialized_encrypted_vllm_model_has_same_outputs( - vllm_runner, tmp_path): + model_ref, vllm_runner, tmp_path, model_path): args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: - model_path = tmp_path / (model_ref + ".tensors") - key_path = tmp_path / (model_ref + ".key") + key_path = tmp_path / model_ref / "model.key" write_keyfile(key_path) outputs = vllm_model.generate(prompts, sampling_params) @@ -111,9 +134,9 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs( def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, - tmp_path): + tmp_path, model_ref, + model_path): with hf_runner(model_ref) as hf_model: - model_path = tmp_path / (model_ref + ".tensors") max_tokens = 50 outputs = hf_model.generate_greedy(prompts, max_tokens=max_tokens) with open_stream(model_path, "wb+") as stream: @@ -123,7 +146,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, with vllm_runner(model_ref, load_format="tensorizer", model_loader_extra_config=TensorizerConfig( - tensorizer_uri=model_path, + tensorizer_uri=str(model_path), num_readers=1, )) as loaded_hf_model: deserialized_outputs = loaded_hf_model.generate_greedy( @@ -132,7 +155,7 @@ def test_deserialized_hf_model_has_same_outputs(hf_runner, vllm_runner, assert outputs == deserialized_outputs -def test_load_without_tensorizer_load_format(vllm_runner, capfd): +def test_load_without_tensorizer_load_format(vllm_runner, capfd, model_ref): model = None try: model = vllm_runner( @@ -150,7 +173,8 @@ def test_load_without_tensorizer_load_format(vllm_runner, capfd): torch.cuda.empty_cache() -def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd): +def test_raise_value_error_on_invalid_load_format(vllm_runner, capfd, + model_ref): model = None try: model = vllm_runner( 
@@ -208,7 +232,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( outputs = base_model.generate(prompts, sampling_params) # load model with two shards and serialize with encryption - model_path = str(tmp_path / (model_ref + "-%02d.tensors")) + model_path = str(tmp_path / model_ref / "model-%02d.tensors") key_path = tmp_path / (model_ref + ".key") tensorizer_config = TensorizerConfig( @@ -242,13 +266,12 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( @pytest.mark.flaky(reruns=3) -def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): +def test_vllm_tensorized_model_has_same_outputs(model_ref, vllm_runner, + tmp_path, model_path): gc.collect() torch.cuda.empty_cache() - model_ref = "facebook/opt-125m" - model_path = tmp_path / (model_ref + ".tensors") config = TensorizerConfig(tensorizer_uri=str(model_path)) - args = EngineArgs(model=model_ref, device="cuda") + args = EngineArgs(model=model_ref) with vllm_runner(model_ref) as vllm_model: outputs = vllm_model.generate(prompts, sampling_params) @@ -264,3 +287,243 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): # noqa: E501 assert outputs == deserialized_outputs + + +def test_load_with_just_model_tensors(just_serialize_model_tensors, model_ref): + # For backwards compatibility, ensure Tensorizer can be still be loaded + # for inference by passing the model reference name, not a local/S3 dir, + # and the location of the model tensors + + model_dir = just_serialize_model_tensors + + extra_config = {"tensorizer_uri": f"{model_dir}/model.tensors"} + + ## Start OpenAI API server + args = [ + "--load-format", + "tensorizer", + "--model-loader-extra-config", + json.dumps(extra_config), + ] + + with RemoteOpenAIServer(model_ref, args): + # This test only concerns itself with being able to load the model + # and successfully initialize the server + pass + + +def test_assert_serialization_kwargs_passed_to_tensor_serializer(tmp_path): + + serialization_params = { + "limit_cpu_concurrency": 2, + } + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + llm = LLM(model=model_ref, ) + + def serialization_test(self, *args, **kwargs): + # This is performed in the ephemeral worker process, so monkey-patching + # will actually work, and cleanup is guaranteed so don't + # need to reset things + + original_dict = serialization_params + to_compare = {} + + original = tensorizer.serialization.TensorSerializer.__init__ + + def tensorizer_serializer_wrapper(self, *args, **kwargs): + nonlocal to_compare + to_compare = kwargs.copy() + return original(self, *args, **kwargs) + + tensorizer.serialization.TensorSerializer.__init__ = ( + tensorizer_serializer_wrapper) + + tensorizer_config = TensorizerConfig(**kwargs["tensorizer_config"]) + self.save_tensorized_model(tensorizer_config=tensorizer_config, ) + return to_compare | original_dict == to_compare + + kwargs = {"tensorizer_config": config.to_serializable()} + + assert assert_from_collective_rpc(llm, serialization_test, kwargs) + + +def test_assert_deserialization_kwargs_passed_to_tensor_deserializer( + tmp_path, capfd): + + deserialization_kwargs = { + "num_readers": "bar", # illegal value + } + + serialization_params = { + "limit_cpu_concurrency": 2, + } + + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = 
TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + + args = EngineArgs(model=model_ref) + tensorize_vllm_model(args, config) + + loader_tc = TensorizerConfig( + tensorizer_uri=str(model_path), + deserialization_kwargs=deserialization_kwargs, + ) + + engine_args = EngineArgs( + model="facebook/opt-125m", + load_format="tensorizer", + model_loader_extra_config=loader_tc.to_serializable(), + ) + + vllm_config = engine_args.create_engine_config() + executor = DummyExecutor(vllm_config) + + assert_specific_tensorizer_error_is_raised( + executor, + tensorizer.serialization.TensorDeserializer, + "__init__", + TypeError, + ) + + +def test_assert_stream_kwargs_passed_to_tensor_deserializer(tmp_path, capfd): + + deserialization_kwargs = { + "num_readers": 1, + } + + serialization_params = { + "limit_cpu_concurrency": 2, + } + + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + + args = EngineArgs(model=model_ref) + tensorize_vllm_model(args, config) + + stream_kwargs = {"mode": "foo"} + + loader_tc = TensorizerConfig( + tensorizer_uri=str(model_path), + deserialization_kwargs=deserialization_kwargs, + stream_kwargs=stream_kwargs, + ) + + engine_args = EngineArgs( + model="facebook/opt-125m", + load_format="tensorizer", + model_loader_extra_config=loader_tc.to_serializable(), + ) + + vllm_config = engine_args.create_engine_config() + executor = DummyExecutor(vllm_config) + + assert_specific_tensorizer_error_is_raised( + executor, + vllm.model_executor.model_loader.tensorizer, + "open_stream", + ValueError, + ) + + +@pytest.mark.asyncio +async def test_serialize_and_serve_entrypoints(tmp_path): + model_ref = "facebook/opt-125m" + + suffix = "test" + try: + result = subprocess.run([ + sys.executable, + f"{VLLM_PATH}/examples/others/tensorize_vllm_model.py", "--model", + model_ref, "serialize", "--serialized-directory", + str(tmp_path), "--suffix", suffix, "--serialization-kwargs", + '{"limit_cpu_concurrency": 4}' + ], + check=True, + capture_output=True, + text=True) + except subprocess.CalledProcessError as e: + print("Tensorizing failed.") + print("STDOUT:\n", e.stdout) + print("STDERR:\n", e.stderr) + raise + + assert "Successfully serialized" in result.stdout + + # Next, try to serve with vllm serve + model_uri = tmp_path / "vllm" / model_ref / suffix / "model.tensors" + + model_loader_extra_config = { + "tensorizer_uri": str(model_uri), + "stream_kwargs": { + "force_http": False, + }, + "deserialization_kwargs": { + "verify_hash": True, + "num_readers": 8, + } + } + + cmd = [ + "-m", "vllm.entrypoints.cli.main", "serve", "--host", "localhost", + "--load-format", "tensorizer", model_ref, + "--model-loader-extra-config", + json.dumps(model_loader_extra_config, indent=2) + ] + + proc = await asyncio.create_subprocess_exec( + sys.executable, + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + + assert proc.stdout is not None + fut = proc.stdout.readuntil(b"Application startup complete.") + + try: + await asyncio.wait_for(fut, 180) + except asyncio.TimeoutError: + pytest.fail("Server did not start successfully") + finally: + proc.terminate() + await proc.communicate() + + +@pytest.mark.parametrize("illegal_value", BLACKLISTED_TENSORIZER_ARGS) +def test_blacklisted_parameter_for_loading(tmp_path, vllm_runner, capfd, + illegal_value): + + serialization_params = { + "limit_cpu_concurrency": 
2, + } + + model_ref = "facebook/opt-125m" + model_path = tmp_path / (model_ref + ".tensors") + config = TensorizerConfig(tensorizer_uri=str(model_path), + serialization_kwargs=serialization_params) + + args = EngineArgs(model=model_ref) + tensorize_vllm_model(args, config) + + loader_tc = {"tensorizer_uri": str(model_path), illegal_value: "foo"} + + try: + vllm_runner( + model_ref, + load_format="tensorizer", + model_loader_extra_config=loader_tc, + ) + except RuntimeError: + out, err = capfd.readouterr() + combined_output = out + err + assert (f"ValueError: {illegal_value} is not an allowed " + f"Tensorizer argument.") in combined_output diff --git a/tests/test_config.py b/tests/test_config.py index 6ed7ef9e6a40..015baef91811 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -7,7 +7,7 @@ from vllm.compilation.backends import VllmBackend from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig, - get_field) + get_field, update_config) from vllm.model_executor.layers.pooler import PoolingType from vllm.platforms import current_platform @@ -46,6 +46,34 @@ def test_get_field(): assert c.default_factory is MISSING +@dataclass +class _TestNestedConfig: + a: _TestConfigFields = field( + default_factory=lambda: _TestConfigFields(a=0)) + + +def test_update_config(): + # Simple update + config1 = _TestConfigFields(a=0) + new_config1 = update_config(config1, {"a": 42}) + assert new_config1.a == 42 + # Nonexistent field + with pytest.raises(AssertionError): + new_config1 = update_config(config1, {"nonexistent": 1}) + # Nested update with dataclass + config2 = _TestNestedConfig() + new_inner_config = _TestConfigFields(a=1, c="new_value") + new_config2 = update_config(config2, {"a": new_inner_config}) + assert new_config2.a == new_inner_config + # Nested update with dict + config3 = _TestNestedConfig() + new_config3 = update_config(config3, {"a": {"c": "new_value"}}) + assert new_config3.a.c == "new_value" + # Nested update with invalid type + with pytest.raises(AssertionError): + new_config3 = update_config(config3, {"a": "new_value"}) + + @pytest.mark.parametrize( ("model_id", "expected_runner_type", "expected_task"), [ @@ -54,7 +82,7 @@ def test_get_field(): ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"), ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"), ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"), - ("openai/whisper-small", "transcription", "transcription"), + ("openai/whisper-small", "generate", "transcription"), ], ) def test_auto_task(model_id, expected_runner_type, expected_task): @@ -69,7 +97,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task): ) assert config.runner_type == expected_runner_type - assert config.task == expected_task + + if config.runner_type == "pooling": + assert config.task == expected_task + else: + assert expected_task in config.supported_tasks @pytest.mark.parametrize( @@ -98,11 +130,50 @@ def test_score_task(model_id, expected_runner_type, expected_task): assert config.task == expected_task +@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"), + [ + ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"), + ]) +def test_draft_task(model_id, expected_runner_type, expected_task): + config = ModelConfig( + model_id, + runner="draft", + tokenizer=model_id, + seed=0, + dtype="float16", + ) + + assert config.runner_type == expected_runner_type + assert config.task == expected_task + + +@pytest.mark.parametrize( + ("model_id", "expected_runner_type", "expected_task"), + [ + 
("openai/whisper-small", "generate", "transcription"), + ], +) +def test_transcription_task(model_id, expected_runner_type, expected_task): + config = ModelConfig( + model_id, + task="transcription", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + assert config.runner_type == expected_runner_type + assert config.task == expected_task + + @pytest.mark.parametrize(("model_id", "bad_task"), [ ("Qwen/Qwen2.5-Math-RM-72B", "generate"), + ("Qwen/Qwen3-0.6B", "transcription"), ]) def test_incorrect_task(model_id, bad_task): - with pytest.raises(ValueError, match=r"does not support the .* task"): + with pytest.raises(ValueError, match=r"does not support task=.*"): ModelConfig( model_id, task=bad_task, diff --git a/tests/test_sequence.py b/tests/test_sequence.py index a782a3bf7716..c734c8514a6d 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -29,7 +29,6 @@ def test_sampler_output_initialization(sampler_output, sample_outputs): assert len(sampler_output) == len(sample_outputs) assert sampler_output.sampled_token_probs is None assert sampler_output.sampled_token_ids is None - assert sampler_output.spec_decode_worker_metrics is None def test_sampler_output_getitem(sampler_output, sample_outputs): diff --git a/tests/test_utils.py b/tests/test_utils.py index a165d2d7213a..53a34642e5ba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,15 +14,18 @@ import pytest import torch import zmq +from transformers import AutoTokenizer from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.transformers_utils.detokenizer_utils import ( + convert_ids_list_to_tokens) from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, common_broadcastable_dtype, - deprecate_kwargs, get_open_port, get_tcp_uri, - is_lossless_cast, join_host_port, make_zmq_path, - make_zmq_socket, memory_profiling, + current_stream, deprecate_kwargs, get_open_port, + get_tcp_uri, is_lossless_cast, join_host_port, + make_zmq_path, make_zmq_socket, memory_profiling, merge_async_iterators, sha256, split_host_port, split_zmq_path, supports_kw, swap_dict_values) @@ -455,6 +458,31 @@ def test_bind_kv_cache(): assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] +def test_bind_kv_cache_kv_sharing(): + from vllm.attention import Attention + + ctx = { + 'layers.0.self_attn': Attention(32, 128, 0.1), + 'layers.1.self_attn': Attention(32, 128, 0.1), + 'layers.2.self_attn': Attention(32, 128, 0.1), + 'layers.3.self_attn': Attention(32, 128, 0.1), + } + kv_cache = [ + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + torch.zeros((1, )), + ] + shared_kv_cache_layers = { + 'layers.2.self_attn': 'layers.1.self_attn', + 'layers.3.self_attn': 'layers.0.self_attn' + } + bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers) + assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] + assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] + assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[1] + assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[0] + def test_bind_kv_cache_non_attention(): from vllm.attention import Attention @@ -918,3 +946,52 @@ def test_split_host_port(): def test_join_host_port(): assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555" assert join_host_port("::1", 5555) == "[::1]:5555" + + +def 
test_convert_ids_list_to_tokens(): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct") + token_ids = tokenizer.encode("Hello, world!") + # token_ids = [9707, 11, 1879, 0] + assert tokenizer.convert_ids_to_tokens(token_ids) == [ + 'Hello', ',', 'Ġworld', '!' + ] + tokens = convert_ids_list_to_tokens(tokenizer, token_ids) + assert tokens == ['Hello', ',', ' world', '!'] + + +def test_current_stream_multithread(): + import threading + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + main_default_stream = torch.cuda.current_stream() + child_stream = torch.cuda.Stream() + + thread_stream_ready = threading.Event() + thread_can_exit = threading.Event() + + def child_thread_func(): + with torch.cuda.stream(child_stream): + thread_stream_ready.set() + thread_can_exit.wait(timeout=10) + + child_thread = threading.Thread(target=child_thread_func) + child_thread.start() + + try: + assert thread_stream_ready.wait( + timeout=5), "Child thread failed to enter stream context in time" + + main_current_stream = current_stream() + + assert main_current_stream != child_stream, "Main thread's current_stream was contaminated by child thread" + assert main_current_stream == main_default_stream, "Main thread's current_stream is not the default stream" + + # Notify child thread it can exit + thread_can_exit.set() + + finally: + # Ensure child thread exits properly + child_thread.join(timeout=5) + if child_thread.is_alive(): + pytest.fail("Child thread failed to exit properly") diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f8aeba8301b1..ccafc8846127 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -393,7 +393,7 @@ def test_decode_prompt_logprobs_chunked_prefill( logprobs=5, prompt_logprobs=5, temperature=0.0) - vllm_results = vllm_model.model.generate( + vllm_results = vllm_model.llm.generate( example_prompts, sampling_params=vllm_sampling_params) for idx, result in enumerate(vllm_results): diff --git a/tests/tokenization/test_do_lower_case.py b/tests/tokenization/test_do_lower_case.py new file mode 100644 index 000000000000..7aa655e1c3b4 --- /dev/null +++ b/tests/tokenization/test_do_lower_case.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.transformers_utils.tokenizer import get_tokenizer + +TOKENIZER_NAMES = ["BAAI/bge-base-en"] + + +@pytest.mark.parametrize("tokenizer_name", TOKENIZER_NAMES) +@pytest.mark.parametrize("n_tokens", [510]) +def test_special_tokens(tokenizer_name: str, n_tokens: int): + tokenizer = get_tokenizer(tokenizer_name, revision="main") + + prompts = '[UNK]' * n_tokens + prompt_token_ids = tokenizer.encode(prompts) + assert len(prompt_token_ids) == n_tokens + 2 diff --git a/tests/tool_use/test_glm4_moe_tool_parser.py b/tests/tool_use/test_glm4_moe_tool_parser.py new file mode 100644 index 000000000000..478f4b916672 --- /dev/null +++ b/tests/tool_use/test_glm4_moe_tool_parser.py @@ -0,0 +1,410 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json + +import pytest + +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +pytest.skip("skip glm4_moe parser test", allow_module_level=True) +# Use a common 
model that is likely to be available +MODEL = "THUDM/GLM-4.5" + + +@pytest.fixture(scope="module") +def glm4_moe_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def glm4_moe_tool_parser(glm4_moe_tokenizer): + return Glm4MoeModelToolParser(glm4_moe_tokenizer) + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 0 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function.name == expected_tool_call.function.name + # Compare arguments as JSON objects to handle formatting differences + actual_args = json.loads(actual_tool_call.function.arguments) + expected_args = json.loads(expected_tool_call.function.arguments) + assert actual_args == expected_args + + +def test_extract_tool_calls_no_tools(glm4_moe_tool_parser): + model_output = "This is a test" + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_call", + "multiple_tool_calls", + "tool_call_with_content_before", + "tool_call_with_mixed_args", + "tool_call_with_chinese_content", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """<tool_call>get_current_weather + <arg_key>city</arg_key> + <arg_value>Dallas</arg_value> + <arg_key>state</arg_key> + <arg_value>TX</arg_value> + <arg_key>unit</arg_key> + <arg_value>fahrenheit</arg_value> + </tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )) + ], + None, + ), + ( + """<tool_call>get_current_weather + <arg_key>city</arg_key> + <arg_value>Dallas</arg_value> + <arg_key>state</arg_key> + <arg_value>TX</arg_value> + <arg_key>unit</arg_key> + <arg_value>fahrenheit</arg_value> + </tool_call> + <tool_call>get_current_weather + <arg_key>city</arg_key> + <arg_value>Orlando</arg_value> + <arg_key>state</arg_key> + <arg_value>FL</arg_value> + <arg_key>unit</arg_key> + <arg_value>fahrenheit</arg_value> + </tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit", + }), + )), + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit", + }), + )), + ], + None, + ), + ( + """I'll help you check the weather. 
<tool_call>get_current_weather + <arg_key>city</arg_key> + <arg_value>Seattle</arg_value> + <arg_key>state</arg_key> + <arg_value>WA</arg_value> + <arg_key>unit</arg_key> + <arg_value>celsius</arg_value> + </tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "Seattle", + "state": "WA", + "unit": "celsius", + }), + )) + ], + "I'll help you check the weather.", + ), + ( + """<tool_call>get_current_weather + <arg_key>city</arg_key> + <arg_value>New York</arg_value> + <arg_key>state</arg_key> + <arg_value>NY</arg_value> + <arg_key>unit</arg_key> + <arg_value>celsius</arg_value> + </tool_call>""", + [ + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "New York", + "state": "NY", + "unit": "celsius", + }), + )) + ], + None, + ), + ("""I will help you get the weather.<tool_call>get_weather + <arg_key>city</arg_key> + <arg_value>Beijing</arg_value> + <arg_key>date</arg_key> + <arg_value>2025-08-01</arg_value> + </tool_call>""", [ + ToolCall(function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "city": "Beijing", + "date": "2025-08-01", + }), + )) + ], "I will help you get the weather."), + ], +) +def test_extract_tool_calls(glm4_moe_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser): + """Test tool extraction when thinking tags are present.""" + model_output = """<think>I want to get the weather.</think> + +I will help you get the weather. 
+<tool_call>get_weather +<arg_key>city</arg_key> +<arg_value>Beijing</arg_value> +<arg_key>date</arg_key> +<arg_value>2025-08-01</arg_value> +</tool_call>""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[0].function.name == "get_weather" + + expected_content = """<think>I want to get the weather.</think> + +I will help you get the weather.""" + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser): + """Test that malformed XML is handled gracefully.""" + model_output = """<tool_call>get_weather +<arg_key>city</arg_key> +<arg_value>Seattle</arg_value> +<arg_key>incomplete_arg +<arg_value>value</arg_value> +</tool_call>""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + # Should handle malformed XML gracefully + # The parser should either extract what it can or return no tool calls + # depending on how robust we want the parsing to be + assert isinstance(extracted_tool_calls.tools_called, bool) + assert isinstance(extracted_tool_calls.tool_calls, list) + + +def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser): + """Test tool calls with no arguments.""" + model_output = """<tool_call>get_current_time +</tool_call>""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[ + 0].function.name == "get_current_time" + # Empty arguments should result in empty JSON object + assert extracted_tool_calls.tool_calls[0].function.arguments == "{}" + + +def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser): + """Test extraction with mixed content and multiple tool calls.""" + model_output = """I will help you get the weather info. + +<tool_call>get_weather +<arg_key>city</arg_key> +<arg_value>Beijing</arg_value> +<arg_key>date</arg_key> +<arg_value>2025-08-01</arg_value> +</tool_call> + +meaningwhile, I will also check the weather in Shanghai. + +<tool_call>get_weather +<arg_key>city</arg_key> +<arg_value>Shanghai</arg_value> +<arg_key>date</arg_key> +<arg_value>2025-08-01</arg_value> +</tool_call>""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 2 + + # Check first tool call + assert extracted_tool_calls.tool_calls[0].function.name == "get_weather" + args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args1["city"] == "Beijing" + assert args1["date"] == "2025-08-01" + + # Check second tool call + assert extracted_tool_calls.tool_calls[1].function.name == "get_weather" + args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments) + assert args2["city"] == "Shanghai" + assert args2["date"] == "2025-08-01" + + # Content should be everything before the first tool call + assert extracted_tool_calls.content == "I will help you get the weather info." 
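
The fixtures above all follow the same GLM-4.5 tool-call layout: a <tool_call> block whose first line is the function name, followed by alternating <arg_key>/<arg_value> pairs. As a rough illustration of that layout only (a hypothetical sketch; sketch_extract_tool_calls and its regexes are invented for this example and are not the vLLM Glm4MoeModelToolParser, which also handles streaming and content before the tool call), the format can be pulled apart with two regular expressions:

import json
import re

# Hypothetical helper for illustration; argument values stay plain strings.
TOOL_CALL_RE = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)
ARG_RE = re.compile(r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
                    re.DOTALL)


def sketch_extract_tool_calls(model_output: str) -> list[dict]:
    calls = []
    for block in TOOL_CALL_RE.findall(model_output):
        # First line of the block is the function name, the rest are args.
        name, _, rest = block.partition("\n")
        args = {k.strip(): v.strip() for k, v in ARG_RE.findall(rest)}
        calls.append({"name": name.strip(), "arguments": json.dumps(args)})
    return calls


# Usage: mirrors the "single_tool_call" fixture above.
output = """<tool_call>get_current_weather
<arg_key>city</arg_key>
<arg_value>Dallas</arg_value>
</tool_call>"""
assert sketch_extract_tool_calls(output) == [{
    "name": "get_current_weather",
    "arguments": json.dumps({"city": "Dallas"})
}]
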
+ + +def test_streaming_basic_functionality(glm4_moe_tool_parser): + """Test basic streaming functionality.""" + # Reset streaming state + glm4_moe_tool_parser.current_tool_name_sent = False + glm4_moe_tool_parser.prev_tool_call_arr = [] + glm4_moe_tool_parser.current_tool_id = -1 + glm4_moe_tool_parser.streamed_args_for_tool = [] + + # Test with a simple tool call + current_text = """<tool_call>get_weather +<arg_key>city</arg_key> +<arg_value>Beijing</arg_value> +</tool_call>""" + + # Mock token IDs for testing + tool_call_start_id = glm4_moe_tool_parser.tool_call_start_token_id or 12345 + tool_call_end_id = glm4_moe_tool_parser.tool_call_end_token_id or 12346 + + result = glm4_moe_tool_parser.extract_tool_calls_streaming( + previous_text="", + current_text=current_text, + delta_text="</tool_call>", + previous_token_ids=[], + current_token_ids=[tool_call_start_id, tool_call_end_id], + delta_token_ids=[tool_call_end_id], + request=None, + ) + + # The result behavior depends on the streaming state + # This test mainly ensures no exceptions are thrown + assert result is None or hasattr(result, 'tool_calls') or hasattr( + result, 'content') + + +def test_streaming_no_tool_calls(glm4_moe_tool_parser): + """Test streaming when there are no tool calls.""" + current_text = "This is just regular text without any tool calls." + + result = glm4_moe_tool_parser.extract_tool_calls_streaming( + previous_text="This is just regular text", + current_text=current_text, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." + + +def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser): + """Test streaming when there's content before tool calls.""" + # Reset streaming state + glm4_moe_tool_parser.current_tool_name_sent = False + glm4_moe_tool_parser.prev_tool_call_arr = [] + glm4_moe_tool_parser.current_tool_id = -1 + glm4_moe_tool_parser.streamed_args_for_tool = [] + + current_text = "I will help you get the weather<tool_call>" + + result = glm4_moe_tool_parser.extract_tool_calls_streaming( + previous_text="I will help you", + current_text=current_text, + delta_text="get the weather.<tool_call>", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return content when no tool call tokens are detected + assert result is not None + assert hasattr(result, 'content') + assert result.content == "get the weather.<tool_call>" + + +def test_extract_tool_calls_special_characters(glm4_moe_tool_parser): + """Test tool calls with special characters and unicode.""" + model_output = """<tool_call>send_message +<arg_key>recipient</arg_key> +<arg_value>Amy</arg_value> +<arg_key>message</arg_key> +<arg_value>It is a nice day</arg_value> +<arg_key>priority</arg_key> +<arg_value>high</arg_value> +</tool_call>""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[0].function.name == "send_message" + + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args["recipient"] == "Amy" + assert args["message"] == "It is a nice day" + assert args["priority"] == "high" + + +def 
test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser): + """Test incomplete tool calls (missing closing tag).""" + model_output = """<tool_call>get_weather +<arg_key>city</arg_key> +<arg_value>Beijing</arg_value> +<arg_key>date</arg_key> +<arg_value>2025-08-01</arg_value>""" + + extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + # Incomplete tool calls should not be extracted + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output diff --git a/tests/tool_use/test_kimi_k2_tool_parser.py b/tests/tool_use/test_kimi_k2_tool_parser.py new file mode 100644 index 000000000000..bd030632f167 --- /dev/null +++ b/tests/tool_use/test_kimi_k2_tool_parser.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +import json + +import pytest + +from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall +from vllm.entrypoints.openai.tool_parsers import KimiK2ToolParser +from vllm.transformers_utils.tokenizer import get_tokenizer + +# Use a common model that is likely to be available +MODEL = "moonshotai/Kimi-K2-Instruct" + + +@pytest.fixture(scope="module") +def kimi_k2_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL, trust_remote_code=True) + + +@pytest.fixture +def kimi_k2_tool_parser(kimi_k2_tokenizer): + return KimiK2ToolParser(kimi_k2_tokenizer) + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + # assert tool call id format + assert actual_tool_call.id.startswith("functions.") + assert actual_tool_call.id.split(':')[-1].isdigit() + assert actual_tool_call.id.split('.')[1].split( + ':')[0] == expected_tool_call.function.name + + +def test_extract_tool_calls_no_tools(kimi_k2_tool_parser): + model_output = "This is a test" + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "tool_call_with_content_before", + "multi_tool_call_with_content_before", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> +functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""", + [ + ToolCall(id='functions.get_weather:0', + function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "city": "Beijing", + }, ), + ), + type='function') + ], + "I'll help you check the weather. ", + ), + ( + """I'll help you check the weather. 
<|tool_calls_section_begin|> <|tool_call_begin|> +functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|> +functions.get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""", + [ + ToolCall(id='functions.get_weather:0', + function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "city": "Beijing", + }, ), + ), + type='function'), + ToolCall(id='functions.get_weather:1', + function=FunctionCall( + name="get_weather", + arguments=json.dumps({ + "city": "Shanghai", + }, ), + ), + type='function') + ], + "I'll help you check the weather. ", + ), + ], +) +def test_extract_tool_calls(kimi_k2_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_invalid_json(kimi_k2_tool_parser): + """we'll return every funcall result""" + model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> +functions.invalid_get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing" <|tool_call_end|> <|tool_call_begin|> +functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" + + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + # Should extract only the valid JSON tool calls + assert len(extracted_tool_calls.tool_calls) == 2 + assert extracted_tool_calls.tool_calls[ + 0].function.name == "invalid_get_weather" + assert extracted_tool_calls.tool_calls[ + 1].function.name == "valid_get_weather" + + +def test_extract_tool_calls_invalid_funcall(kimi_k2_tool_parser): + """we'll return every funcall result""" + model_output = """I'll help you check the weather. <|tool_calls_section_begin|> <|tool_call_begin|> +functions.invalid_get_weather.0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_call_begin|> +functions.valid_get_weather:1 <|tool_call_argument_begin|> {"city": "Shanghai"} <|tool_call_end|> <|tool_calls_section_end|>""" + + extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + + assert extracted_tool_calls.tools_called + # Should extract only the valid JSON tool calls + assert len(extracted_tool_calls.tool_calls) == 1 + assert extracted_tool_calls.tool_calls[ + 0].function.name == "valid_get_weather" + + +def test_streaming_basic_functionality(kimi_k2_tool_parser): + """Test basic streaming functionality.""" + # Reset streaming state + kimi_k2_tool_parser.current_tool_name_sent = False + kimi_k2_tool_parser.prev_tool_call_arr = [] + kimi_k2_tool_parser.current_tool_id = -1 + kimi_k2_tool_parser.streamed_args_for_tool = [] + + # Test with a simple tool call + current_text = """ check the weather. 
<|tool_calls_section_begin|> <|tool_call_begin|> +functions.get_weather:0 <|tool_call_argument_begin|> {"city": "Beijing"} <|tool_call_end|> <|tool_calls_section_end|>""" + + # First call should handle the initial setup + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="I'll help you", + current_text=current_text, + delta_text="<|tool_calls_section_end|>", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # The result might be None or contain tool call information + # This depends on the internal state management + if result is not None and hasattr(result, + 'tool_calls') and result.tool_calls: + assert len(result.tool_calls) >= 0 + + +def test_streaming_no_tool_calls(kimi_k2_tool_parser): + """Test streaming when there are no tool calls.""" + current_text = "This is just regular text without any tool calls." + + result = kimi_k2_tool_parser.extract_tool_calls_streaming( + previous_text="This is just regular text", + current_text=current_text, + delta_text=" without any tool calls.", + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=None, + ) + + # Should return the delta text as content + assert result is not None + assert hasattr(result, 'content') + assert result.content == " without any tool calls." diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py new file mode 100644 index 000000000000..40c3158e9e68 --- /dev/null +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -0,0 +1,618 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Generator +from typing import Optional + +import pytest + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( + Qwen3CoderToolParser) +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" + + +@pytest.fixture(scope="module") +def qwen3_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def qwen3_tool_parser(qwen3_tokenizer): + return Qwen3CoderToolParser(qwen3_tokenizer) + + +@pytest.fixture +def sample_tools(): + return [ + ChatCompletionToolsParam(type="function", + function={ + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city name" + }, + "state": { + "type": "string", + "description": + "The state code" + }, + "unit": { + "type": "string", + "enum": + ["fahrenheit", "celsius"] + } + }, + "required": ["city", "state"] + } + }), + ChatCompletionToolsParam(type="function", + function={ + "name": "calculate_area", + "description": + "Calculate area of a shape", + "parameters": { + "type": "object", + "properties": { + "shape": { + "type": "string" + }, + "dimensions": { + "type": "object" + }, + "precision": { + "type": "integer" + } + } + } + }) + ] + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + # Qwen3 parser doesn't 
generate IDs during extraction + assert actual_tool_call.type == "function" + assert ( + actual_tool_call.function.name == expected_tool_call.function.name) + assert (json.loads(actual_tool_call.function.arguments) == json.loads( + expected_tool_call.function.arguments)) + + +def stream_delta_message_generator( + qwen3_tool_parser: Qwen3CoderToolParser, + qwen3_tokenizer: AnyTokenizer, + model_output: str, + request: Optional[ChatCompletionRequest] = None +) -> Generator[DeltaMessage, None, None]: + all_token_ids = qwen3_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=qwen3_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = qwen3_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=request, + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(qwen3_tool_parser): + model_output = "This is a test response without any tool calls" + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool", + "single_tool_with_content", + "single_tool_multiline_param", + "parallel_tools", + "tool_with_typed_params", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ('''<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], None), + ('''Sure! Let me check the weather for you.<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], "Sure! 
Let me check the weather for you."), + ('''<tool_call> +<function=calculate_area> +<parameter=shape> +rectangle +</parameter> +<parameter=dimensions> +{"width": 10, + "height": 20} +</parameter> +<parameter=precision> +2 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "rectangle", + "dimensions": { + "width": 10, + "height": 20 + }, + "precision": 2 + }))) + ], None), + ('''<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call> +<tool_call> +<function=get_current_weather> +<parameter=city> +Orlando +</parameter> +<parameter=state> +FL +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))), + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit" + }))) + ], None), + ('''Let me calculate that area for you.<tool_call> +<function=calculate_area> +<parameter=shape> +circle +</parameter> +<parameter=dimensions> +{"radius": 15.5} +</parameter> +<parameter=precision> +3 +</parameter> +</function> +</tool_call>''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "circle", + "dimensions": { + "radius": 15.5 + }, + "precision": 3 + }))) + ], "Let me calculate that area for you."), + ], +) +def test_extract_tool_calls(qwen3_tool_parser, sample_tools, model_output, + expected_tool_calls, expected_content): + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def test_extract_tool_calls_fallback_no_tags(qwen3_tool_parser, sample_tools): + """Test fallback parsing when XML tags are missing""" + model_output = '''<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +</function>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + assert (extracted_tool_calls.tool_calls[0].function.name == + "get_current_weather") + + +def test_extract_tool_calls_type_conversion(qwen3_tool_parser): + """Test parameter type conversion based on tool schema""" + tools = [ + ChatCompletionToolsParam(type="function", + function={ + "name": "test_types", + "parameters": { + "type": "object", + "properties": { + "int_param": { + "type": "integer" + }, + "float_param": { + "type": "float" + }, + "bool_param": { + "type": "boolean" + }, + "str_param": { + "type": "string" + }, + "obj_param": { + "type": "object" + } + } + } + }) + ] + + model_output = '''<tool_call> +<function=test_types> +<parameter=int_param> +42 +</parameter> +<parameter=float_param> +3.14 +</parameter> +<parameter=bool_param> +true +</parameter> +<parameter=str_param> +hello world +</parameter> +<parameter=obj_param> +{"key": 
"value"} +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, messages=[], tools=tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert args["int_param"] == 42 + assert args["float_param"] == 3.14 + assert args["bool_param"] is True + assert args["str_param"] == "hello world" + assert args["obj_param"] == {"key": "value"} + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool", + "single_tool_with_content", + "parallel_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ("This is a test without tools", [], "This is a test without tools"), + ('''<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], ""), + ('''Sure! Let me check the weather for you.<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call>''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], "Sure! Let me check the weather for you."), + ('''<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +<parameter=unit> +fahrenheit +</parameter> +</function> +</tool_call> +<tool_call> +<function=get_current_weather> +<parameter=city> +Orlando +</parameter> +<parameter=state> +FL +</parameter> +<parameter=unit> +celsius +</parameter> +</function> +</tool_call>''', [ + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))), + ToolCall( + function=FunctionCall(name="get_current_weather", + arguments=json.dumps({ + "city": "Orlando", + "state": "FL", + "unit": "celsius" + }))) + ], ""), + ], +) +def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, + sample_tools, model_output, + expected_tool_calls, expected_content): + """Test incremental streaming behavior""" + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} # Track state per tool index + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + # Initialize state for new tool + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + # First chunk should have id, name, and type + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + # Should only be set once + assert tool_states[idx]["name"] is None + 
tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + # Accumulate arguments incrementally + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify final content + assert other_content == expected_content + + # Verify we got all expected tool calls + assert len(tool_states) == len(expected_tool_calls) + + # Verify each tool call + for idx, expected_tool in enumerate(expected_tool_calls): + state = tool_states[idx] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == expected_tool.function.name + + # Parse accumulated arguments + arguments_str = state["arguments"] + assert arguments_str is not None + actual_args = json.loads(arguments_str) + expected_args = json.loads(expected_tool.function.arguments) + assert actual_args == expected_args + + +def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, + qwen3_tokenizer, + sample_tools): + """Test that streaming is truly incremental""" + model_output = '''I'll check the weather.<tool_call> +<function=get_current_weather> +<parameter=city> +Dallas +</parameter> +<parameter=state> +TX +</parameter> +</function> +</tool_call>''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + chunks = [] + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + chunks.append(delta_message) + + # Should have multiple chunks + assert len(chunks) > 3 + + # First chunk(s) should be content + assert chunks[0].content is not None + assert chunks[0].tool_calls is None or chunks[0].tool_calls == [] + + # Should have a chunk with tool header (id, name, type) + header_found = False + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].id: + header_found = True + assert (chunk.tool_calls[0].function.name == "get_current_weather") + assert chunk.tool_calls[0].type == "function" + # Empty initially + assert chunk.tool_calls[0].function.arguments == "" + break + assert header_found + + # Should have chunks with incremental arguments + arg_chunks = [] + for chunk in chunks: + if chunk.tool_calls and chunk.tool_calls[0].function.arguments: + arg_chunks.append(chunk.tool_calls[0].function.arguments) + + # Arguments should be streamed incrementally + assert len(arg_chunks) > 1 + + # Concatenated arguments should form valid JSON + full_args = "".join(arg_chunks) + parsed_args = json.loads(full_args) + assert parsed_args["city"] == "Dallas" + assert parsed_args["state"] == "TX" diff --git a/tests/tool_use/test_tool_choice_required.py b/tests/tool_use/test_tool_choice_required.py index 3b43b723d438..e0ed221a93e1 100644 --- a/tests/tool_use/test_tool_choice_required.py +++ b/tests/tool_use/test_tool_choice_required.py @@ -72,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output, assert isinstance(schema, dict) # use build_regex_from_schema used in JSONLogitsProcessor to create Guide - from outlines_core.fsm.json_schema import build_regex_from_schema + from outlines_core.json_schema import build_regex_from_schema regex = build_regex_from_schema(json.dumps(schema)) compiled = re.compile(regex) matches = compiled.fullmatch(json.dumps(sample_output)) is not None diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py index a13cf7064d54..6cefbae4bdd1 100644 --- a/tests/tpu/test_quantization_accuracy.py +++ b/tests/tpu/test_quantization_accuracy.py @@ -14,7 +14,7 @@ @dataclass 
class GSM8KAccuracyTestConfig: model_name: str - excepted_value: float + expected_value: float def get_model_args(self) -> str: return (f"pretrained={self.model_name}," @@ -25,13 +25,13 @@ def get_model_args(self) -> str: ACCURACY_CONFIGS = [ GSM8KAccuracyTestConfig( model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - excepted_value=0.76), # no bias + expected_value=0.76), # no bias # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU, # so only one of these tests can run in a single call to pytest. As # a follow up, move this into the LM-EVAL section of the CI. # GSM8KAccuracyTestConfig( # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", - # excepted_value=0.66), # bias in QKV layers + # expected_value=0.66), # bias in QKV layers ] @@ -45,7 +45,7 @@ def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): batch_size="auto", ) - EXPECTED_VALUE = config.excepted_value + EXPECTED_VALUE = config.expected_value measured_value = results["results"][TASK][FILTER] assert (measured_value - RTOL < EXPECTED_VALUE and measured_value + RTOL > EXPECTED_VALUE diff --git a/tests/utils.py b/tests/utils.py index a37872830dad..f4317e6bdb40 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -818,14 +818,15 @@ def create_new_process_for_each_test( Args: method: The process creation method. Can be either "spawn" or "fork". - If not specified, - it defaults to "spawn" on ROCm platforms and "fork" otherwise. + If not specified, it defaults to "spawn" on ROCm and XPU + platforms and "fork" otherwise. Returns: A decorator to run test functions in separate processes. """ if method is None: - method = "spawn" if current_platform.is_rocm() else "fork" + use_spawn = current_platform.is_rocm() or current_platform.is_xpu() + method = "spawn" if use_spawn else "fork" assert method in ["spawn", "fork"], "Method must be either 'spawn' or 'fork'" diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py new file mode 100644 index 000000000000..b4e0101a0d4b --- /dev/null +++ b/tests/v1/attention/test_attention_backends.py @@ -0,0 +1,466 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for v1 attention backends without GPUModelRunner dependency.""" + +import pytest +import torch + +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config, + get_attention_backend) +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + +BACKENDS_TO_TEST = [ + _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLASHINFER_VLLM_V1, + _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1 +] + +# Remove flashinfer from the list if it's not available +try: + import flashinfer # noqa: F401 +except ImportError: + BACKENDS_TO_TEST.remove(_Backend.FLASHINFER_VLLM_V1) + + +def _convert_dtype_to_torch(dtype): + """Convert ModelDType to torch.dtype.""" + if isinstance(dtype, str): + if dtype == "auto": + return torch.float16 # Default dtype for testing + elif dtype in STR_DTYPE_TO_TORCH_DTYPE: + return STR_DTYPE_TO_TORCH_DTYPE[dtype] + else: + raise ValueError(f"Unknown dtype: {dtype}") + elif isinstance(dtype, torch.dtype): + return dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + +# Define common batch configurations +BATCH_SPECS = { + "small_decode": + 
BatchSpec(seq_lens=[32, 40], query_lens=[1, 1]), + "small_prefill": + BatchSpec(seq_lens=[32, 40], query_lens=[8, 8]), + "mixed_small": + BatchSpec(seq_lens=[32, 40, 48, 56], query_lens=[1, 1, 5, 5]), + "medium_decode": + BatchSpec(seq_lens=[128, 256, 512, 1024, 128, 256, 512, 1024], + query_lens=[1, 1, 1, 1, 1, 1, 1, 1]), + "medium_prefill": + BatchSpec(seq_lens=[256, 512, 1024, 2048], query_lens=[16, 16, 16, 16]), + "mixed_medium": + BatchSpec(seq_lens=[512, 1024, 2048, 512, 1024, 2048], + query_lens=[1, 1, 1, 7, 7, 7]), + "large_decode": + BatchSpec(seq_lens=[2048] * 32, query_lens=[1] * 32), + "large_prefill": + BatchSpec(seq_lens=[4096] * 8, query_lens=[32] * 8), + "single_decode": + BatchSpec(seq_lens=[1024], query_lens=[1]), + "single_prefill": + BatchSpec(seq_lens=[1024], query_lens=[64]), +} + + +def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + 2, # K and V + num_blocks, + kv_cache_spec.block_size, + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), + device=device, + ) + return kv_cache + + +def create_and_prepopulate_kv_cache( + k_contexts: list[torch.Tensor], + v_contexts: list[torch.Tensor], + block_size: int, + num_kv_heads: int, + head_size: int, + dtype: torch.dtype, + device: torch.device, + num_blocks: int, + common_attn_metadata: CommonAttentionMetadata, + randomize_blocks: bool = True) -> torch.Tensor: + """Create and prepopulate a KV cache with context data. + + Args: + k_contexts: List of key context tensors for each sequence + v_contexts: List of value context tensors for each sequence + seq_lens: List of sequence lengths + block_size: Size of each block + num_kv_heads: Number of KV heads + head_size: Size of each head + dtype: Data type for the cache + device: Device to create the cache on + num_blocks: Total number of blocks in the cache + block_table: Block table tensor to populate + randomize_blocks: Whether to randomly permute blocks + or use sequential order + + Returns: + Tuple of (kv_cache, updated_block_table) + """ + batch_size = len(k_contexts) + seq_lens = common_attn_metadata.seq_lens_cpu + query_lens = common_attn_metadata.query_start_loc_cpu[ + 1:] - common_attn_metadata.query_start_loc_cpu[:-1] + context_lens = common_attn_metadata.num_computed_tokens_cpu + block_table = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + + # Create KV cache + kv_cache = torch.empty(2, + num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(2, -1, num_kv_heads, head_size) + + # Populate the cache with the context tokens + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + k_context, v_context = k_contexts[i], v_contexts[i] + start = start_block_idx * block_size + end = start + k_context.shape[0] + kv_cache_flat[0, start:end, ...] = k_context + kv_cache_flat[1, start:end, ...] 
= v_context + + # Stay block aligned and allocate enough blocks for the new tokens + start_block_idx += cdiv(int(seq_lens[i]), block_size) + + blocks_end = start_block_idx + + # Permute the context blocks (excluding block 0 which is null) + if randomize_blocks: + perm = torch.randperm( + blocks_end - 1) + 1 # Random permutation starting from block 1 + else: + perm = torch.arange( + 1, blocks_end) # Sequential order starting from block 1 + + inv_perm = torch.zeros(blocks_end, dtype=torch.long, device=device) + inv_perm[1:] = torch.argsort( + perm) + 1 # Add 1 to account for starting from block 1 + kv_cache[:, 1:blocks_end, ...] = kv_cache[:, perm, ...] + + # Construct the right block table + # Start from block_id=1 since block_id=0 is considered the null block + start_block_idx = 1 + for i in range(batch_size): + num_blocks_for_seq = cdiv(int(seq_lens[i]), block_size) + start = start_block_idx + end = start + num_blocks_for_seq + block_table[i, :num_blocks_for_seq] = inv_perm[start:end] + start_block_idx += num_blocks_for_seq + + # Create a realistic slot mapping that corresponds to the block table + for i in range(batch_size): + token_offsets = torch.arange(int(query_lens[i])) + int(context_lens[i]) + block_indices = token_offsets // block_size + token_inter_block_offsets = token_offsets % block_size + start = common_attn_metadata.query_start_loc_cpu[i] + end = common_attn_metadata.query_start_loc_cpu[i + 1] + slot_mapping[start:end] = block_table[ + i, + block_indices] * block_size + token_inter_block_offsets.to(device) + + return kv_cache + + +class MockAttentionLayer: + """A mock attention layer for testing.""" + + def __init__(self, device: torch.device): + self._q_scale = torch.tensor(1.0, device=device) + self._k_scale = torch.tensor(1.0, device=device) + self._v_scale = torch.tensor(1.0, device=device) + # Add float versions for flashinfer + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + + +def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec, + vllm_config, device: torch.device, + common_attn_metadata: CommonAttentionMetadata, + query: torch.Tensor, key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor) -> torch.Tensor: + """Run attention computation using the specified backend's AttentionImpl.""" + + builder_cls, impl_cls = get_attention_backend(backend) + + # Mock flashinfer's get_per_layer_parameters if needed + if backend == _Backend.FLASHINFER_VLLM_V1: + import unittest.mock + + from vllm.v1.attention.backends.flashinfer import PerLayerParameters + + def mock_get_per_layer_parameters(vllm_config): + # Return mock parameters for a single layer + head_size = vllm_config.model_config.get_head_size() + return { + "mock_layer": + PerLayerParameters( + window_left=-1, # No sliding window + logits_soft_cap=0.0, # No soft cap + sm_scale=1.0 / (head_size**0.5) # Standard scale + ) + } + + with unittest.mock.patch( + 'vllm.v1.attention.backends.flashinfer.get_per_layer_parameters', + mock_get_per_layer_parameters): + builder = builder_cls(kv_cache_spec, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + else: + # Build metadata + builder = builder_cls(kv_cache_spec, vllm_config, device) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # Instantiate implementation + num_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = 
vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + scale = 1.0 / (head_size**0.5) + impl = impl_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype="auto", + ) + + # Create mock layer and output buffer + mock_layer = MockAttentionLayer(device) + output = torch.empty_like(query) + + # Run forward pass + # NOTE: The query, key, and value are already shaped correctly + # in the calling test function. + output = impl.forward(mock_layer, + query, + key, + value, + kv_cache, + attn_metadata, + output=output) + + return output + + +@pytest.mark.parametrize("batch_spec_name", [ + "small_decode", "small_prefill", "mixed_small", "medium_decode", + "medium_prefill", "mixed_medium" +]) +@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"]) +def test_backend_correctness(batch_spec_name: str, model: str): + """ + Test that all backends produce similar outputs to a reference implementation + using torch.nn.functional.scaled_dot_product_attention. + + This test works by: + 1. Generating a batch of sequences with specified context and query lengths. + 2. Computing a ground-truth attention output using torch.sdpa on + contiguous Q, K, and V tensors. + 3. Simulating vLLM's paged KV cache: It takes the context portion of the + K/V tensors and manually places them into a paged buffer according to + the test's (randomly generated) block table. + 4. Running each vLLM attention backend with the new queries and the + simulated paged KV cache. + 5. Comparing the vLLM backend's output to the ground-truth SDPA output. + """ + batch_spec = BATCH_SPECS[batch_spec_name] + vllm_config = create_vllm_config(model_name=model) + device = torch.device("cuda:0") + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + # 1. Setup + batch_size = batch_spec.batch_size + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) + num_kv_heads = vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config) + head_size = vllm_config.model_config.get_head_size() + dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype) + block_size = vllm_config.cache_config.block_size + scale = 1.0 / (head_size**0.5) + + # 2. 
Generate data and compute SDPA reference output + all_q_vllm, all_k_vllm, all_v_vllm = [], [], [] + all_sdpa_outputs = [] + k_contexts, v_contexts = [], [] + + for i in range(batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + context_len = s_len - q_len + + # Generate Q, K, V for the whole sequence to be used in SDPA + q = torch.randn(q_len, + num_q_heads, + head_size, + dtype=dtype, + device=device) + k_full = torch.randn(s_len, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + v_full = torch.randn(s_len, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + + # SDPA expects (N, H, L, D), so unsqueeze batch and permute + q_sdpa_in = q.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0, ( + f"num_q_heads ({num_q_heads}) must be divisible by " + f"num_kv_heads ({num_kv_heads})") + repeats = num_q_heads // num_kv_heads + k_sdpa_in = k_sdpa_in.repeat_interleave(repeats, dim=1) + v_sdpa_in = v_sdpa_in.repeat_interleave(repeats, dim=1) + + # Create causal mask: query token i attends to positions 0 to + # (context_len + i) + kv_len = s_len + offset = context_len + attn_mask = torch.full((q_len, kv_len), + float('-inf'), + device=device, + dtype=dtype) + for i in range(q_len): + attn_mask[i, :offset + i + 1] = 0.0 + + sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, + k_sdpa_in, + v_sdpa_in, + attn_mask=attn_mask, + scale=scale, + enable_gqa=True) + # Convert back to (L, H, D) + all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0)) + + # Inputs for vLLM backends are just the new tokens + all_q_vllm.append(q) + all_k_vllm.append(k_full[context_len:]) + all_v_vllm.append(v_full[context_len:]) + + # Contextual K/V data used to populate the paged cache + k_contexts.append(k_full[:context_len]) + v_contexts.append(v_full[:context_len]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + key_vllm = torch.cat(all_k_vllm, dim=0) + value_vllm = torch.cat(all_v_vllm, dim=0) + sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, vllm_config.cache_config.block_size, device) + + # 3. Simulate Paged KV Cache and a realistic slot_mapping + kv_cache = create_and_prepopulate_kv_cache( + k_contexts=k_contexts, + v_contexts=v_contexts, + block_size=block_size, + num_kv_heads=num_kv_heads, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks or 1000, + common_attn_metadata=common_attn_metadata, + randomize_blocks=True) + + # 4. 
Run vLLM backends and compare
+    # Note: flex_attention has known Triton kernel compatibility issues
+    # with test infrastructures
+    for backend_name in BACKENDS_TO_TEST:
+        # FlashAttention + FlexAttention:
+        # [2, num_blocks, block_size, num_kv_heads, head_size]
+        # FlashInfer:
+        # [num_blocks, 2, block_size, num_kv_heads, head_size]
+        # Select the appropriate KV cache format for each backend
+        kv_cache_for_backend = kv_cache
+        if backend_name == _Backend.FLASHINFER_VLLM_V1:
+            kv_cache_for_backend = kv_cache.transpose(0, 1)
+
+        backend_output = run_attention_backend(backend_name, kv_cache_spec,
+                                               vllm_config, device,
+                                               common_attn_metadata,
+                                               query_vllm, key_vllm,
+                                               value_vllm,
+                                               kv_cache_for_backend)
+
+        # Check shape and dtype consistency
+        assert backend_output.shape == sdpa_output.shape, (
+            f"[{backend_name}] shape {backend_output.shape} != "
+            f"SDPA shape {sdpa_output.shape}")
+        assert backend_output.dtype == sdpa_output.dtype, (
+            f"[{backend_name}] dtype {backend_output.dtype} != "
+            f"SDPA dtype {sdpa_output.dtype}")
+
+        assert torch.isfinite(backend_output).all(), (
+            f"[{backend_name}] produced non-finite values")
+
+        # Check numerical similarity
+        rtol = 1e-2
+        atol = 5e-3
+
+        if backend_name == _Backend.FLEX_ATTENTION:
+            atol = 5e-1  # TODO: figure out why flex_attention has such large
+            # numerical differences for medium_decode, medium_prefill,
+            # mixed_medium
+
+        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
+        max_rel_diff = torch.max(
+            torch.abs(backend_output - sdpa_output) /
+            torch.abs(sdpa_output)).item()
+        all_close = torch.allclose(backend_output,
+                                   sdpa_output,
+                                   rtol=rtol,
+                                   atol=atol)
+
+        if not all_close:
+            print(f"[{backend_name}] output differs from SDPA baseline. "
+                  f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})")
+            print(f"[{backend_name}] output: {backend_output}")
+            print(f"[{backend_name}] SDPA baseline: {sdpa_output}")
+
+        assert all_close, (
+            f"[{backend_name}] output differs from SDPA baseline. 
" + f"Max diff: {max_diff:.6f} (rel: {max_rel_diff:.6f})") diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py new file mode 100644 index 000000000000..30cfbdda5d86 --- /dev/null +++ b/tests/v1/attention/utils.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions for attention-related v1 tests.""" + +from dataclasses import dataclass +from typing import Union + +import pytest +import torch + +from vllm.config import (CacheConfig, CompilationConfig, DeviceConfig, + LoadConfig, ModelConfig, ModelDType, ParallelConfig, + SchedulerConfig, VllmConfig) +from vllm.platforms import _Backend +from vllm.utils import resolve_obj_by_qualname +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import FullAttentionSpec + + +@dataclass +class BatchSpec: + """Specification for a batch configuration (workload shape only).""" + seq_lens: list[int] + query_lens: list[int] + + name: str = "unnamed" + + @property + def batch_size(self): + return len(self.seq_lens) + + def __post_init__(self): + assert len(self.seq_lens) == len(self.query_lens) + + def compute_num_tokens(self): + return sum(self.query_lens) + + +def create_common_attn_metadata( + batch_spec: BatchSpec, + block_size: int, + device: torch.device, + max_block_idx: int = 1000) -> CommonAttentionMetadata: + """Create CommonAttentionMetadata from a BatchSpec and ModelParams.""" + # Create query start locations + query_start_loc = torch.zeros(batch_spec.batch_size + 1, + dtype=torch.int32, + device=device) + query_start_loc[1:] = torch.tensor(batch_spec.query_lens, + dtype=torch.int32, + device=device).cumsum(0) + query_start_loc_cpu = query_start_loc.cpu() + num_tokens = batch_spec.compute_num_tokens() + + # Create sequence lengths + seq_lens = torch.tensor(batch_spec.seq_lens, + dtype=torch.int32, + device=device) + seq_lens_cpu = seq_lens.cpu() + + # Create computed tokens (context length for each sequence) + context_lens = [ + batch_spec.seq_lens[i] - batch_spec.query_lens[i] + for i in range(batch_spec.batch_size) + ] + num_computed_tokens_cpu = torch.tensor(context_lens, dtype=torch.int32) + + # Create block table (random for testing) + max_blocks = max(batch_spec.seq_lens) // block_size + 1 + block_table_tensor = torch.randint(0, + max_block_idx, + (batch_spec.batch_size, max_blocks), + dtype=torch.int32, + device=device) + + # Create slot mapping + slot_mapping = torch.randint(0, + max_block_idx, (num_tokens, ), + dtype=torch.int64, + device=device) + + # Calculate max query length + max_query_len = max(batch_spec.query_lens) + + return CommonAttentionMetadata( + query_start_loc=query_start_loc, + query_start_loc_cpu=query_start_loc_cpu, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + num_computed_tokens_cpu=num_computed_tokens_cpu, + num_reqs=batch_spec.batch_size, + num_actual_tokens=num_tokens, + max_query_len=max_query_len, + block_table_tensor=block_table_tensor, + slot_mapping=slot_mapping, + ) + + +def get_attention_backend(backend_name: _Backend): + """Set up attention backend classes for testing. + + Args: + backend_name: Name of the backend ("flash_attn", "flashinfer", etc.) 
+ vllm_config: VllmConfig instance + + Returns: + Tuple of (backend_builder_class, backend_impl_class) + """ + backend_map = { + _Backend.FLASH_ATTN_VLLM_V1: + "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend", + _Backend.FLASHINFER_VLLM_V1: + "vllm.v1.attention.backends.flashinfer.FlashInferBackend", + _Backend.FLEX_ATTENTION: + "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend", + _Backend.TRITON_ATTN_VLLM_V1: + "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend", + } + + if backend_name not in backend_map: + raise ValueError(f"Unknown backend: {backend_name}") + + backend_class_name = backend_map[backend_name] + + try: + backend_class = resolve_obj_by_qualname(backend_class_name) + return backend_class.get_builder_cls(), backend_class.get_impl_cls() + except ImportError as e: + pytest.skip(f"{backend_name} not available: {e}") + + +def create_standard_kv_cache_spec( + vllm_config: VllmConfig) -> FullAttentionSpec: + """Create a FullAttentionSpec from ModelParams only.""" + return FullAttentionSpec( + block_size=vllm_config.cache_config.block_size, + num_kv_heads=vllm_config.model_config.get_num_kv_heads( + vllm_config.parallel_config), + head_size=vllm_config.model_config.get_head_size(), + dtype=vllm_config.model_config.dtype, + use_mla=vllm_config.model_config.use_mla, + sliding_window=vllm_config.model_config.get_sliding_window(), + ) + + +def create_vllm_config(model_name: str = "meta-llama/Meta-Llama-3-8B", + tensor_parallel_size: int = 1, + max_model_len: int = 1024, + dtype: Union[ModelDType, torch.dtype] = "auto", + block_size: int = 16, + max_num_seqs: int = 256, + max_num_batched_tokens: int = 8192, + add_mock_model_methods: bool = True) -> VllmConfig: + """Create a VllmConfig for testing with reasonable defaults.""" + + model_config = ModelConfig( + model=model_name, + tokenizer=model_name, + trust_remote_code=False, + dtype=dtype, + seed=0, + max_model_len=max_model_len, + ) + + cache_config = CacheConfig( + block_size=block_size, + cache_dtype="auto", + swap_space=0, + ) + # Set cache blocks for testing + # (these may be set during initialization normally) + cache_config.num_gpu_blocks = 1000 + cache_config.num_cpu_blocks = 0 + + parallel_config = ParallelConfig( + tensor_parallel_size=tensor_parallel_size, ) + + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + ) + + device_config = DeviceConfig() + load_config = LoadConfig() + compilation_config = CompilationConfig() + + if add_mock_model_methods: + # Add mock methods to satisfy backends that need them + # This is a workaround because tests don't build full, real models, + # but some backends expect to query the model for layer-specific + # parameters + import types + model_config.get_num_layers = types.MethodType(lambda self: 1, + model_config) + model_config.get_sliding_window_for_layer = types.MethodType( + lambda self, i: None, model_config) + model_config.get_logits_soft_cap_for_layer = types.MethodType( + lambda self, i: 0.0, model_config) + model_config.get_sm_scale_for_layer = types.MethodType( + lambda self, i: 1.0 / model_config.get_head_size()**0.5, + model_config) + + return VllmConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + compilation_config=compilation_config, + ) + + +def create_dummy_kv_cache(block_size: int, + num_kv_heads: int, + head_size: int, + 
dtype: torch.dtype, + device: torch.device, + num_blocks: int = 100) -> torch.Tensor: + """Create a dummy KV cache tensor for testing.""" + kv_cache = torch.randn( + num_blocks, + 2, # K and V + block_size, + num_kv_heads, + head_size, + dtype=dtype, + device=device) + return kv_cache diff --git a/tests/spec_decode/e2e/__init__.py b/tests/v1/core/__init__.py similarity index 100% rename from tests/spec_decode/e2e/__init__.py rename to tests/v1/core/__init__.py diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py new file mode 100644 index 000000000000..3ccefbd81cab --- /dev/null +++ b/tests/v1/core/test_async_scheduler.py @@ -0,0 +1,228 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import deque + +import pytest + +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import RequestStatus + +from .utils import create_requests, create_scheduler + + +def _make_model_runner_output( + scheduler_output: SchedulerOutput, ) -> ModelRunnerOutput: + req_ids = list(scheduler_output.num_scheduled_tokens.keys()) + return ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index={ + req_id: i + for i, req_id in enumerate(req_ids) + }, + sampled_token_ids=[[i] for i in range(len(req_ids))], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + + +@pytest.mark.parametrize("max_tokens", [1, 2, 3, 5]) +def test_stop_by_max_tokens(max_tokens: int): + scheduler = create_scheduler(async_scheduling=True) + requests = create_requests(num_requests=2, max_tokens=max_tokens) + req0, req1 = requests + + sched_outputs: deque[SchedulerOutput] = deque() + scheduler.add_request(req0) + sched_outputs.append(scheduler.schedule()) + + scheduler.add_request(req1) + sched_outputs.append(scheduler.schedule()) + + while sched_outputs: + sched_output = sched_outputs.popleft() + model_runner_output = _make_model_runner_output(sched_output) + scheduler.update_from_output(sched_output, model_runner_output) + + sched_output = scheduler.schedule() + if sched_output.num_scheduled_tokens: + sched_outputs.append(sched_output) + + assert scheduler.get_num_unfinished_requests() == 0 + assert req0.num_output_tokens == max_tokens + assert req1.num_output_tokens == max_tokens + + +def test_abort(): + scheduler = create_scheduler(async_scheduling=True) + requests = create_requests(num_requests=10, max_tokens=20) + + for req in requests: + scheduler.add_request(req) + + sched_outputs: deque[SchedulerOutput] = deque() + sched_outputs.append(scheduler.schedule()) + sched_outputs.append(scheduler.schedule()) + + abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9] + abort_order_copy = abort_order.copy() + + def abort_request(): + if not abort_order: + return + req = requests[abort_order.pop(0)] + scheduler.finish_requests(req.request_id, + RequestStatus.FINISHED_ABORTED) + + while sched_outputs: + # Abort a scheduled request. 
+ abort_request() + sched_output = sched_outputs.popleft() + model_runner_output = _make_model_runner_output(sched_output) + scheduler.update_from_output(sched_output, model_runner_output) + + sched_output = scheduler.schedule() + if sched_output.num_scheduled_tokens: + sched_outputs.append(sched_output) + + for i, req in enumerate(requests): + assert req.status == RequestStatus.FINISHED_ABORTED + assert req.num_output_tokens == abort_order_copy.index(i) + + +def test_preempt(): + scheduler = create_scheduler(async_scheduling=True) + requests = create_requests(num_requests=10, max_tokens=20) + + for req in requests: + scheduler.add_request(req) + + sched_outputs: deque[SchedulerOutput] = deque() + sched_outputs.append(scheduler.schedule()) + sched_outputs.append(scheduler.schedule()) + + abort_order = [0, 8, 3, 1, 6, 4, 2, 5, 7, 9] + abort_order_copy = abort_order.copy() + + def abort_request(): + if not abort_order: + return + req = requests[abort_order.pop(0)] + scheduler.finish_requests(req.request_id, + RequestStatus.FINISHED_ABORTED) + + while sched_outputs: + # Abort a scheduled request. + abort_request() + sched_output = sched_outputs.popleft() + model_runner_output = _make_model_runner_output(sched_output) + scheduler.update_from_output(sched_output, model_runner_output) + + sched_output = scheduler.schedule() + if sched_output.num_scheduled_tokens: + sched_outputs.append(sched_output) + + for i, req in enumerate(requests): + assert req.status == RequestStatus.FINISHED_ABORTED + assert req.num_output_tokens == abort_order_copy.index(i) + + +def test_prefix_caching_for_prefill_dedup(): + CHUNK_SIZE = 1000 + BLOCK_SIZE = 16 + num_prompt_tokens = 100 + scheduler = create_scheduler(async_scheduling=True, + max_num_batched_tokens=CHUNK_SIZE, + enable_prefix_caching=True, + block_size=BLOCK_SIZE) + requests = create_requests(num_requests=5, + num_tokens=num_prompt_tokens, + max_tokens=3, + same_prompt=True) + requests_copy = requests.copy() + + # Two requests with the same prompt. + req0 = requests.pop(0) + req1 = requests.pop(0) + scheduler.add_request(req0) + scheduler.add_request(req1) + + sched_outputs: deque[SchedulerOutput] = deque() + sched_output = scheduler.schedule() + sched_outputs.append(sched_output) + # Make sure prefix caching de-duplicates the prompts in the same step, + # so all the blocks except the last are shared between the two requests. + assert len(sched_output.num_scheduled_tokens) == 2 + num_blocks = num_prompt_tokens // BLOCK_SIZE + assert req0.num_cached_tokens == 0 + assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE + + sched_outputs.append(scheduler.schedule()) + while sched_outputs: + if requests: + scheduler.add_request(requests.pop(0)) + sched_output = sched_outputs.popleft() + model_runner_output = _make_model_runner_output(sched_output) + scheduler.update_from_output(sched_output, model_runner_output) + sched_output = scheduler.schedule() + if sched_output.num_scheduled_tokens: + sched_outputs.append(sched_output) + + # Other requests scheduled after the two requests should also get + # prefix cache hit. 
+ assert scheduler.get_num_unfinished_requests() == 0 + for req in requests_copy[1:]: + assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE + + +def test_prefix_caching_for_multi_turn(): + CHUNK_SIZE = 1000 + BLOCK_SIZE = 16 + num_prompt_tokens = 100 + num_output_tokens = 200 + scheduler = create_scheduler(async_scheduling=True, + max_num_batched_tokens=CHUNK_SIZE, + enable_prefix_caching=True, + block_size=BLOCK_SIZE) + requests = create_requests(num_requests=5, + num_tokens=num_prompt_tokens, + max_tokens=num_output_tokens) + + for req in requests: + scheduler.add_request(req) + sched_outputs: deque[SchedulerOutput] = deque() + sched_outputs.append(scheduler.schedule()) + sched_outputs.append(scheduler.schedule()) + + # Process the requests. + while sched_outputs: + sched_output = sched_outputs.popleft() + model_runner_output = _make_model_runner_output(sched_output) + scheduler.update_from_output(sched_output, model_runner_output) + sched_output = scheduler.schedule() + if sched_output.num_scheduled_tokens: + sched_outputs.append(sched_output) + assert scheduler.get_num_unfinished_requests() == 0 + + # Create next-turn requests whose prompts are the full output of the + # previous turn. + next_turn_requests = create_requests( + num_requests=5, + num_tokens=num_prompt_tokens + num_output_tokens, + max_tokens=num_output_tokens, + ) + for i, req in enumerate(next_turn_requests): + req.prompt_token_ids = (requests[i].prompt_token_ids + + list(requests[i].output_token_ids)) + # Schedule the next-turn requests. + for req in next_turn_requests: + scheduler.add_request(req) + sched_outputs.append(scheduler.schedule()) + + # Make sure the next-turn requests get prefix cache hit by the previous + # requests. + for req in next_turn_requests: + assert (req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * + BLOCK_SIZE) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index e80ad8a68151..ccdbe79dfea4 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -8,7 +8,7 @@ from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes, sha256 +from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_manager import KVCacheManager # disable yapf here as it formats differently than isort such that both fail # yapf: disable @@ -16,7 +16,8 @@ FreeKVCacheBlockQueue, KVCacheBlock, PrefixCachingMetrics, estimate_max_model_len, generate_block_hash_extra_keys, get_kv_cache_config, get_max_concurrency_for_kv_cache_config, - hash_block_tokens, hash_request_tokens, unify_kv_cache_configs) + hash_block_tokens, hash_request_tokens, init_none_hash, + unify_kv_cache_configs) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheTensor, SlidingWindowSpec) @@ -78,24 +79,27 @@ def new_sliding_window_spec(block_size=16, sliding_window=sliding_window) -def test_none_hash(monkeypatch): +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) +def test_none_hash(monkeypatch, hash_fn): import vllm.v1.core.kv_cache_utils # case 1: PYTHONHASHSEED is not set, use random with monkeypatch.context() as m: m.delenv('PYTHONHASHSEED', raising=False) reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) + reloaded_kv_cache_utils.init_none_hash(hash_fn) assert 
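
All of the new test_async_scheduler.py tests above drive the scheduler the same way: keep up to two SchedulerOutput objects in flight, answer each with a fabricated ModelRunnerOutput, and reschedule until nothing remains. Factored out, the drive loop looks roughly like this (a sketch; make_output is a callback shaped like the _make_model_runner_output helper in the test file):

```python
from collections import deque
from typing import Callable

from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput


def drain(scheduler,
          make_output: Callable[[SchedulerOutput], ModelRunnerOutput]) -> None:
    """Run the scheduler to completion with one step of schedule-ahead."""
    sched_outputs: deque[SchedulerOutput] = deque()
    # Calling schedule() twice before the first model output mimics the
    # pipelining that async scheduling allows.
    sched_outputs.append(scheduler.schedule())
    sched_outputs.append(scheduler.schedule())

    while sched_outputs:
        sched_output = sched_outputs.popleft()
        scheduler.update_from_output(sched_output, make_output(sched_output))
        next_output = scheduler.schedule()
        if next_output.num_scheduled_tokens:
            sched_outputs.append(next_output)


# Usage mirroring test_stop_by_max_tokens:
#   scheduler = create_scheduler(async_scheduling=True)
#   for req in create_requests(num_requests=2, max_tokens=4):
#       scheduler.add_request(req)
#   drain(scheduler, _make_model_runner_output)
#   assert scheduler.get_num_unfinished_requests() == 0
```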
reloaded_kv_cache_utils.NONE_HASH is not None assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) assert reloaded_kv_cache_utils.NONE_HASH != 0 - # case 2: PYTHONHASHSEED is set, use the seed + # case 2: PYTHONHASHSEED is set, use the seed and hash_fn with monkeypatch.context() as m: m.setenv('PYTHONHASHSEED', 'python hash seed') reloaded_kv_cache_utils = importlib.reload(vllm.v1.core.kv_cache_utils) + reloaded_kv_cache_utils.init_none_hash(hash_fn) assert reloaded_kv_cache_utils.NONE_HASH is not None assert isinstance(reloaded_kv_cache_utils.NONE_HASH, int) - assert sha256('python hash seed') == reloaded_kv_cache_utils.NONE_HASH + assert hash_fn('python hash seed') == reloaded_kv_cache_utils.NONE_HASH def test_kv_cache_block(): @@ -128,8 +132,8 @@ def test_free_kv_cache_block_queue_initialization(): block = KVCacheBlock(block_id=0) queue = FreeKVCacheBlockQueue([block]) assert queue.num_free_blocks == 1 - assert queue.free_list_head == block - assert queue.free_list_tail == block + assert queue.fake_free_list_head.next_free_block is block + assert queue.fake_free_list_tail.prev_free_block is block def test_free_kv_cache_block_queue_operations(): @@ -141,36 +145,38 @@ def test_free_kv_cache_block_queue_operations(): # Check initial state assert queue.num_free_blocks == 5 - assert queue.free_list_head == blocks[0] - assert queue.free_list_tail == blocks[4] + assert queue.fake_free_list_head.next_free_block is blocks[0] + assert queue.fake_free_list_tail.prev_free_block is blocks[4] # Pop the first block block1 = queue.popleft() assert block1 == blocks[0] assert queue.num_free_blocks == 4 - assert queue.free_list_head == blocks[1] - assert queue.free_list_tail == blocks[4] + assert queue.fake_free_list_head.next_free_block is blocks[1] + assert queue.fake_free_list_tail.prev_free_block is blocks[4] # Remove a block from the middle block_to_remove = blocks[2] queue.remove(block_to_remove) assert queue.num_free_blocks == 3 - assert blocks[1].next_free_block == blocks[3] - assert blocks[3].prev_free_block == blocks[1] + assert blocks[1].next_free_block is blocks[3] + assert blocks[3].prev_free_block is blocks[1] # Append a block back queue.append(block_to_remove) assert queue.num_free_blocks == 4 - assert queue.free_list_tail == block_to_remove - assert block_to_remove.prev_free_block == blocks[4] - assert block_to_remove.next_free_block is None + assert queue.fake_free_list_tail.prev_free_block is block_to_remove + assert block_to_remove.prev_free_block is blocks[4] + assert block_to_remove.next_free_block is queue.fake_free_list_tail # Pop blocks until empty for _ in range(4): queue.popleft() assert queue.num_free_blocks == 0 - assert queue.free_list_head is None - assert queue.free_list_tail is None + assert (queue.fake_free_list_head.next_free_block + is queue.fake_free_list_tail) + assert (queue.fake_free_list_tail.prev_free_block + is queue.fake_free_list_head) # Attempt to pop from an empty queue with pytest.raises(ValueError) as e: @@ -178,6 +184,111 @@ def test_free_kv_cache_block_queue_operations(): assert str(e.value) == "No free blocks available" +def test_free_kv_cache_block_queue_append_n(): + # Create an empty FreeKVCacheBlockQueue with these blocks + queue = FreeKVCacheBlockQueue([]) + blocks = [KVCacheBlock(block_id=i) for i in range(6)] + # Append 0 block + # fake_head->fake_tail + queue.append_n([]) + assert queue.num_free_blocks == 0 + assert (queue.fake_free_list_head.next_free_block + is queue.fake_free_list_tail) + assert (queue.fake_free_list_tail.prev_free_block + 
is queue.fake_free_list_head) + # Append 1 block + # fake_head->b0->fake_tail + queue.append_n(blocks[0:1]) + assert queue.num_free_blocks == 1 + assert queue.fake_free_list_head.next_free_block is blocks[0] + assert blocks[0].prev_free_block is queue.fake_free_list_head + assert blocks[0].next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is blocks[0] + # Append 2 blocks + # fake_head->b0->b4->b5->fake_tail + queue.append_n(blocks[4:6]) + assert queue.num_free_blocks == 3 + assert queue.fake_free_list_head.next_free_block is blocks[0] + assert blocks[0].prev_free_block is queue.fake_free_list_head + assert blocks[0].next_free_block is blocks[4] + assert blocks[4].prev_free_block is blocks[0] + assert blocks[4].next_free_block is blocks[5] + assert blocks[5].prev_free_block is blocks[4] + assert blocks[5].next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is blocks[5] + # Append 3 blocks + # fake_head->b0->b4->b5->b1->b2->b3->fake_tail + queue.append_n(blocks[1:4]) + assert queue.num_free_blocks == 6 + assert queue.fake_free_list_head.next_free_block is blocks[0] + assert blocks[0].prev_free_block is queue.fake_free_list_head + assert blocks[0].next_free_block is blocks[4] + assert blocks[4].prev_free_block is blocks[0] + assert blocks[4].next_free_block is blocks[5] + assert blocks[5].prev_free_block is blocks[4] + assert blocks[5].next_free_block is blocks[1] + assert blocks[1].prev_free_block is blocks[5] + assert blocks[1].next_free_block is blocks[2] + assert blocks[2].prev_free_block is blocks[1] + assert blocks[2].next_free_block is blocks[3] + assert blocks[3].prev_free_block is blocks[2] + assert blocks[3].next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is blocks[3] + + +def test_free_kv_cache_block_queue_popleft_n(): + blocks = [KVCacheBlock(block_id=i) for i in range(6)] + # Create a empty FreeKVCacheBlockQueue with these blocks + queue = FreeKVCacheBlockQueue( + [blocks[1], blocks[3], blocks[5], blocks[4], blocks[0], blocks[2]]) + assert queue.num_free_blocks == 6 + assert queue.fake_free_list_head.next_free_block is blocks[1] + assert blocks[1].prev_free_block is queue.fake_free_list_head + assert blocks[1].next_free_block is blocks[3] + assert blocks[3].prev_free_block is blocks[1] + assert blocks[3].next_free_block is blocks[5] + assert blocks[5].prev_free_block is blocks[3] + assert blocks[5].next_free_block is blocks[4] + assert blocks[4].prev_free_block is blocks[5] + assert blocks[4].next_free_block is blocks[0] + assert blocks[0].prev_free_block is blocks[4] + assert blocks[0].next_free_block is blocks[2] + assert blocks[2].prev_free_block is blocks[0] + assert blocks[2].next_free_block is queue.fake_free_list_tail + assert queue.fake_free_list_tail.prev_free_block is blocks[2] + + # Pop 0 block + # fake_head->b1->b3->b5->b4->b0->b2->fake_tail + assert len(queue.popleft_n(0)) == 0 + # Pop 1 block + # fake_head->b3->b5->b4->b0->b2->fake_tail + result_blocks = queue.popleft_n(1) + assert len(result_blocks) == 1 + assert result_blocks[0] is blocks[1] + for block in result_blocks: + assert block.prev_free_block is None + assert block.next_free_block is None + # Pop 2 blocks + # fake_head->b4->b0->b2->fake_tail + result_blocks = queue.popleft_n(2) + assert len(result_blocks) == 2 + assert result_blocks[0] is blocks[3] + assert result_blocks[1] is blocks[5] + for block in result_blocks: + assert block.prev_free_block is None + assert 
block.next_free_block is None + # Pop 3 blocks + # fake_head->fake_tail + result_blocks = queue.popleft_n(3) + assert len(result_blocks) == 3 + assert result_blocks[0] is blocks[4] + assert result_blocks[1] is blocks[0] + assert result_blocks[2] is blocks[2] + for block in result_blocks: + assert block.prev_free_block is None + assert block.next_free_block is None + + def test_free_kv_cache_block_queue_get_all_free_blocks(): # Create a list of KVCacheBlock objects blocks = [KVCacheBlock(block_id=i) for i in range(5)] @@ -287,9 +398,10 @@ def test_generate_block_hash_extra_keys_cache_salt(): assert next_mm_idx == 1 -@pytest.mark.parametrize("hash_fn", [sha256, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) def test_hash_block_tokens(hash_fn): import vllm.v1.core.kv_cache_utils + init_none_hash(hash_fn) parent_block_hash = 123 curr_block_token_ids = (1, 2, 3) extra_keys = ("key1", "key2") @@ -303,9 +415,10 @@ def test_hash_block_tokens(hash_fn): assert block_hash.extra_keys == extra_keys -@pytest.mark.parametrize("hash_fn", [sha256, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) def test_hash_request_tokens(hash_fn): import vllm.v1.core.kv_cache_utils + init_none_hash(hash_fn) request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], @@ -332,8 +445,10 @@ def test_hash_request_tokens(hash_fn): assert block_hashes[1].extra_keys == ("hash2", ) -@pytest.mark.parametrize("hash_fn", [sha256, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) def test_hash_tokens_different_mm_input(hash_fn): + init_none_hash(hash_fn) + request1 = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], @@ -359,8 +474,10 @@ def test_hash_tokens_different_mm_input(hash_fn): assert block_hashes1[1] != block_hashes2[1] -@pytest.mark.parametrize("hash_fn", [sha256, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) def test_hash_request_tokens_no_mm_inputs(hash_fn): + init_none_hash(hash_fn) + request = make_request( request_id=0, prompt_token_ids=[_ for _ in range(6)], @@ -916,4 +1033,4 @@ def test_get_kv_cache_config(): ], kv_cache_groups=[ KVCacheGroupSpec(["layer_1", "layer_2"], new_kv_cache_spec()) - ]) \ No newline at end of file + ]) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 7a42778831c5..085616303d85 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -11,11 +11,12 @@ from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.sampling_params import SamplingParams -from vllm.utils import sha256 +from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_manager import KVCacheManager, Request from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, - KVCacheBlock, hash_block_tokens) + KVCacheBlock, hash_block_tokens, + init_none_hash) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, SlidingWindowSpec) @@ -91,7 +92,7 @@ def make_kv_cache_config_hybrid_model(block_size: int, ) -@pytest.mark.parametrize("hash_algo", ["sha256", "hash"]) +@pytest.mark.parametrize("hash_algo", ["sha256", "sha256_cbor_64bit", "hash"]) def test_prefill(hash_algo): manager = KVCacheManager( make_kv_cache_config(16, 11), @@ -101,7 +102,8 @@ def test_prefill(hash_algo): ) # choose the 
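
The rewritten assertions above reflect that FreeKVCacheBlockQueue now keeps permanent sentinel nodes (fake_free_list_head / fake_free_list_tail) instead of nullable head/tail pointers, which is what lets append_n and popleft_n splice whole runs of blocks without special-casing an empty queue. A self-contained toy version of that sentinel design (an illustration, not vLLM's implementation):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    block_id: int
    prev: Optional["Node"] = None
    next: Optional["Node"] = None


class SentinelFreeList:
    """Toy doubly-linked free list with permanent head/tail sentinels."""

    def __init__(self) -> None:
        self.head = Node(block_id=-1)  # plays the role of fake_free_list_head
        self.tail = Node(block_id=-2)  # plays the role of fake_free_list_tail
        self.head.next = self.tail
        self.tail.prev = self.head
        self.num_free = 0

    def append_n(self, nodes: list["Node"]) -> None:
        # Splice each node in front of the tail sentinel; the empty queue
        # needs no special case because the sentinels always exist.
        for node in nodes:
            last = self.tail.prev
            last.next = node
            node.prev = last
            node.next = self.tail
            self.tail.prev = node
        self.num_free += len(nodes)

    def popleft_n(self, n: int) -> list["Node"]:
        assert n <= self.num_free, "not enough free blocks"
        popped = []
        for _ in range(n):
            node = self.head.next
            self.head.next = node.next
            node.next.prev = self.head
            node.prev = node.next = None  # detach, as the new tests assert
            popped.append(node)
        self.num_free -= n
        return popped


free_list = SentinelFreeList()
free_list.append_n([Node(i) for i in range(3)])
assert [n.block_id for n in free_list.popleft_n(2)] == [0, 1]
assert free_list.num_free == 1
```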
hash function according to the parameter - hash_fn = sha256 if hash_algo == "sha256" else hash + hash_fn = (sha256_cbor_64bit if hash_algo == "sha256_cbor_64bit" else + sha256 if hash_algo == "sha256" else hash) # Complete 3 blocks (48 tokens) common_token_ids = [i for i in range(3) for _ in range(16)] @@ -153,13 +155,14 @@ def test_prefill(hash_algo): assert block.ref_cnt == 2 # At this point, we should have 5 free blocks left. - assert manager.block_pool.free_block_queue.num_free_blocks == 5 + free_block_queue = manager.block_pool.free_block_queue + assert free_block_queue.num_free_blocks == 5 manager.free(req0) manager.free(req1) # All blocks should be available. - assert manager.block_pool.free_block_queue.num_free_blocks == 10 + assert free_block_queue.num_free_blocks == 10 # The order should be # [unallocated (6, 7, 8, 9, 10)] # [unique_req0 (4)] @@ -186,14 +189,10 @@ def test_prefill(hash_algo): # Although we only have 6 free blocks, we have 8 blocks in # the free block queue due to lazy removal. - assert manager.block_pool.free_block_queue.num_free_blocks == 6 - assert all([ - b.ref_cnt == 0 - for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ]) - assert len([ - b for b in manager.block_pool.free_block_queue.get_all_free_blocks() - ]) == 6 + assert free_block_queue.num_free_blocks == 6 + assert all( + [b.ref_cnt == 0 for b in free_block_queue.get_all_free_blocks()]) + assert len([b for b in free_block_queue.get_all_free_blocks()]) == 6 manager.free(req2) @@ -207,9 +206,12 @@ def test_prefill(hash_algo): computed_blocks) # This block ID order also checks the eviction order. assert blocks.get_block_ids() == ([7, 8, 9, 10, 4, 5, 6, 3, 2, 1], ) - assert manager.block_pool.free_block_queue.num_free_blocks == 0 - assert manager.block_pool.free_block_queue.free_list_head is None - assert manager.block_pool.free_block_queue.free_list_tail is None + + assert free_block_queue.num_free_blocks == 0 + assert (free_block_queue.fake_free_list_head.next_free_block + is free_block_queue.fake_free_list_tail) + assert (free_block_queue.fake_free_list_tail.prev_free_block + is free_block_queue.fake_free_list_head) def test_prefill_hybrid_model(): @@ -696,12 +698,14 @@ def test_basic_prefix_caching_disabled(): assert not blocks -@pytest.mark.parametrize("hash_fn", [sha256, hash]) +@pytest.mark.parametrize("hash_fn", [sha256, sha256_cbor_64bit, hash]) def test_cache_blocks(hash_fn): """ This is a unit test that tests the correctness of the _cache_full_blocks function of KVCacheManager. 
""" + init_none_hash(hash_fn) + block_size = 4 block_pool = BlockPool( num_gpu_blocks=5, @@ -1093,6 +1097,73 @@ def test_prefix_cache_stats_disabled(): assert manager.prefix_cache_stats is None +def test_maybe_evict_cached_block(): + pool = BlockPool(num_gpu_blocks=4, enable_caching=True) + block_hash0 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=10, + token_ids=(100, )), + group_id=1000) + block_hash1 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=20, + token_ids=(200, )), + group_id=2000) + block_hash2 = BlockHashWithGroupId(block_hash=BlockHash(hash_value=30, + token_ids=(300, )), + group_id=3000) + block_hashes = [ + block_hash0, + block_hash1, + block_hash2, + # block3 had the exact same block_hash as the first block + block_hash0, + ] + assert len(pool.blocks) == len(block_hashes) + # Manually add all blocks to cached_blocks + for block, block_hash in zip(pool.blocks, block_hashes): + block.block_hash = block_hash + pool.cached_block_hash_to_block[block_hash][block.block_id] = block + + block0, block1, block2, block3 = pool.blocks + assert pool.cached_block_hash_to_block == { + block_hash0: { + block0.block_id: block0, + block3.block_id: block3 + }, + block_hash1: { + block1.block_id: block1 + }, + block_hash2: { + block2.block_id: block2 + } + } + # Evict block1 + pool._maybe_evict_cached_block(block1) + assert pool.cached_block_hash_to_block == { + block_hash0: { + block0.block_id: block0, + block3.block_id: block3 + }, + block_hash2: { + block2.block_id: block2 + } + } + # Evict block0: block_hash0 entry should NOT be removed, as block3 + # also use the same hash + pool._maybe_evict_cached_block(block0) + assert pool.cached_block_hash_to_block == { + block_hash0: { + block3.block_id: block3 + }, + block_hash2: { + block2.block_id: block2 + } + } + # Evict block2 + pool._maybe_evict_cached_block(block2) + assert pool.cached_block_hash_to_block == {block_hash0: {3: block3}} + # Evict block3 + pool._maybe_evict_cached_block(block3) + assert pool.cached_block_hash_to_block == {} + + @pytest.mark.parametrize("blocks_to_cache", [2, 3, 10]) def test_kv_cache_events(blocks_to_cache: int): block_size = 16 diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 02d2c83ab158..a858a4d8c823 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -19,133 +19,7 @@ from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.request import StructuredOutputRequest -EOS_TOKEN_ID = 50256 - - -def create_scheduler( - model: str = "facebook/opt-125m", - max_num_seqs: int = 16, - max_num_batched_tokens: int = 8192, - enable_prefix_caching: Optional[bool] = None, - long_prefill_token_threshold: int = 0, - disable_chunked_mm_input: bool = False, - use_kv_connector: bool = False, - num_blocks: int = 10000, - block_size: int = 16, - max_model_len: Optional[int] = None, - num_speculative_tokens: Optional[int] = None, - skip_tokenizer_init: bool = False, -) -> Scheduler: - '''Create scheduler under test. 
- - Args: - model: model under test - max_num_seqs: max sequences to schedule - max_num_batch_tokens: max num tokens to batch - enable_prefix_caching: optionally force APC config - (True/False) or use default - (None) - - Returns: - {class}`Scheduler` instance - ''' - if max_model_len is None: - max_model_len = max_num_batched_tokens - scheduler_config = SchedulerConfig( - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_model_len, - long_prefill_token_threshold=long_prefill_token_threshold, - disable_chunked_mm_input=disable_chunked_mm_input, - enable_chunked_prefill=True, - ) - model_config = ModelConfig( - model=model, - task="auto", - tokenizer=model, - tokenizer_mode="auto", - trust_remote_code=True, - dtype="float16", - seed=42, - skip_tokenizer_init=skip_tokenizer_init, - ) - # Cache config, optionally force APC - kwargs_cache = ({} if enable_prefix_caching is None else { - 'enable_prefix_caching': enable_prefix_caching - }) - cache_config = CacheConfig( - block_size=block_size, - gpu_memory_utilization=0.9, - swap_space=0, - cache_dtype="auto", - **kwargs_cache, - ) - kv_transfer_config = KVTransferConfig( - kv_connector="SharedStorageConnector", - kv_role="kv_both", - kv_connector_extra_config={"shared_storage_path": "local_storage"}, - ) if use_kv_connector else None - - speculative_config: Optional[SpeculativeConfig] = None - if num_speculative_tokens is not None: - speculative_config = SpeculativeConfig( - model="ngram", num_speculative_tokens=num_speculative_tokens) - - vllm_config = VllmConfig( - scheduler_config=scheduler_config, - model_config=model_config, - cache_config=cache_config, - kv_transfer_config=kv_transfer_config, - speculative_config=speculative_config, - ) - kv_cache_config = KVCacheConfig( - num_blocks=num_blocks, # A large number of blocks to hold all requests - kv_cache_tensors=[], - kv_cache_groups=[ - KVCacheGroupSpec(['layer'], - FullAttentionSpec(block_size, 1, 1, torch.float32, - False)) - ], - ) - cache_config.num_gpu_blocks = num_blocks - return Scheduler( - vllm_config=vllm_config, - kv_cache_config=kv_cache_config, - log_stats=True, - structured_output_manager=StructuredOutputManager(vllm_config), - ) - - -def create_requests(num_requests: int, - num_tokens: int = 10, - mm_positions: Optional[list[PlaceholderRange]] = None, - max_tokens: int = 16, - stop_token_ids: Optional[list[int]] = None, - prompt_logprobs: Optional[int] = None): - sampling_params = SamplingParams(ignore_eos=False, - max_tokens=max_tokens, - stop_token_ids=stop_token_ids, - prompt_logprobs=prompt_logprobs) - requests = [] - for i in range(num_requests): - if mm_positions is not None: - mm_position = mm_positions[i] - mm_inputs = [MultiModalKwargs({})] * len(mm_position) - else: - mm_position = None - mm_inputs = None - request = Request( - request_id=f"{i}", - prompt_token_ids=[i] * num_tokens, - sampling_params=sampling_params, - pooling_params=None, - multi_modal_inputs=mm_inputs, - multi_modal_placeholders=mm_position, - multi_modal_hashes=None, - eos_token_id=EOS_TOKEN_ID, - ) - requests.append(request) - return requests +from .utils import EOS_TOKEN_ID, create_requests, create_scheduler def test_add_requests(): @@ -451,6 +325,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + req.status = RequestStatus.RUNNING scheduler_output = SchedulerOutput( scheduled_new_reqs=[], @@ -504,6 +379,7 @@ def 
test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + req.status = RequestStatus.RUNNING scheduler_output = SchedulerOutput( scheduled_new_reqs=[], @@ -556,6 +432,7 @@ def test_stop_via_update_from_output(): req.num_computed_tokens = req.num_tokens scheduler.requests[req.request_id] = req scheduler.running.append(req) + req.status = RequestStatus.RUNNING scheduler_output = SchedulerOutput( scheduled_new_reqs=[], @@ -703,6 +580,65 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], scheduler.update_from_output(scheduler_output1, model_runner_output) +def test_preempt_during_execution(): + # NOTE(woosuk): The actual number of available blocks is 10 instead of 11 + # because block 0 is reserved as the null block. + scheduler = create_scheduler(max_num_batched_tokens=100, + block_size=16, + num_blocks=11, + enable_prefix_caching=False) + requests = create_requests(num_requests=2, num_tokens=80) + + # Schedule the first request. + scheduler.add_request(requests[0]) + scheduler_output0 = scheduler.schedule() + assert len(scheduler_output0.num_scheduled_tokens) == 1 + assert len(scheduler_output0.scheduled_new_reqs[0].block_ids[0]) == 5 + + # Schedule the second request while the first request is still running. + # This scenario can occur in certain cases, when max_concurrent_batches > 1 + # (e.g., when pipeline parallelism is used). + scheduler.add_request(requests[1]) + scheduler_output1 = scheduler.schedule() + assert len(scheduler_output1.num_scheduled_tokens) == 1 + assert len(scheduler_output1.scheduled_new_reqs[0].block_ids[0]) == 5 + + # Get the output of the first request. + model_runner_output0 = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[0]], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(scheduler_output0, model_runner_output0) + + # Schedule the first request again. This will cause the preemption + # of the second request because the KV cache is full. + _ = scheduler.schedule() + assert len(scheduler.running) == 1 + assert scheduler.running[0] == requests[0] + assert requests[1].status == RequestStatus.PREEMPTED + + model_runner_output1 = ModelRunnerOutput( + req_ids=[requests[1].request_id], + req_id_to_index={requests[1].request_id: 0}, + sampled_token_ids=[[42]], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + ) + scheduler.update_from_output(scheduler_output1, model_runner_output1) + + # The second request (that is preempted) should be updated with the + # sampled token id. 
+ assert len(requests[1].output_token_ids) == 1 + assert requests[1].output_token_ids[0] == 42 + + # Note - these test cases mirror some of those in test_rejection_sampler.py @pytest.mark.parametrize( "spec_tokens,output_tokens,expected", diff --git a/tests/v1/core/test_scheduler_e2e.py b/tests/v1/core/test_scheduler_e2e.py index 85415f6ad4b6..bd0320baef87 100644 --- a/tests/v1/core/test_scheduler_e2e.py +++ b/tests/v1/core/test_scheduler_e2e.py @@ -14,7 +14,7 @@ @pytest.fixture(scope="module") -def model() -> LLM: +def llm() -> LLM: return LLM(MODEL, enforce_eager=True, enable_prefix_caching=True, @@ -24,16 +24,16 @@ def model() -> LLM: block_size=16) -def test_concurrent_partial_prefill(model): - outputs = model.generate([PROMPT] * 3) +def test_concurrent_partial_prefill(llm): + outputs = llm.generate([PROMPT] * 3) assert len(outputs) == 3 for output in outputs: assert len(output.outputs) == 1 -def test_prefix_cache_stats_is_recorded(model): +def test_prefix_cache_stats_is_recorded(llm): # 17 tokens will make sure first 16 tokens are cached in a block input_tokens = {"prompt_token_ids": [101] * 17} - _ = model.generate([input_tokens]) - outputs = model.generate([input_tokens]) + _ = llm.generate([input_tokens]) + outputs = llm.generate([input_tokens]) assert outputs[0].num_cached_tokens == 16 diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index a9e1898df934..b67c05bd7ac1 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -1,13 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import random + import torch from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId, KVCacheBlock) -from vllm.v1.core.single_type_kv_cache_manager import SlidingWindowManager -from vllm.v1.kv_cache_interface import SlidingWindowSpec +from vllm.v1.core.single_type_kv_cache_manager import ( + ChunkedLocalAttentionManager, SlidingWindowManager) +from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, + SlidingWindowSpec) def get_sliding_window_manager(sliding_window_spec, block_pool): @@ -17,6 +21,80 @@ def get_sliding_window_manager(sliding_window_spec, block_pool): kv_cache_group_id=0) +def get_chunked_local_attention_manager(chunked_local_attention_spec, + block_pool): + return ChunkedLocalAttentionManager(chunked_local_attention_spec, + block_pool, + caching_hash_fn=lambda x: x, + kv_cache_group_id=0) + + +def test_chunked_local_attention_possible_cached_prefix(): + block_size = 2 + chunked_local_attention_spec = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + attention_chunk_size=4, + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + manager = get_chunked_local_attention_manager(chunked_local_attention_spec, + block_pool) + + def run_one_case(block_is_cached, tail_token, expect_length): + block_hash_list = [ + BlockHash(i, ()) for i in range(len(block_is_cached)) + ] + + block_pool.cached_block_hash_to_block.clear() + + # Mock the block pool with the cached blocks + for i, (block_hash, + is_cached) in enumerate(zip(block_hash_list, block_is_cached)): + if is_cached: + block_pool.cached_block_hash_to_block[BlockHashWithGroupId( + block_hash, 0)] = { + i: block_pool.blocks[i + 10], + } + + computed_blocks = manager.find_longest_cache_hit( + block_hashes=block_hash_list, + 
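
Stepping back to test_preempt_during_execution (a few hunks above): the preemption there is forced purely by block-budget arithmetic, which is worth spelling out.

```python
# Block budget behind test_preempt_during_execution (arithmetic only).
block_size = 16
num_blocks = 11                   # value passed to create_scheduler(...)
usable_blocks = num_blocks - 1    # block 0 is reserved as the null block
prompt_tokens = 80

blocks_per_prompt = -(-prompt_tokens // block_size)  # ceil(80 / 16) == 5
assert blocks_per_prompt == 5
# Scheduling both 80-token prompts consumes every usable block.
assert 2 * blocks_per_prompt == usable_blocks

# Once request 0 samples its first token it holds 81 tokens and needs a
# sixth block, so the scheduler has to preempt request 1 to make room.
assert -(-(prompt_tokens + 1) // block_size) == blocks_per_prompt + 1
```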
max_length=len(block_hash_list) * block_size + tail_token, + kv_cache_group_ids=[0], + block_pool=block_pool, + kv_cache_spec=chunked_local_attention_spec, + use_eagle=False)[0] + assert len(computed_blocks) == expect_length + + assert all(block == block_pool.null_block + for block in computed_blocks[:(expect_length - 1) // 2]) + + run_one_case([True], 0, 1) + run_one_case([True], 1, 1) + run_one_case([True, False], 0, 2) + run_one_case([True, False], 1, 2) + run_one_case([True, True], 0, 2) + run_one_case([True, True], 1, 2) + run_one_case([True, True, False], 0, 2) + run_one_case([True, True, False], 1, 2) + run_one_case([True, True, True], 0, 3) + run_one_case([True, True, True], 1, 3) + run_one_case([True, True, True, False], 0, 4) + run_one_case([True, True, True, False], 1, 4) + run_one_case([random.choice([True, False])] * 8 + [True], 1, 9) + run_one_case([random.choice([True, False])] * 8 + [False], 1, 8) + run_one_case([random.choice([True, False])] * 8 + [True, True], 1, 10) + run_one_case([random.choice([True, False])] * 8 + [True, False], 0, 10) + run_one_case([random.choice([True, False])] * 8 + [True, False], 1, 10) + run_one_case([random.choice([True, False])] * 8 + [False, True], 0, 10) + run_one_case([random.choice([True, False])] * 8 + [False, True], 1, 10) + run_one_case([random.choice([True, False])] * 8 + [False, False], 0, 10) + run_one_case([random.choice([True, False])] * 8 + [False, False], 1, 10) + + def test_sliding_window_possible_cached_prefix(): block_size = 2 sliding_window_spec = SlidingWindowSpec( @@ -84,6 +162,58 @@ def run_one_case(block_is_cached, expect_length): ], 8) +def test_chunked_local_attention_remove_skipped_blocks(): + attention_spec = ChunkedLocalAttentionSpec( + block_size=2, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + attention_chunk_size=4, + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) + + manager = get_chunked_local_attention_manager(attention_spec, block_pool) + + null_block_id = block_pool.null_block.block_id + + def id_to_block_table(ids) -> list[KVCacheBlock]: + return [ + KVCacheBlock(id_) + if id_ != null_block_id else block_pool.null_block for id_ in ids + ] + + def assert_block_id(block_table: list[KVCacheBlock], ids: list[int]): + for block, id_ in zip(block_table, ids): + if id_ == null_block_id: + assert block == block_pool.null_block + else: + assert block.block_id == id_ + + original_block_ids = [ + 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 + ] + block_table = id_to_block_table(original_block_ids) + manager.req_to_blocks["test"] = block_table + + manager.remove_skipped_blocks("test", 0) + assert_block_id(block_table, original_block_ids) + + # For 4th token (0-indexed), token 0-3 is out of the local attention window. + manager.remove_skipped_blocks("test", 4) + assert_block_id(block_table, [null_block_id] * 2) + + # For 6th token (0-indexed), token 4 - 6 are in local attention window, + # token 0 - 3 are out, 2 blocks can be removed. + manager.remove_skipped_blocks("test", 6) + assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:]) + # For 12th token (0-indexed), + # token 0-11 are out, 6 block can be removed. 
+ manager.remove_skipped_blocks("test", 12) + assert_block_id(block_table, [null_block_id] * 6) + + def test_sliding_window_remove_skipped_blocks(): sliding_window_spec = SlidingWindowSpec( block_size=2, @@ -172,3 +302,26 @@ def test_get_num_blocks_to_allocate(): cached_blocks_1) == 20 assert manager.get_num_blocks_to_allocate("2", 20 * block_size, cached_blocks_2) == 15 + + +def test_chunked_local_attention_get_num_blocks_to_allocate(): + block_size = 2 + attention_spec = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + attention_chunk_size=4, # Placeholder value, not related to test result + use_mla=False, + ) + + block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) + manager = get_chunked_local_attention_manager(attention_spec, block_pool) + cached_blocks_1 = [KVCacheBlock(i + 1) for i in range(10)] + cached_blocks_2 = [block_pool.null_block for _ in range(5) + ] + [KVCacheBlock(i + 1) for i in range(5)] + + assert manager.get_num_blocks_to_allocate("1", 20 * block_size, + cached_blocks_1) == 20 + assert manager.get_num_blocks_to_allocate("2", 20 * block_size, + cached_blocks_2) == 15 diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py new file mode 100644 index 000000000000..0b7d8251b640 --- /dev/null +++ b/tests/v1/core/utils.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional, Union + +import torch + +from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig, + SchedulerConfig, SpeculativeConfig, VllmConfig) +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.sched.async_scheduler import AsyncScheduler +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.request import Request +from vllm.v1.structured_output import StructuredOutputManager + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, + skip_tokenizer_init: bool = False, + async_scheduling: bool = False, +) -> Union[Scheduler, AsyncScheduler]: + '''Create scheduler under test. 
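
The remove_skipped_blocks expectations above follow from where the current local-attention chunk starts: with attention_chunk_size=4 and block_size=2, every block that ends before the start of the chunk containing the next token can be dropped. The sketch below reproduces that calculation, treating the second argument of remove_skipped_blocks as the number of tokens already computed; it is an illustration of the window math, not the manager's actual code.

```python
def num_removable_blocks(num_computed_tokens: int,
                         attention_chunk_size: int,
                         block_size: int) -> int:
    """Blocks that fall entirely outside the current local-attention chunk."""
    # First token index that the next token still attends to.
    chunk_start = (num_computed_tokens // attention_chunk_size
                   ) * attention_chunk_size
    # Blocks ending strictly before chunk_start can be replaced by the
    # null block.
    return chunk_start // block_size


# Matches test_chunked_local_attention_remove_skipped_blocks
# (attention_chunk_size=4, block_size=2):
assert num_removable_blocks(0, 4, 2) == 0
assert num_removable_blocks(4, 4, 2) == 2   # tokens 0-3 leave the window
assert num_removable_blocks(6, 4, 2) == 2   # tokens 4-6 stay in the window
assert num_removable_blocks(12, 4, 2) == 6  # tokens 0-11 leave the window
```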
+ + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + {class}`Scheduler` instance + ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, + async_scheduling=async_scheduling, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + skip_tokenizer_init=skip_tokenizer_init, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + speculative_config: Optional[SpeculativeConfig] = None + if num_speculative_tokens is not None: + speculative_config = SpeculativeConfig( + model="ngram", num_speculative_tokens=num_speculative_tokens) + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + speculative_config=speculative_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + scheduler_cls = AsyncScheduler if async_scheduling else Scheduler + return scheduler_cls( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests( + num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None, + same_prompt: bool = False, +) -> list[Request]: + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + prompt_token_ids = ([0] * num_tokens if same_prompt else [i] * + num_tokens) + request = Request( + request_id=f"{i}", + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + pooling_params=None, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + ) + requests.append(request) + return requests diff --git a/tests/v1/e2e/test_cascade_attention.py b/tests/v1/e2e/test_cascade_attention.py index 161bcd4d3ef9..f2f460513605 100644 --- 
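
With create_scheduler and create_requests now shared via tests/v1/core/utils.py, new scheduler tests no longer have to re-declare the config boilerplate. A sketch of what such a test might look like on top of the shared helpers (the test itself is hypothetical):

```python
# Hypothetical test built on the shared helpers in tests/v1/core/utils.py.
from .utils import create_requests, create_scheduler


def test_same_prompt_requests_schedule_together():
    # same_prompt=True gives every request identical prompt token ids,
    # which is what the prefix-caching dedup test above relies on.
    scheduler = create_scheduler(enable_prefix_caching=True, block_size=16)
    for req in create_requests(num_requests=2, num_tokens=32,
                               same_prompt=True):
        scheduler.add_request(req)

    sched_output = scheduler.schedule()
    # Both short prompts fit well under the default 8192-token budget,
    # so they are scheduled in the same step.
    assert len(sched_output.num_scheduled_tokens) == 2
```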
a/tests/v1/e2e/test_cascade_attention.py +++ b/tests/v1/e2e/test_cascade_attention.py @@ -5,10 +5,10 @@ from vllm import LLM, SamplingParams -from ...utils import fork_new_process_for_each_test +from ...utils import create_new_process_for_each_test -@fork_new_process_for_each_test +@create_new_process_for_each_test() @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"]) def test_cascade_attention(example_system_message, monkeypatch, attn_backend): diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 93e7c12f3a09..2423f966acfa 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -6,8 +6,10 @@ from typing import Any import pytest +import torch from vllm import LLM, SamplingParams +from vllm.distributed import cleanup_dist_env_and_memory @pytest.fixture @@ -53,14 +55,6 @@ def model_name(): return "meta-llama/Llama-3.1-8B-Instruct" -def eagle_model_name(): - return "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" - - -def eagle3_model_name(): - return "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" - - def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], @@ -77,6 +71,8 @@ def test_ngram_correctness( ref_llm = LLM(model=model_name, max_model_len=1024) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() spec_llm = LLM( model=model_name, @@ -103,34 +99,50 @@ def test_ngram_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.7 * len(ref_outputs)) del spec_llm - - -@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"]) + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("model_setup", [ + ("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), + ("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), +], + ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle"]) def test_eagle_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], sampling_config: SamplingParams, - model_name: str, - use_eagle3: bool, + model_setup: tuple[str, str, str, int], ): ''' Compare the outputs of a original LLM and a speculative LLM should be the same when using eagle speculative decoding. + model_setup: (method, model_name, eagle_model_name, tp_size) ''' with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") + method, model_name, spec_model_name, tp_size = model_setup - ref_llm = LLM(model=model_name, max_model_len=2048) + ref_llm = LLM(model=model_name, + max_model_len=2048, + tensor_parallel_size=tp_size) ref_outputs = ref_llm.chat(test_prompts, sampling_config) del ref_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() - spec_model_name = eagle3_model_name( - ) if use_eagle3 else eagle_model_name() spec_llm = LLM( model=model_name, trust_remote_code=True, + tensor_parallel_size=tp_size, speculative_config={ - "method": "eagle3" if use_eagle3 else "eagle", + "method": method, "model": spec_model_name, "num_speculative_tokens": 3, "max_model_len": 2048, @@ -152,3 +164,5 @@ def test_eagle_correctness( # Upon failure, inspect the outputs to check for inaccuracy. 
assert matches > int(0.66 * len(ref_outputs)) del spec_llm + torch.cuda.empty_cache() + cleanup_dist_env_and_memory() diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index e137452f2625..412df3acff12 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -336,9 +336,10 @@ async def test_customize_loggers(monkeypatch): await engine.do_log_stats() - assert len(engine.stat_loggers) == 1 - assert len(engine.stat_loggers[0]) == 1 - engine.stat_loggers[0][0].log.assert_called_once() + stat_loggers = engine.logger_manager.per_engine_logger_dict + assert len(stat_loggers) == 1 + assert len(stat_loggers[0]) == 1 + stat_loggers[0][0].log.assert_called_once() @pytest.mark.asyncio(scope="module") diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 65f1da803fb2..2ac6dc796bd1 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -565,8 +565,8 @@ def create_mock_executor(vllm_config): from vllm.v1.engine.utils import EngineZmqAddresses - def mock_startup_handshake(self, handshake_socket, on_head_node, - parallel_config): + def mock_startup_handshake(self, handshake_socket, local_client, + headless, parallel_config): return EngineZmqAddresses(inputs=["tcp://127.0.0.1:5555"], outputs=["tcp://127.0.0.1:5556"], coordinator_input=None, diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 059106c62a20..f37686317fd1 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -112,9 +112,9 @@ def test_compatibility_with_skip_tokenizer_init( example_prompts, structured_outputs=True, ) - model: LLM = vllm_model_skip_tokenizer_init.model + llm: LLM = vllm_model_skip_tokenizer_init.llm with pytest.raises(ValueError): - _ = model.generate(example_prompts, sampling_params_list) + _ = llm.generate(example_prompts, sampling_params_list) def test_parallel_sampling(vllm_model, example_prompts) -> None: @@ -125,8 +125,8 @@ def test_parallel_sampling(vllm_model, example_prompts) -> None: example_prompt: test fixture providing prompts for testing. 
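
The spec-decode tests above now tear the reference engine down before building the speculative one, so that GPU memory and the distributed state are released between the two in-process engines. Condensed, the teardown sequence (using only calls that appear in this diff) is:

```python
# Teardown between two in-process LLM instances, as in test_spec_decode.py.
import torch

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory

ref_llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_model_len=2048)
# ... run ref_llm.chat(...) and keep the outputs for comparison ...
del ref_llm
torch.cuda.empty_cache()        # drop cached CUDA allocations
cleanup_dist_env_and_memory()   # tear down the distributed state

# Only now is it safe to construct the second engine in the same process.
spec_llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
               speculative_config={
                   "method": "eagle",
                   "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
                   "num_speculative_tokens": 3,
                   "max_model_len": 2048,
               })
```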
""" sampling_params_list, n_list = _get_test_sampling_params(example_prompts) - model: LLM = vllm_model.model - outputs = model.generate(example_prompts, sampling_params_list) + llm: LLM = vllm_model.llm + outputs = llm.generate(example_prompts, sampling_params_list) # Validate each request response for out, n in zip(outputs, n_list): @@ -166,10 +166,10 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): speculative_config=speculative_config, disable_log_stats=False, ) as vllm_model: - model: LLM = vllm_model.model + llm: LLM = vllm_model.llm sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = model.generate(example_prompts, sampling_params) + outputs = llm.generate(example_prompts, sampling_params) n_prompts = len(example_prompts) assert len(outputs) == n_prompts @@ -180,7 +180,7 @@ def test_engine_metrics(vllm_runner, monkeypatch, example_prompts): total_tokens += len(out.outputs[0].token_ids) assert total_tokens == max_tokens * n_prompts - metrics = model.get_metrics() + metrics = llm.get_metrics() def find_metric(name) -> list[Metric]: found = [] diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py index 1c8c5f25e29b..949ab764e2e9 100644 --- a/tests/v1/engine/test_output_processor.py +++ b/tests/v1/engine/test_output_processor.py @@ -35,7 +35,7 @@ def _ref_convert_id_to_token( Returns: String representation of input token id """ - return tokenizer.convert_ids_to_tokens(token_id) or "" + return tokenizer.decode([token_id]) or "" @pytest.mark.parametrize( diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index a39ab47b8d87..8bddfb0b48a5 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -41,6 +41,10 @@ ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", None), ("mistralai/Ministral-8B-Instruct-2410", "xgrammar", "mistral", None), ("Qwen/Qwen2.5-1.5B-Instruct", "xgrammar", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", None), + ("mistralai/Ministral-8B-Instruct-2410", "outlines", "mistral", None), + ("mistralai/Ministral-8B-Instruct-2410", "outlines", "auto", + NGRAM_SPEC_CONFIG), #FIXME: This test is flaky on CI thus disabled #("Qwen/Qwen2.5-1.5B-Instruct", "guidance", "auto"), ("mistralai/Ministral-8B-Instruct-2410", "guidance", "auto", @@ -106,13 +110,15 @@ def test_structured_output( enforce_eager = bool(not current_platform.is_tpu()) # Use a single LLM instance for several scenarios to # speed up the test suite. 
- llm = LLM(model=model_name, - enforce_eager=enforce_eager, - max_model_len=1024, - guided_decoding_backend=guided_decoding_backend, - guided_decoding_disable_any_whitespace=True, - tokenizer_mode=tokenizer_mode, - speculative_config=speculative_config) + llm = LLM( + model=model_name, + enforce_eager=enforce_eager, + max_model_len=1024, + guided_decoding_backend=guided_decoding_backend, + guided_decoding_disable_any_whitespace=(guided_decoding_backend + in {"xgrammar", "guidance"}), + tokenizer_mode=tokenizer_mode, + speculative_config=speculative_config) # # Test 1: Generate JSON output based on a provided schema @@ -146,32 +152,33 @@ def test_structured_output( # # Test 2: Generate JSON object without a schema # - sampling_params = SamplingParams( - temperature=1.0, - max_tokens=4096, - n=2, - guided_decoding=GuidedDecodingParams(json_object=True)) + if guided_decoding_backend != "outlines": + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=4096, + n=2, + guided_decoding=GuidedDecodingParams(json_object=True)) - outputs = llm.generate( - prompts=("Generate a JSON object with curly braces for a person with " - "name and age fields for John Smith who is 31 years old. " - "Make the response as short as possible."), - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts=( + "Generate a JSON object with curly braces for a person with " + "name and age fields for John Smith who is 31 years old. " + "Make the response as short as possible."), + sampling_params=sampling_params, + use_tqdm=True) - assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) - for i in range(2): - generated_text = output.outputs[i].text - print(generated_text) - assert generated_text is not None + for i in range(2): + generated_text = output.outputs[i].text + print(generated_text) + assert generated_text is not None - # Parse to verify it is a valid JSON object - parsed_json = json.loads(generated_text) - assert isinstance(parsed_json, dict) + # Parse to verify it is a valid JSON object + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) # # Test 3: test a jsonschema incompatible with xgrammar @@ -210,96 +217,97 @@ def test_structured_output( parsed_json = json.loads(generated_text) assert isinstance(parsed_json, dict) - # - # Test 4: Generate SQL statement using EBNF grammar - # - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) - outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True, - ) + if guided_decoding_backend != "outlines": + # + # Test 4: Generate SQL statement using EBNF grammar + # + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) + outputs = llm.generate( + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. 
Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) - assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt - generated_text = output.outputs[0].text - assert generated_text is not None + generated_text = output.outputs[0].text + assert generated_text is not None - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") - assert generated_text.strip() == ground_truth + assert generated_text.strip() == ground_truth - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - # - # Test 5: Generate SQL statement using Lark grammar - # - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) - outputs = llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short as " - "possible."), - sampling_params=sampling_params, - use_tqdm=True, - ) + # + # Test 5: Generate SQL statement using Lark grammar + # + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) + outputs = llm.generate( + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. 
Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) - assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - prompt = output.prompt + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + prompt = output.prompt - generated_text = output.outputs[0].text - assert generated_text is not None + generated_text = output.outputs[0].text + assert generated_text is not None - # use Lark to parse the output, and make sure it's a valid parse tree - from lark import Lark - parser = Lark(sample_sql_lark) - parser.parse(generated_text) + # use Lark to parse the output, and make sure it's a valid parse tree + from lark import Lark + parser = Lark(sample_sql_lark) + parser.parse(generated_text) - # remove spaces for comparison b/c we removed them in the grammar - ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( - " ", "") + # remove spaces for comparison b/c we removed them in the grammar + ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace( + " ", "") - assert generated_text.strip() == ground_truth + assert generated_text.strip() == ground_truth - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - # - # Test 6: Test invalid grammar input - # - sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - max_tokens=1000, - guided_decoding=GuidedDecodingParams(grammar="not a grammar")) - with pytest.raises(ValueError, match="Failed to convert the grammar "): - llm.generate( - prompts=( - "Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1. Make the response as short " - "as possible."), - sampling_params=sampling_params, - use_tqdm=True, - ) + # + # Test 6: Test invalid grammar input + # + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams(grammar="not a grammar")) + with pytest.raises(ValueError, match="Failed to convert the grammar "): + llm.generate( + prompts= + ("Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. 
Make the response as short " + "as possible."), + sampling_params=sampling_params, + use_tqdm=True, + ) # # Test 7: Generate text based on a regex pattern @@ -421,35 +429,36 @@ def test_structured_output( output_json = json.loads(generated_text) jsonschema.validate(instance=output_json, schema=json_schema) - # - # Test 11: Generate structured output using structural_tag format - # - structural_tag_config = { - "type": - "structural_tag", - "structures": [{ - "begin": "<function=get_weather>", - "schema": { - "type": "object", - "properties": { - "city": { - "type": "string" - } + if guided_decoding_backend != "outlines": + # + # Test 11: Generate structured output using structural_tag format + # + structural_tag_config = { + "type": + "structural_tag", + "structures": [{ + "begin": "<function=get_weather>", + "schema": { + "type": "object", + "properties": { + "city": { + "type": "string" + } + }, + "additionalProperties": False }, - "additionalProperties": False - }, - "end": "</function>" - }], - "triggers": ["<function="] - } + "end": "</function>" + }], + "triggers": ["<function="] + } - sampling_params = SamplingParams( - temperature=0.0, - max_tokens=4096, - guided_decoding=GuidedDecodingParams( - structural_tag=json.dumps(structural_tag_config))) + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=4096, + guided_decoding=GuidedDecodingParams( + structural_tag=json.dumps(structural_tag_config))) - prompt = """ + prompt = """ You have access to the following function to retrieve the weather in a city: { @@ -469,7 +478,7 @@ def test_structured_output( start_tag => `<function` parameters => a JSON dict with the function argument name - as key and function argument value as value. + as key and function argument value as value. end_tag => `</function>` Here is an example, @@ -488,37 +497,37 @@ def test_structured_output( Make the response as short as possible. 
""" - # Change this once other backends support structural_tag - outputs = llm.generate(prompts=prompt, - sampling_params=sampling_params, - use_tqdm=True) - assert outputs is not None + # Change this once other backends support structural_tag + outputs = llm.generate(prompts=prompt, + sampling_params=sampling_params, + use_tqdm=True) + assert outputs is not None - for output in outputs: - assert output is not None - assert isinstance(output, RequestOutput) - generated_text = output.outputs[0].text - assert generated_text is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + generated_text = output.outputs[0].text + assert generated_text is not None - # Search for function call pattern in the response - function_call_pattern = r'<function=get_weather>(.*?)</function>' - matches = re.findall(function_call_pattern, generated_text) - - if not matches: - print(f"Warning: No function calls found in response: " - f"{generated_text!r}") - continue - - # Take the first function call if multiple are found - json_str = matches[0] - try: - json_content = json.loads(json_str) - assert "city" in json_content - assert isinstance(json_content["city"], str) - print(f"Found valid function call: {generated_text!r}") - except (json.JSONDecodeError, AssertionError) as e: - pytest.fail("Invalid function call format: " - f"{generated_text!r}\nError: {str(e)}") + # Search for function call pattern in the response + function_call_pattern = r'<function=get_weather>(.*?)</function>' + matches = re.findall(function_call_pattern, generated_text) + + if not matches: + print(f"Warning: No function calls found in response: " + f"{generated_text!r}") + continue + + # Take the first function call if multiple are found + json_str = matches[0] + try: + json_content = json.loads(json_str) + assert "city" in json_content + assert isinstance(json_content["city"], str) + print(f"Found valid function call: {generated_text!r}") + except (json.JSONDecodeError, AssertionError) as e: + pytest.fail("Invalid function call format: " + f"{generated_text!r}\nError: {str(e)}") @pytest.mark.skip_global_cleanup diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/tests/v1/entrypoints/openai/responses/__init__.py similarity index 100% rename from vllm/attention/ops/blocksparse_attention/__init__.py rename to tests/v1/entrypoints/openai/responses/__init__.py diff --git a/tests/v1/entrypoints/openai/responses/conftest.py b/tests/v1/entrypoints/openai/responses/conftest.py new file mode 100644 index 000000000000..2dcdda04ecb5 --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/conftest.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +# Use a small reasoning model to test the responses API. +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + "--max-model-len", + "8192", + "--enforce-eager", # For faster startup. 
+ "--reasoning-parser", + "deepseek_r1", + ] + + +@pytest.fixture(scope="module") +def server(default_server_args): + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/v1/entrypoints/openai/responses/test_basic.py b/tests/v1/entrypoints/openai/responses/test_basic.py new file mode 100644 index 000000000000..974ea8673c44 --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_basic.py @@ -0,0 +1,75 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import openai # use the official client for correctness check +import pytest + + +@pytest.mark.asyncio +async def test_simple_input(client: openai.AsyncOpenAI): + response = await client.responses.create(input="What is 13 * 24?") + print(response) + + outputs = response.output + # Whether the output contains the answer. + assert outputs[-1].type == "message" + assert "312" in outputs[-1].content[0].text + + # Whether the output contains the reasoning. + assert outputs[0].type == "reasoning" + assert outputs[0].text != "" + + +@pytest.mark.asyncio +async def test_instructions(client: openai.AsyncOpenAI): + response = await client.responses.create( + instructions="Finish the answer with QED.", + input="What is 13 * 24?", + ) + print(response) + + output_text = response.output[-1].content[0].text + assert "312" in output_text + assert "QED" in output_text + + +@pytest.mark.asyncio +async def test_chat(client: openai.AsyncOpenAI): + response = await client.responses.create(input=[ + { + "role": "system", + "content": "Finish the answer with QED." + }, + { + "role": "user", + "content": "What is 5 * 3?" + }, + { + "role": "assistant", + "content": "15. QED." + }, + { + "role": "user", + "content": "Multiply the result by 2." + }, + ], ) + print(response) + + output_text = response.output[-1].content[0].text + assert "30" in output_text + assert "QED" in output_text + + +@pytest.mark.asyncio +async def test_chat_with_input_type(client: openai.AsyncOpenAI): + response = await client.responses.create(input=[ + { + "role": "user", + "content": [{ + "type": "input_text", + "text": "Hello!" 
+ }], + }, + ], ) + print(response) + assert response.status == "completed" diff --git a/tests/v1/entrypoints/openai/responses/test_image.py b/tests/v1/entrypoints/openai/responses/test_image.py new file mode 100644 index 000000000000..f3bce91e97cd --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_image.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json + +import openai +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer +from vllm.multimodal.utils import encode_image_base64, fetch_image + +# Use a small vision model for testing +MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct" +MAXIMUM_IMAGES = 2 +# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) +TEST_IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png", + "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png", + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", +] + + +@pytest.fixture(scope="module") +def default_image_server_args(): + return [ + "--enforce-eager", + "--max-model-len", + "6000", + "--max-num-seqs", + "128", + "--limit-mm-per-prompt", + json.dumps({"image": MAXIMUM_IMAGES}), + ] + + +@pytest.fixture(scope="module") +def image_server(default_image_server_args): + with RemoteOpenAIServer(MODEL_NAME, + default_image_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(image_server): + async with image_server.get_async_client() as async_client: + yield async_client + + +@pytest.fixture(scope="session") +def base64_encoded_image() -> dict[str, str]: + return { + image_url: encode_image_base64(fetch_image(image_url)) + for image_url in TEST_IMAGE_URLS + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image(client: openai.AsyncOpenAI, + model_name: str, image_url: str): + content_text = "What's in this image?" + messages = [{ + "role": + "user", + "content": [ + { + "type": "input_image", + "image_url": image_url, + "detail": "auto", + }, + { + "type": "input_text", + "text": content_text + }, + ], + }] + + # test image url + response = await client.responses.create( + model=model_name, + input=messages, + ) + assert len(response.output_text) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +async def test_single_chat_session_image_base64encoded( + client: openai.AsyncOpenAI, + model_name: str, + image_url: str, + base64_encoded_image: dict[str, str], +): + content_text = "What's in this image?" 
+ messages = [{ + "role": + "user", + "content": [ + { + "type": "input_image", + "image_url": + f"data:image/jpeg;base64,{base64_encoded_image[image_url]}", + "detail": "auto", + }, + { + "type": "input_text", + "text": content_text + }, + ], + }] + # test image base64 + response = await client.responses.create( + model=model_name, + input=messages, + ) + assert len(response.output_text) > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) +async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, + image_urls: list[str]): + messages = [{ + "role": + "user", + "content": [ + *({ + "type": "input_image", + "image_url": image_url, + "detail": "auto", + } for image_url in image_urls), + { + "type": "input_text", + "text": "What's in this image?" + }, + ], + }] + + if len(image_urls) > MAXIMUM_IMAGES: + with pytest.raises(openai.BadRequestError): # test multi-image input + await client.responses.create( + model=model_name, + input=messages, + ) + # the server should still work afterwards + response = await client.responses.create( + model=model_name, + input=[{ + "role": "user", + "content": "What's the weather like in Paris today?", + }], + ) + assert len(response.output_text) > 0 + else: + response = await client.responses.create( + model=model_name, + input=messages, + ) + assert len(response.output_text) > 0 diff --git a/tests/v1/entrypoints/openai/responses/test_stateful.py b/tests/v1/entrypoints/openai/responses/test_stateful.py new file mode 100644 index 000000000000..a2d581ef7ced --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_stateful.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio + +import openai +import pytest + + +@pytest.mark.asyncio +async def test_store(client: openai.AsyncOpenAI): + # By default, store is True. + response = await client.responses.create(input="Hello!") + assert response.status == "completed" + + # Retrieve the response. + response = await client.responses.retrieve(response.id) + assert response.status == "completed" + + # Test store=False. + response = await client.responses.create( + input="Hello!", + store=False, + ) + assert response.status == "completed" + + # The response should not be found. + with pytest.raises(openai.NotFoundError, + match="Response with id .* not found."): + await client.responses.retrieve(response.id) + + +@pytest.mark.asyncio +async def test_background(client: openai.AsyncOpenAI): + # NOTE: This query should be easy enough for the model to answer + # within the 10 seconds. 
+ response = await client.responses.create( + input="Hello!", + background=True, + ) + assert response.status == "queued" + + max_retries = 10 + for _ in range(max_retries): + await asyncio.sleep(1) + response = await client.responses.retrieve(response.id) + if response.status != "queued": + break + print(response) + + assert response.status == "completed" + + +@pytest.mark.asyncio +async def test_background_error(client: openai.AsyncOpenAI): + with pytest.raises( + openai.BadRequestError, + match="background can only be used when `store` is true"): + _ = await client.responses.create( + input="What is 13 * 24?", + background=True, + store=False, + ) + + +@pytest.mark.asyncio +async def test_background_cancel(client: openai.AsyncOpenAI): + response = await client.responses.create( + input="Write a long story about a cat.", + background=True, + ) + assert response.status == "queued" + + # Cancel the response before it is completed. + # FIXME: This test can be flaky. + await asyncio.sleep(0.5) + response = await client.responses.cancel(response.id) + assert response.status == "cancelled" + + # Make sure the response status remains unchanged. + await asyncio.sleep(5) + response = await client.responses.retrieve(response.id) + assert response.status == "cancelled" + + +@pytest.mark.asyncio +async def test_cancel_completed(client: openai.AsyncOpenAI): + response = await client.responses.create(input="Hello") + assert response.status == "completed" + + with pytest.raises(openai.BadRequestError, + match="Cannot cancel a synchronous response."): + await client.responses.cancel(response.id) + + +@pytest.mark.asyncio +async def test_previous_response_id(client: openai.AsyncOpenAI): + response1 = await client.responses.create( + instructions="You are tested on your ability to retrieve the correct " + "information from the previous response.", + input="Hello, my name is John.") + + response2 = await client.responses.create( + input="Actually, my name is not John. My real name is Mark.", + previous_response_id=response1.id, + ) + + response3 = await client.responses.create( + input="What is my real name again? Answer in one word.", + previous_response_id=response2.id, + ) + print(response3) + assert "Mark" in response3.output[-1].content[0].text + assert "John" not in response3.output[-1].content[0].text + + +@pytest.mark.asyncio +async def test_two_responses_with_same_prev_id(client: openai.AsyncOpenAI): + response1 = await client.responses.create( + instructions="You are tested on your ability to retrieve the correct " + "information from the previous response.", + input="Hello, my name is John.") + + # Both response 2 and 3 use response 1 as the previous response. + response2 = client.responses.create( + input="Actually, my name is not John. My name is Mark.", + previous_response_id=response1.id, + ) + response3 = client.responses.create( + input="What is my name again? 
Answer in one word.", + previous_response_id=response1.id, + ) + + _ = await response2 + response3_result = await response3 + print(response3_result) + assert "John" in response3_result.output[-1].content[0].text + assert "Mark" not in response3_result.output[-1].content[0].text diff --git a/tests/v1/entrypoints/openai/responses/test_structured_output.py b/tests/v1/entrypoints/openai/responses/test_structured_output.py new file mode 100644 index 000000000000..c4c43a87b601 --- /dev/null +++ b/tests/v1/entrypoints/openai/responses/test_structured_output.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json + +import openai +import pytest +from pydantic import BaseModel + + +@pytest.mark.asyncio +async def test_structured_output(client: openai.AsyncOpenAI): + response = await client.responses.create( + input=[ + { + "role": "system", + "content": "Extract the event information." + }, + { + "role": "user", + "content": + "Alice and Bob are going to a science fair on Friday.", + }, + ], + text={ + "format": { + "type": "json_schema", + "name": "calendar_event", + "schema": { + "type": "object", + "properties": { + "event_name": { + "type": "string" + }, + "date": { + "type": "string" + }, + "participants": { + "type": "array", + "items": { + "type": "string" + } + }, + }, + "required": ["event_name", "date", "participants"], + "additionalProperties": False, + }, + "description": "A calendar event.", + "strict": True, + } + }, + ) + print(response) + + # NOTE: The JSON schema is applied to the output text, not reasoning. + output_text = response.output[-1].content[0].text + event = json.loads(output_text) + + assert event["event_name"].lower() == "science fair" + assert event["date"] == "Friday" + participants = event["participants"] + assert len(participants) == 2 + assert participants[0] == "Alice" + assert participants[1] == "Bob" + + +@pytest.mark.asyncio +async def test_structured_output_with_parse(client: openai.AsyncOpenAI): + + class CalendarEvent(BaseModel): + event_name: str + date: str + participants: list[str] + + response = await client.responses.parse( + model=None, + instructions="Extract the event information.", + input="Alice and Bob are going to a science fair on Friday.", + text_format=CalendarEvent, + ) + print(response) + + # The output is successfully parsed. + event = response.output_parsed + assert event is not None + + # The output is correct. 
+ assert event.event_name.lower() == "science fair" + assert event.date == "Friday" + participants = event.participants + assert len(participants) == 2 + assert participants[0] == "Alice" + assert participants[1] == "Bob" diff --git a/tests/v1/entrypoints/openai/test_completion.py b/tests/v1/entrypoints/openai/test_completion.py index 776fd42bbc35..2462f8f9f10c 100644 --- a/tests/v1/entrypoints/openai/test_completion.py +++ b/tests/v1/entrypoints/openai/test_completion.py @@ -7,6 +7,7 @@ import pytest import pytest_asyncio import regex as re +import requests from openai import BadRequestError from tests.utils import RemoteOpenAIServer @@ -26,7 +27,8 @@ def default_server_args(): "2048", "--max-num-seqs", "128", - "--enforce-eager" + "--enforce-eager", + "--enable-prompt-tokens-details", ] @@ -679,3 +681,17 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str): prompt=prompt, extra_body={"guided_grammar": invalid_simplified_sql_grammar}, ) + + +@pytest.mark.asyncio +async def test_completion_with_empty_prompt_embeds( + client: openai.AsyncOpenAI) -> None: + """Test completion with empty prompt embeds.""" + payload: dict[str, list] = {"prompt_embeds": []} + headers: dict[str, str] = {"Content-Type": "application/json"} + # base_url = http://localhost:8000/v1/completions + response = requests.post(f"{client.base_url}completions", + headers=headers, + json=payload) + assert response.status_code == 200, ( + f"Expected status code 200, got {response.status_code}. ") diff --git a/tests/v1/entrypoints/openai/test_multi_api_servers.py b/tests/v1/entrypoints/openai/test_multi_api_servers.py index e84b5e3095d0..f7c31b0c4377 100644 --- a/tests/v1/entrypoints/openai/test_multi_api_servers.py +++ b/tests/v1/entrypoints/openai/test_multi_api_servers.py @@ -2,136 +2,19 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import os -import re import openai # use the official client for correctness check import pytest import pytest_asyncio -import requests from tests.utils import RemoteOpenAIServer +from tests.v1.test_utils import check_request_balancing MODEL_NAME = "ibm-research/PowerMoE-3b" DP_SIZE = os.getenv("DP_SIZE", "1") -def get_prometheus_metrics( - server: RemoteOpenAIServer) -> dict[str, dict[str, float]]: - """Fetch and parse Prometheus metrics from the /metrics endpoint. - - Returns: - Dict mapping metric names to their values grouped by labels. 
- For example: {"vllm:request_success": { - "engine=0": 5.0, "engine=1": 3.0} - } - """ - try: - response = requests.get(server.url_for("metrics"), timeout=10) - response.raise_for_status() - - metrics: dict[str, dict[str, float]] = {} - - # Regex patterns for Prometheus metrics - metric_with_labels = re.compile( - r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$') - metric_simple = re.compile( - r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$') - - for line in response.text.split('\n'): - line = line.strip() - # Skip comments and empty lines - if not line or line.startswith('#'): - continue - - # Try to match metric with labels first - match = metric_with_labels.match(line) - if match: - metric_name, labels_part, value_str = match.groups() - try: - value = float(value_str) - if metric_name not in metrics: - metrics[metric_name] = {} - metrics[metric_name][f'{{{labels_part}}}'] = value - except ValueError: - continue - else: - # Try simple metric without labels - match = metric_simple.match(line) - if match: - metric_name, value_str = match.groups() - try: - value = float(value_str) - if metric_name not in metrics: - metrics[metric_name] = {} - metrics[metric_name][''] = value - except ValueError: - continue - - return metrics - except Exception as e: - pytest.fail(f"Failed to fetch Prometheus metrics: {e}") - return {} - - -def get_engine_request_counts( - metrics: dict[str, dict[str, float]]) -> dict[str, float]: - """Extract request counts per engine from Prometheus metrics. - - Returns: - Dict mapping engine indices to request counts. - For example: {"0": 15.0, "1": 12.0} - """ - engine_counts = {} - - # Look for request success metrics with engine labels - success_metrics = metrics.get("vllm:request_success_total", {}) - engine_pattern = re.compile(r'engine="([^"]*)"') - - for labels, count in success_metrics.items(): - # Extract engine ID from labels using regex - match = engine_pattern.search(labels) - if match: - engine_id = match.group(1) - if engine_id not in engine_counts: - engine_counts[engine_id] = 0.0 - engine_counts[engine_id] += count - - return engine_counts - - -def check_request_balancing(server: RemoteOpenAIServer): - """Check request balancing via Prometheus metrics if DP_SIZE > 1. - - Args: - server: The RemoteOpenAIServer instance - """ - dp_size = int(DP_SIZE) - if dp_size <= 1: - return - - # Get metrics after all requests are completed - metrics = get_prometheus_metrics(server) - engine_counts = get_engine_request_counts(metrics) - - # Check that multiple engines received requests - engines_with_requests = [ - engine for engine, count in engine_counts.items() if count > 0 - ] - assert len(engines_with_requests) == dp_size, ( - f"Expected requests to be distributed across multiple engines," - f" but only engine(s) {engines_with_requests} received " - f"requests. 
Engine counts: {engine_counts}") - - # Verify that the load is reasonably balanced - # (no engine should handle all requests) - total_requests = sum(engine_counts.values()) - - for count in engine_counts.values(): - assert count > total_requests // (dp_size + 1), ( - f"requests are imbalanced: {engine_counts}") - - @pytest.fixture(scope="module") def default_server_args(): return [ @@ -217,7 +100,7 @@ async def make_request(): assert all(completion is not None for completion in results) # Check request balancing via Prometheus metrics if DP_SIZE > 1 - check_request_balancing(server) + check_request_balancing(server, int(DP_SIZE)) @pytest.mark.asyncio @@ -295,4 +178,4 @@ async def make_streaming_request(): assert all(results), "Not all streaming requests completed successfully." # Check request balancing via Prometheus metrics if DP_SIZE > 1 - check_request_balancing(server) + check_request_balancing(server, int(DP_SIZE)) diff --git a/vllm/prompt_adapter/__init__.py b/tests/v1/kv_connector/__init__.py similarity index 100% rename from vllm/prompt_adapter/__init__.py rename to tests/v1/kv_connector/__init__.py diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 72848c1a706e..b1780d8a9af8 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -3,16 +3,10 @@ import filecmp import shutil import tempfile -from collections import defaultdict from pathlib import Path from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa - SharedStorageConnector) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.config import KVTransferConfig MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -25,62 +19,6 @@ SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) -class TestSharedStorageConnector(SharedStorageConnector): - - def __init__(self, config: VllmConfig, role): - self.name = config.kv_transfer_config.kv_connector_extra_config["name"] - self._connector = SharedStorageConnector(config, role) - self.call_record: dict[str, int] = defaultdict(int) - # Use a unique temp file per connector - self._event_file = tempfile.gettempdir( - ) + f"/connector_{self.name}-{self.role.name}_events.log" - # Start with an empty file - with open(self._event_file, "w") as _: - pass - - def __getattribute__(self, name): - if name in ("_connector", "call_record", "name", "_event_file", - "__class__", "__dict__", "__getattribute__", - "__init__"): # avoid recursion - return object.__getattribute__(self, name) - if not hasattr(self._connector, name): - return object.__getattribute__(self, name) - attr = getattr(self._connector, name) - - # Intercept calls to the connector interface and write an event - # for each one to a file, which can be read back in the main test proc. 
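# A minimal standalone sketch of the per-engine counting performed by the
# request-balancing helper that this change relocates to tests/v1/test_utils.py
# (now called as check_request_balancing(server, dp_size)). The metric and
# label names come from the removed code above; the helper name used here is
# illustrative only.
import re


def engine_request_counts(metrics_text: str) -> dict[str, float]:
    """Sum vllm:request_success_total per engine="..." label."""
    pattern = re.compile(r'^vllm:request_success_total'
                         r'\{[^}]*engine="([^"]+)"[^}]*\}\s+([\d.eE+-]+)$')
    counts: dict[str, float] = {}
    for line in metrics_text.splitlines():
        match = pattern.match(line.strip())
        if match:
            engine_id, value = match.group(1), float(match.group(2))
            counts[engine_id] = counts.get(engine_id, 0.0) + value
    return counts


# Two engines with three and two successful requests respectively.
sample = ('vllm:request_success_total{engine="0"} 3.0\n'
          'vllm:request_success_total{engine="1"} 2.0\n')
assert engine_request_counts(sample) == {"0": 3.0, "1": 2.0}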
- if callable(attr): - - def wrapper(*args, **kwargs): - self.call_record[name] += 1 - - # Include args that we're interested in - to_log = [name] - for arg in args: - if isinstance(arg, int): - to_log.append(str(arg)) - elif isinstance(arg, KVCacheBlocks): - to_log.append( - f"num_blocks={[len(b) for b in arg.blocks]}") - - # Log the event as a line to the file - try: - with open(self._event_file, "a") as f: - f.write(' '.join(to_log) + "\n") - except Exception as e: - print(f"[ERROR] Could not log event {name} " - f"for {self.name}: {e}") - return attr(*args, **kwargs) - - return wrapper - return attr - - -KVConnectorFactory.register_connector("TestSharedStorageConnector", - TestSharedStorageConnector.__module__, - TestSharedStorageConnector.__name__) - - # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -115,19 +53,27 @@ def test_multi_shared_storage_connector_consistency(): kv_role="kv_both", kv_connector_extra_config={ "connectors": [{ - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", + "kv_connector": + "TestSharedStorageConnector", + "kv_role": + "kv_both", "kv_connector_extra_config": { "shared_storage_path": str(storage_1_path), "name": "storage1", - } + }, + "kv_connector_module_path": + "tests.v1.kv_connector.unit.utils", }, { - "kv_connector": "TestSharedStorageConnector", - "kv_role": "kv_both", + "kv_connector": + "TestSharedStorageConnector", + "kv_role": + "kv_both", "kv_connector_extra_config": { "shared_storage_path": str(storage_2_path), "name": "storage2", - } + }, + "kv_connector_module_path": + "tests.v1.kv_connector.unit.utils", }] }, ) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e30a250449aa..99bde919c725 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1,22 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +import tempfile +import textwrap import time -import uuid -from collections import defaultdict -from typing import Optional from unittest.mock import patch import pytest +import ray +from vllm import LLM +from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ( KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata, NixlConnectorWorker) from vllm.forward_context import ForwardContext +from vllm.mocks.mock_nixl_connector import FakeNixlWrapper +from vllm.sampling_params import SamplingParams from .utils import create_request, create_scheduler, create_vllm_config +def _make_stub_pkg() -> str: + """Return a directory that makes + `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.""" + td = tempfile.mkdtemp() + pkg_root = os.path.join(td, "nixl", "_api") + os.makedirs(pkg_root, exist_ok=True) + + stub = textwrap.dedent("""\ + # Forward the real FakeNixlWrapper that the driver already defined. 
+ print("In fake package") + from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent + """) + with open(os.path.join(pkg_root, "__init__.py"), "w") as f: + f.write(stub) + + # touch parent package + open(os.path.join(td, "nixl", "__init__.py"), "w").close() + return td + + def test_basic_interface(): """Unit test for basic NixlConnector interface functionality.""" @@ -41,9 +66,9 @@ def test_basic_interface(): assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) - assert len(kv_connector_metadata.requests) == 1 - assert request_id in kv_connector_metadata.requests - req_meta = kv_connector_metadata.requests[request_id] + assert len(kv_connector_metadata.reqs_to_recv) == 1 + assert request_id in kv_connector_metadata.reqs_to_recv + req_meta = kv_connector_metadata.reqs_to_recv[request_id] for block_id, block in zip( req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator. @@ -78,83 +103,12 @@ def test_prompt_less_than_block_size(): kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) - assert len(kv_connector_metadata.requests) == 0 + assert len(kv_connector_metadata.reqs_to_recv) == 0 # This request should be scheduled regularly. assert len(scheduler_output.scheduled_new_reqs) == 1 -class FakeNixlWrapper: - """Mock implementation of NixlWrapper for testing. - - We don't inherit from nixl._api.nixl_agent because nixl may not be - installed. - """ - - AGENT_METADATA = b"fake_agent_metadata" - REMOTE_AGENT_NAME = "remote_agent" - - def __init__(self, agent_name: str, *args, **kwargs): - self._cycles_before_xfer_done = 0 - self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( - lambda: 0) - - def get_reg_descs(self, caches_data, memory_type: str) -> list: - return [str(uuid.uuid4()) for _ in caches_data] - - def register_memory(self, descs) -> None: - pass - - def get_xfer_descs(self, blocks_data, memory_type: str) -> list: - return [str(uuid.uuid4()) for _ in blocks_data] - - def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: - return uuid.uuid4().int - - def get_agent_metadata(self) -> bytes: - return self.AGENT_METADATA - - def add_remote_agent(self, agent_metadata: bytes) -> str: - return self.REMOTE_AGENT_NAME - - def get_new_notifs(self) -> dict[str, list[bytes]]: - # Used to collect done_sending, which we don't test yet. - return {} - - def check_xfer_state(self, handle: int) -> str: - if self._check_xfer_state_cycles[ - handle] >= self._cycles_before_xfer_done: - return "DONE" - self._check_xfer_state_cycles[handle] += 1 - return "PROC" - - def release_xfer_handle(self, handle: int) -> None: - pass - - def send_notif(self, agent_name: str, notif_msg: bytes) -> None: - pass - - def make_prepped_xfer(self, - xfer_type: str, - local_xfer_side_handle: int, - local_block_descs_ids: list[int], - remote_xfer_side_handle: int, - remote_block_descs_ids: list[int], - notif_msg: Optional[bytes] = None) -> int: - return uuid.uuid4().int - - def transfer(self, handle: int) -> str: - return "PROC" - - ############################################################ - # Follow are for changing the behavior during testing. 
- ############################################################ - - def set_cycles_before_xfer_done(self, cycles: int): - """Set the number of cycles before a transfer is considered done.""" - self._cycles_before_xfer_done = cycles - - class FakeNixlConnectorWorker(NixlConnectorWorker): REMOTE_ENGINE_ID = "remote_engine" @@ -163,8 +117,8 @@ def __init__(self, *args, hand_shake_latency: float = 1.8, **kwargs): super().__init__(*args, **kwargs) self._hand_shake_latency = hand_shake_latency - def _nixl_handshake(self, host: str, port: int, - remote_tp_size: int) -> dict[int, str]: + def _nixl_handshake(self, host: str, port: int, remote_tp_size: int, + expected_engine_id: str) -> dict[int, str]: # Mimic slow _nixl_handshake, as well as bypass zmq communication. time.sleep(self._hand_shake_latency) # These should've been done in register_kv_caches(), called by @@ -174,6 +128,8 @@ def _nixl_handshake(self, host: str, port: int, self.num_blocks = 1 self.dst_num_blocks[self.engine_id] = self.num_blocks + assert expected_engine_id == self.REMOTE_ENGINE_ID + remote_agent_name = self.add_remote_agent( NixlAgentMetadata( engine_id=self.REMOTE_ENGINE_ID, @@ -371,3 +327,86 @@ def test_concurrent_load_kv( if cnt_finished_reqs == total_reqs: return raise TimeoutError("Took too long to complete async handshake.") + + +# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which +# we put here is important. First run ray, it will clean up the resources, then +# the rest of the tests. +@pytest.mark.parametrize("distributed_executor_backend", ["ray", None]) +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper) +def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend): + """ + Test lifecycle of an aborted Remote Prefill request hitting the timeout. + -----> P + | {process request} + <-/--- | {result is NOT delivered, eg proxy is down} + | + | + | {eventually free blocks} + """ + model_name = "Qwen/Qwen3-0.6B" + kv_transfer_config = KVTransferConfig( + kv_connector="NixlConnector", + kv_role="kv_both", + ) + timeout = 6 + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout)) + + # Build runtime_env only if we’re using Ray + if distributed_executor_backend == "ray": + runtime_env = { + "working_dir": _make_stub_pkg(), # ship stub package + "env_vars": { + "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout), + }, + } + ray.init(runtime_env=runtime_env) + + llm = LLM( + model=model_name, + enforce_eager=True, + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + distributed_executor_backend=distributed_executor_backend, + ) + remote_prefill_opts = { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": None, + "remote_port": None, + } + # Simulate sidecar request + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=1, + extra_args={"kv_transfer_params": remote_prefill_opts}) + scheduler = llm.llm_engine.engine_core.engine_core.scheduler + req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[ + 0].req_to_blocks + + padding = "Just making this request a little longer so that we're sure " + "we're not hitting the small-request lower bound beneath which we don't " + "actually trigger the whole kv transfer, but rather just recompute the " + "blocks on D." + _ = llm.generate([f"What is the capital of Japan? 
{padding}"], + sampling_params) + + # Request finished but not freed + assert '0' in scheduler.finished_req_ids and '0' in req_to_blocks + # Some other request, 0 still not freed + _ = llm.generate([f"What is the capital of Italy? {padding}"], + sampling_params) + assert '0' in req_to_blocks + assert '1' in scheduler.finished_req_ids and '1' in req_to_blocks + + # Wait for timeout and trigger another scheduler loop + time.sleep(timeout) + _ = llm.generate([f"What is the capital of France? {padding}"], + sampling_params) + # Request-0 times out and is cleared! + assert '0' not in req_to_blocks diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py new file mode 100644 index 000000000000..cad73f68e9f1 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from concurrent.futures import Future +from typing import Optional + +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator +from vllm.v1.outputs import ModelRunnerOutput + + +class DummyModelRunnerOutput(ModelRunnerOutput): + + def __init__(self, + finished_sending: Optional[set[str]] = None, + finished_recving: Optional[set[str]] = None): + self.finished_sending = finished_sending + self.finished_recving = finished_recving + + +def test_aggregate_workers_output(): + aggregator = KVOutputAggregator(world_size=2) + + output1 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + output2 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving is None + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving=None) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + assert aggregated.finished_sending == {'req1'} + assert aggregated.finished_recving is None + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + + aggregated = aggregator.aggregate([output1, output2]) + + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving == {'req2'} + + +def test_async_aggregate_workers_output(): + aggregator = KVOutputAggregator(world_size=2) + + future1: Future[DummyModelRunnerOutput] = Future() + future2: Future[DummyModelRunnerOutput] = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + output2 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving is None + + future1 = Future() + future2 = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + 
finished_recving=None) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + assert aggregated.finished_sending == {'req1'} + assert aggregated.finished_recving is None + + future1 = Future() + future2 = Future() + result_future = aggregator.async_aggregate([future1, future2]) + + output1 = DummyModelRunnerOutput(finished_sending=None, + finished_recving=None) + output2 = DummyModelRunnerOutput(finished_sending={'req1'}, + finished_recving={'req2'}) + future1.set_result(output1) + future2.set_result(output2) + + assert result_future.done() + aggregated = result_future.result() + assert aggregated is output1 + assert aggregated.finished_sending is None + assert aggregated.finished_recving == {'req2'} diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 983d900606fc..cf20d44fbaae 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import tempfile +from collections import defaultdict from typing import Any, Optional import torch @@ -7,6 +9,11 @@ from vllm import SamplingParams from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig, ModelConfig, SchedulerConfig, VllmConfig) +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) @@ -187,3 +194,58 @@ def create_model_runner_output( finished_sending=finished_sending, finished_recving=finished_recving, ) + + +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}-{self.role.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", "__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + # Intercept calls to the connector interface and write an event + # for each one to a file, which can be read back in the main test proc. 
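# The TestSharedStorageConnector being added here records connector calls via
# __getattribute__ delegation (the wrapping continues just below). A
# stripped-down sketch of that pattern, with hypothetical names, wrapping a
# plain list instead of a real connector:
class CallRecorder:

    def __init__(self, inner):
        self._inner = inner
        self.calls = []

    def __getattribute__(self, name):
        # Serve our own attributes directly to avoid infinite recursion.
        if name in ("_inner", "calls", "__class__", "__dict__",
                    "__getattribute__", "__init__"):
            return object.__getattribute__(self, name)
        attr = getattr(object.__getattribute__(self, "_inner"), name)
        if not callable(attr):
            return attr

        def wrapper(*args, **kwargs):
            # Record the call, then forward it to the wrapped object.
            object.__getattribute__(self, "calls").append(name)
            return attr(*args, **kwargs)

        return wrapper


recorder = CallRecorder(inner=[])
recorder.append(1)
recorder.append(2)
assert recorder.calls == ["append", "append"]
assert recorder._inner == [1, 2]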
+ if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + + # Include args that we're interested in + to_log = [name] + for arg in args: + if isinstance(arg, int): + to_log.append(str(arg)) + elif isinstance(arg, KVCacheBlocks): + to_log.append( + f"num_blocks={[len(b) for b in arg.blocks]}") + + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(' '.join(to_log) + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__, + TestSharedStorageConnector.__name__) diff --git a/tests/v1/metrics/test_ray_metrics.py b/tests/v1/metrics/test_ray_metrics.py index 0898ae65e7cd..92f6c6f0e89c 100644 --- a/tests/v1/metrics/test_ray_metrics.py +++ b/tests/v1/metrics/test_ray_metrics.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + import pytest import ray +from vllm.config import ModelDType from vllm.sampling_params import SamplingParams from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger @@ -27,7 +30,7 @@ def use_v1_only(monkeypatch): def test_engine_log_metrics_ray( example_prompts, model: str, - dtype: str, + dtype: ModelDType, max_tokens: int, ) -> None: """ Simple smoke test, verifying this can be used without exceptions. @@ -37,11 +40,14 @@ def test_engine_log_metrics_ray( class EngineTestActor: async def run(self): - engine_args = AsyncEngineArgs( - model=model, - dtype=dtype, - disable_log_stats=False, - ) + # Set environment variable inside the Ray actor since environment + # variables from pytest fixtures don't propagate to Ray actors + os.environ['VLLM_USE_V1'] = '1' + + engine_args = AsyncEngineArgs(model=model, + dtype=dtype, + disable_log_stats=False, + enforce_eager=True) engine = AsyncLLM.from_engine_args( engine_args, stat_loggers=[RayPrometheusStatLogger]) diff --git a/tests/v1/sample/test_logprobs.py b/tests/v1/sample/test_logprobs.py index 69180e6e5db4..680e2ce98bb2 100644 --- a/tests/v1/sample/test_logprobs.py +++ b/tests/v1/sample/test_logprobs.py @@ -12,6 +12,7 @@ assert_incr_detok_str_matches_non_incr_detok_str, compute_correct_cumulative_logprob, get_test_batch) from vllm import SamplingParams +from vllm.config import LogprobsMode from ...conftest import HfRunner, VllmRunner @@ -112,7 +113,7 @@ def _run_and_validate( max_tokens: int, do_apc: bool, ) -> None: - vllm_results = vllm_model.model.generate( + vllm_results = vllm_model.llm.generate( test_prompts, sampling_params=vllm_sampling_params) for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip( @@ -288,7 +289,7 @@ def test_get_logprobs_and_prompt_logprobs( """ with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - do_apc = vllm_model.model.llm_engine.cache_config.enable_prefix_caching + do_apc = vllm_model.llm.llm_engine.cache_config.enable_prefix_caching if do_apc and (temperature < 2.0 or batch_logprobs_composition != SAMPLE_PROMPT): # Skip some test-cases to save time. 
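# The comment added in test_ray_metrics.py above calls out a general Ray
# gotcha: environment variables set in the driver process (for example by a
# pytest monkeypatch fixture) are not visible inside Ray actors. A small
# illustrative sketch with a hypothetical actor; besides setting os.environ
# inside the actor, the variable can also be shipped via a runtime_env:
import os

import ray


@ray.remote
class EnvReader:

    def read(self) -> str:
        return os.environ.get("VLLM_USE_V1", "<unset>")


reader = EnvReader.options(runtime_env={
    "env_vars": {
        "VLLM_USE_V1": "1"
    }
}).remote()
assert ray.get(reader.read.remote()) == "1"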
@@ -378,7 +379,7 @@ def test_none_logprobs(vllm_model, example_prompts, prompt_logprobs=None, temperature=0.0, ) - results_logprobs_none = vllm_model.model.generate( + results_logprobs_none = vllm_model.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_none, ) @@ -408,7 +409,7 @@ def test_zero_logprobs(vllm_model, example_prompts, logprobs=0, prompt_logprobs=0, temperature=0.0) - results_logprobs_zero = vllm_model.model.generate( + results_logprobs_zero = vllm_model.llm.generate( example_prompts, sampling_params=sampling_params_logprobs_zero) for i in range(len(results_logprobs_zero)): @@ -426,3 +427,45 @@ def test_zero_logprobs(vllm_model, example_prompts, # prompt token assert prompt_logprobs is not None assert len(prompt_token_ids) == len(prompt_logprobs) + + +@pytest.mark.parametrize( + "logprobs_mode", + ["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"]) +def test_logprobs_mode(logprobs_mode: LogprobsMode, + monkeypatch: pytest.MonkeyPatch): + """Test with LLM engine with different logprobs_mode. + For logprobs, we should have non-positive values. + For logits, we should expect at least one positive values. + """ + from vllm import LLM + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + llm = LLM( + "facebook/opt-125m", + max_logprobs=5, + enable_prefix_caching=False, + # 2 other llms alive during whole session + gpu_memory_utilization=0.05, + max_model_len=16, + logprobs_mode=logprobs_mode) + vllm_sampling_params = SamplingParams(logprobs=1) + results = llm.generate(["Hello world"], + sampling_params=vllm_sampling_params) + + total_token_with_logprobs = 0 + positive_values = 0 + for output in results[0].outputs: + for logprobs in output.logprobs: + for token_id in logprobs: + logprob = logprobs[token_id] + if "logprobs" in logprobs_mode: + assert logprob.logprob <= 0 + if logprob.logprob > 0: + positive_values = positive_values + 1 + total_token_with_logprobs = total_token_with_logprobs + 1 + assert total_token_with_logprobs >= len(results[0].outputs) + if "logits" in logprobs_mode: + assert positive_values > 0 + del llm diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index ac0f3eb58836..f53e1e1c485d 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -14,30 +14,30 @@ @pytest.fixture(scope="module") -def model() -> LLM: +def llm() -> LLM: # Disable prefix caching so that we can test prompt logprobs. 
# TODO remove this after https://github.com/vllm-project/vllm/pull/13949 # is merged return LLM(MODEL, enforce_eager=True, enable_prefix_caching=False) -def test_n_gt_1(model): +def test_n_gt_1(llm): """ParallelSampling is supported.""" params = SamplingParams(n=3) - outputs = model.generate(PROMPT, params) + outputs = llm.generate(PROMPT, params) assert len(outputs[0].outputs) == 3 -def test_best_of(model): +def test_best_of(llm): """Raise a ValueError since best_of is deprecated.""" params = SamplingParams(n=2, best_of=3) with pytest.raises(ValueError): - _ = model.generate(PROMPT, params) + _ = llm.generate(PROMPT, params) -def test_penalties(model): +def test_penalties(llm): """Check that we do not get errors if applied.""" params = SamplingParams( @@ -49,18 +49,18 @@ def test_penalties(model): top_p=0.5, top_k=3, ) - _ = model.generate(PROMPT, params) + _ = llm.generate(PROMPT, params) -def test_stop(model): +def test_stop(llm): """Check that we respect the stop words.""" - output = model.generate(PROMPT, SamplingParams(temperature=0)) + output = llm.generate(PROMPT, SamplingParams(temperature=0)) split_text = output[0].outputs[0].text.split() STOP_IDX = 5 params = SamplingParams(temperature=0, stop=split_text[STOP_IDX]) - output = model.generate(PROMPT, params) + output = llm.generate(PROMPT, params) new_split_text = output[0].outputs[0].text.split() # Output should not contain the stop word. @@ -69,40 +69,40 @@ def test_stop(model): params = SamplingParams(temperature=0, stop=split_text[STOP_IDX], include_stop_str_in_output=True) - output = model.generate(PROMPT, params) + output = llm.generate(PROMPT, params) new_split_text = output[0].outputs[0].text.split() # Output should contain the stop word. assert len(new_split_text) == STOP_IDX + 1 -def test_stop_token_ids(model): +def test_stop_token_ids(llm): """Check that we respect the stop token ids.""" - output = model.generate(PROMPT, SamplingParams(temperature=0)) + output = llm.generate(PROMPT, SamplingParams(temperature=0)) stop_token_id_0 = output[0].outputs[0].token_ids[5] stop_token_id_1 = output[0].outputs[0].token_ids[6] stop_token_ids = [stop_token_id_1, stop_token_id_0] params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) - output = model.generate(PROMPT, params) + output = llm.generate(PROMPT, params) assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 stop_token_ids = [stop_token_id_0, stop_token_id_1] params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) - output = model.generate(PROMPT, params) + output = llm.generate(PROMPT, params) assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 -def test_detokenize_false(model): +def test_detokenize_false(llm): """Check that detokenize=False option works.""" - output = model.generate(PROMPT, SamplingParams(detokenize=False)) + output = llm.generate(PROMPT, SamplingParams(detokenize=False)) assert len(output[0].outputs[0].token_ids) > 0 assert len(output[0].outputs[0].text) == 0 - output = model.generate( + output = llm.generate( PROMPT, SamplingParams(detokenize=False, logprobs=3, prompt_logprobs=3)) assert len(output[0].outputs[0].token_ids) > 0 @@ -118,28 +118,28 @@ def test_detokenize_false(model): assert all(lp.decoded_token is None for lp in logprobs.values()) -def test_bad_words(model): +def test_bad_words(llm): """Check that we respect bad words.""" - output = model.generate(PROMPT, SamplingParams(temperature=0)) + output = llm.generate(PROMPT, SamplingParams(temperature=0)) split_text = output[0].outputs[0].text.split() 
bad_words_1 = " ".join(split_text[:2]) params = SamplingParams(temperature=0, bad_words=[bad_words_1]) - output = model.generate(PROMPT, params) + output = llm.generate(PROMPT, params) new_text = output[0].outputs[0].text assert bad_words_1 not in new_text bad_words_2 = new_text.split()[-1] params = SamplingParams(temperature=0, bad_words=[bad_words_1, bad_words_2]) - output = model.generate(PROMPT, params) + output = llm.generate(PROMPT, params) new_text = output[0].outputs[0].text assert bad_words_1 not in new_text assert bad_words_2 not in new_text -def test_logits_processor(model): +def test_logits_processor(llm): """Check that we reject logits processor.""" # This sample logits processor gives infinite score to the i-th token, @@ -150,47 +150,45 @@ def pick_ith(token_ids, logits): return logits with pytest.raises(ValueError): - _ = model.generate(PROMPT, - SamplingParams(logits_processors=[pick_ith])) + _ = llm.generate(PROMPT, SamplingParams(logits_processors=[pick_ith])) -def test_allowed_token_ids(model): +def test_allowed_token_ids(llm): """Check that we can use allowed_token_ids.""" TOKEN_ID = 10 allowed_token_ids = [TOKEN_ID] - output = model.generate( - PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids)) + output = llm.generate(PROMPT, + SamplingParams(allowed_token_ids=allowed_token_ids)) assert output[0].outputs[0].token_ids[-1] == TOKEN_ID # Reject empty allowed_token_ids. with pytest.raises(ValueError): - _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[])) + _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[])) # Reject negative token id. with pytest.raises(ValueError): - _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1])) + _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[-1])) # Reject out of vocabulary. 
with pytest.raises(ValueError): - _ = model.generate(PROMPT, - SamplingParams(allowed_token_ids=[10000000])) + _ = llm.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000])) -def test_priority(model): +def test_priority(llm): """Check that we reject requests with priority.""" # Reject all allowed token ids with pytest.raises(ValueError): - _ = model.generate(PROMPT, priority=[1]) + _ = llm.generate(PROMPT, priority=[1]) -def test_seed(model): +def test_seed(llm): """Check that seed impacts randomness.""" - out_1 = model.generate(PROMPT, SamplingParams(seed=42)) - out_2 = model.generate(PROMPT, SamplingParams(seed=42)) - out_3 = model.generate(PROMPT, SamplingParams(seed=43)) + out_1 = llm.generate(PROMPT, SamplingParams(seed=42)) + out_2 = llm.generate(PROMPT, SamplingParams(seed=42)) + out_3 = llm.generate(PROMPT, SamplingParams(seed=43)) assert out_1[0].outputs[0].text == out_2[0].outputs[0].text assert out_1[0].outputs[0].text != out_3[0].outputs[0].text diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 5efab2c14407..5c74a286c4a9 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -6,6 +6,10 @@ import pytest import torch +from tests.v1.attention.utils import (BatchSpec, _Backend, + create_common_attn_metadata, + create_standard_kv_cache_spec, + get_attention_backend) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig, SpeculativeConfig, VllmConfig) @@ -64,13 +68,19 @@ def test_prepare_inputs(): """ device = torch.device(current_platform.device_type) - # a = 4, b = 7, c = 5 + # q1 = 4, q2 = 7, q3 = 5 # n1 = 1, n2 = 3, n3 = 2 - # Cumulative lengths: [0, 4, 11, 16] - cu_target_query_lens = torch.tensor([0, 4, 11, 16], - dtype=torch.int32, - device=device) + batch_spec = BatchSpec( + seq_lens=[4, 7, 5], + query_lens=[4, 7, 5], + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) # Rejected tokens per request: [1, 3, 2] num_rejected_tokens = torch.tensor([1, 3, 2], @@ -104,15 +114,13 @@ def test_prepare_inputs(): ], dtype=torch.int32, device=device) + proposer = _create_proposer("eagle", 1) - # n1 + n2 + n3 - a - b -c - num_tokens = cu_target_query_lens[-1].item() - num_rejected_tokens.sum( - ).item() + updated_metadata, token_indices = proposer.prepare_inputs( + common_attn_metadata, num_rejected_tokens.cpu()) - cu_num_tokens, token_indices = EagleProposer.prepare_inputs( - cu_target_query_lens, num_rejected_tokens, num_tokens) - - assert torch.equal(cu_num_tokens, expected_cu_num_tokens) + assert torch.equal(updated_metadata.query_start_loc, + expected_cu_num_tokens) assert token_indices.shape[0] == expected_cu_num_tokens[-1].item() assert torch.equal(token_indices, expected_token_indices) @@ -209,6 +217,7 @@ def test_propose(num_speculative_tokens): seq_len_2 = 3 total_tokens = seq_len_1 + seq_len_2 vocab_size = 100 + seq_lens = [seq_len_1, seq_len_2] # Create proposer first so we can use its actual hidden_size proposer = _create_proposer("eagle", num_speculative_tokens) @@ -270,9 +279,16 @@ def create_deterministic_logits(token_ids): proposer.attn_layer_names = ["layer.0"] # Create input tensors - cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens], - dtype=torch.int32, - device=device) + batch_spec = BatchSpec( + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + block_size=16, + device=device, + ) 
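# Spelling out the arithmetic that test_prepare_inputs above encodes: each
# request keeps query_len - num_rejected tokens once the rejected speculative
# tail is dropped, so with query lengths [4, 7, 5] and [1, 3, 2] rejected
# tokens the cumulative boundaries become [0, 3, 7, 10] and 10 of the original
# 16 tokens remain. (Derived from the test's own comments; shown here only as
# a worked example.)
query_lens = [4, 7, 5]      # q1, q2, q3
num_rejected = [1, 3, 2]    # n1, n2, n3
kept = [q - n for q, n in zip(query_lens, num_rejected)]  # [3, 4, 3]
cu_num_tokens = [0]
for k in kept:
    cu_num_tokens.append(cu_num_tokens[-1] + k)
assert cu_num_tokens == [0, 3, 7, 10]
assert sum(query_lens) - sum(num_rejected) == 10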
target_token_ids = torch.randint(0, vocab_size, (total_tokens, ), @@ -284,25 +300,29 @@ def create_deterministic_logits(token_ids): target_hidden_states = torch.randn(total_tokens, hidden_size, device=device) - target_slot_mapping = torch.randint(0, - 100, (total_tokens, ), - device=device) next_token_ids = torch.randint(0, vocab_size, (batch_size, ), dtype=torch.int32, device=device) - block_table = torch.randint(0, 10, (batch_size, 10), device=device) - sampling_metadata = mock.MagicMock() - # Call the method under test + attn_metadata_builder_cls, _ = get_attention_backend( + _Backend.FLASH_ATTN_VLLM_V1) + attn_metadata_builder = attn_metadata_builder_cls( + kv_cache_spec=create_standard_kv_cache_spec(proposer.vllm_config), + vllm_config=proposer.vllm_config, + device=device, + ) + + # Mock runner for attention metadata building + proposer.runner = mock.MagicMock() + proposer.runner.attn_metadata_builders = [attn_metadata_builder] + result = proposer.propose(target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=block_table, + common_attn_metadata=common_attn_metadata, sampling_metadata=sampling_metadata) assert result.shape == (batch_size, num_speculative_tokens) diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py index 64a41bec3791..6716d27f571f 100644 --- a/tests/v1/test_async_llm_dp.py +++ b/tests/v1/test_async_llm_dp.py @@ -90,8 +90,10 @@ class SimpleStatsLogger(StatLoggerBase): def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): stats_loggers[engine_index] = self - def record(self, scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats]): + def record(self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0): if iteration_stats: self.finished_req_count += len( iteration_stats.finished_requests) diff --git a/tests/v1/test_external_lb_dp.py b/tests/v1/test_external_lb_dp.py index 17952dfb0d91..98fefad1ff4a 100644 --- a/tests/v1/test_external_lb_dp.py +++ b/tests/v1/test_external_lb_dp.py @@ -17,7 +17,7 @@ # Number of data parallel ranks for external LB testing DP_SIZE = int(os.getenv("DP_SIZE", "2")) -# Default tensor parallell size to use +# Default tensor parallel size to use TP_SIZE = int(os.getenv("TP_SIZE", "1")) diff --git a/tests/v1/test_hybrid_lb_dp.py b/tests/v1/test_hybrid_lb_dp.py new file mode 100644 index 000000000000..08336489abee --- /dev/null +++ b/tests/v1/test_hybrid_lb_dp.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import os +import threading +import time +from contextlib import AsyncExitStack + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer +from tests.v1.test_utils import check_request_balancing +from vllm.platforms import Platform + +MODEL_NAME = "ibm-research/PowerMoE-3b" + +# Number of data parallel ranks for hybrid LB testing (4 total) +DP_SIZE = int(os.getenv("DP_SIZE", "4")) +# Default tensor parallel size to use +TP_SIZE = int(os.getenv("TP_SIZE", "1")) + +# Number of nodes (2 nodes, each with 2 DP ranks) +NUM_NODES = 2 +DP_SIZE_LOCAL = DP_SIZE // NUM_NODES # 2 ranks per node + + +class HybridLBServerManager: + """Manages hybrid data parallel vLLM server 
instances where each node + runs a single logical API server that balances requests only to the + DP engines running on that same node.""" + + def __init__(self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_size_local: int = DP_SIZE_LOCAL, + tp_size: int = TP_SIZE): + self.model_name = model_name + self.dp_size = dp_size + self.dp_size_local = dp_size_local + self.tp_size = tp_size + self.api_server_count = api_server_count + self.base_server_args = base_server_args + self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.server_threads: list[threading.Thread] = [] + self.num_nodes = dp_size // dp_size_local + + def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: + """Start all server instances for hybrid LB mode.""" + for node_id in range(self.num_nodes): + # Create server args for this specific node + server_args = self.base_server_args.copy() + + # Calculate start rank for this node + start_rank = node_id * self.dp_size_local + + # Add hybrid LB specific arguments + server_args.extend([ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size_local), + "--data-parallel-start-rank", + str(start_rank), + "--data-parallel-hybrid-lb", # Enable hybrid LB mode + "--tensor-parallel-size", + str(self.tp_size), + "--port", + str(8000 + node_id), # Different port for each node + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ]) + + # Use a thread to start each server to allow parallel initialization + def start_server(node: int, sargs: list[str]): + try: + # Calculate GPU devices for this node + gpus_per_node = self.dp_size_local * self.tp_size + gpu_start = node * gpus_per_node + gpu_end = gpu_start + gpus_per_node + + # Start the server + server = RemoteOpenAIServer( + self.model_name, + sargs, + auto_port=False, + env_dict={ + "CUDA_VISIBLE_DEVICES": + ",".join( + str(Platform.device_id_to_physical_device_id( + i)) for i in range(gpu_start, gpu_end)) + }) + server.__enter__() + print(f"Hybrid LB node {node} started successfully with " + f"{self.dp_size_local} local DP ranks and " + f"{self.api_server_count} API servers") + self.servers.append((server, sargs)) + except Exception as e: + print(f"Failed to start hybrid LB node {node}: {e}") + raise + + thread = threading.Thread(target=start_server, + args=(node_id, server_args)) + thread.start() + + self.server_threads.append(thread) + + # Wait for all servers to start + for thread in self.server_threads: + thread.join() + + # Give servers additional time to fully initialize and coordinate + time.sleep(3) + + if len(self.servers) != self.num_nodes: + raise Exception("Servers failed to start") + + return self.servers + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all server instances.""" + while self.servers: + try: + self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now +def servers(request, default_server_args): + api_server_count = request.param + with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, + 
default_server_args, DP_SIZE_LOCAL, + TP_SIZE) as server_list: + yield server_list + + +@pytest_asyncio.fixture +async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]): + # Create a client for each node (each node has its own API endpoint) + async with AsyncExitStack() as stack: + yield [ + await stack.enter_async_context(server.get_async_client()) + for server, _ in servers + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_hybrid_lb_completion(clients: list[openai.AsyncOpenAI], + servers: list[tuple[RemoteOpenAIServer, + list[str]]], + model_name: str) -> None: + + async def make_request(client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes early + # or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. + assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. + assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request to each node + for i, client in enumerate(clients): + result = await make_request(client) + assert result is not None + print( + f"Hybrid LB node {i} handled single completion request successfully" + ) + + await asyncio.sleep(0.5) + + # Send requests to all nodes - each should balance within its local DP ranks + num_requests_per_node = 25 # Total 50 requests across 2 nodes + all_tasks = [] + + for i, client in enumerate(clients): + tasks = [make_request(client) for _ in range(num_requests_per_node)] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + # Second burst of requests + all_tasks = [] + for i, client in enumerate(clients): + tasks = [make_request(client) for _ in range(num_requests_per_node)] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(completion is not None for completion in results) + + _, server_args = servers[0] + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print( + f"Successfully completed hybrid LB test with {len(clients)} nodes " + f"({DP_SIZE_LOCAL} DP ranks each, API server count: {api_server_count})" + ) + + # Check request balancing within each node + for i, (server, _) in enumerate(servers): + print(f"Checking request balancing for node {i}") + check_request_balancing(server, DP_SIZE_LOCAL) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_hybrid_lb_completion_streaming(clients: list[ + openai.AsyncOpenAI], servers: list[tuple[RemoteOpenAIServer, list[str]]], + model_name: str) -> None: + prompt = "What is an LLM?" 
+ + async def make_streaming_request(client: openai.AsyncOpenAI): + # Perform a non-streaming request to get the expected full output + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." + return True # Indicate success for this request + + # Test single request to each node + for i, client in enumerate(clients): + result = await make_streaming_request(client) + assert result is not None + print( + f"Hybrid LB node {i} handled single streaming request successfully" + ) + + await asyncio.sleep(0.5) + + # Send streaming requests to all nodes + num_requests_per_node = 25 # Total 50 requests across 2 nodes + all_tasks = [] + + for i, client in enumerate(clients): + tasks = [ + make_streaming_request(client) + for _ in range(num_requests_per_node) + ] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(results), "Not all streaming requests completed successfully." + + await asyncio.sleep(0.5) + + # Second burst of streaming requests + all_tasks = [] + for i, client in enumerate(clients): + tasks = [ + make_streaming_request(client) + for _ in range(num_requests_per_node) + ] + all_tasks.extend(tasks) + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests_per_node * len(clients) + assert all(results), "Not all streaming requests completed successfully." 
+ + _, server_args = servers[0] + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print(f"Successfully completed hybrid LB streaming test with " + f"{len(clients)} nodes ({DP_SIZE_LOCAL} DP ranks each, " + f"API server count: {api_server_count})") + + # Check request balancing within each node + for i, (server, _) in enumerate(servers): + print(f"Checking streaming request balancing for node {i}") + check_request_balancing(server, DP_SIZE_LOCAL) diff --git a/tests/v1/test_internal_lb_dp.py b/tests/v1/test_internal_lb_dp.py new file mode 100644 index 000000000000..9aef4d5821e8 --- /dev/null +++ b/tests/v1/test_internal_lb_dp.py @@ -0,0 +1,639 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import os +import threading +import time + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer +from tests.v1.test_utils import check_request_balancing +from vllm.platforms import Platform + +MODEL_NAME = "ibm-research/PowerMoE-3b" + +# Number of data parallel ranks for multi-node internal LB testing +DP_SIZE = int(os.getenv("DP_SIZE", "2")) +# Default tensor parallel size to use +TP_SIZE = int(os.getenv("TP_SIZE", "1")) + +# Number of nodes to simulate +NUM_NODES = 2 + + +class MultinodeInternalLBServerManager: + """Manages multi-node data parallel vLLM server instances for internal + load balancer testing using --headless mode.""" + + def __init__(self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + dp_per_node: int = 1, + tp_size: int = TP_SIZE): + self.model_name = model_name + self.dp_size = dp_size + self.dp_per_node = dp_per_node + self.tp_size = tp_size + self.api_server_count = api_server_count + self.base_server_args = base_server_args + self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.server_threads: list[threading.Thread] = [] + + def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: + """Start all server instances for multi-node internal LB mode.""" + for rank in range(0, self.dp_size, self.dp_per_node): + # Create server args for this specific rank + server_args = self.base_server_args.copy() + + if rank == 0: + # Head node - runs API server and first DP rank + server_args.extend([ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_per_node), + "--tensor-parallel-size", + str(self.tp_size), + "--port", + "8000", # Single endpoint for all requests + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ]) + else: + # Secondary nodes - run in headless mode + server_args.extend([ + "--headless", + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_per_node), + "--data-parallel-start-rank", + str(rank), + "--tensor-parallel-size", + str(self.tp_size), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ]) + + # Use a thread to start each server to allow parallel initialization + def start_server(r: int, sargs: list[str]): + gpus_per_node = self.tp_size * self.dp_per_node + try: + # Start the server + server = RemoteOpenAIServer( + self.model_name, + sargs, + auto_port=False, + env_dict={ + "CUDA_VISIBLE_DEVICES": + ",".join( + 
str(Platform.device_id_to_physical_device_id( + i)) for i in range(r, r + gpus_per_node)) + }) + server.__enter__() + if r == 0: + print( + f"Head node (rank {r}) started successfully with " + f"{self.api_server_count} API servers") + else: + print(f"Headless node (rank {r}) started successfully") + self.servers.append((server, sargs)) + except Exception as e: + print(f"Failed to start server rank {r}: {e}") + raise + + thread = threading.Thread(target=start_server, + args=(rank, server_args)) + thread.start() + + self.server_threads.append(thread) + + # Wait for all servers to start + for thread in self.server_threads: + thread.join() + + # Give servers additional time to fully initialize and coordinate + time.sleep(3) + + if len(self.servers) != self.dp_size // self.dp_per_node: + raise Exception("Servers failed to start") + + return self.servers + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop all server instances.""" + while self.servers: + try: + self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + + +class APIOnlyServerManager: + """Manages API-only server (Node 0) and headless engines server (Node 1) + for testing separated API server and engine configuration.""" + + def __init__(self, + model_name: str, + dp_size: int, + api_server_count: int, + base_server_args: list, + tp_size: int = TP_SIZE): + self.model_name = model_name + self.dp_size = dp_size + self.tp_size = tp_size + self.api_server_count = api_server_count + self.base_server_args = base_server_args + self.servers: list[tuple[RemoteOpenAIServer, list[str]]] = [] + self.server_threads: list[threading.Thread] = [] + + def __enter__(self) -> list[tuple[RemoteOpenAIServer, list[str]]]: + """Start API-only server and headless engines server.""" + + # Start API-only server (Node 0) - no engines, only API server + api_server_args = self.base_server_args.copy() + api_server_args.extend([ + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + "0", # No engines on this node + "--tensor-parallel-size", + str(self.tp_size), + "--port", + "8000", + "--api-server-count", + str(self.api_server_count), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ]) + + # Start headless engines server (Node 1) - all engines, no API server + engines_server_args = self.base_server_args.copy() + engines_server_args.extend([ + "--headless", + "--data-parallel-size", + str(self.dp_size), + "--data-parallel-size-local", + str(self.dp_size), # All engines on this node + "--tensor-parallel-size", + str(self.tp_size), + "--data-parallel-address", + "127.0.0.1", + "--data-parallel-rpc-port", + "13345", + ]) + + # Use threads to start both servers in parallel + def start_api_server(): + try: + server = RemoteOpenAIServer( + self.model_name, + api_server_args, + auto_port=False, + env_dict={}) # No GPUs needed for API-only server + server.__enter__() + print(f"API-only server started successfully with " + f"{self.api_server_count} API servers") + self.servers.append((server, api_server_args)) + except Exception as e: + print(f"Failed to start API-only server: {e}") + raise + + def start_engines_server(): + try: + server = RemoteOpenAIServer( + self.model_name, + engines_server_args, + auto_port=False, + env_dict={ + "CUDA_VISIBLE_DEVICES": + ",".join( + str(Platform.device_id_to_physical_device_id(i)) + for i in range(self.dp_size * self.tp_size)) + }) + server.__enter__() + print(f"Headless engines server started 
successfully with " + f"{self.dp_size} engines") + self.servers.append((server, engines_server_args)) + except Exception as e: + print(f"Failed to start headless engines server: {e}") + raise + + # Start API server first + api_thread = threading.Thread(target=start_api_server) + api_thread.start() + self.server_threads.append(api_thread) + + # Start engines server second + engines_thread = threading.Thread(target=start_engines_server) + engines_thread.start() + self.server_threads.append(engines_thread) + + # Wait for both servers to start + for thread in self.server_threads: + thread.join() + + # Give servers additional time to fully initialize and coordinate + time.sleep(3) + + if len(self.servers) != 2: + raise Exception("Both servers failed to start") + + return self.servers + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop both server instances.""" + while self.servers: + try: + self.servers.pop()[0].__exit__(exc_type, exc_val, exc_tb) + except Exception as e: + print(f"Error stopping server: {e}") + + +@pytest.fixture(scope="module") +def default_server_args(): + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "2048", + "--max-num-seqs", + "128", + "--enforce-eager", + ] + + +@pytest.fixture(scope="module", params=[1, 4]) +def servers(request, default_server_args): + api_server_count = request.param + with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE, + api_server_count, + default_server_args, + DP_SIZE // NUM_NODES, + TP_SIZE) as server_list: + yield server_list + + +@pytest.fixture(scope="module", params=[1, 4]) +def api_only_servers(request, default_server_args): + """Fixture for API-only server + headless engines configuration.""" + api_server_count = request.param + with APIOnlyServerManager(MODEL_NAME, DP_SIZE, api_server_count, + default_server_args, TP_SIZE) as server_list: + yield server_list + + +@pytest_asyncio.fixture +async def client(servers: list[tuple[RemoteOpenAIServer, list[str]]]): + # For internal LB, we only connect to the head node (rank 0) + # which provides the single API endpoint + head_server = servers[0][0] + async with head_server.get_async_client() as client: + yield client + + +@pytest_asyncio.fixture +async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer, + list[str]]]): + """Client fixture for API-only server configuration.""" + # Connect to the API-only server (first server in the list) + api_server = api_only_servers[0][0] + async with api_server.get_async_client() as client: + yield client + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_multinode_dp_completion(client: openai.AsyncOpenAI, + servers: list[tuple[RemoteOpenAIServer, + list[str]]], + model_name: str) -> None: + + async def make_request(): + completion = await client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes early + # or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. 
+ assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. + assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request + result = await make_request() + assert result is not None + print( + "Multi-node internal LB handled single completion request successfully" + ) + + await asyncio.sleep(0.5) + + # Send multiple requests - internal LB should distribute across DP ranks + num_requests = 50 + all_tasks = [make_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + # Second burst of requests + all_tasks = [make_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + _, server_args = servers[0] + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print(f"Successfully completed multi-node internal LB test with " + f"{len(servers)} DP ranks (API server count: {api_server_count})") + + # Check request balancing via Prometheus metrics + head_server = servers[0][0] + check_request_balancing(head_server, DP_SIZE) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_multinode_dp_completion_streaming(client: openai.AsyncOpenAI, + servers: list[ + tuple[RemoteOpenAIServer, + list[str]]], + model_name: str) -> None: + prompt = "What is an LLM?" + + async def make_streaming_request(): + # Perform a non-streaming request to get the expected full output + single_completion = await client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." 
+ return True # Indicate success for this request + + # Test single streaming request + result = await make_streaming_request() + assert result is not None + print( + "Multi-node internal LB handled single streaming request successfully") + + await asyncio.sleep(0.5) + + # Send multiple streaming requests - internal LB should distribute across + # DP ranks + num_requests = 50 + all_tasks = [make_streaming_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(results), "Not all streaming requests completed successfully." + + await asyncio.sleep(0.5) + + # Second burst of streaming requests + all_tasks = [make_streaming_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(results), "Not all streaming requests completed successfully." + + _, server_args = servers[0] + api_server_count = ( + server_args.count('--api-server-count') + and server_args[server_args.index('--api-server-count') + 1] or 1) + print(f"Successfully completed multi-node internal LB streaming test with " + f"{len(servers)} DP ranks (API server count: {api_server_count})") + + # Check request balancing via Prometheus metrics + head_server = servers[0][0] + check_request_balancing(head_server, DP_SIZE) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_api_only_multinode_dp_completion( + api_only_client: openai.AsyncOpenAI, + api_only_servers: list[tuple[RemoteOpenAIServer, + list[str]]], model_name: str) -> None: + """Test API-only server with all engines on separate headless server.""" + + async def make_request(): + completion = await api_only_client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=10, + temperature=1.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + + choice = completion.choices[0] + # The exact number of tokens can vary slightly with temperature=1.0, + # so we check for a reasonable minimum length. + assert len(choice.text) >= 1 + # Finish reason might not always be 'length' if the model finishes + # early or due to other reasons, especially with high temperature. + # So, we'll accept 'length' or 'stop'. + assert choice.finish_reason in ("length", "stop") + + # Token counts can also vary, so we check they are positive. 
+ assert completion.usage.completion_tokens > 0 + assert completion.usage.prompt_tokens > 0 + assert completion.usage.total_tokens > 0 + return completion + + # Test single request + result = await make_request() + assert result is not None + print("API-only server handled single completion request successfully") + + await asyncio.sleep(0.5) + + # Send multiple requests - should be distributed across engines on + # headless server + num_requests = 50 + all_tasks = [make_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + await asyncio.sleep(0.5) + + # Second burst of requests + all_tasks = [make_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(completion is not None for completion in results) + + _, api_server_args = api_only_servers[0] + api_server_count = ( + api_server_args.count('--api-server-count') + and api_server_args[api_server_args.index('--api-server-count') + 1] + or 1) + print(f"Successfully completed API-only multi-node test with {DP_SIZE} " + f"engines on headless server (API server count: {api_server_count})") + + # Check request balancing via Prometheus metrics + api_server = api_only_servers[0][0] + check_request_balancing(api_server, DP_SIZE) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_api_only_multinode_dp_completion_streaming( + api_only_client: openai.AsyncOpenAI, + api_only_servers: list[tuple[RemoteOpenAIServer, + list[str]]], model_name: str) -> None: + """Test API-only server streaming with all engines on separate + headless server.""" + prompt = "What is an LLM?" + + async def make_streaming_request(): + # Perform a non-streaming request to get the expected full output + single_completion = await api_only_client.completions.create( + model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + ) + single_output = single_completion.choices[0].text + + # Perform the streaming request + stream = await api_only_client.completions.create(model=model_name, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True) + chunks: list[str] = [] + finish_reason_count = 0 + last_chunk = None + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + last_chunk = chunk # Keep track of the last chunk + + # finish reason should only return in the last block for OpenAI API + assert finish_reason_count == 1, ( + "Finish reason should appear exactly once.") + assert last_chunk is not None, ( + "Stream should have yielded at least one chunk.") + assert last_chunk.choices[ + 0].finish_reason == "length", "Finish reason should be 'length'." + # Check that the combined text matches the non-streamed version. + assert "".join( + chunks + ) == single_output, "Streamed output should match non-streamed output." 
+ return True # Indicate success for this request + + # Test single streaming request + result = await make_streaming_request() + assert result is not None + print("API-only server handled single streaming request successfully") + + await asyncio.sleep(0.5) + + # Send multiple streaming requests - should be distributed across engines + num_requests = 50 + all_tasks = [make_streaming_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(results), "Not all streaming requests completed successfully." + + await asyncio.sleep(0.5) + + # Second burst of streaming requests + all_tasks = [make_streaming_request() for _ in range(num_requests)] + + results = await asyncio.gather(*all_tasks) + assert len(results) == num_requests + assert all(results), "Not all streaming requests completed successfully." + + _, api_server_args = api_only_servers[0] + api_server_count = ( + api_server_args.count('--api-server-count') + and api_server_args[api_server_args.index('--api-server-count') + 1] + or 1) + print(f"Successfully completed API-only streaming test with {DP_SIZE} " + f"engines on headless server (API server count: {api_server_count})") + + # Check request balancing via Prometheus metrics + api_server = api_only_servers[0][0] + check_request_balancing(api_server, DP_SIZE) diff --git a/tests/v1/test_oracle.py b/tests/v1/test_oracle.py index 7a7ba346a719..b4d4348c7fd9 100644 --- a/tests/v1/test_oracle.py +++ b/tests/v1/test_oracle.py @@ -40,12 +40,6 @@ def test_unsupported_configs(monkeypatch): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - with pytest.raises(NotImplementedError): - AsyncEngineArgs( - model=MODEL, - kv_cache_dtype="fp8", - ).create_engine_config() - with pytest.raises(NotImplementedError): AsyncEngineArgs( model=MODEL, @@ -112,9 +106,9 @@ def test_v1_llm_by_default(monkeypatch): m.delenv("VLLM_USE_V1") # Should default to V1 for supported config. - model = LLM(MODEL, enforce_eager=True, enable_lora=True) - print(model.generate("Hello my name is")) - assert hasattr(model.llm_engine, "engine_core") + llm = LLM(MODEL, enforce_eager=True, enable_lora=True) + print(llm.generate("Hello my name is")) + assert hasattr(llm.llm_engine, "engine_core") m.delenv("VLLM_USE_V1") diff --git a/tests/v1/test_utils.py b/tests/v1/test_utils.py index a3df882a9e29..0b892bd9dffd 100644 --- a/tests/v1/test_utils.py +++ b/tests/v1/test_utils.py @@ -1,9 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import re + +import pytest +import requests import torch -from vllm.v1.utils import bind_kv_cache +from tests.utils import RemoteOpenAIServer +from vllm.v1.worker.utils import bind_kv_cache def test_bind_kv_cache(): @@ -61,3 +66,122 @@ def test_bind_kv_cache_non_attention(): assert runner_kv_caches[0] is kv_cache['model.layers.20.attn'] assert runner_kv_caches[1] is kv_cache['model.layers.28.attn'] + + +# Prometheus metrics utilities for testing + + +def get_prometheus_metrics( + server: RemoteOpenAIServer) -> dict[str, dict[str, float]]: + """Fetch and parse Prometheus metrics from the /metrics endpoint. + + Returns: + Dict mapping metric names to their values grouped by labels. 
+ For example: {"vllm:request_success": { + "engine=0": 5.0, "engine=1": 3.0} + } + """ + try: + response = requests.get(server.url_for("metrics"), timeout=10) + response.raise_for_status() + + metrics: dict[str, dict[str, float]] = {} + + # Regex patterns for Prometheus metrics + metric_with_labels = re.compile( + r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$') + metric_simple = re.compile( + r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$') + + for line in response.text.split('\n'): + line = line.strip() + # Skip comments and empty lines + if not line or line.startswith('#'): + continue + + # Try to match metric with labels first + match = metric_with_labels.match(line) + if match: + metric_name, labels_part, value_str = match.groups() + try: + value = float(value_str) + if metric_name not in metrics: + metrics[metric_name] = {} + metrics[metric_name][f'{{{labels_part}}}'] = value + except ValueError: + continue + else: + # Try simple metric without labels + match = metric_simple.match(line) + if match: + metric_name, value_str = match.groups() + try: + value = float(value_str) + if metric_name not in metrics: + metrics[metric_name] = {} + metrics[metric_name][''] = value + except ValueError: + continue + + return metrics + except Exception as e: + pytest.fail(f"Failed to fetch Prometheus metrics: {e}") + return {} + + +def get_engine_request_counts( + metrics: dict[str, dict[str, float]]) -> dict[str, float]: + """Extract request counts per engine from Prometheus metrics. + + Returns: + Dict mapping engine indices to request counts. + For example: {"0": 15.0, "1": 12.0} + """ + engine_counts = {} + + # Look for request success metrics with engine labels + success_metrics = metrics.get("vllm:request_success_total", {}) + engine_pattern = re.compile(r'engine="([^"]*)"') + + for labels, count in success_metrics.items(): + # Extract engine ID from labels using regex + match = engine_pattern.search(labels) + if match: + engine_id = match.group(1) + if engine_id not in engine_counts: + engine_counts[engine_id] = 0.0 + engine_counts[engine_id] += count + + return engine_counts + + +def check_request_balancing(server: RemoteOpenAIServer, dp_size: int): + """Check request balancing via Prometheus metrics if dp_size > 1. + + Args: + server: The RemoteOpenAIServer instance + dp_size: Number of data parallel ranks + """ + if dp_size <= 1: + return + + # Get metrics after all requests are completed + metrics = get_prometheus_metrics(server) + engine_counts = get_engine_request_counts(metrics) + + # Check that multiple engines received requests + engines_with_requests = [ + engine for engine, count in engine_counts.items() if count > 0 + ] + assert len(engines_with_requests) == dp_size, ( + f"Expected requests to be distributed across multiple engines," + f" but only engine(s) {engines_with_requests} received " + f"requests. 
Engine counts: {engine_counts}") + + # Verify that the load is reasonably balanced + # (no engine should handle all requests) + total_requests = sum(engine_counts.values()) + + for count in engine_counts.values(): + assert count > total_requests // (dp_size + 1), ( + f"requests are imbalanced: {engine_counts}") diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py index fe65976a58a1..c8cd099a98cf 100644 --- a/tests/v1/tpu/test_basic.py +++ b/tests/v1/tpu/test_basic.py @@ -67,6 +67,7 @@ def test_basic( assert "1024" in output or "0, 1" in output +@pytest.mark.skip(reason="Temporarily disabled due to timeout") @pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a basic test for TPU only") @pytest.mark.parametrize("max_tokens", [8]) @@ -144,3 +145,35 @@ def test_gemma3_27b_with_text_input_and_tp( for output, answer in zip(vllm_outputs, answers): generated_text = output[1] assert answer in generated_text + + +@pytest.mark.skipif(not current_platform.is_tpu(), + reason="This is a basic test for TPU only") +def test_w8a8_quantization( + vllm_runner: type[VllmRunner], + monkeypatch: pytest.MonkeyPatch, +) -> None: + model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8" + max_tokens = 5 + tensor_parallel_size = 1 + max_num_seqs = 4 + + prompt = "The next numbers of the sequence " + ", ".join( + str(i) for i in range(1024)) + " are:" + example_prompts = [prompt] + + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + with vllm_runner( + model, + max_num_batched_tokens=64, + max_model_len=4096, + gpu_memory_utilization=0.7, + max_num_seqs=max_num_seqs, + tensor_parallel_size=tensor_parallel_size) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, + max_tokens) + output = vllm_outputs[0][1] + + assert "1024" in output or "0, 1" in output diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py index e279edfffbc7..bfba3af57f71 100644 --- a/tests/v1/tpu/test_pallas.py +++ b/tests/v1/tpu/test_pallas.py @@ -50,6 +50,7 @@ class FakeAttentionLayer: slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64) max_num_reqs = 8 max_num_blocks_per_req = 8 + num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32) block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req), dtype=torch.int32) context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32) @@ -65,6 +66,7 @@ class FakeAttentionLayer: context_lens=context_lens, query_start_loc=query_start_loc, num_seqs=num_seqs, + num_kv_update_slices=num_kv_update_slices, num_slices_per_kv_cache_update_block=8, ) @@ -93,4 +95,6 @@ class FakeAttentionLayer: sm_scale=scale, sliding_window=sliding_window, soft_cap=logits_soft_cap, + k_scale=1.0, + v_scale=1.0, ) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index d13df553db62..7fec4782517c 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -3,15 +3,19 @@ import random +import numpy as np import pytest import torch from vllm.attention import Attention from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig, set_current_vllm_config) +from vllm.distributed.parallel_state import (init_distributed_environment, + initialize_model_parallel) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams -from vllm.utils import GiB_bytes +from vllm.utils import GiB_bytes, 
update_environment_variables from vllm.v1.core.kv_cache_utils import (estimate_max_model_len, get_kv_cache_config) from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, @@ -434,21 +438,38 @@ def rnd_stride_order(): assert all(not kv.is_contiguous() for kv in model_runner.kv_caches) +def test_update_config(model_runner): + # Simple update + model_runner.update_config({"load_config": {"load_format": "dummy"}}) + assert model_runner.load_config.load_format == "dummy" + # Raise error on non-existing config + with pytest.raises(AssertionError): + model_runner.update_config({"do_not_exist_config": "dummy"}) + + def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2): # In this test, model_runner loads model + weights in one go, while # model_runner_2 loads dummy weights first then load real weights inplace model_runner.load_model() original_load_format = model_runner_2.load_config.load_format - model_runner_2.load_config.load_format = "dummy" + model_runner_2.update_config({"load_config": {"load_format": "dummy"}}) model_runner_2.load_model() # Initial model loading with dummy weights assert str(model_runner.get_model().state_dict()) != str( model_runner_2.get_model().state_dict()) - model_runner_2.load_config.load_format = original_load_format - model_runner_2.load_model() # Load real weights inplace + model_runner_2.update_config( + {"load_config": { + "load_format": original_load_format + }}) + model_runner_2.reload_weights() # Load real weights inplace assert str(model_runner.get_model().state_dict()) == str( model_runner_2.get_model().state_dict()) +def test_reload_weights_before_load_model(model_runner): + with pytest.raises(AssertionError): + model_runner.reload_weights() + + def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(): torch.set_default_dtype(torch.float16) layer_0 = "model.layers.0.self_attn.attn" @@ -674,3 +695,147 @@ def test_init_kv_cache_with_kv_sharing_valid(): assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 + + +def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): + ''' + The GPU model runner creates different views into the + KVCacheTensors for the attention and mamba layers + (via _reshape_kv_cache_tensors function). 
This test verifies + that the views are compatible: writing a mamba block + will not corrupt an attention block and vice-versa + ''' + + current_platform.seed_everything(42) + + update_environment_variables({ + 'RANK': "0", + 'LOCAL_RANK': "0", + 'WORLD_SIZE': "1", + 'MASTER_ADDR': 'localhost', + 'MASTER_PORT': '12345', + }) + init_distributed_environment() + initialize_model_parallel(tensor_model_parallel_size=1) + torch.set_default_dtype(torch.float16) + + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + ) + model_config = ModelConfig( + model="ibm-granite/granite-4.0-tiny-preview", + dtype="float16", + ) + cache_config = CacheConfig( + block_size=BLOCK_SIZE, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + ) + parallel_config = ParallelConfig() + vllm_config = VllmConfig( + model_config=model_config, + cache_config=cache_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + ) + + layer_0 = "model.layers.0.self_attn.attn" + layer_1 = "model.layers.1.self_attn.attn" + layer_2 = "model.layers.2.mixer" + layer_3 = "model.layers.3.mixer" + layer_4 = "model.layers.4.mixer" + layer_5 = "model.layers.5.mixer" + + with set_current_vllm_config(vllm_config): + hf_config = vllm_config.model_config.hf_config + fwd_context = {} + for key in [layer_0, layer_1]: + fwd_context[key] = Attention( + num_heads=model_config.get_num_attention_heads( + parallel_config), + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + scale=1.0, + prefix=key, + ) + for key in [layer_2, layer_3, layer_4, layer_5]: + fwd_context[key] = MambaMixer2( + hidden_size = hf_config.hidden_size, + ssm_state_size = hf_config.mamba_d_state, + conv_kernel_size = hf_config.mamba_d_conv, + intermediate_size = hf_config.mamba_expand *\ + hf_config.hidden_size, + use_conv_bias = hf_config.mamba_conv_bias, + use_bias = hf_config.mamba_proj_bias, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + rms_norm_eps=hf_config.rms_norm_eps, + activation=hf_config.hidden_act, + prefix=key, + ) + # suppress var not used error + assert fwd_context is not None + vllm_ctx = vllm_config.compilation_config.static_forward_context + + with monkeypatch.context() as m: + + m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER") + + runner = GPUModelRunner(vllm_config, DEVICE) + kv_cache_spec = runner.get_kv_cache_spec() + + available_memory = 5 * GiB_bytes + kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec, + available_memory) + runner.initialize_kv_cache(kv_cache_config) + + # random partition of blocks + # blocks0 will be assigned to attention layers + # blocks1 will be assigned to mamba layers + num_blocks = kv_cache_config.num_blocks + ind = np.arange(num_blocks) + np.random.shuffle(ind) + blocks0, blocks1 = ind[:(num_blocks // 2)], ind[(num_blocks // 2):] + + attn_shape = vllm_ctx[layer_0].kv_cache[0].shape + conv_shape = vllm_ctx[layer_2].kv_cache[0][0].shape + ssm_shape = vllm_ctx[layer_2].kv_cache[0][1].shape + + # assert we are using FlashInfer + assert attn_shape[0] == num_blocks + + attn_blocks_constant = torch.full((len(blocks0), *attn_shape[1:]), + device=DEVICE, + fill_value=3.33) + conv_blocks_constant = torch.full((len(blocks1), *conv_shape[1:]), + device=DEVICE, + fill_value=6.66) + ssm_blocks_constant = torch.full((len(blocks1), *ssm_shape[1:]), + device=DEVICE, + fill_value=9.99) + + # fill all attention blocks with 
constant + for layer in [layer_0, layer_1]: + vllm_ctx[layer].kv_cache[0][ + blocks0, :] = attn_blocks_constant.detach().clone() + + # fill all mamba blocks with constant + for layer in [layer_2, layer_3, layer_4, layer_5]: + vllm_ctx[layer].kv_cache[0][0][ + blocks1, :] = conv_blocks_constant.detach().clone() + vllm_ctx[layer].kv_cache[0][1][ + blocks1, :] = ssm_blocks_constant.detach().clone() + + # verify attention and mamba contents are correct + for layer in [layer_0, layer_1]: + assert torch.equal(vllm_ctx[layer].kv_cache[0][blocks0, :], + attn_blocks_constant) + for layer in [layer_2, layer_3, layer_4, layer_5]: + assert torch.equal(vllm_ctx[layer].kv_cache[0][0][blocks1, :], + conv_blocks_constant) + assert torch.equal(vllm_ctx[layer].kv_cache[0][1][blocks1, :], + ssm_blocks_constant) diff --git a/tools/ep_kernels/elastic_ep/eep_nvshmem.patch b/tools/ep_kernels/elastic_ep/eep_nvshmem.patch new file mode 100644 index 000000000000..5ebdaea58dd8 --- /dev/null +++ b/tools/ep_kernels/elastic_ep/eep_nvshmem.patch @@ -0,0 +1,92 @@ +From 18c0599c2f07ec965132efa25961dc8179c2dda3 Mon Sep 17 00:00:00 2001 +From: Yongji Wu <wuyongji317@gmail.com> +Date: Tue, 20 May 2025 13:41:12 -0700 +Subject: [PATCH] fix reinit issues due to states not cleaned up + +fix double free +--- + src/host/init/init.cu | 10 ++++++++++ + .../internal/host/nvshmemi_mem_transport.hpp | 15 +++++++++++++++ + src/modules/bootstrap/uid/bootstrap_uid.cpp | 5 +++++ + 3 files changed, 30 insertions(+) + +diff --git a/src/host/init/init.cu b/src/host/init/init.cu +index b1c5dbf..1fecb4b 100644 +--- a/src/host/init/init.cu ++++ b/src/host/init/init.cu +@@ -43,6 +43,8 @@ + #include "internal/host/nvshmemi_types.h" + #include "internal/host/shared_memory.h" + #include "internal/host/nvshmemi_symmetric_heap.hpp" ++// eep-dev ++#include "internal/host/nvshmemi_mem_transport.hpp" + + extern __constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d; + static std::map<void *, int> registered_device_states; +@@ -1293,6 +1295,14 @@ void nvshmemid_hostlib_finalize(void *device_ctx, void *transport_device_ctx) { + /* Multi-init Multi-fini*/ + nvshmemi_state = NULL; + nvshmemi_device_state.nvshmemi_is_nvshmem_initialized = 0; ++ ++ // eep-dev ++ nvshmemi_mem_p2p_transport::destroy_instance(); ++ nvshmemi_mem_remote_transport::destroy_instance(); ++ free(nvshmemi_default_session); ++ nvshmemi_default_session = nullptr; ++ nvshmemi_device_state.nvshmemi_is_nvshmem_bootstrapped = false; ++ + nvshmemi_is_device_state_ready = false; + } else + nvshmemi_boot_handle.barrier(&nvshmemi_boot_handle); +diff --git a/src/include/internal/host/nvshmemi_mem_transport.hpp b/src/include/internal/host/nvshmemi_mem_transport.hpp +index 2495844..e4f408a 100644 +--- a/src/include/internal/host/nvshmemi_mem_transport.hpp ++++ b/src/include/internal/host/nvshmemi_mem_transport.hpp +@@ -36,6 +36,13 @@ class nvshmemi_mem_p2p_transport final { + return p2p_objref_; + } + } ++ // eep-dev ++ static void destroy_instance(void) { ++ if (p2p_objref_ != nullptr) { ++ delete p2p_objref_; ++ p2p_objref_ = nullptr; ++ } ++ } + + void print_mem_handle(int pe_id, int transport_idx, nvshmemi_symmetric_heap &obj); + +@@ -87,6 +94,14 @@ class nvshmemi_mem_remote_transport final { + } + } + ++ // eep-dev ++ static void destroy_instance(void) { ++ if (remote_objref_ != nullptr) { ++ delete remote_objref_; ++ remote_objref_ = nullptr; ++ } ++ } ++ + int gather_mem_handles(nvshmemi_symmetric_heap &obj, uint64_t heap_offset, size_t size); + /* On-demand registration and 
release of memory */ + int register_mem_handle(nvshmem_mem_handle_t *local_handles, int transport_idx, +diff --git a/src/modules/bootstrap/uid/bootstrap_uid.cpp b/src/modules/bootstrap/uid/bootstrap_uid.cpp +index a1fa748..788fa96 100644 +--- a/src/modules/bootstrap/uid/bootstrap_uid.cpp ++++ b/src/modules/bootstrap/uid/bootstrap_uid.cpp +@@ -630,6 +630,11 @@ int nvshmemi_bootstrap_plugin_pre_init(bootstrap_handle_t* handle, const int abi + // Discover the network for bootstrap, if not done previously. + // This code needs to be stateful to be able to be called multiple times by the caller + BOOTSTRAP_CHECK(bootstrap_net_init()); ++ // eep-dev ++ if (handle->pre_init_ops != nullptr) { ++ BOOTSTRAP_PTR_FREE(handle->pre_init_ops); ++ handle->pre_init_ops = nullptr; ++ } + if (handle->pre_init_ops == nullptr) { + BOOTSTRAP_CALLOC(&handle->pre_init_ops, 1); + handle->pre_init_ops->get_unique_id = bootstrap_get_unique_id; +-- +2.43.0 + diff --git a/tools/ep_kernels/elastic_ep/install_eep_libraries.sh b/tools/ep_kernels/elastic_ep/install_eep_libraries.sh new file mode 100644 index 000000000000..9d7dc1032f5e --- /dev/null +++ b/tools/ep_kernels/elastic_ep/install_eep_libraries.sh @@ -0,0 +1,86 @@ +#!/bin/bash + +set -ex + +# Default workspace directory +WORKSPACE=$(pwd)/eep_kernels_workspace +INSTALL_NVSHMEM=true + +# Parse command line arguments +while getopts "w:n" opt; do + case $opt in + w) + WORKSPACE="$OPTARG" + ;; + n) + INSTALL_NVSHMEM=false + ;; + \?) + echo "Invalid option: -$OPTARG" >&2 + exit 1 + ;; + esac +done + +if [ ! -d "$WORKSPACE" ]; then + mkdir -p $WORKSPACE +fi + + +# install dependencies if not installed +pip3 install cmake torch ninja + +# build nvshmem +pushd $WORKSPACE +# Reset NVSHMEM build if requested +if [ "$INSTALL_NVSHMEM" = true ]; then + mkdir -p nvshmem_src + wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz + tar -xvf nvshmem_src_3.2.5-1.txz -C nvshmem_src --strip-components=1 + pushd nvshmem_src + wget https://github.com/deepseek-ai/DeepEP/raw/main/third-party/nvshmem.patch + git init + git apply -vvv nvshmem.patch + git apply --reject --whitespace=fix ../../eep_nvshmem.patch +else + pushd nvshmem_src +fi + +# assume CUDA_HOME is set correctly +if [ -z "$CUDA_HOME" ]; then + echo "CUDA_HOME is not set, please set it to your CUDA installation directory." + exit 1 +fi + +# disable all features except IBGDA +export NVSHMEM_IBGDA_SUPPORT=1 + +export NVSHMEM_SHMEM_SUPPORT=0 +export NVSHMEM_UCX_SUPPORT=0 +export NVSHMEM_USE_NCCL=0 +export NVSHMEM_PMIX_SUPPORT=0 +export NVSHMEM_TIMEOUT_DEVICE_POLLING=0 +export NVSHMEM_USE_GDRCOPY=0 +export NVSHMEM_IBRC_SUPPORT=0 +export NVSHMEM_BUILD_TESTS=0 +export NVSHMEM_BUILD_EXAMPLES=0 +export NVSHMEM_MPI_SUPPORT=0 +export NVSHMEM_BUILD_HYDRA_LAUNCHER=0 +export NVSHMEM_BUILD_TXZ_PACKAGE=0 +export NVSHMEM_TIMEOUT_DEVICE_POLLING=0 + +cmake -G Ninja -S . -B $WORKSPACE/nvshmem_build/ -DCMAKE_INSTALL_PREFIX=$WORKSPACE/nvshmem_install +cmake --build $WORKSPACE/nvshmem_build/ --target install + +popd + +export CMAKE_PREFIX_PATH=$WORKSPACE/nvshmem_install:$CMAKE_PREFIX_PATH + +# build and install pplx, require pytorch installed +pushd $WORKSPACE +git clone https://github.com/ppl-ai/pplx-kernels +cd pplx-kernels +# see https://github.com/pypa/pip/issues/9955#issuecomment-838065925 +# PIP_NO_BUILD_ISOLATION=0 disables build isolation +PIP_NO_BUILD_ISOLATION=0 TORCH_CUDA_ARCH_LIST=9.0a+PTX pip install . 
--no-deps -v + diff --git a/tools/mypy.sh b/tools/mypy.sh index 77d342da1ec8..781d8fc02884 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -31,7 +31,5 @@ run_mypy vllm/inputs run_mypy vllm/lora run_mypy vllm/model_executor run_mypy vllm/plugins -run_mypy vllm/prompt_adapter -run_mypy vllm/spec_decode run_mypy vllm/worker run_mypy vllm/v1 diff --git a/typos.toml b/typos.toml deleted file mode 100644 index f51ce2f36208..000000000000 --- a/typos.toml +++ /dev/null @@ -1,179 +0,0 @@ -[files] -# these files may be written in non english words -extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*", - "benchmarks/sonnet.txt", "tests/lora/data/*", "build/*", - "vllm/third_party/*"] -ignore-hidden = true -ignore-files = true -ignore-dot = true -ignore-vcs = true -ignore-global = true -ignore-parent = true - -[default] -binary = false -check-filename = false -check-file = true -unicode = true -ignore-hex = true -identifier-leading-digits = false -locale = "en" -extend-ignore-identifiers-re = ["NVML_*", ".*Unc.*", ".*_thw", - ".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*", ".*fo.*", ".*ba.*", - ".*ot.*", ".*[Tt]h[rR].*"] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[default.extend-identifiers] -bbc5b7ede = "bbc5b7ede" -womens_doubles = "womens_doubles" -v_2nd = "v_2nd" -splitted_input = "splitted_input" -NOOPs = "NOOPs" -typ = "typ" -nin_shortcut = "nin_shortcut" -UperNetDecoder = "UperNetDecoder" -subtile = "subtile" -cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin" -SFOuput = "SFOuput" -# huggingface transformers repo uses these words -depthwise_seperable_out_channel = "depthwise_seperable_out_channel" -DepthWiseSeperableConv1d = "DepthWiseSeperableConv1d" -depthwise_seperable_CNN = "depthwise_seperable_CNN" - -[default.extend-words] -iy = "iy" -tendencias = "tendencias" -# intel cpu features -tme = "tme" -dout = "dout" -Pn = "Pn" -arange = "arange" - -[type.py] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.py.extend-identifiers] -arange = "arange" -NDArray = "NDArray" -EOFError = "EOFError" - -[type.py.extend-words] - -[type.cpp] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.cpp.extend-identifiers] -countr_one = "countr_one" - -[type.cpp.extend-words] - -[type.rust] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.rust.extend-identifiers] -flate2 = "flate2" - -[type.rust.extend-words] -ser = "ser" - -[type.lock] -extend-glob = [] -check-file = false -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.lock.extend-identifiers] - -[type.lock.extend-words] - -[type.jl] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.jl.extend-identifiers] - -[type.jl.extend-words] -modul = "modul" -egals = "egals" -usig = "usig" -egal = "egal" - -[type.go] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.go.extend-identifiers] -flate = "flate" - -[type.go.extend-words] - -[type.css] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.css.extend-identifiers] -nd = "nd" - -[type.css.extend-words] - -[type.man] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - 
-[type.man.extend-identifiers] -Nd = "Nd" - -[type.man.extend-words] - -[type.cert] -extend-glob = [] -check-file = false -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.cert.extend-identifiers] - -[type.cert.extend-words] - -[type.sh] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.sh.extend-identifiers] -stap = "stap" -ot = "ot" - -[type.sh.extend-words] - -[type.vimscript] -extend-glob = [] -extend-ignore-identifiers-re = [] -extend-ignore-words-re = [] -extend-ignore-re = [] - -[type.vimscript.extend-identifiers] -windo = "windo" - -[type.vimscript.extend-words] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index eb9d0b405892..cf296a3b534b 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -13,7 +13,7 @@ logger = init_logger(__name__) -if not current_platform.is_tpu() and not current_platform.is_hpu(): +if not current_platform.is_tpu() and not current_platform.is_xpu(): try: import vllm._C except ImportError as e: @@ -956,35 +956,31 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, c_strides, per_act_token, per_out_ch) -def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, - a_scales: torch.Tensor, b_scales: torch.Tensor, - alphas: torch.Tensor, problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, - out_dtype: torch.dtype, device: torch.device): +def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, + b_tensors: torch.Tensor, a_scales: torch.Tensor, + b_scales: torch.Tensor, alphas: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor): """ - An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs the gemms for each combination based on the specified problem sizes. This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized input and expert weights. - a_/b_scales: The blockscales in FP8-E4M3 precision - - expert_offsets/sf_offsets: Indices that mark at which token index - each expert begins its computation. The number of tokens - computed with expert E is expert_offsets[E + 1] - - expert_offsets[E] And the sf_size per expert is + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is sf_offset[E+1] - sf_offset[E] - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped MMs used in the fused MoE operation. 
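A minimal sketch (toy offsets, not values from this patch) of how the cumulative `expert_offsets`/`sf_offsets` described in the docstring above translate into per-expert token and blockscale counts:

```python
# Illustrative only: toy cumulative offsets as described in the docstring.
expert_offsets = [0, 4, 4, 10]   # 3 experts; expert 1 receives no tokens
sf_offsets = [0, 1, 1, 3]        # blockscale (sf) offsets per expert

num_experts = len(expert_offsets) - 1
tokens_per_expert = [expert_offsets[e + 1] - expert_offsets[e]
                     for e in range(num_experts)]
sf_per_expert = [sf_offsets[e + 1] - sf_offsets[e]
                 for e in range(num_experts)]

print(tokens_per_expert)  # [4, 0, 6]
print(sf_per_expert)      # [1, 0, 2]
```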
""" - m_topk = a_tensors.shape[0] - n = b_tensors.shape[1] - c_shape = (m_topk, n) - c = torch.empty(c_shape, device=device, dtype=out_dtype) - torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, - b_scales, alphas, problem_sizes, - expert_offsets, sf_offsets) - return c.to(out_dtype) + return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, alphas, + problem_sizes, expert_offsets, + sf_offsets) # aqlm @@ -1463,30 +1459,6 @@ def ggml_moe_get_block_size(quant_type: int) -> int: # mamba -def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], - conv_states: Optional[torch.Tensor], - query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - silu_activation: bool, pad_slot_id: int): - torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation, - pad_slot_id) - - -def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, - weight: torch.Tensor, bias_: Optional[torch.Tensor], - silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor], - pad_slot_id: int): - torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices, pad_slot_id) - - def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], @@ -1866,6 +1838,26 @@ def cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, return out +def sm100_cutlass_mla_decode(out: torch.Tensor, q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, page_table: torch.Tensor, + workspace: torch.Tensor, scale: float, + num_kv_splits: int) -> torch.Tensor: + torch.ops._C.sm100_cutlass_mla_decode(out, q_nope, q_pe, + kv_c_and_k_pe_cache, seq_lens, + page_table, workspace, scale, + num_kv_splits) + return out + + +def sm100_cutlass_mla_get_workspace_size(max_seq_len: int, num_batches: int, + sm_count: int, + num_kv_splits: int) -> int: + return torch.ops._C.sm100_cutlass_mla_get_workspace_size( + max_seq_len, num_batches, sm_count, num_kv_splits) + + if hasattr(torch.ops._C, "weight_packed_linear"): @register_fake("_C::weight_packed_linear") diff --git a/vllm/assets/video.py b/vllm/assets/video.py index 16412121cf0a..8ab0e9760be8 100644 --- a/vllm/assets/video.py +++ b/vllm/assets/video.py @@ -59,7 +59,9 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: if idx in frame_indices: # only decompress needed ret, frame = cap.retrieve() if ret: - frames.append(frame) + # OpenCV uses BGR format, we need to convert it to RGB + # for PIL and transformers compatibility + frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) frames = np.stack(frames) if len(frames) < num_frames: @@ -71,10 +73,7 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: def video_to_pil_images_list(path: str, num_frames: int = -1) -> list[Image.Image]: frames = video_to_ndarrays(path, num_frames) - return [ - Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - for frame in frames - ] + return [Image.fromarray(frame) for frame in frames] def video_get_metadata(path: str) -> dict[str, Any]: diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 990ea054f338..ba20da4fd75f 100644 --- 
a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,6 +9,8 @@ import torch +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.multimodal import MultiModalPlaceholderMap if TYPE_CHECKING: @@ -267,7 +269,6 @@ def __init__( alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, kv_cache_dtype: str = "auto", - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -289,7 +290,7 @@ def forward( raise NotImplementedError def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: tuple[int, int]): + group_shape: GroupShape): """ Does this attention implementation support fused output quantization. This is used by the AttnFusionPass to only fuse output quantization @@ -298,7 +299,7 @@ def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, TODO(luka) merge parameters into QuantDescriptor :param dtype: quantized dtype :param static: static or dynamic quantization - :param group_shape: quant group shape. (-1, -1) for per-tensor. + :param group_shape: quant group shape. :return: is fusion supported for this type of quantization """ return False diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py deleted file mode 100644 index fe9738d804cb..000000000000 --- a/vllm/attention/backends/blocksparse_attn.py +++ /dev/null @@ -1,465 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType) -from vllm.attention.backends.utils import (CommonAttentionState, - CommonMetadataBuilder) -from vllm.attention.ops.blocksparse_attention.interface import ( - LocalStridedBlockSparseAttn, get_head_sliding_step) -from vllm.attention.ops.paged_attn import PagedAttention -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - - -@dataclass -class BlocksparseParams: - max_seqlen: int - - # Num q heads per tensor-parallel rank/partition - num_heads: int # per TP partition - # Num kv heads per tensor-parallel rank/partition - num_kv_heads: int - - # block size used for blocksparse attention. - # This is the block_size used in `local_blocks`, `vert_stride`. - block_size: int - - # Number of blocks for local attention, i.e., number of - # local attended tokens / `sparse_block_size` - local_blocks: int - - # Attend to one block per every `vert_stride` blocks. - # Controlling the sparsity - vert_stride: int - """ - If to use the same vertical stride offset for all heads, - i.e., attend to the same block of tokens on all heads. - By default, it is False, i.e., attention on the non-local - blocks depends on the `head_idx`, that is on - blocks satisfying - `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` - where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, - `block_idx = position_id // sparse_block_size`. - See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` - for more detail. - """ - homo_head: bool = False - - # If within a group, the kv offsets that each q attends is the same or no. 
- homo_head_group: bool = False - - # Decided by homo_head and homo_head group - head_sliding_step: int = field(init=False) - - # range of q heads to for a TP rank - active_head_range: Tuple = field(init=False) - - def __post_init__(self): - assert self.block_size > 0 - assert self.local_blocks >= 0 - assert self.vert_stride >= 1 - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - total_heads = tp_size * self.num_heads - total_kv_heads = tp_size * self.num_kv_heads - - if self.homo_head: - self.head_sliding_step = 0 - elif self.homo_head_group: - head_sliding_step = get_head_sliding_step(total_kv_heads, - self.vert_stride) - # negative indicates sliding along kv heads, i.e., homo q group - self.head_sliding_step = -head_sliding_step - else: - self.head_sliding_step = get_head_sliding_step( - total_heads, self.vert_stride) - - self.active_head_range = ( - tp_rank * self.num_heads, - (tp_rank + 1) * self.num_heads, - ) - - -class BlocksparseFlashAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "BLOCK_SPARSE_FLASH_ATTN" - - @staticmethod - def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: - return BlocksparseFlashAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return BlocksparseFlashAttentionMetadata - - @staticmethod - def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: - return BlocksparseFlashAttentionMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - ) -> None: - PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class BlocksparseFlashAttentionMetadata(AttentionMetadata): - """A copy of Metadata for FlashAttentionBackend, - to avoid having to install flash_attn. - - NOTE: Any python object stored here is not updated when it is - cuda-graph replayed. If you have values that need to be changed - dynamically, it should be stored in tensor. The tensor has to be - updated from `CUDAGraphRunner.forward` API. - """ - # (batch_size,). The sequence length per sequence. Sequence length means - # the computed tokens + new tokens None if it is a decoding. - seq_lens: Optional[List[int]] - # seq_lens stored as a tensor. - seq_lens_tensor: Optional[torch.Tensor] - - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ----------------------| - # |-- query_len ---| - - # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. 
- max_decode_seq_len: int - # (batch_size + 1,). The cumulative subquery lengths of the sequences in - # the batch, used to index into subquery. E.g., if the subquery length - # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] - # (batch_size + 1,). The cumulative sequence lengths of the sequences in - # the batch, used to index into sequence. E.g., if the sequence length is - # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - # (batch_size,) A tensor of context lengths (tokens that are computed - # so far). - context_lens_tensor: Optional[torch.Tensor] - - # (batch_size, max_blocks_per_seq). - # Block addresses per sequence. (Seq id -> list of physical block) - # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks - # in the kv cache. Each block can contain up to block_size tokens. - # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph - # captured. - block_tables: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - - # Max number of query tokens for among request in the batch. - max_decode_query_len: Optional[int] = None - - _cached_prefill_metadata: Optional[ - "BlocksparseFlashAttentionMetadata"] = None - _cached_decode_metadata: Optional[ - "BlocksparseFlashAttentionMetadata"] = None - - @property - def prefill_metadata( - self) -> Optional["BlocksparseFlashAttentionMetadata"]: - if self.num_prefills == 0: - return None - - if self._cached_prefill_metadata is not None: - return self._cached_prefill_metadata - - assert self.seq_lens is not None - assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None - assert self.block_tables is not None - assert self.seq_start_loc is not None - - self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( - num_prefills=self.num_prefills, - num_prefill_tokens=self.num_prefill_tokens, - num_decode_tokens=0, - slot_mapping=self.slot_mapping[:self.num_prefill_tokens], - multi_modal_placeholder_index_maps=self. 
- multi_modal_placeholder_index_maps, - enable_kv_scales_calculation=self.enable_kv_scales_calculation, - seq_lens=self.seq_lens[:self.num_prefills], - seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], - max_query_len=self.max_query_len, - max_prefill_seq_len=self.max_prefill_seq_len, - max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], - block_tables=self.block_tables[:self.num_prefills], - use_cuda_graph=False, - ) - return self._cached_prefill_metadata - - @property - def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: - if self.num_decode_tokens == 0: - return None - - if self._cached_decode_metadata is not None: - return self._cached_decode_metadata - assert self.block_tables is not None - assert self.seq_lens_tensor is not None - - self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=self.num_decode_tokens, - slot_mapping=self.slot_mapping[self.num_prefill_tokens:], - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - seq_lens=None, - seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, - max_prefill_seq_len=0, - max_decode_seq_len=self.max_decode_seq_len, - query_start_loc=None, - seq_start_loc=None, - context_lens_tensor=None, - block_tables=self.block_tables[self.num_prefills:], - use_cuda_graph=self.use_cuda_graph, - ) - return self._cached_decode_metadata - - -class BlocksparseFlashAttentionMetadataBuilder( - CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): - - _metadata_cls = BlocksparseFlashAttentionMetadata - - -class BlocksparseFlashAttentionImpl(AttentionImpl): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prompt_tokens -------------->| - |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| - - Otherwise, the layout is as follows: - |<------------------ num_generation_tokens (M) ----------------->| - |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. 
- - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - assert blocksparse_params is not None - assert alibi_slopes is None, ValueError( - "Alibi not support for blocksparse flash attention.") - assert sliding_window is None, ValueError( - "sliding_window is invalid for blocksparse attention.") - assert logits_soft_cap is None, ValueError( - "logits_soft_cap is invalid for blocksparse attention.") - - if "num_heads" not in blocksparse_params: - blocksparse_params["num_heads"] = num_heads - if "num_kv_heads" not in blocksparse_params: - blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads - self.blocksparse_params = BlocksparseParams(**blocksparse_params) - self.kv_cache_dtype = kv_cache_dtype - - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.alibi_slopes = alibi_slopes - self.num_kv_heads = num_kv_heads - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - self.local_blocks = self.blocksparse_params.local_blocks - self.vert_stride = self.blocksparse_params.vert_stride - self.sparse_block_size = self.blocksparse_params.block_size - self.head_sliding_step = self.blocksparse_params.head_sliding_step - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - - total_num_heads = num_heads * self.tp_size - self.bs_attn = LocalStridedBlockSparseAttn( - total_num_heads, - self.blocksparse_params.max_seqlen, - self.blocksparse_params.local_blocks, - self.blocksparse_params.vert_stride, - self.blocksparse_params.block_size, - homo_head=self.blocksparse_params.homo_head, - active_head_range=self.blocksparse_params.active_head_range, - ) - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "BlocksparseFlashAttentionImpl") - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: BlocksparseFlashAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with FlashAttention and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. 
- Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for BlocksparseFlashAttentionImpl") - - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - if prefill_meta := attn_metadata.prefill_metadata: - - # Prompt run. - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - - assert kv_cache.numel() == 0 \ - or prefill_meta.block_tables is None \ - or prefill_meta.block_tables.numel() == 0, \ - "Does not support prefix-enabled attention." - - output = self.bs_attn( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - sm_scale=self.scale, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - output = PagedAttention.forward_decode( - query, - key_cache, - value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - self.blocksparse_params.max_seqlen, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - tp_rank=self.tp_rank, - blocksparse_local_blocks=self.local_blocks, - blocksparse_vert_stride=self.vert_stride, - blocksparse_block_size=self.sparse_block_size, - blocksparse_head_sliding_step=self.head_sliding_step, - ) - - assert output is not None - # Reshape the output tensor. 
- return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/cpu_mla.py b/vllm/attention/backends/cpu_mla.py deleted file mode 100644 index 793cb87b7434..000000000000 --- a/vllm/attention/backends/cpu_mla.py +++ /dev/null @@ -1,307 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch - -import vllm._custom_ops as ops -from vllm._ipex_ops import ipex_ops -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState -from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata -from vllm.utils import make_tensor_with_pad -from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder - - -class CPUMLABackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "CPU_MLA" - - @staticmethod - def get_metadata_cls() -> Type["CPUMLAMetadata"]: - return CPUMLAMetadata - - @staticmethod - def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]: - return CPUMLAMetadataBuilder - - @staticmethod - def get_state_cls() -> Type["MLACommonState"]: - return MLACommonState - - @staticmethod - def get_impl_cls() -> Type["CPUMLAImpl"]: - return CPUMLAImpl - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, # assumed to be 1 for MLA - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - ops.copy_blocks_mla(kv_caches, src_to_dists) - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [576] - - -@dataclass -class CPUMLAMetadata(TorchSDPAMetadata): - # New for MLA - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor = None - - # required by MLACommonImpl - is_profile_run: bool = False - - -class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]): - - def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: - self.chunked_prefill = input_builder.chunked_prefill - self.input_builder = input_builder - assert not self.chunked_prefill, \ - "chunked prefill is currently not supported" - - def prepare(self): - self.input_data = self.input_builder.input_data - - def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size): - input_data = self.input_data - prefill_seq_lens = seq_lens[0:input_data.num_prefills] - prefill_query_lens = query_lens[0:input_data.num_prefills] - slot_mapping = torch.tensor(input_data.slot_mapping, - dtype=torch.long, - device="cpu") - - # metadata for prefill - if input_data.num_prefills > 0: - query_lens_tensor = torch.tensor(prefill_query_lens, - dtype=torch.int32, - device="cpu") - kv_lens_tensor = torch.tensor(prefill_seq_lens, - dtype=torch.int32, - device="cpu") - query_start_loc = torch.zeros(input_data.num_prefills + 1, - dtype=torch.int32, - device="cpu") - kv_start_loc = torch.zeros(input_data.num_prefills + 1, - dtype=torch.int32, - device="cpu") - 
torch.cumsum(query_lens_tensor, - dim=0, - dtype=torch.int32, - out=query_start_loc[1:]) - torch.cumsum(kv_lens_tensor, - dim=0, - dtype=torch.int32, - out=kv_start_loc[1:]) - max_query_len = max(prefill_query_lens) - max_kv_len = max(prefill_seq_lens) - - # for chunked-prefill - if self.chunked_prefill: - prefill_block_tables = make_tensor_with_pad( - self.input_data.prefill_block_tables, - pad=0, - dtype=torch.int32, - device="cpu", - ) - else: - prefill_block_tables = None - - else: - query_start_loc = None - kv_start_loc = None - max_query_len = None - max_kv_len = None - prefill_block_tables = None - - # metadata for decode - if input_data.num_decode_tokens != 0: - seq_lens_tensor = torch.tensor( - input_data.seq_lens[input_data.num_prefills:], - dtype=torch.int32, - device="cpu", - ) - block_tables = make_tensor_with_pad( - self.input_data.decode_block_tables, - pad=0, - dtype=torch.int32, - device="cpu", - ) - else: - block_tables = torch.tensor([]) - seq_lens_tensor = torch.tensor( - input_data.seq_lens[:input_data.num_prefills], - dtype=torch.int32, - device="cpu", - ) - - # For multi-modal models - placeholder_index_maps = None - if len(input_data.multi_modal_inputs_list) != 0: - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - input_data.multi_modal_placeholder_maps.items() - } - - return CPUMLAMetadata( - chunked_prefill=self.chunked_prefill, - seq_lens=prefill_seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_kv_len=max_kv_len, - prefill_query_start_loc=query_start_loc, - kv_start_loc=kv_start_loc, - max_decode_seq_len=input_data.max_decode_seq_len, - num_prefills=input_data.num_prefills, - num_prefill_tokens=input_data.num_prefill_tokens, - num_decode_tokens=input_data.num_decode_tokens, - block_tables=block_tables, - prefill_block_tables=prefill_block_tables, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=False, - input_positions=torch.tensor([self.input_data.input_positions])) - - -class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - # MLA Specific Arguments - **mla_args) -> None: - super().__init__(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **mla_args) - - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] - if any(unsupported_features): - raise NotImplementedError( - "CPUMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "CPUMLAImpl") - - # states is implemented. 
- if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "CPUMLAImpl with FP8 KV cache not yet supported") - - def _forward_prefill( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: CPUMLAMetadata, # type: ignore[override] - ) -> torch.Tensor: - - prefill_metadata = attn_metadata.prefill_metadata - assert prefill_metadata is not None - - kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ - -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv_nope\ - .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim - v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], - value=0) - - output = torch.empty_like(q) - ipex_ops.varlen_attention( - query=q, - key=k, - value=v_padded, - out=output, - seqlen_q=prefill_metadata.prefill_query_start_loc, - seqlen_k=prefill_metadata.prefill_query_start_loc, - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.max_query_len, - pdropout=0.0, - softmax_scale=self.scale, - zero_tensors=False, - is_causal=True, - return_softmax=False, - gen_=None, - logits_soft_cap=0.0, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - - # remove padding - output = output.view(-1, self.num_heads, - q.shape[-1])[..., :v.shape[-1]] - return output.reshape(-1, self.num_heads * v.shape[-1]) - - def _forward_decode( - self, - q_nope: torch.Tensor, - q_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: CPUMLAMetadata, # type: ignore[override] - ) -> torch.Tensor: - assert kv_c_and_k_pe_cache.numel() > 0 - - decode_meta = attn_metadata.decode_metadata - assert decode_meta is not None - - q = torch.cat([q_nope, q_pe], dim=-1) - o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank) - - # Run MQA - ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale, - decode_meta.block_tables, - decode_meta.seq_lens_tensor) - return self._v_up_proj(o) diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py new file mode 100644 index 000000000000..bd9bc427728d --- /dev/null +++ b/vllm/attention/backends/differential_flash_attn.py @@ -0,0 +1,996 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""" An implementation of https://arxiv.org/pdf/2410.05258 """ +from collections import defaultdict +from dataclasses import dataclass +from itertools import accumulate +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch +from einops import rearrange + +from vllm import _custom_ops as ops +# yapf conflicts with isort for this block +# yapf: disable +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType, + is_quantized_kv_cache) +from vllm.attention.backends.flash_attn import FlashAttentionBackend +# yapf: enable +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + compute_slot_mapping, + compute_slot_mapping_start_idx, + is_all_cross_attn_metadata_set, + is_all_encoder_attn_metadata_set, + is_block_tables_empty) +from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, + get_flash_attn_version) +from 
vllm.logger import init_logger +from vllm.multimodal import MultiModalPlaceholderMap +from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +logger = init_logger(__name__) + + +class DifferentialFlashAttentionBackend(AttentionBackend): + accept_output_buffer = False + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + assert num_kv_heads % 2 == 0, "num_kv_heads must be divisible by 2" + return (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) + + @staticmethod + def get_name() -> str: + return "DIFFERENTIAL_FLASH_ATTN" + + @staticmethod + def get_impl_cls() -> Type["DifferentialFlashAttentionImpl"]: + return DifferentialFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["DifferentialFlashAttentionMetadata"]: + return DifferentialFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["DifferentialFlashAttentionMetadataBuilder"]: + return DifferentialFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class DifferentialFlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). 
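A small numeric illustration (toy values only) of the `context_len`/`query_len`/`seq_len` relationship sketched in the comment diagram above; only the arithmetic is shown, not the metadata class itself:

```python
# seq_len = context_len + query_len, per the diagram in the comments above.
seq_lens = [7, 12, 1]     # computed tokens + new tokens, per sequence
query_lens = [4, 6, 1]    # new tokens scheduled this iteration

context_lens = [s - q for s, q in zip(seq_lens, query_lens)]
print(context_lens)       # [3, 6, 0]
```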
+ context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + + use_cuda_graph: bool + + # Maximum query length in the batch. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + _cached_prefill_metadata: Optional[ + "DifferentialFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "DifferentialFlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + encoder_seq_start_loc: Optional[torch.Tensor] = None + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + # Cross-layer shared attention block tables + cross_layer_shared_block_tables: Optional[torch.Tensor] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return is_all_encoder_attn_metadata_set(self) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
+ ''' + return is_all_cross_attn_metadata_set(self) + + @property + def prefill_metadata( + self) -> Optional["DifferentialFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + seq_start_loc = (None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + cross_layer_shared_block_tables = ( + None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[:self.num_prefills]) + + self._cached_prefill_metadata = DifferentialFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=self. + multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata( + self) -> Optional["DifferentialFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + cross_layer_shared_block_tables = ( + None if self.cross_layer_shared_block_tables is None else + self.cross_layer_shared_block_tables[self.num_prefills:]) + self._cached_decode_metadata = DifferentialFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, + seq_lens=None, + seq_lens_tensor=seq_lens_tensor, + max_decode_query_len=self.max_decode_query_len, + max_query_len=self.max_query_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + # Batch may be composed of prefill|decodes, adjust query start + # indices to refer to the start of decodes. E.g. + # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. + query_start_loc=(self.query_start_loc[self.num_prefills:] - + self.query_start_loc[self.num_prefills]) + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Multi-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
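The rebasing of `query_start_loc` in the `decode_metadata` property above can be seen with the toy endpoints from the inline comment ("[3, 9] => [0, 6]"); this is only an illustration of the slice-and-shift arithmetic:

```python
import torch

# Slice off the prefill prefix and shift so the decode portion starts at 0.
num_prefills = 3
query_start_loc = torch.tensor([0, 1, 2, 3, 9])

decode_query_start_loc = (query_start_loc[num_prefills:] -
                          query_start_loc[num_prefills])
print(decode_query_start_loc)  # tensor([0, 6])
```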
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class DifferentialFlashAttentionMetadataBuilder( + AttentionMetadataBuilder[DifferentialFlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + + def prepare(self): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.cross_layer_shared_block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.multimodal_placeholder_maps: Dict[ + str, + MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + # TODO: add support for chunked prefill and prefix caching. 
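A simplified, plain-Python sketch (not the CUDA path) of the per-step bookkeeping that `advance_step` performs above before handing off to `ops.advance_step_flashattn`: every running decode sequence grows by one token and the batch maximum is refreshed:

```python
# Toy decode batch; values are illustrative only.
seq_lens = [17, 33, 8]          # one entry per decode request
num_queries = len(seq_lens)

for i in range(num_queries):
    seq_lens[i] += 1            # each sequence advances by one token

max_decode_seq_len = max(seq_lens)
print(seq_lens, max_decode_seq_len)   # [18, 34, 9] 34
```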
+ assert not chunked_prefill_enabled, \ + "chunked prefill is not supported for now" + assert not prefix_cache_hit, "prefix caching is not supported for now" + + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + mm_maps = inter_data.multi_modal_placeholder_maps + if mm_maps: + for modality, placeholders in mm_maps.items(): + self.multimodal_placeholder_maps[modality].extend( + placeholders) + + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + cross_layer_shared_block_table = [] + if prefix_cache_hit: + cross_layer_shared_block_table = block_tables[seq_id] + elif block_tables is not None: + if curr_sliding_window_block == 0: + cross_layer_shared_block_table = block_tables[seq_id] + else: + cross_layer_shared_block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.cross_layer_shared_block_tables.append( + cross_layer_shared_block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables(self, num_seqs: int, + block_tables: List[List[int]], + graph_block_tables) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + # max_batch_size, max_blocks = self.runner.graph_block_tables.shape + max_batch_size, max_blocks = graph_block_tables.shape + assert max_batch_size >= num_seqs + + # graph_block_tables = self.runner.graph_block_tables[:num_seqs] + graph_block_tables = graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
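A rough sketch (NumPy, toy sizes) of the padding performed by `_get_graph_runner_block_tables`: ragged per-sequence block tables are copied into a preallocated fixed-shape buffer and truncated to `max_blocks` when needed:

```python
import numpy as np

# Assumed preallocated, zero-filled buffer of shape [max_batch_size, max_blocks].
graph_block_tables = np.zeros((4, 5), dtype=np.int32)
block_tables = [[10, 11, 12], [], [20, 21, 22, 23, 24, 25]]

num_seqs = len(block_tables)
view = graph_block_tables[:num_seqs]
for i, bt in enumerate(block_tables):
    if bt:
        n = min(len(bt), view.shape[1])
        view[i, :n] = bt[:n]       # blocks beyond max_blocks are dropped

print(view)
```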
+ graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + query_start_loc = list(accumulate(query_lens, initial=0)) + seq_start_loc = list(accumulate(seq_lens, initial=0)) + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + + self.cross_layer_shared_block_tables.extend([] * + cuda_graph_pad_size) + + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables, self.runner.graph_block_tables) + cross_layer_shared_block_tables = \ + self._get_graph_runner_block_tables( + num_seqs, self.cross_layer_shared_block_tables, + self.runner.cross_layer_shared_graph_block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + cross_layer_shared_block_tables = make_tensor_with_pad( + self.cross_layer_shared_block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, + device, + self.runner.pin_memory) + seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32, + device, self.runner.pin_memory) + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + self.multimodal_placeholder_maps.items() + } + + return DifferentialFlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + 
max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc_tensor, + seq_start_loc=seq_start_loc_tensor, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + cross_layer_shared_block_tables=cross_layer_shared_block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class DifferentialFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + differential_flash_attention_config: Optional[Dict[str, Any]] = None, + ) -> None: + if differential_flash_attention_config is None: + differential_flash_attention_config = {} + self.differential_flash_attention_config = \ + differential_flash_attention_config + self.used_shared_kv_cache = kv_sharing_target_layer_name is not None + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + if use_irope: + logger.warning( + "Using irope in V0 is not supported yet, it will fall back " + "to global attention for long context.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + self.vllm_flash_attn_version = get_flash_attn_version( + requires_alibi=self.alibi_slopes is not None) + if is_quantized_kv_cache(self.kv_cache_dtype) and ( + not self.kv_cache_dtype.startswith("fp8") + or not flash_attn_supports_fp8()): + raise NotImplementedError( + f"FlashAttention does not support {self.kv_cache_dtype} " + "kv-cache on this device " + f"(FA supports fp8 = {flash_attn_supports_fp8()}).") + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. 
" + f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type + + self.lambda_full = None + self.subln = self.differential_flash_attention_config["subln"] + + def split_heads(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + x = rearrange(x, "... (H two) D -> ... H two D", two=2) + x1 = x[..., 0, :] + x2 = x[..., 1, :] + return x1.contiguous(), x2.contiguous() + + def split_kv_cache(self, x): + # split by num_heads, the stripe pattern is friendly to tensor parallel. + if x.numel() == 0: + return torch.empty(0), torch.empty(0) + + x1, x2 = x[0], x[1] + return x1, x2 + + def populate_kv_cache(self, layer: AttentionLayer, key: torch.Tensor, + value: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata): + if kv_cache.numel() > 0 and key is not None and value is not None: + updated_slot_mapping = attn_metadata.slot_mapping + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + def forward_generate_kv_cache( + self, query: torch.Tensor, key: Optional[torch.Tensor], + value: Optional[torch.Tensor], k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata) -> torch.Tensor: + + head_size = self.head_size + num_heads = self.num_heads // 2 + num_kv_heads = self.num_kv_heads // 2 + + query = query.view(-1, num_heads, head_size) + if key is not None: + assert value is not None + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + else: + assert value is None + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[ + 0] == num_prefill_tokens + num_decode_tokens, "key shape mismatch" + assert value.shape[ + 0] == num_prefill_tokens + num_decode_tokens, "value shape mismatch" + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens, "query shape mismatch" + assert decode_query.shape[ + 0] == num_decode_tokens, "decode query shape mismatch" + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. 
+ if k_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0: + # normal attention + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + assert prefill_output.shape == output[: + num_prefill_tokens].shape + output[:num_prefill_tokens] = prefill_output + else: + raise Exception("prefix caching not supported") + + if decode_meta := attn_metadata.decode_metadata: + block_tables_arg = decode_meta.block_tables + try: + output[num_prefill_tokens:] = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + except Exception as e: + logger.error("Error in PagedAttention.forward_decode: %s", + str(e)) + raise e + + # Reshape the output tensor. + return output.view(-1, num_heads, head_size) + + def forward_with_kv_cache_only( + self, + query: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata, + ): + if not attn_metadata.decode_metadata: + block_tables_arg = attn_metadata.cross_layer_shared_block_tables + else: + block_tables_arg = attn_metadata.block_tables + + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=k_cache, + v_cache=v_cache, + block_table=block_tables_arg, + cache_seqlens=attn_metadata.seq_lens_tensor, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ).squeeze(1) + return output + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: DifferentialFlashAttentionMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads, head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] + output: shape = [num_tokens, num_heads, head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + NOTE: It in-place updates the output tensor. + NOTE: FP8 quantization, flash-attn expect the size of + {q,k,v}_descale to be (num_sequences, num_kv_heads). 
+ We use torch's .expand() to avoid duplicating values + """ + if self.lambda_full is None: + self.lambda_init = self.differential_flash_attention_config[ + "lambda_init"] + lambda_q1 = self.differential_flash_attention_config["lambda_q1"] + lambda_k1 = self.differential_flash_attention_config["lambda_k1"] + lambda_q2 = self.differential_flash_attention_config["lambda_q2"] + lambda_k2 = self.differential_flash_attention_config["lambda_k2"] + lambda_1 = torch.exp( + torch.sum(lambda_q1 * lambda_k1, dim=-1).float()).type_as(q) + lambda_2 = torch.exp( + torch.sum(lambda_q2 * lambda_k2, dim=-1).float()).type_as(q) + self.lambda_full = lambda_1 - lambda_2 + self.lambda_init + + if not self.used_shared_kv_cache: # need to generate kv-cache + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size) + + q1, q2 = self.split_heads(q) + k1, k2 = self.split_heads(k) + v1, v2 = self.split_heads(v) + + # kv_cache shape is (2, 2, num_blocks, block_size, num_kv_heads // 2, head_size) # noqa: E501 + # Split by half along the first dimension. + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + assert kv_cache1.is_contiguous(), "kv_cache1 is not contiguous" + assert kv_cache2.is_contiguous(), "kv_cache2 is not contiguous" + + if kv_cache1.numel() != 0: + self.populate_kv_cache(layer, k1, v1, kv_cache1, attn_metadata) + self.populate_kv_cache(layer, k2, v2, kv_cache2, attn_metadata) + + key_cache1, value_cache1 = self.split_kv_cache(kv_cache1) + key_cache2, value_cache2 = self.split_kv_cache(kv_cache2) + else: + key_cache1, value_cache1 = torch.empty(0), torch.empty(0) + key_cache2, value_cache2 = torch.empty(0), torch.empty(0) + attn11 = self.forward_generate_kv_cache(q1, k1, v1, key_cache1, + value_cache1, + attn_metadata) + attn12 = self.forward_generate_kv_cache(q1, k1, v2, key_cache1, + value_cache2, + attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_generate_kv_cache(q2, k2, v1, key_cache2, + value_cache1, + attn_metadata) + attn22 = self.forward_generate_kv_cache(q2, k2, v2, key_cache2, + value_cache2, + attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + attn = attn1 - self.lambda_full * attn2 + # attn shape (-1, self.num_heads // 2, 2 * self.head_dim) + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, + "... H (two D) -> ... 
(H two) D", + two=2) + + else: # reuse the kv cache, full attention + q = q.view(-1, self.num_heads, self.head_size) + q1, q2 = self.split_heads(q) + # kv_cache shape is (2, num_blocks, block_size, num_kv_heads, head_size) # noqa: E501 + kv_cache1, kv_cache2 = self.split_kv_cache(kv_cache) + key_cache1, value_cache1 = kv_cache1[0], kv_cache1[1] + key_cache2, value_cache2 = kv_cache2[0], kv_cache2[1] + + attn11 = self.forward_with_kv_cache_only(q1, key_cache1, + value_cache1, + attn_metadata) + attn12 = self.forward_with_kv_cache_only(q1, key_cache1, + value_cache2, + attn_metadata) + attn11 = attn11.view(q1.shape) + attn12 = attn12.view(q1.shape) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = self.forward_with_kv_cache_only(q2, key_cache2, + value_cache1, + attn_metadata) + attn22 = self.forward_with_kv_cache_only(q2, key_cache2, + value_cache2, + attn_metadata) + attn21 = attn21.view(q2.shape) + attn22 = attn22.view(q2.shape) + attn2 = torch.cat([attn21, attn22], dim=-1) + + attn = attn1 - self.lambda_full * attn2 + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + # reshape back to 2 * num_head + attn_output = rearrange(attn, + "... H (two D) -> ... (H two) D", + two=2) + attn_output = attn_output.view(-1, self.num_heads * self.head_size) + return attn_output diff --git a/vllm/attention/backends/dual_chunk_flash_attn.py b/vllm/attention/backends/dual_chunk_flash_attn.py index f62a43b441f2..fa6f3f1b39cc 100644 --- a/vllm/attention/backends/dual_chunk_flash_attn.py +++ b/vllm/attention/backends/dual_chunk_flash_attn.py @@ -287,7 +287,6 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -295,7 +294,8 @@ def __init__( dual_chunk_attention_config: Optional[Dict[str, Any]] = None, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "DUAL_CHUNK_FLASH_ATTN backend.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -1055,7 +1055,6 @@ def _dual_chunk_flash_attn_prefill_func( v_states_intra, softmax_scale=softmax_scale, causal=True, - block_table=block_table, stage="intra", vertical_indices=vertical_buffer, slash_indices=slash_buffer, @@ -1070,7 +1069,6 @@ def _dual_chunk_flash_attn_prefill_func( v_states_intra, softmax_scale=softmax_scale, causal=True, - block_table=block_table, stage="intra", vertical_indices=intra_vertical_indices, slash_indices=intra_slash_indices, @@ -1085,7 +1083,6 @@ def _dual_chunk_flash_attn_prefill_func( v_states_succ, softmax_scale=softmax_scale, causal=False, - block_table=block_table, stage="succ", vertical_indices=succ_vertical_buffer, slash_indices=succ_slash_buffer, @@ -1100,7 +1097,6 @@ def _dual_chunk_flash_attn_prefill_func( v_states_succ, softmax_scale=softmax_scale, causal=False, - block_table=block_table, stage="succ", vertical_indices=succ_vertical_indices, slash_indices=succ_slash_indices, @@ -1115,7 +1111,6 @@ def _dual_chunk_flash_attn_prefill_func( v_states_inter, softmax_scale=softmax_scale, causal=False, - block_table=block_table, stage="inter", vertical_indices=inter_vertical_buffer, slash_indices=inter_slash_buffer, @@ -1130,7 +1125,6 @@ def _dual_chunk_flash_attn_prefill_func( v_states_inter, 
softmax_scale=softmax_scale, causal=False, - block_table=block_table, stage="inter", vertical_indices=inter_vertical_indices, slash_indices=inter_slash_indices, @@ -1151,7 +1145,6 @@ def _do_flash_attn( value_states: torch.Tensor, softmax_scale: float, causal: bool = True, - block_table: torch.Tensor = None, max_seqlen_k: Optional[int] = None, stage: str = "intra", vertical_indices: Optional[torch.Tensor] = None, @@ -1230,7 +1223,6 @@ def _do_flash_attn( device=query_states.device), max_seqlen_k=max_seqlen_k, causal=causal, - block_table=block_table.unsqueeze(0), return_softmax_lse=True, ) softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index bf8e373802f8..ee36fd19e012 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import dataclass from itertools import accumulate -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type import torch @@ -615,17 +615,14 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") + raise NotImplementedError("KV sharing is not supported in V0 " + "FLASH_ATTN backend.") if use_irope: logger.warning( "Using irope in V0 is not supported yet, it will fall back " diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b7d80f5194c0..56d3da699f40 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -11,7 +11,8 @@ try: from flashinfer import BatchDecodeWithPagedKVCacheWrapper - from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import (CUDAGraphBatchDecodeWithPagedKVCacheWrapper, + trtllm_batch_decode_with_kv_cache) from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -22,7 +23,10 @@ BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None BatchPrefillWithPagedKVCacheWrapper = None + trtllm_batch_decode_with_kv_cache = None FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + raise ImportError("FlashInfer is not installed. 
Please install it from " + "https://github.com/flashinfer-ai/flashinfer") from None import torch @@ -40,6 +44,7 @@ from vllm.attention.ops.paged_attn import PagedAttention from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, make_tensor_with_pad) @@ -49,10 +54,9 @@ from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD" - class FlashInferBackend(AttentionBackend): + cached_sm100a_supported: Optional[bool] = None @staticmethod def get_name() -> str: @@ -85,7 +89,7 @@ def get_kv_cache_shape( @staticmethod def get_kv_cache_stride_order() -> Tuple[int, ...]: - cache_layout = FLASHINFER_KV_CACHE_LAYOUT + cache_layout = FlashInferState.get_kv_cache_layout() assert (cache_layout in ("NHD", "HND")) stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2, 4) @@ -119,6 +123,47 @@ def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: else: raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + @staticmethod + def use_trtllm_decode_attention( + batch_size: int, + max_seq_len: int, + kv_cache_dtype: str, + num_qo_heads: Optional[int], + num_kv_heads: Optional[int], + attn_head_size: Optional[int], + ) -> bool: + if FlashInferBackend.cached_sm100a_supported is None: + FlashInferBackend.cached_sm100a_supported = ( + current_platform.has_device_capability(100)) + if not FlashInferBackend.cached_sm100a_supported: + return False + # Check if the dimensions are supported by TRTLLM decode attention + if (attn_head_size is None or num_qo_heads is None + or num_kv_heads is None or num_qo_heads // num_kv_heads > 8 + or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): + return False + env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION + if env_value is not None: + logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s", + env_value) + # Environment variable is set - respect it + # Making the conditional check for zero because + # the path is automatically enabled if the batch size condition + # is satisfied. 
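+            # A value of "0" disables the TRTLLM decode path even on
+            # supported hardware; any other value forces it on.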
+ no_use_trtllm = (env_value == "0") + if not no_use_trtllm: + logger.info_once("Using TRTLLM decode attention.") + return not no_use_trtllm + else: + # Environment variable not set - use auto-detection + use_trtllm = (FlashInferBackend.cached_sm100a_supported + and batch_size <= 256 and max_seq_len < 131072 + and kv_cache_dtype == "auto") + if use_trtllm: + logger.warning_once( + "Using TRTLLM decode attention (auto-detected).") + return use_trtllm + @dataclass class PerLayerParameters: @@ -207,10 +252,19 @@ def _get_workspace_buffer(self): device=self.runner.device) return self._workspace_buffer - def get_kv_cache_layout(self): - if self._kv_cache_layout is None: - self._kv_cache_layout = FLASHINFER_KV_CACHE_LAYOUT - return self._kv_cache_layout + @staticmethod + def get_kv_cache_layout(): + from vllm.v1.attention.backends.utils import _KV_CACHE_LAYOUT_OVERRIDE + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + logger.info_once("Using KV cache layout %s", + _KV_CACHE_LAYOUT_OVERRIDE) + return _KV_CACHE_LAYOUT_OVERRIDE + cache_layout = envs.VLLM_KV_CACHE_LAYOUT + if cache_layout is None: + logger.info_once("Using default KV cache layout NHD") + return "NHD" + logger.info_once("Using KV cache layout %s", cache_layout) + return cache_layout def _get_prefill_wrapper(self): if self._prefill_wrapper is None: @@ -323,6 +377,8 @@ def graph_capture_get_metadata_for_batch( num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, + max_decode_seq_len=0, + seq_lens_tensor=self._graph_seq_lens, block_tables=self._graph_block_tables, paged_kv_indptr=paged_kv_indptr_tensor_host, paged_kv_indices=paged_kv_indices_tensor_host, @@ -348,6 +404,8 @@ def get_graph_input_buffers(self, attn_metadata, is_encoder_decoder_model: bool = False): return { + "block_tables": attn_metadata.block_tables, + "seq_lens_tensor": attn_metadata.seq_lens_tensor, "slot_mapping": attn_metadata.slot_mapping, } @@ -355,7 +413,13 @@ def prepare_graph_input_buffers(self, input_buffers, attn_metadata, is_encoder_decoder_model: bool = False): - return + # FlashInfer-specific logic: copy additional tensors + num_total_blocks = attn_metadata.decode_metadata.seq_lens_tensor.shape[ + 0] + input_buffers["seq_lens_tensor"][:num_total_blocks].copy_( + attn_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"][:num_total_blocks].copy_( + attn_metadata.block_tables, non_blocking=True) def begin_forward(self, model_input): assert not self._is_graph_capturing @@ -388,6 +452,8 @@ class FlashInferMetadata(AttentionMetadata): # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. max_prefill_seq_len: int + max_decode_seq_len: int + # Number of query tokens for each request in the batch. # Currently, we require that all requests have the same number of query # tokens during the decoding phase. 
When speculavie decoding is enabled, @@ -790,6 +856,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], use_captured_graph = cuda_graph_pad_size != -1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens decode_query_len = max(query_lens[self.num_prefills:], default=1) @@ -895,6 +962,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, block_tables=block_tables, paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, @@ -931,14 +999,14 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "FLASHINFER backend.") if use_irope: logger.warning_once( "Using irope in FlashInfer is not supported yet, it will fall" @@ -1081,13 +1149,36 @@ def forward( assert decode_meta.decode_wrapper._logits_soft_cap == ( logits_soft_cap or 0.0) assert decode_meta.decode_wrapper._sm_scale == softmax_scale - - decode_output = decode_meta.decode_wrapper.run( - decode_query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - ) + # TODO: @pavanimajety Remove this once the switch happens + # inside flashinfer. + if not FlashInferBackend.use_trtllm_decode_attention( + num_decode_tokens, attn_metadata.max_decode_seq_len, + kv_cache_dtype, attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, attn_metadata.head_dim): + decode_output = decode_meta.decode_wrapper.run( + decode_query, + kv_cache.permute(*stride_order), + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + ) + else: + workspace_buffer = ( + decode_meta.decode_wrapper._int_workspace_buffer) + assert FlashInferState.get_kv_cache_layout() == "HND" + decode_output = trtllm_batch_decode_with_kv_cache( + query=decode_query, + kv_cache=kv_cache.permute(*stride_order), + workspace_buffer=workspace_buffer, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + scale=softmax_scale, + block_tables=attn_metadata.block_tables, + seq_lens=decode_meta.seq_lens_tensor, + block_size=attn_metadata.page_size, + max_seq_len=attn_metadata.max_decode_seq_len, + kv_cache_dtype=kv_cache_dtype, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float) if prefill_output is None and decode_output is not None: # Decode only batch. 
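
For reference, the decode-path selection added to the FlashInfer backend above can be restated as a standalone sketch. The helper name, signature, and the has_sm100a flag below are illustrative stand-ins and are not part of the patch; the thresholds mirror the checks in FlashInferBackend.use_trtllm_decode_attention: an SM100-class device, a query/KV head ratio that divides evenly and is at most 8, head size 128, and, for auto-detection, batch size <= 256, max sequence length < 131072, and the default "auto" KV-cache dtype.

from typing import Optional


def should_use_trtllm_decode(has_sm100a: bool,
                             batch_size: int,
                             max_seq_len: int,
                             kv_cache_dtype: str,
                             num_qo_heads: int,
                             num_kv_heads: int,
                             head_size: int,
                             env_override: Optional[str] = None) -> bool:
    """Illustrative restatement of the TRTLLM decode gating logic."""
    if not has_sm100a:
        return False
    # Dimension checks: head size must be 128 and the query/KV head ratio
    # must be a whole number no larger than 8.
    if (num_qo_heads % num_kv_heads != 0
            or num_qo_heads // num_kv_heads > 8
            or head_size != 128):
        return False
    if env_override is not None:
        # "0" disables the path; any other value enables it.
        return env_override != "0"
    # Auto-detection: small-enough batches and sequences with the default
    # ("auto") KV-cache dtype use the TRTLLM decode kernel.
    return (batch_size <= 256 and max_seq_len < 131072
            and kv_cache_dtype == "auto")

For example, under these assumptions should_use_trtllm_decode(True, 64, 8192, "auto", 32, 8, 128) evaluates to True, while any call with env_override="0" returns False.
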
diff --git a/vllm/attention/backends/flashmla.py b/vllm/attention/backends/flashmla.py index e185d0260d0a..a242ac9bbe0b 100644 --- a/vllm/attention/backends/flashmla.py +++ b/vllm/attention/backends/flashmla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import torch @@ -181,7 +181,6 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str] = None, @@ -189,20 +188,17 @@ def __init__( **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py deleted file mode 100644 index bf778a1e5016..000000000000 --- a/vllm/attention/backends/hpu_attn.py +++ /dev/null @@ -1,318 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company -############################################################################### - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -import vllm_hpu_extension.kernels as kernels -import vllm_hpu_extension.ops as ops -from vllm_hpu_extension.flags import enabled_flags -from vllm_hpu_extension.utils import Matmul, Softmax, VLLMKVCache - -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.utils import CommonAttentionState -from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention, - HPUPagedAttentionMetadata) -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class HPUAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "HPU_ATTN" - - @staticmethod - def get_impl_cls() -> Type["HPUAttentionImpl"]: - return HPUAttentionImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return HPUAttentionMetadata - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return HPUPagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dsts: torch.Tensor, - ) -> None: - HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dsts: torch.Tensor, - ) -> None: - HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) - - -@dataclass -class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): - """Metadata for HPUAttentionbackend.""" - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - attn_bias: Optional[torch.Tensor] - seq_lens_tensor: Optional[torch.Tensor] - context_lens_tensor: Optional[torch.Tensor] - - -class HPUAttentionImpl(AttentionImpl, torch.nn.Module): - """ - If the input tensors contain prompt tokens, the layout is as follows: - |<--------------- num_prefill_tokens ----------------->| - |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| - - Otherwise, the layout is as follows: - |<----------------- num_decode_tokens ------------------>| - |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| - - Generation tokens can contain padding when cuda-graph is used. - Currently, prompt tokens don't contain any padding. - - The prompts might have different lengths, while the generation tokens - always have length 1. 
- """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - max_seq_len: int = 4096, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - super(AttentionImpl, self).__init__() - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if use_irope: - logger.warning_once( - "Using irope in HPU is not supported yet, it will fall back " - "to global attention for long context.") - self.kv_cache_dtype = kv_cache_dtype - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.matmul_qk = Matmul() - self.softmax = Softmax() - self.matmul_av = Matmul() - self.batch2block_matmul = Matmul() - self.block2batch_matmul = Matmul() - self.k_cache = VLLMKVCache() - self.v_cache = VLLMKVCache() - self.fused_scaled_dot_product_attention = kernels.fsdpa() - - self.prefill_impl = 'naive' - if "flex_attention" in enabled_flags(): - self.prefill_impl = 'flex' - if "fsdpa" in enabled_flags(): - assert alibi_slopes is None, \ - 'Prefill with FusedSDPA not supported with alibi slopes!' - self.prefill_impl = 'fsdpa' - - self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads - self.sliding_window = sliding_window - self.alibi_slopes = alibi_slopes - if alibi_slopes is not None: - alibi_slopes_tensor = torch.tensor(alibi_slopes, - dtype=torch.bfloat16) - self.alibi_slopes = alibi_slopes_tensor - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - - if self.prefill_impl == 'fsdpa': - assert alibi_slopes is None, \ - 'Prefill with FusedSDPA not supported with alibi slopes!' - - supported_head_sizes = HPUPagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - self.attn_type = attn_type - if self.attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "HPUAttentionImpl") - - if is_quantized_kv_cache(self.kv_cache_dtype): - raise NotImplementedError( - "HPUAttention with FP8 KV cache not yet supported") - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: HPUAttentionMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with xFormers and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - attn_metadata: Metadata for attention. 
- Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for HPUAttentionImpl") - - batch_size, seq_len, hidden_size = query.shape - _, seq_len_kv, _ = key.shape - - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - block_indices = attn_metadata.block_indices - block_offsets = attn_metadata.block_offsets - key_cache = None - value_cache = None - if attn_metadata.is_prompt and self.attn_type \ - is not AttentionType.ENCODER_ONLY: - key = key.unflatten(0, (block_indices.size(0), -1)) - value = value.unflatten(0, (block_indices.size(0), -1)) - if kv_cache is not None and isinstance(kv_cache, tuple): - key_cache, value_cache = HPUPagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - key_cache = self.k_cache(key, key_cache, block_indices, - block_offsets) - value_cache = self.v_cache(value, value_cache, block_indices, - block_offsets) - - if attn_metadata.is_prompt: - # Prompt run. - query_shape = (batch_size, seq_len, self.num_heads, self.head_size) - kv_shape = (batch_size, seq_len_kv, self.num_kv_heads, - self.head_size) - - attn_bias = attn_metadata.attn_bias - if attn_bias is not None and self.alibi_slopes is not None: - position_bias = _make_alibi_bias(self.alibi_slopes, - self.num_kv_heads, - attn_bias.dtype, - attn_bias.shape[-1]) - attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) - attn_bias.add_(position_bias) - - block_list = attn_metadata.block_list if attn_metadata \ - and attn_metadata.block_list is not None else None - - out = ops.prompt_attention( - impl=self.prefill_impl, - query=query.view(query_shape), - key=key.view(kv_shape), - value=value.view(kv_shape), - is_causal=True, - attn_bias=attn_bias, - valid_seq_lengths=attn_metadata.seq_lens_tensor, - **self.common_attention_args(block_list, key_cache, - value_cache)) - output = out.reshape(batch_size, seq_len, hidden_size) - else: - # Decoding run. - output = HPUPagedAttention.forward_decode( - query=query, - block_mapping=attn_metadata.block_mapping, - block_bias=attn_metadata.attn_bias, - block_groups=attn_metadata.block_groups, - **self.common_attention_args(attn_metadata.block_list, - key_cache, value_cache)) - # Reshape the output tensor. 
- return output.view(batch_size, seq_len, hidden_size) - - def common_attention_args(self, - block_list=None, - key_cache=None, - value_cache=None): - fsdpa_op = self.fused_scaled_dot_product_attention.apply \ - if self.fused_scaled_dot_product_attention is not None else None - return { - 'scale': self.scale, - 'matmul_qk_op': self.matmul_qk, - 'matmul_av_op': self.matmul_av, - 'batch2block_matmul_op': self.batch2block_matmul, - 'block2batch_matmul_op': self.block2batch_matmul, - 'fsdpa_op': fsdpa_op, - 'keys_fetch_func': self.k_cache.fetch_from_cache, - 'values_fetch_func': self.v_cache.fetch_from_cache, - 'softmax_op': self.softmax, - 'block_list': block_list, - 'key_cache': key_cache, - 'value_cache': value_cache, - } - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - num_kv_heads: int, - dtype: torch.dtype, - seq_len: int, -) -> torch.Tensor: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. - # Calculate a matrix where each element represents ith element- jth - # element. - bias = bias[None, :] - bias[:, None] - - padded_len = (seq_len + 7) // 8 * 8 - num_heads = alibi_slopes.shape[0] - bias = torch.empty( - 1, # batch size - num_heads, - seq_len, - padded_len, - device=alibi_slopes.device, - dtype=dtype, - )[:, :, :, :seq_len].copy_(bias) - bias.mul_(alibi_slopes[:, None, None]) - if num_heads != num_kv_heads: - bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) - return bias diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py deleted file mode 100644 index 410ada3b0828..000000000000 --- a/vllm/attention/backends/ipex_attn.py +++ /dev/null @@ -1,403 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" Attention layer with torch scaled_dot_product_attention - and PagedAttention.""" -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch - -from vllm._ipex_ops import ipex_ops -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.utils import CommonAttentionState -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) -from vllm.logger import init_logger - -logger = init_logger(__name__) - -_PARTITION_SIZE = 512 - - -class IpexAttnBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "IPEX" - - @staticmethod - def get_impl_cls() -> Type["IpexAttnBackendImpl"]: - return IpexAttnBackendImpl - - @staticmethod - def get_metadata_cls() -> Type["IpexAttnMetadata"]: - return IpexAttnMetadata - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - from vllm._ipex_ops import ipex_ops as ops - ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) - - @staticmethod - def copy_blocks( - kv_caches: 
List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - from vllm._ipex_ops import ipex_ops as ops - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -@dataclass -class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for IpexAttnBackend. - """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - is_prompt: bool - slot_mapping: torch.Tensor - seq_lens: Optional[List[int]] - seqlen_q: Optional[torch.Tensor] - max_seqlen: Optional[int] - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[torch.Tensor]] = None - - @property - def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: - # Currently chunked prefill is not supported - if self.num_decode_tokens == 0: - assert self.num_prefills > 0 - return self - - return None - - @property - def decode_metadata(self) -> Optional["IpexAttnMetadata"]: - # Currently chunked prefill is not supported - if self.num_prefills > 0: - assert self.num_decode_tokens == 0 - return None - - return self - - -class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if use_irope: - logger.warning_once( - "Using irope in Ipex is not supported yet, it will fall" - " back to global attention for long context.") - if blocksparse_params is not None: - raise ValueError( - "IPEX backend does not support block-sparse attention.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.need_mask = (self.sliding_window is not None) - if logits_soft_cap is None: - logits_soft_cap = -1 - self.logits_soft_cap = logits_soft_cap - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - if is_quantized_kv_cache(kv_cache_dtype): - raise NotImplementedError( - "IPEX backend does not support FP8 KV cache. 
" - "Please use xFormers backend instead.") - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "IpexAttnBackendImpl") - - def split_kv_cache( - self, - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - x = 1 - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) - return key_cache, value_cache - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: IpexAttnMetadata, # type: ignore - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with IPEX varlen_attention and PagedAttention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for IpexAttentionImpl") - - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache, value_cache = self.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - ipex_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - layer._k_scale_float, - layer._v_scale_float, - ) - - if attn_metadata.is_prompt: - assert attn_metadata.seq_lens is not None - if (kv_cache.numel() == 0 - or attn_metadata.block_tables.numel() == 0): - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, - dim=1) - - if attn_metadata.attn_bias is None: - if self.sliding_window is not None: - att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore - else: - att_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, None, dtype=query.dtype) - attn_metadata.attn_bias = att_masks - - output = torch.empty( - (num_tokens, self.num_heads, self.head_size), - dtype=query.dtype, - device=query.device) - ipex_ops.varlen_attention( - query, - key, - value, - output, - attn_metadata.seqlen_q, - attn_metadata.seqlen_q, - self.alibi_slopes, - attn_metadata.max_seqlen, - attn_metadata.max_seqlen, - pdropout=0.0, - softmax_scale=self.scale, - zero_tensors=False, - is_causal=True, - return_softmax=False, - gen_=None, - window_size_left=-1, - window_size_right=-1, - logits_soft_cap=self.logits_soft_cap, - ) - else: - # prefix-enabled attention - raise RuntimeError( - "IPEX backend doesn't support prefix decoding.") - - else: - # Decoding run. 
- max_seq_len = attn_metadata.max_decode_seq_len - output = torch.empty_like(query) - block_size = value_cache.shape[3] - num_seqs, num_heads, head_size = query.shape - max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // - _PARTITION_SIZE) - # NOTE(woosuk): We use a simple heuristic to decide whether to use - # PagedAttention V1 or V2. If the number of partitions is 1, we use - # V1 to avoid the overhead of reduction. Also, if the number of - # sequences or heads is large, we use V1 since there is enough work - # to parallelize. - # TODO(woosuk): Tune this heuristic. - # For context len > 8192, use V2 kernel to avoid shared memory - # shortage. - use_v1 = (max_seq_len <= 8192 and - (max_num_partitions == 1 or num_seqs * num_heads > 512)) - if use_v1: - # Run PagedAttention V1. - ipex_ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - self.num_kv_heads, - self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, - block_size, - max_seq_len, - self.alibi_slopes, - self.kv_cache_dtype, - layer._k_scale_float, - layer._v_scale_float, - ) - else: - # Run PagedAttention V2. - assert _PARTITION_SIZE % block_size == 0 - tmp_output = torch.empty( - size=(num_seqs, num_heads, max_num_partitions, head_size), - dtype=output.dtype, - device=output.device, - ) - exp_sums = torch.empty( - size=(num_seqs, num_heads, max_num_partitions), - dtype=torch.float32, - device=output.device, - ) - max_logits = torch.empty_like(exp_sums) - ipex_ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - self.num_kv_heads, - self.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens_tensor, - block_size, - max_seq_len, - self.alibi_slopes, - self.kv_cache_dtype, - layer._k_scale_float, - layer._v_scale_float, - ) - - # Reshape the output tensor. - return output.view(-1, self.num_heads * self.head_size) - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[torch.Tensor]: - attn_biases = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. 
- bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)) - bias.mul_(alibi_slopes[:, None, None]) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype, - device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) - attn_biases.append((bias + inf_mask).to(dtype)) - - return attn_biases - - -def _make_sliding_window_bias( - seq_lens: List[int], - window_size: Optional[int], - dtype: torch.dtype, -) -> List[torch.Tensor]: - attn_biases = [] - for seq_len in seq_lens: - tensor = torch.full( - (1, seq_len, seq_len), - dtype=dtype, - fill_value=1, - ) - shift = 0 - mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore - if window_size is not None: - mask = torch.triu(mask, diagonal=shift - window_size + 1) - mask = torch.log(mask) - attn_biases.append(mask.to(dtype)) - - return attn_biases diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 0c3ff26d04c8..52c4a9e7da3d 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -997,7 +997,6 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py deleted file mode 100644 index c900666955a3..000000000000 --- a/vllm/attention/backends/pallas.py +++ /dev/null @@ -1,356 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -import torch_xla.experimental.custom_kernel # Required to register custom ops. 
- -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, AttentionType, - is_quantized_kv_cache) -from vllm.attention.backends.utils import CommonAttentionState -from vllm.logger import init_logger - -logger = init_logger(__name__) - - -class PallasAttentionBackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "PALLAS" - - @staticmethod - def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: - return PallasAttentionBackendImpl - - @staticmethod - def get_metadata_cls() -> Type["PallasMetadata"]: - return PallasMetadata - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (num_kv_heads, num_blocks, block_size, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise RuntimeError("swap_blocks is not used for the TPU backend.") - - @torch.compile(backend="openxla") - @staticmethod - def copy_blocks( - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - src_to_dists: Tuple[torch.Tensor, torch.Tensor], - ) -> None: - src_indices, dst_indices = src_to_dists - for k_cache, v_cache in kv_caches: - torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) - k_cache[:, dst_indices] = k_cache[:, src_indices] - torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) - v_cache[:, dst_indices] = v_cache[:, src_indices] - - -@dataclass -class PallasMetadata(AttentionMetadata): - - # Currently, input sequences can only contain all prefills - # or all decoding. - block_tables: Optional[torch.Tensor] = None - context_lens: Optional[torch.Tensor] = None - effective_query_lens: Optional[torch.Tensor] = None - - @property - def prefill_metadata(self) -> Optional["PallasMetadata"]: - if self.num_prefills == 0: - return None - - assert self.num_decode_tokens == 0 - return self - - @property - def decode_metadata(self) -> Optional["PallasMetadata"]: - if self.num_decode_tokens == 0: - return None - - assert self.num_prefills == 0 - assert self.num_prefill_tokens == 0 - assert self.block_tables is not None - assert self.context_lens is not None - return self - - -class PallasAttentionBackendImpl(AttentionImpl): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if use_irope: - logger.warning_once( - "Using irope in Pallas is not supported yet, it will fall back " - "to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.logits_soft_cap = logits_soft_cap - if head_size % 128 != 0: - raise NotImplementedError( - f"Head size must be a multiple of 128, found {head_size}.") - if alibi_slopes is not None: - raise NotImplementedError("Alibi slopes is not supported.") - if 
sliding_window is not None: - raise NotImplementedError("Sliding window is not supported.") - if is_quantized_kv_cache(kv_cache_dtype): - raise NotImplementedError("FP8 KV cache dtype is not supported.") - if blocksparse_params is not None: - raise NotImplementedError("Blocksparse is not supported.") - - if torch_xla.tpu.version() < 4: - raise NotImplementedError("TPU version must be 4 or higher.") - - self.megacore_mode = None - tpu_env = torch_xla.tpu.get_tpu_env() - tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) - or tpu_env.get("TYPE", None) - or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) - assert tpu_type is not None - tpu_type = tpu_type.lower() - - if (("lite" not in tpu_type) and ("v6" not in tpu_type)): - if self.num_kv_heads % 2 == 0: - self.megacore_mode = "kv_head" - else: - # NOTE(woosuk): If the batch size is not a multiple of 2, the - # megacore mode will be None. - self.megacore_mode = "batch" - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor], - attn_metadata: PallasMetadata, - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with Pallas attention. - - Args: - query: shape = [batch_size, seq_len, num_heads * head_size] - key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, seq_len, num_kv_heads * head_size] - kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] - kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] - NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor - with shape [0] for profiling run. - attn_metadata: Metadata for attention. - Returns: - shape = [batch_size, seq_len, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for PallasAttentionImpl") - - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 - batch_size, seq_len, hidden_size = query.shape - query = query.view(batch_size, seq_len, self.num_heads, self.head_size) - key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) - value = value.view(batch_size, seq_len, self.num_kv_heads, - self.head_size) - - if kv_cache[0].numel() > 0: - slot_mapping = attn_metadata.slot_mapping - key_cache, value_cache = kv_cache - write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) - - query = query * self.scale - if attn_metadata.num_prefills > 0: - if attn_metadata.block_tables is None: - # Prefill without paged KV cache. - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") - - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, - dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, - dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, - self.head_size) - # FlashAttention kernel requires the input shape to be - # [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. 
- output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) - else: - # Prefill with paged KV cache. - # TODO(woosuk): Tune the below knobs. - num_kv_pages_per_compute_block = 16 - num_queries_per_compute_block = 16 - assert seq_len % num_queries_per_compute_block == 0 - output = torch.ops.xla.multi_queries_paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.effective_query_lens, - num_kv_pages_per_compute_block, - num_queries_per_compute_block, - use_kernel=True, - attn_logits_soft_cap=self.logits_soft_cap, - ) - else: - # Decoding run. - assert kv_cache[0].numel() > 0 - query = query.squeeze(dim=1) - pages_per_compute_block = 16 # TODO(woosuk): Tune this value. - - assert attn_metadata.block_tables is not None - assert attn_metadata.context_lens is not None - # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire - # block table in SMEM. Therefore, if the block table is too large, - # the kernel compilation will fail. To avoid this, we split the - # batch dimension into smaller chunks and run the kernel multiple - # times. - MAX_SMEM_USAGE = 512 * 1024 - size_per_seq = 4 * attn_metadata.block_tables.shape[1] - max_num_seq = MAX_SMEM_USAGE // size_per_seq - - if batch_size <= max_num_seq: - output = paged_attention( - query, - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - pages_per_compute_block, - self.megacore_mode, - attn_logits_soft_cap=self.logits_soft_cap, - ) - else: - chunk_size = max_num_seq - # Make sure the chunk size is a multiple of 2. - chunk_size = chunk_size // 2 * 2 - num_chunks = (batch_size + chunk_size - 1) // chunk_size - - output = torch.empty_like(query) - for chunk_idx in range(num_chunks): - chunk_start = chunk_idx * chunk_size - chunk_end = chunk_start + chunk_size - # NOTE(woosuk): We skip this line because it causes Dynamo - # compilation error. Instead, we rely on the slice operation - # to handle the out-of-bound case. - # chunk_end = min(chunk_end, batch_size) - chunk_output = paged_attention( - query[chunk_start:chunk_end], - key_cache, - value_cache, - attn_metadata.context_lens[chunk_start:chunk_end], - attn_metadata.block_tables[chunk_start:chunk_end], - pages_per_compute_block, - self.megacore_mode, - attn_logits_soft_cap=self.logits_soft_cap, - ) - output[chunk_start:chunk_end] = chunk_output - - # Reshape the output tensor. 
- return output.reshape(batch_size, seq_len, hidden_size) - - -def write_to_kv_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, -) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) - torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) - - key = key.flatten(0, 2) - value = value.flatten(0, 2) - key_cache = key_cache.flatten(0, 2) - value_cache = value_cache.flatten(0, 2) - key_cache.index_copy_(0, slot_mapping, key) - value_cache.index_copy_(0, slot_mapping, value) - - -def paged_attention( - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - pages_per_compute_block: int, - megacore_mode: Optional[str], - *, - attn_logits_soft_cap: Optional[float], -) -> torch.Tensor: - batch_size = query.shape[0] - if megacore_mode == "batch" and batch_size % 2 != 0: - megacore_mode = None - else: - megacore_mode = megacore_mode - - return torch.ops.xla.paged_attention( - query, - key_cache, - value_cache, - context_lens, - block_tables, - pages_per_compute_block, - megacore_mode=megacore_mode, - attn_logits_soft_cap=attn_logits_soft_cap, - ) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 1edf34351db3..a165a786d63d 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Type, Union +from typing import TYPE_CHECKING, Optional, Type, Union import torch @@ -367,7 +367,6 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -375,17 +374,14 @@ def __init__( **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") from aiter import flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 1e2c21f4e69d..1ee1dea729d9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -4,7 +4,7 @@ import itertools from dataclasses import dataclass from functools import cache -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple, Type import torch @@ -19,6 +19,8 @@ PagedAttentionMetadata) from vllm.config import get_current_vllm_config from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from vllm.platforms.rocm import use_rocm_custom_paged_attention 
@@ -492,21 +494,18 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") + raise NotImplementedError("KV sharing is not supported in V0 " + "ROCM_FLASH backend.") if use_irope: logger.warning_once( "Using irope in ROCm Flash Attention is not supported yet, it " "will fail back to global attention for long context.") - if blocksparse_params is not None: - raise ValueError( - "ROCmFlashAttention does not support blocksparse attention.") if use_irope: logger.warning( "Using irope in V0 is not supported yet, it will fall back " @@ -598,10 +597,10 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim)) def fused_output_quant_supported(self, dtype: torch.dtype, static: bool, - group_shape: tuple[int, int]): + group_shape: GroupShape): if self.use_triton_flash_attn: return dtype == current_platform.fp8_dtype( - ) and static and group_shape == (-1, -1) # per-tensor + ) and static and group_shape == GroupShape.PER_TENSOR # Only supported in the Triton backend return False diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py deleted file mode 100644 index af5fe81dc883..000000000000 --- a/vllm/attention/backends/torch_sdpa.py +++ /dev/null @@ -1,707 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" Attention layer with torch scaled_dot_product_attention - and PagedAttention.""" -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type - -import torch -from torch.nn.functional import scaled_dot_product_attention - -# yapf conflicts with isort for this block -# yapf: disable -from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionLayer, - AttentionMetadata, - AttentionMetadataBuilder, - AttentionType, - is_quantized_kv_cache) -# yapf: enable -from vllm.attention.backends.utils import CommonAttentionState -from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex -from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.logger import init_logger -from vllm.utils import make_tensor_with_pad -from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder - -logger = init_logger(__name__) - - -class TorchSDPABackend(AttentionBackend): - - @staticmethod - def get_name() -> str: - return "TORCH_SDPA" - - @staticmethod - def get_impl_cls() -> Type["TorchSDPABackendImpl"]: - return TorchSDPABackendImpl - - @staticmethod - def get_metadata_cls() -> Type["AttentionMetadata"]: - return TorchSDPAMetadata - - @staticmethod - def get_state_cls() -> Type["CommonAttentionState"]: - return CommonAttentionState - - @staticmethod - def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: - return TorchSDPAMetadataBuilder - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise 
NotImplementedError("Swap is not supported in TorchSDPABackend.") - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - ) -> None: - PagedAttention.copy_blocks(kv_caches, src_to_dists) - - -@dataclass -class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): - """Metadata for TorchSDPABackend. - """ - # Currently, input sequences can only contain all prompts - # or all decoding. True if all sequences are prompts. - chunked_prefill: bool - seq_lens: Optional[List[int]] = None # For non-chunked prefill - - # For chunked prefill only - max_query_len: Optional[int] = None - max_kv_len: Optional[int] = None - prefill_query_start_loc: Optional[torch.Tensor] = None - kv_start_loc: Optional[torch.Tensor] = None - prefill_block_tables: Optional[torch.Tensor] = None - - # For V1 logits index only - query_start_loc: Optional[torch.Tensor] = None - - # Begin encoder attn & enc/dec cross-attn fields... - # Encoder sequence lengths representation - encoder_seq_lens: Optional[List[int]] = None - encoder_seq_lens_tensor: Optional[torch.Tensor] = None - - # Maximum sequence length among encoder sequences - max_encoder_seq_len: Optional[int] = None - - # Number of tokens input to encoder - num_encoder_tokens: Optional[int] = None - - # Cross-attention memory-mapping data structures: slot mapping - # and block tables - cross_slot_mapping: Optional[torch.Tensor] = None - cross_block_tables: Optional[torch.Tensor] = None - - def __post_init__(self): - # Set during the execution of the first attention op. - # It is a list because it is needed to set per prompt - # when alibi slopes is used. It is because of the limitation - # from xformer API. - # will not appear in the __repr__ and __init__ - self.attn_bias: Optional[List[torch.Tensor]] = None - self.encoder_attn_bias: Optional[List[torch.Tensor]] = None - self.cross_attn_bias: Optional[List[torch.Tensor]] = None - - @property - def is_all_encoder_attn_metadata_set(self): - ''' - All attention metadata required for encoder attention is set. - ''' - return ((self.encoder_seq_lens is not None) - and (self.encoder_seq_lens_tensor is not None) - and (self.max_encoder_seq_len is not None)) - - @property - def is_all_cross_attn_metadata_set(self): - ''' - All attention metadata required for enc/dec cross-attention is set. - - Superset of encoder attention required metadata. - ''' - return (self.is_all_encoder_attn_metadata_set - and (self.cross_slot_mapping is not None) - and (self.cross_block_tables is not None)) - - @property - def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: - if self.num_prefill_tokens == 0: - return None - return self - - @property - def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: - if self.num_decode_tokens == 0: - return None - return self - - def get_seq_lens( - self, - attn_type: str, - ): - ''' - Extract appropriate sequence lengths from attention metadata - according to attention type. 
- - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate sequence lengths tensor for query - * Appropriate sequence lengths tensor for key & value - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - seq_lens_q = self.seq_lens - seq_lens_kv = self.seq_lens - elif attn_type == AttentionType.ENCODER: - seq_lens_q = self.encoder_seq_lens - seq_lens_kv = self.encoder_seq_lens - elif attn_type == AttentionType.ENCODER_DECODER: - seq_lens_q = self.seq_lens - seq_lens_kv = self.encoder_seq_lens - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - return seq_lens_q, seq_lens_kv - - def get_attn_bias( - self, - attn_type: str, - ) -> Optional[List[torch.Tensor]]: - ''' - Extract appropriate attention bias from attention metadata - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - * Appropriate attention bias value given the attention type - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - return self.attn_bias - elif attn_type == AttentionType.ENCODER: - return self.encoder_attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - return self.cross_attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - def set_attn_bias( - self, - attn_bias: List[torch.Tensor], - attn_type: str, - ) -> None: - ''' - Update appropriate attention bias field of attention metadata, - according to attention type. - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * attn_bias: The desired attention bias value - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - self.attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER: - self.encoder_attn_bias = attn_bias - elif attn_type == AttentionType.ENCODER_DECODER: - self.cross_attn_bias = attn_bias - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - def get_seq_len_block_table_args( - self, - attn_type: str, - ) -> tuple: - ''' - The particular choice of sequence-length- and block-table-related - attributes which should be extracted from attn_metadata is dependent - on the type of attention operation. 
- - Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & - cross-attn block-tables fields - Encoder attn -> select encoder sequence lengths fields & no block tables - - Arguments: - - * attn_metadata: Attention metadata structure associated with attention - * is_prompt: True if prefill, False otherwise - * attn_type: encoder attention, decoder self-attention, - encoder/decoder cross-attention - - Returns: - - * Appropriate sequence-lengths tensor - * Appropriate max sequence-length scalar - * Appropriate block tables (or None) - ''' - - if (attn_type == AttentionType.DECODER - or attn_type == AttentionType.ENCODER_ONLY): - # Decoder self-attention - # Choose max_seq_len based on whether we are in prompt_run - return (self.seq_lens_tensor, self.max_decode_seq_len, - self.block_tables) - elif attn_type == AttentionType.ENCODER_DECODER: - # Enc/dec cross-attention KVs match encoder sequence length; - # cross-attention utilizes special "cross" block tables - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - self.cross_block_tables) - elif attn_type == AttentionType.ENCODER: - # No block tables associated with encoder attention - return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, - None) - else: - raise AttributeError(f"Invalid attention type {str(attn_type)}") - - -class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): - - def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: - self.chunked_prefill = input_builder.chunked_prefill - self.input_builder = input_builder - - def prepare(self): - self.input_data = self.input_builder.input_data - - def build(self, seq_lens: List[int], query_lens: List[int], - cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: - input_data = self.input_data - prefill_seq_lens = seq_lens[0:input_data.num_prefills] - prefill_query_lens = query_lens[0:input_data.num_prefills] - slot_mapping = torch.tensor(input_data.slot_mapping, - dtype=torch.long, - device="cpu") - - # For chunked-prefill - if self.chunked_prefill and input_data.num_prefill_tokens != 0: - prefill_block_tables = make_tensor_with_pad( - self.input_data.prefill_block_tables, - pad=0, - dtype=torch.int32, - device="cpu", - ) - query_lens_tensor = torch.tensor(prefill_query_lens, - dtype=torch.int32, - device="cpu") - kv_lens_tensor = torch.tensor(prefill_seq_lens, - dtype=torch.int32, - device="cpu") - query_start_loc = torch.zeros(input_data.num_prefills + 1, - dtype=torch.int32, - device="cpu") - kv_start_loc = torch.zeros(input_data.num_prefills + 1, - dtype=torch.int32, - device="cpu") - torch.cumsum(query_lens_tensor, - dim=0, - dtype=torch.int32, - out=query_start_loc[1:]) - torch.cumsum(kv_lens_tensor, - dim=0, - dtype=torch.int32, - out=kv_start_loc[1:]) - max_query_len = max(prefill_query_lens) - max_kv_len = max(prefill_seq_lens) - else: - prefill_block_tables = None - query_start_loc = None - kv_start_loc = None - max_query_len = None - max_kv_len = None - - # For paged attention - if input_data.num_decode_tokens != 0: - seq_lens_tensor = torch.tensor( - input_data.seq_lens[input_data.num_prefills:], - dtype=torch.int32, - device="cpu", - ) - block_tables = make_tensor_with_pad( - self.input_data.decode_block_tables, - pad=0, - dtype=torch.int32, - device="cpu", - ) - else: - block_tables = torch.tensor([]) - seq_lens_tensor = torch.tensor( - input_data.seq_lens[:input_data.num_prefills], - dtype=torch.int32, - device="cpu", - ) - - # 
For multi-modal models - placeholder_index_maps = None - if len(input_data.multi_modal_inputs_list) != 0: - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - input_data.multi_modal_placeholder_maps.items() - } - - attn_metadata = TorchSDPAMetadata( - chunked_prefill=self.chunked_prefill, - seq_lens=prefill_seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_query_len=max_query_len, - max_kv_len=max_kv_len, - prefill_query_start_loc=query_start_loc, - kv_start_loc=kv_start_loc, - max_decode_seq_len=input_data.max_decode_seq_len, - num_prefills=input_data.num_prefills, - num_prefill_tokens=input_data.num_prefill_tokens, - num_decode_tokens=input_data.num_decode_tokens, - block_tables=block_tables, - prefill_block_tables=prefill_block_tables, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=False, - ) - - return attn_metadata - - -class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[List[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, - logits_soft_cap: Optional[float] = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, - ) -> None: - if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if blocksparse_params is not None: - raise ValueError( - "Torch SPDA does not support block-sparse attention.") - if logits_soft_cap is not None: - logger.warning_once("Torch SPDA does not support logits soft cap. " - "Outputs may be slightly off.") - if use_irope: - logger.warning_once( - "Using irope in Torch SPDA is not supported yet, it will fall" - " back to global attention for long context.") - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - if alibi_slopes is not None: - alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) - self.alibi_slopes = alibi_slopes - self.sliding_window = sliding_window - self.kv_cache_dtype = kv_cache_dtype - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.need_mask = (self.alibi_slopes is not None - or self.sliding_window is not None) - - supported_head_sizes = PagedAttention.get_supported_head_sizes() - if head_size not in supported_head_sizes: - raise ValueError( - f"Head size {head_size} is not supported by PagedAttention. " - f"Supported head sizes are: {supported_head_sizes}.") - - if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: - raise NotImplementedError( - "Torch SDPA backend FP8 KV cache requires " - "intel_extension_for_pytorch support.") - self.attn_type = attn_type - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: TorchSDPAMetadata, # type: ignore - output: Optional[torch.Tensor] = None, - output_scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass with torch SDPA and PagedAttention. 
- - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] - NOTE: kv_cache will be an empty tensor with shape [0] - for profiling run. - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for TorchSDPABackendImpl") - - # For warming-up - if attn_metadata is None: - return query - - attn_type = self.attn_type - if (attn_type == AttentionType.ENCODER - and (not attn_metadata.is_all_encoder_attn_metadata_set)): - raise AttributeError("Encoder attention requires setting " - "encoder metadata attributes.") - elif (attn_type == AttentionType.ENCODER_DECODER - and (not attn_metadata.is_all_cross_attn_metadata_set)): - raise AttributeError("Encoder/decoder cross-attention " - "requires setting cross-attention " - "metadata attributes.") - - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - if key is not None: - assert value is not None - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - else: - assert value is None - - if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): - # KV-cache during decoder-self- or - # encoder-decoder-cross-attention, but not - # during encoder attention. - # - # Even if there are no new key/value pairs to cache, - # we still need to break out key_cache and value_cache - # i.e. for later use by paged attention - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_kv_heads, self.head_size) - - if (key is not None) and (value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - # During cross-attention decode, key & value will be None, - # preventing this IF-statement branch from running - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - PagedAttention.write_to_paged_cache( - key, value, key_cache, value_cache, updated_slot_mapping, - self.kv_cache_dtype, layer._k_scale, layer._v_scale) - - if attn_type != AttentionType.ENCODER: - # Decoder self-attention supports chunked prefill. 
- # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - else: - # Encoder attention - chunked prefill is not applicable; - # derive token-count from query shape & and treat them - # as 100% prefill tokens - assert attn_metadata.num_encoder_tokens is not None - num_prefill_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 - - if attn_type == AttentionType.DECODER: - # Only enforce this shape-constraint for decoder - # self-attention - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) - if prefill_meta := attn_metadata.prefill_metadata: - if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore - assert attn_metadata.seq_lens is not None - self._run_sdpa_forward(output, - query, - key, - value, - prefill_meta, - attn_type=attn_type) - else: - # prefix-enabled attention - assert not self.need_mask - import intel_extension_for_pytorch.llm.modules as ipex_modules - output = torch.empty_like(query) - ipex_modules.PagedAttention.flash_attn_varlen_func( - output[:prefill_meta.num_prefill_tokens, :, :], - query[:prefill_meta.num_prefill_tokens, :, :], - key_cache, - value_cache, - prefill_meta.prefill_query_start_loc, - prefill_meta.kv_start_loc, - prefill_meta.max_query_len, - prefill_meta.max_kv_len, - self.scale, - True, - prefill_meta.prefill_block_tables, - self.alibi_slopes, - ) - - if decode_meta := attn_metadata.decode_metadata: - assert attn_type != AttentionType.ENCODER_ONLY, ( - "Encoder-only models should not have decode metadata.") - # Decoding run. - ( - seq_lens_arg, - max_seq_len_arg, - block_tables_arg, - ) = decode_meta.get_seq_len_block_table_args(attn_type) - - PagedAttention.forward_decode( - output[attn_metadata.num_prefill_tokens:, :, :], - query[attn_metadata.num_prefill_tokens:, :, :], - key_cache, - value_cache, - block_tables_arg, - seq_lens_arg, - max_seq_len_arg, - self.kv_cache_dtype, - self.num_kv_heads, - self.scale, - self.alibi_slopes, - layer._k_scale, - layer._v_scale, - ) - - # Reshape the output tensor. 
- return output.view(-1, self.num_heads * self.head_size) - - def _run_sdpa_forward( - self, - output: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attn_metadata: TorchSDPAMetadata, - attn_type: str = AttentionType.DECODER, - ) -> None: - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=1) - value = value.repeat_interleave(self.num_queries_per_kv, dim=1) - - attn_masks = attn_metadata.get_attn_bias(attn_type) - if attn_masks is None: - if self.alibi_slopes is not None: - attn_masks = _make_alibi_bias( - self.alibi_slopes, query.dtype, - attn_metadata.seq_lens) # type: ignore - elif self.sliding_window is not None: - assert attn_metadata.seq_lens is not None - attn_masks = _make_sliding_window_bias( - attn_metadata.seq_lens, self.sliding_window, - query.dtype) # type: ignore - else: - seq_lens, _ = attn_metadata.get_seq_lens(attn_type) - attn_masks = [None] * len(seq_lens) - attn_metadata.set_attn_bias(attn_masks, attn_type) - - query = query.movedim(0, query.dim() - 2) - key = key.movedim(0, key.dim() - 2) - value = value.movedim(0, value.dim() - 2) - - causal_attn = (attn_type == AttentionType.DECODER) - - seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) - start_q, start_kv = 0, 0 - for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, - attn_masks): - end_q = start_q + seq_len_q - end_kv = start_kv + seq_len_kv - sub_out = scaled_dot_product_attention( - query[None, :, start_q:end_q, :], - key[None, :, start_kv:end_kv, :], - value[None, :, start_kv:end_kv, :], - attn_mask=mask, - dropout_p=0.0, - is_causal=causal_attn and mask is None, - scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) - output[start_q:end_q, :, :] = sub_out - start_q, start_kv = end_q, end_kv - - -def _make_alibi_bias( - alibi_slopes: torch.Tensor, - dtype: torch.dtype, - seq_lens: List[int], -) -> List[torch.Tensor]: - attn_biases: List[torch.Tensor] = [] - for seq_len in seq_lens: - bias = torch.arange(seq_len, dtype=dtype) - # NOTE(zhuohan): HF uses - # `bias = bias[None, :].repeat(seq_len, 1)` - # here. We find that both biases give the same results, but - # the bias below more accurately follows the original ALiBi - # paper. 
- bias = bias[None, :] - bias[:, None] - - num_heads = alibi_slopes.shape[0] - bias = bias[None, :].repeat((num_heads, 1, 1)) - bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) - inf_mask = torch.empty( - (1, seq_len, seq_len), - dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) - attn_biases.append((bias + inf_mask).to(dtype)) - - return attn_biases - - -def _make_sliding_window_bias( - seq_lens: List[int], - window_size: Optional[int], - dtype: torch.dtype, -) -> List[torch.Tensor]: - attn_biases: List[torch.Tensor] = [] - for seq_len in seq_lens: - tensor = torch.full( - (1, seq_len, seq_len), - dtype=dtype, - fill_value=1, - ) - shift = 0 - mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore - if window_size is not None: - mask = torch.triu(mask, diagonal=shift - window_size + 1) - mask = torch.log(mask) - attn_biases.append(mask.to(dtype)) - - return attn_biases diff --git a/vllm/attention/backends/triton_mla.py b/vllm/attention/backends/triton_mla.py index e06f7d54e342..fba5b5f6bca8 100644 --- a/vllm/attention/backends/triton_mla.py +++ b/vllm/attention/backends/triton_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Dict, List, Optional, Type +from typing import List, Optional, Type import torch @@ -35,7 +35,6 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -43,17 +42,14 @@ def __init__( **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index b583240c73c4..0bc38b414290 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with xFormers and PagedAttention.""" from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Dict, List, Optional, Tuple, Type import torch from xformers import ops as xops @@ -387,17 +387,14 @@ def __init__( alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, ) -> None: if kv_sharing_target_layer_name is not None: - raise NotImplementedError("KV sharing is not supported in V0.") - if blocksparse_params is not None: - raise ValueError( - "XFormers does not support block-sparse attention.") + raise NotImplementedError("KV 
sharing is not supported in V0 " + "XFORMERS backend.") if logits_soft_cap is not None: logger.warning_once("XFormers does not support logits soft cap. " "Outputs may be slightly off.") diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index f0ad68b16405..178453ecdc4e 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer.""" -from typing import Any, Dict, List, Optional +from typing import List, Optional import torch import torch.nn as nn @@ -10,18 +10,47 @@ import vllm.envs as envs from vllm.attention import AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target from vllm.config import CacheConfig, get_current_vllm_config from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group, is_v1_kv_transfer_group) from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.platforms import _Backend, current_platform from vllm.utils import direct_register_custom_op -from vllm.v1.attention.backends.utils import validate_kv_sharing_target + +logger = init_logger(__name__) +USE_XFORMERS_OPS = None + + +def check_xformers_availability(): + global USE_XFORMERS_OPS + if USE_XFORMERS_OPS is not None: + return USE_XFORMERS_OPS + + if current_platform.is_cuda() and current_platform.has_device_capability( + 100): + # Xformers FA is not compatible with B200 + USE_XFORMERS_OPS = False + else: + try: + from importlib.util import find_spec + + find_spec("xformers.ops") + USE_XFORMERS_OPS = True + except ImportError: + USE_XFORMERS_OPS = False + + # the warning only needs to be shown once + if not USE_XFORMERS_OPS: + logger.warning("Xformers is not available, falling back.") + + return USE_XFORMERS_OPS class Attention(nn.Module): @@ -45,7 +74,6 @@ def __init__( alibi_slopes: Optional[List[float]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, use_mla: bool = False, @@ -109,6 +137,15 @@ def __init__( self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window + # For v1 we have backend agnostic iRoPE (local chunked attention) + # we have to store the flag on the layer so gpu model runner can + # set KVSpec appropriately (and pop it so it doesnt get passed to + # the backends) + if envs.VLLM_USE_V1: + self.use_irope = extra_impl_args.pop("use_irope", False) + else: + self.use_irope = extra_impl_args.get("use_irope", False) + quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None and not isinstance( @@ -134,12 +171,11 @@ def __init__( kv_cache_dtype, block_size, is_attention_free, - blocksparse_params is not None, use_mla=use_mla) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, 
kv_sharing_target_layer_name, **extra_impl_args) self.backend = backend_name_to_enum(attn_backend.get_name()) self.dtype = dtype @@ -160,10 +196,6 @@ def __init__( self.attn_type = attn_type if kv_sharing_target_layer_name is not None: - if not envs.VLLM_USE_V1: - raise NotImplementedError( - "Cross-layer KV sharing is not supported in V0.") - validate_kv_sharing_target( prefix, kv_sharing_target_layer_name, @@ -318,6 +350,10 @@ def __init__( _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.PALLAS_VLLM_V1 } else _Backend.TORCH_SDPA + if (self.attn_backend == _Backend.XFORMERS + and not check_xformers_availability()): + self.attn_backend = _Backend.TORCH_SDPA + def forward( self, query: torch.Tensor, diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py deleted file mode 100644 index 05fa9d11f228..000000000000 --- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +++ /dev/null @@ -1,433 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -from vllm.triton_utils import tl, triton - - -def blocksparse_flash_attn_varlen_fwd( - q, - k, - v, # (#tokens, n_heads, head_size) - cu_seqlens_k, - cu_seqlens_q, - sm_scale, - sparse_layout, - *, - block_size=64, - q_block_size=None, - max_seqlen=None): - # split q to blocks - - assert isinstance(sparse_layout, (list, tuple)) - - _, n_heads, head_size = q.shape - batch_size = cu_seqlens_k.size(0) - 1 - q_block_size = q_block_size or block_size - - assert q.dim() == k.dim() == v.dim() == 3 - assert q.size(1) % k.size(1) == 0 - assert q.size(2) == k.size(2) - # TODO(linxihui): allow k, v to have different head_size - assert k.shape == v.shape - assert cu_seqlens_k.dim() == 1 - - q_k_ratio = q.size(1) // k.size(1) - - if cu_seqlens_q is None: - if q.size(0) == batch_size: # decoding only - cu_seqlens_q = torch.arange( - 0, - batch_size + 1, - dtype=cu_seqlens_k.dtype, - device=cu_seqlens_k.device, - ) - elif q.size(0) == k.size(0): - cu_seqlens_q = cu_seqlens_k - else: - raise ValueError("cu_seqlens_q must be specified\ - if it mix of prefilling and decoding.") - else: - assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) - - # switch to use cpu to avoid too many kernel launches when iterated over - q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() - k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() - - assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( - "length of q should either be 1 (decoding) or same as k (prefilling).") - - if max_seqlen: - assert k_lens.max() <= max_seqlen - - n_blocks = (q_lens + q_block_size - 1) // q_block_size - - q_batch_ids = torch.tensor( - [i for i, n in enumerate(n_blocks) for _ in range(n)], - dtype=cu_seqlens_q.dtype, - device=cu_seqlens_q.device, - ) - q_start_sids = torch.tensor( - [i * q_block_size for n in n_blocks for i in range(n)], - dtype=cu_seqlens_q.dtype, - device=cu_seqlens_q.device, - ) - - out = q.new_empty(q.shape) - cu_seqlens_q = cu_seqlens_q.contiguous() - cu_seqlens_k = cu_seqlens_k.contiguous() - - layout_crow_indices, layout_col_indices = sparse_layout - block_d = triton.next_power_of_2(head_size) - - decoding_only = (q_lens == 1).all().item() - grid = (len(q_start_sids), n_heads, 1) - - _fwd_kernel_batch_inference[grid]( - q, - k, - v, - out, - sm_scale, - cu_seqlens_q[:-1], - cu_seqlens_q[1:], - cu_seqlens_k[:-1], - cu_seqlens_k[1:], - q_batch_ids, - q_start_sids, - 0, 
- *q.stride(), - 0, - *k.stride(), - 0, - *v.stride(), - 0, - *out.stride(), - layout_crow_indices, - layout_col_indices, - *layout_crow_indices.stride(), - *layout_col_indices.stride(), - q_k_ratio, - HAS_BATCH_DIM=False, - D_HEAD=head_size, - BLOCK_M=q_block_size, - BLOCK_N=block_size, - BLOCK_D=block_d, - BLOCK_M_LOADING=(16 if decoding_only else - q_block_size), # smaller for decoding - EVEN_D=block_d == head_size, - num_warps=1 if decoding_only else 4, - num_stages=3) - - return out - - -@triton.jit -def _fwd_kernel_inner( - acc, - l_i, - m_i, - q, - Q, - k_block_col_idx, - layout_col_ptr, - layout_col_stride_h, - layout_col_stride_m, - k_ptrs, - v_ptrs, - off_h, - offs_m, - offs_n, - offs_d, - stride_kt, - stride_vt, - sm_scale, - k_seqlen, - past_len, - LAST_K_BLOCK: tl.constexpr, - BLOCK_M_LOADING: tl.constexpr, - BLOCK_N: tl.constexpr, - D_HEAD: tl.constexpr, - EVEN_D: tl.constexpr, - M_LT_N: tl.constexpr, -): - k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + - k_block_col_idx * layout_col_stride_m).to(tl.int32) - start_n = k_block_id * BLOCK_N - if LAST_K_BLOCK: - if EVEN_D: - k = tl.load( - k_ptrs + start_n * stride_kt, - mask=offs_n[None, :] + start_n < k_seqlen, - other=0.0, - ) - else: - k = tl.load( - k_ptrs + start_n * stride_kt, - mask=(offs_n[None, :] + start_n < k_seqlen) & - (offs_d[:, None] < D_HEAD), - other=0.0, - ) - else: - if EVEN_D: - k = tl.load(k_ptrs + start_n * stride_kt) - else: - k = tl.load(k_ptrs + start_n * stride_kt, - mask=offs_d[:, None] < D_HEAD, - other=0.0) - - qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - - # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N - if LAST_K_BLOCK | M_LT_N: - qk += tl.where( - offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), - 0, - float("-inf"), - ) - - # flash-attn2 - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - p = tl.math.exp2(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - # update m_i - m_i = m_ij - l_i = l_i * alpha + l_ij - - p = p.to(Q.dtype.element_ty) - # update acc - if LAST_K_BLOCK: - if EVEN_D: - v = tl.load( - v_ptrs + start_n * stride_vt, - mask=offs_n[:, None] + start_n < k_seqlen, - other=0.0, - ) - else: - v = tl.load( - v_ptrs + start_n * stride_vt, - mask=(offs_n[:, None] + start_n < k_seqlen) & - (offs_d[None, :] < D_HEAD), - other=0.0, - ) - else: - if EVEN_D: - v = tl.load(v_ptrs + start_n * stride_vt) - else: - v = tl.load(v_ptrs + start_n * stride_vt, - mask=offs_d[None, :] < D_HEAD, - other=0.0) - - acc += tl.dot(p, v) - - return acc, l_i, m_i - - -@triton.heuristics({ - "M_LT_N": - lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], -}) -@triton.jit -def _fwd_kernel_batch_inference( - Q, - K, - V, - Out, - sm_scale, - q_batch_starts, - q_batch_ends, - k_batch_starts, - k_batch_ends, - q_batch_ids, - q_start_sids, - stride_qb, - stride_qt, - stride_qh, - stride_qd, - stride_kb, - stride_kt, - stride_kh, - stride_kd, - stride_vb, - stride_vt, - stride_vh, - stride_vd, - stride_ob, - stride_ot, - stride_oh, - stride_od, - layout_crow_ptr, - layout_col_ptr, - layout_crow_stride_h, - layout_crow_stride_m, - layout_col_stride_h, - layout_col_stride_m, - q_k_ratio, - HAS_BATCH_DIM: tl.constexpr, - D_HEAD: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D: tl.constexpr, - BLOCK_M_LOADING: tl.constexpr, - EVEN_D: tl.constexpr, - M_LT_N: tl.constexpr, -): - """ - NOTATION: - pid: position id - sid: storage id - sbid: 
storage block id - pbid: position block id - offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) - - TODO(linxihui): - Optimize grouped-attn - """ - off_zm = tl.program_id(0) - off_h = tl.program_id(1) - - off_h_for_kv = off_h // q_k_ratio - - if HAS_BATCH_DIM: - off_z = tl.program_id(2) - Q += off_z * stride_qb - K += off_z * stride_kb - V += off_z * stride_vb - Out += off_z * stride_ob - start_m = off_zm - q_start_sid = start_m * BLOCK_M # always 0 for decoding - else: - off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] - q_start_sid = tl.load(q_start_sids + off_zm) - start_m = q_start_sid // BLOCK_M # q_sbid - - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_D) - - q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) - q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start - k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) - k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start - past_len = k_seqlen - q_seqlen - - Q += q_cu_start * stride_qt + off_h * stride_qh - K += k_cu_start * stride_kt + off_h_for_kv * stride_kh - V += k_cu_start * stride_vt + off_h_for_kv * stride_vh - Out += q_cu_start * stride_ot + off_h * stride_oh - - q_pbid = (past_len + q_start_sid) // BLOCK_M - - if EVEN_D: - q = tl.load( - Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, - mask=offs_m[:, None] < q_seqlen, - other=0.0, - ) - else: - q = tl.load( - Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, - mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - other=0.0, - ) - - sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + - q_pbid * layout_crow_stride_m) - - # TODO(linxihui): load at once, with any Triton version - # that supports `tl.split`, e.g., Triton 3.0 - k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) - k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) - - m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) - - k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd - v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd - - sm_scale *= ( - 1.44269504 # 1/log2 as we use base2 for exponential and logarithm - ) - - for k_block_col_idx in range(k_block_start, k_block_end - 1): - acc, l_i, m_i = _fwd_kernel_inner( - acc, - l_i, - m_i, - q, - Q, - k_block_col_idx, - layout_col_ptr, - layout_col_stride_h, - layout_col_stride_m, - k_ptrs, - v_ptrs, - off_h, - offs_m, - offs_n, - offs_d, - stride_kt, - stride_vt, - sm_scale, - k_seqlen, - past_len, - False, - BLOCK_M_LOADING, - BLOCK_N, - D_HEAD, - EVEN_D, - M_LT_N, - ) - - acc, l_i, m_i = _fwd_kernel_inner( - acc, - l_i, - m_i, - q, - Q, - k_block_end - 1, - layout_col_ptr, - layout_col_stride_h, - layout_col_stride_m, - k_ptrs, - v_ptrs, - off_h, - offs_m, - offs_n, - offs_d, - stride_kt, - stride_vt, - sm_scale, - k_seqlen, - past_len, - True, - BLOCK_M_LOADING, - BLOCK_N, - D_HEAD, - EVEN_D, - M_LT_N, - ) - - # flash-attn 2 - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - - # write output - if EVEN_D: - tl.store( - Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, - acc, - mask=offs_m[:, None] < q_seqlen, - ) - else: - tl.store( - Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, - acc, - mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < 
D_HEAD), - ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py deleted file mode 100644 index c6f6cc29793f..000000000000 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math - -import torch - -from vllm.platforms import current_platform - -from .utils import (dense_to_crow_col, get_head_sliding_step, - get_sparse_attn_mask) - -IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) - -if IS_COMPUTE_8_OR_ABOVE: - from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd - - -class LocalStridedBlockSparseAttn(torch.nn.Module): - - def __init__( - self, - n_heads, - max_seqlen, - local_blocks, - vert_stride, - block_size, - device=None, - dtype=None, - homo_head=False, - active_head_range=None, - q_block_size=None, - use_spda=None, - ): - super().__init__() - if use_spda is None: - use_spda = current_platform.is_rocm() or \ - current_platform.is_cpu() or not \ - IS_COMPUTE_8_OR_ABOVE - device = device or (torch.cuda.current_device() - if current_platform.is_cuda_alike() else "cpu") - device = torch.device(device) - # NOTE: vllm CPU backend support BF16 instead of FP16. - dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE - or device.type == "cpu" else torch.half) - - self.n_heads = n_heads - self.max_seqlen = max_seqlen - self.local_blocks = local_blocks - self.vert_stride = vert_stride - self.use_spda = use_spda - self.dtype = dtype - self.device = device - self.block_size = block_size - self.q_block_size = q_block_size - self.homo_head = homo_head - self.active_head_range = active_head_range - self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, - homo_head) - - sparse_layout, sparse_pattern, self.dense_attn_mask = ( - self.get_attn_pattern(dtype, device)) - - if q_block_size is not None and q_block_size != block_size: - if q_block_size > block_size: - assert q_block_size % block_size == 0 - blocks_to_merge = q_block_size // block_size - shape = sparse_pattern.shape - sparse_pattern = sparse_pattern.view(shape[0], -1, - blocks_to_merge, - shape[-1]) - sparse_pattern = sparse_pattern.sum(2) - sparse_layout = dense_to_crow_col(sparse_pattern) - else: - raise ValueError( - "Does not support smaller q_block_size. It will be slower." - ) - - self.sparse_layout = sparse_layout - - def get_attn_pattern(self, dtype, device): - sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( - self.n_heads, - self.max_seqlen, - self.max_seqlen, - dtype, - device, - block_size=self.block_size, - local_blocks=self.local_blocks, - vert_stride=self.vert_stride, - homo_head=self.homo_head, - return_dense=self.use_spda, - dense_mask_type="bias", - ) - if (not self.homo_head) and (self.active_head_range is not None): - assert isinstance(self.active_head_range, tuple) - assert (len(self.active_head_range) == 2) - h_start, h_end = self.active_head_range - sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) - if self.use_spda: - dense_attn_mask = dense_attn_mask[h_start:h_end] - return sparse_layout, sparse_pattern, dense_attn_mask - - def varlen_attn(self, - q, - k, - v, - cu_seqlens_k, - cu_seqlens_q=None, - sm_scale=None): - """ - q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). - Support grouped attention, with `q[:, i*r:(i*r + r)]` - is correspondent to `k[:, i]`, where `r` is the q/k ratio. 
- cu_seqlens_k: shape=(batch_size + 1,), - indicating segment of samples, - e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i - cu_seqlens_q: shape=(batch_size + 1, ). - Default None: same as cu_seqlens_k for prefilling or - [0, 1, .., batch_size] for decoding. - The only case you need to specify is when q is a mix of - prefilling and decoding. - sm_scale: softmax scale, default to 1/sqrt(head_size). - - return: tensor of shape as q. - """ - assert ( - IS_COMPUTE_8_OR_ABOVE - ), "Requires compute capability of 8 or above (Ampere or newer) to use \ - Triton kernel." - - sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) - - return blocksparse_flash_attn_varlen_fwd( - q, - k, - v, - cu_seqlens_k, - cu_seqlens_q, - sm_scale, - self.sparse_layout, - block_size=self.block_size, - q_block_size=self.q_block_size, - max_seqlen=self.max_seqlen, - ) - - @staticmethod - def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): - """ - :param x: (total_tokens, n_heads, head_size) - :return: (batch, n_heads, length, head_size) - """ - x_padded = x.new_empty( - len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) - cu_seqlens = cu_seqlens.cpu() - for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): - x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, - 1).unsqueeze(1)) - return x_padded.flatten(1, 2) - - @staticmethod - def transpose_and_unpad(x_padded, cu_seqlens): - """ - :param x_padded: (batch, n_heads, length, head_size) - :return: (total_tokens, n_heads, head_size) - """ - cu_seqlens = cu_seqlens.cpu() - total_n_tokens = cu_seqlens[-1] - x = x_padded.new_empty(total_n_tokens, x_padded.size(1), - x_padded.size(3)) - for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): - x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) - return x - - def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): - """For CPU, V100 or other older GPUs. - NOTE: torch SPDA supports nested tensor, - but seems extremely slow. Choose to pad instead. - """ - assert (cu_seqlens_q is None or - (cu_seqlens_q - == cu_seqlens_k).all()), "Can only handle prompt with SPDA." - assert q.size(0) == k.size(0), "can only handle prompt with SPDA." - - assert q.size(1) % k.size(1) == 0 - q_k_ratio = q.size(1) // k.size(1) - sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) - cu_seqlens = cu_seqlens_k.cpu() - maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() - - if (self.dense_attn_mask.dtype != q.dtype - or self.dense_attn_mask.device != q.device): - _, _, self.dense_attn_mask = self.get_attn_pattern( - q.dtype, q.device) - attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] - - q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) - k2, v2 = (self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) - for x in [k, v]) - spda_output = torch.nn.functional.scaled_dot_product_attention( - q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) - return self.transpose_and_unpad(spda_output, cu_seqlens) - - def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): - """Dispatch to `varlen_attn` (Ampere or newer) or - `self.spda`(cpu, Volta, Turing or older)based on - the type of device used and cuda compute capability. - - q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). - Support grouped attention, with `q[:, i*r:(i*r + r)]` - is correspondent to `k[:, i]`, where `r` is the q/k ratio. 
- cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, - e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i - cu_seqlens_q: shape=(batch_size + 1, ). - Default None: same as cu_seqlens_k for prefilling or - [0, 1, .., batch_size] for decoding. - The only case you need to specify - is when q is a mix of prefilling - and decoding. - sm_scale: softmax scale, default to 1/sqrt(head_size). - - return: tensor of shape as q. - """ - assert k.dim() == 3 - if self.use_spda: - return self.spda( - q, - k, - v, - cu_seqlens_k, - cu_seqlens_q=cu_seqlens_q, - sm_scale=sm_scale, - ) - return self.varlen_attn(q, - k, - v, - cu_seqlens_k, - cu_seqlens_q=cu_seqlens_q, - sm_scale=sm_scale) diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py deleted file mode 100644 index 445720c709c4..000000000000 --- a/vllm/attention/ops/blocksparse_attention/utils.py +++ /dev/null @@ -1,246 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Helper functions for 3D sparse pattern -# These function are not optimized and very inefficient. -# Avoid calling them too frequent or use a cache mechanism. - -from functools import lru_cache - -import numpy as np -import torch - -from vllm.triton_utils import triton - - -class csr_matrix: - """Simple implementation of CSR matrix conversion without scipy. - This replaced scipy.sparse.csr_matrix() previously used.""" - - def __init__(self, input_array): - if not isinstance(input_array, np.ndarray): - raise ValueError("Input must be a NumPy array") - - self.shape = input_array.shape - rows, cols = self.shape - data = [] - indices = [] - indptr = [0] - - for i in range(rows): - for j in range(cols): - if input_array[i, j]: - data.append(input_array[i, j]) - indices.append(j) - indptr.append(len(indices)) - - self.data = np.array(data) - self.indices = np.array(indices) - self.indptr = np.array(indptr) - - -def dense_to_crow_col(x: torch.Tensor): - """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. 
- NOTE: col_indices padded -1 - """ - device = x.device - pad = -1 - dim = x.dim() - assert x.dim() in (2, 3) - if x.dim() == 2: - x = x[None] - x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] - crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) - cols = [torch.from_numpy(xi.indices) for xi in x] - max_cols = max(len(xi) for xi in cols) - cols = [ - torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) - for xi in cols - ] - cols = torch.vstack(cols) - if dim == 2: - crows = crows[0] - cols = cols[0] - return crows.to(device), cols.to(device) - - -def crow_col_to_dense(crows: torch.Tensor, - cols: torch.Tensor, - dtype: torch.dtype = torch.float16): - dim = crows.dim() - if dim == 1: - crows = crows[None] - cols = cols[None] - device = crows.device - crows, cols = crows.cpu(), cols.cpu() # faster in cpu - shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) - x = torch.zeros(shape, dtype=dtype) - for i in range(shape[0]): - for j in range(shape[1]): - x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 - if dim == 1: - x = x[0] - return x.to(device) - - -def dense_to_ccol_row(x: torch.Tensor): - """Similar, but to CSC format""" - x = x.transpose(-2, -1) - return dense_to_crow_col(x) - - -def ccol_row_to_dense(ccol: torch.Tensor, - rows: torch.Tensor, - dtype: torch.dtype = torch.float16): - return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() - - -def _get_sparse_attn_mask_homo_head( - q_len: int, - max_seqlen: int, - dtype: torch.dtype, - device: torch.device, - block_size: int = 128, - local_blocks: int = 4, - vert_stride: int = 4, - return_dense: bool = False, -): - """ - :return: a tuple of 3: - - tuple of crow_indices, col_indices representation - of CSR format. - - block dense mask - - all token dense mask (be aware that it can be - OOM if it is too big) if `return_dense==True`, - otherwise, None - """ - with torch.no_grad(): - num_blocks = triton.cdiv(max_seqlen, block_size) - q_pos = torch.arange(num_blocks)[:, None] - k_pos = torch.arange(num_blocks)[None] - mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 - block_mask_dense = (((q_pos >= k_pos) - & ((q_pos - k_pos < local_blocks) - | mask_vert_strided)).to(device).to(dtype)) - num_blocks_q = triton.cdiv(q_len, block_size) - block_mask_dense_output = (dense_to_crow_col( - block_mask_dense[-num_blocks_q:].contiguous())) - if return_dense: - mask_dense = torch.kron( - block_mask_dense, - block_mask_dense.new_ones((block_size, block_size)), - ) - causal_mask = torch.tril(torch.ones( - max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] - mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask - return ( - block_mask_dense_output, - block_mask_dense, - mask_dense, - ) - else: - return ( - block_mask_dense_output, - block_mask_dense, - None, - ) - - -def binary_mask_to_bias(mask_dense: torch.Tensor): - mask_dense = 1 - mask_dense - mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) - return mask_dense - - -def get_head_sliding_step(n_heads: int, - vert_stride: int, - homo_head: bool = False): - if homo_head: - return 0 - return max(1, int(vert_stride / n_heads)) - - -@lru_cache -def get_sparse_attn_mask( - n_heads: int, - q_len: int, - max_seqlen: int, - dtype: torch.dtype, - device: torch.device, - block_size: int = 64, - local_blocks: int = 4, - vert_stride: int = 4, - homo_head: bool = True, - return_dense: bool = False, - dense_mask_type: str = "binary", -): - """ - :param dense_mask_type: "binary" (0 for skip token, 1 for others) - or "bias" (-inf 
for skip token, 0 or others) - :return: a tuple of 3: - - tuple of crow_indices, col_indices representation - of CSR format. - - block dense mask - - all token dense mask (be aware that it can be OOM if it - is too big) if `return_dense==True`, otherwise, None - """ - assert dense_mask_type in ("binary", "bias") - if homo_head: - with torch.no_grad(): - (crow, col), block_mask_dense, mask_dense = ( - _get_sparse_attn_mask_homo_head( - q_len, - max_seqlen, - dtype, - device, - block_size, - local_blocks, - vert_stride, - return_dense, - )) - crow = crow[None].expand(n_heads, crow.shape[0]) - col = col[None].expand(n_heads, col.shape[0]) - if return_dense: - mask_dense = mask_dense[None].expand(n_heads, - *mask_dense.shape) - if dense_mask_type == "bias": - mask_dense = binary_mask_to_bias(mask_dense) - return (crow, col), block_mask_dense, mask_dense - - with torch.no_grad(): - num_blocks = triton.cdiv(max_seqlen, block_size) - q_pos = torch.arange(num_blocks)[None, :, None] - k_pos = torch.arange(num_blocks)[None, None] - head_sliding_step = get_head_sliding_step(n_heads, vert_stride) - mask_vert_strided = [ - (torch.arange(num_blocks) + h * head_sliding_step + 1) % - vert_stride == 0 for h in range(n_heads) - ] - mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) - block_mask_dense = (((q_pos >= k_pos) - & ((q_pos - k_pos < local_blocks) - | mask_vert_strided)).to(device).to(dtype)) - num_blocks_q = triton.cdiv(q_len, block_size) - block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] - if return_dense: - mask_dense = torch.kron( - block_mask_dense, - block_mask_dense.new_ones((block_size, block_size)), - ) - causal_mask = torch.tril(torch.ones( - max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] - mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] - if dense_mask_type == "bias": - mask_dense = binary_mask_to_bias(mask_dense) - - return ( - dense_to_crow_col(block_mask_dense_output), - block_mask_dense, - mask_dense, - ) - else: - return ( - dense_to_crow_col(block_mask_dense_output), - block_mask_dense, - None, - ) diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py deleted file mode 100644 index 412dd20ec1de..000000000000 --- a/vllm/attention/ops/hpu_paged_attn.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -############################################################################### - -from dataclasses import dataclass -from typing import List, Optional, Tuple - -import torch -from vllm_hpu_extension import cache_ops, ops - -# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. 
-_PARTITION_SIZE = 512 - - -@dataclass -class HPUPagedAttentionMetadata: - """Metadata for PagedAttention.""" - block_list: Optional[torch.Tensor] - block_mapping: Optional[torch.Tensor] - block_usage: Optional[torch.Tensor] - block_indices: Optional[torch.Tensor] - block_offsets: Optional[torch.Tensor] - block_groups: Optional[torch.Tensor] - - -class HPUPagedAttention: - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [64, 80, 96, 112, 128, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - ) -> Tuple[int, ...]: - return (num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - return key_cache, value_cache - - @staticmethod - def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, kv_cache_dtype: str, - is_prompt: bool) -> None: - cache_ops.reshape_and_cache(key, value, key_cache, value_cache, - slot_mapping, kv_cache_dtype, is_prompt) - - @staticmethod - def forward_decode(**kwargs) -> torch.Tensor: - return ops.flat_pa(**kwargs) - - @staticmethod - def swap_blocks( - src_kv_cache: Tuple[torch.Tensor, torch.Tensor], - dst_kv_cache: Tuple[torch.Tensor, torch.Tensor], - src_to_dsts: torch.Tensor, - ) -> None: - src_key_cache = src_kv_cache[0] - dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts) - - src_value_cache = src_kv_cache[1] - dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts) - - @staticmethod - def copy_blocks( - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - src_to_dsts: torch.Tensor, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py deleted file mode 100644 index 891975498916..000000000000 --- a/vllm/attention/ops/ipex_attn.py +++ /dev/null @@ -1,195 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Tuple - -try: - import intel_extension_for_pytorch.llm.modules as ipex_modules - _use_ipex = True -# AttributeError is to handle a bug in ipex https://github.com/intel/intel-extension-for-pytorch/pull/813 -except (ImportError, AttributeError): - _use_ipex = False - -import torch - -from vllm import _custom_ops as ops - - -class _PagedAttention: - - @staticmethod - def get_supported_head_sizes() -> List[int]: - return [32, 64, 80, 96, 112, 128, 192, 256] - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - *args, - ) -> Tuple[int, ...]: - return 2, num_blocks, block_size * num_kv_heads * head_size - - @staticmethod - def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - *args, - ) -> Tuple[torch.Tensor, torch.Tensor]: - x = 16 // kv_cache.element_size() - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, - -1, x) - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, 
num_kv_heads, head_size, -1) - return key_cache, value_cache - - @staticmethod - def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - - @staticmethod - def forward_decode( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - tp_rank: int = 0 - blocksparse_local_blocks: int = 0 - blocksparse_vert_stride: int = 0 - blocksparse_block_size: int = 64 - blocksparse_head_sliding_step: int = 0 - block_size = value_cache.shape[3] - - ops.paged_attention_v1( - output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - tp_rank, - blocksparse_local_blocks, - blocksparse_vert_stride, - blocksparse_block_size, - blocksparse_head_sliding_step, - ) - - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: torch.Tensor, - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) - - -class _IPEXPagedAttention(_PagedAttention): - - @staticmethod - def split_kv_cache( - kv_cache: torch.Tensor, - num_kv_heads: int, - head_size: int, - *args, - ) -> Tuple[torch.Tensor, torch.Tensor]: - num_blocks = kv_cache.shape[1] - - key_cache = kv_cache[0] - key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) - value_cache = kv_cache[1] - value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) - return key_cache, value_cache - - @staticmethod - def write_to_paged_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - ipex_modules.PagedAttention.reshape_and_cache( - key, value, key_cache, value_cache, - slot_mapping.flatten().int()) - - @staticmethod - def forward_decode( - output: torch.Tensor, - query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - context_lens: torch.Tensor, - max_context_len: int, - kv_cache_dtype: str, - num_kv_heads: int, - scale: float, - alibi_slopes: Optional[torch.Tensor], - k_scale: torch.Tensor, - v_scale: torch.Tensor, - *args, - ) -> None: - block_size = value_cache.shape[2] - head_mapping = torch.arange( - 0, - num_kv_heads, - device="cpu", - dtype=torch.int32, - ).view(num_kv_heads, - 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() - ipex_modules.PagedAttention.single_query_cached_kv_attention( - output, query.contiguous(), key_cache, value_cache, head_mapping, - scale, block_tables, context_lens, block_size, max_context_len, - alibi_slopes) - - -PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/attention/ops/rocm_aiter_mla.py 
b/vllm/attention/ops/rocm_aiter_mla.py index cce6b4639460..d91cda255ff3 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -6,7 +6,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer def get_aiter_mla_metadata(max_batch_size: int, block_size: int, @@ -93,8 +93,12 @@ def mla_decode_fwd_fake( if current_platform.is_rocm(): + if is_torch_equal_or_newer("2.7.0"): + tags = () + else: + tags = (torch.Tag.needs_fixed_stride_order, ), direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", op_func=mla_decode_fwd_impl, mutates_args=["o"], fake_impl=mla_decode_fwd_fake, - tags=[torch.Tag.needs_fixed_stride_order]) + tags=tags) diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index c65f09523a3c..eb9c4f1c1030 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -8,10 +8,9 @@ # - Thomas Parnell <tpa@zurich.ibm.com> import torch -import triton -import triton.language as tl from vllm.logger import init_logger +from vllm.triton_utils import tl, triton logger = init_logger(__name__) @@ -145,7 +144,19 @@ def kernel_unified_attention_2d( mask=query_mask_1, other=0.0) - num_blocks = cdiv_fn(seq_len, BLOCK_SIZE) + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + ( + BLOCK_M - 1) // num_queries_per_kv + 1 + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles (blocks) that need to be processed to + # cover the longest sequence prefix (due to causal masking, blocks beyond + # this prefix can be skipped) + num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE) # iterate through tiles for j in range(0, num_blocks): diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index df14aea729f3..2e3c8638125f 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -3,6 +3,7 @@ import os from contextlib import contextmanager +from dataclasses import dataclass from functools import cache from typing import Generator, Optional, Union @@ -79,31 +80,61 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: return forced_attn_backend -def supports_head_size( +@dataclass(frozen=True) +class _IsSupported: + can_import: bool + head_size: bool + dtype: bool + + def __bool__(self) -> bool: + return self.can_import and self.head_size and self.dtype + + +def is_attn_backend_supported( attn_backend: Union[str, type[AttentionBackend]], head_size: int, -) -> bool: + dtype: torch.dtype, + *, + allow_import_error: bool = True, +) -> _IsSupported: if isinstance(attn_backend, str): try: attn_backend = resolve_obj_by_qualname(attn_backend) except ImportError: - return False + if not allow_import_error: + raise + + return _IsSupported(can_import=False, head_size=False, dtype=False) assert isinstance(attn_backend, type) # TODO: Update the interface once V0 is removed if get_supported_head_sizes := getattr(attn_backend, "get_supported_head_sizes", None): - return head_size in get_supported_head_sizes() - if validate_head_size := getattr(attn_backend, "validate_head_size", None): + is_head_size_supported = head_size in 
get_supported_head_sizes() + elif validate_head_size := getattr(attn_backend, "validate_head_size", + None): try: validate_head_size(head_size) - return True + is_head_size_supported = True except Exception: - return False + is_head_size_supported = False + else: + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "head size validation") - raise NotImplementedError(f"{attn_backend.__name__} does not support " - "head size validation") + if get_supported_dtypes := getattr(attn_backend, "get_supported_dtypes", + None): + is_dtype_supported = dtype in get_supported_dtypes() + else: + raise NotImplementedError(f"{attn_backend.__name__} does not support " + "dtype validation") + + return _IsSupported( + can_import=True, + head_size=is_head_size_supported, + dtype=is_dtype_supported, + ) def get_attn_backend( @@ -112,7 +143,6 @@ def get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, is_attention_free: bool, - is_blocksparse: bool = False, use_mla: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" @@ -126,7 +156,6 @@ def get_attn_backend( kv_cache_dtype=kv_cache_dtype, block_size=block_size, is_attention_free=is_attention_free, - is_blocksparse=is_blocksparse, use_v1=envs.VLLM_USE_V1, use_mla=use_mla, ) @@ -139,16 +168,9 @@ def _cached_get_attn_backend( kv_cache_dtype: Optional[str], block_size: int, is_attention_free: bool, - is_blocksparse: bool = False, use_v1: bool = False, use_mla: bool = False, ) -> type[AttentionBackend]: - if is_blocksparse: - logger.info("Using BlocksparseFlashAttention backend.") - from vllm.attention.backends.blocksparse_attn import ( - BlocksparseFlashAttentionBackend) - return BlocksparseFlashAttentionBackend - # If there are no attention layers (e.g. 
we are running Mamba), # use the placeholder NO_ATTENTION if is_attention_free: diff --git a/vllm/spec_decode/__init__.py b/vllm/attention/utils/__init__.py similarity index 100% rename from vllm/spec_decode/__init__.py rename to vllm/attention/utils/__init__.py diff --git a/vllm/attention/utils/kv_sharing_utils.py b/vllm/attention/utils/kv_sharing_utils.py new file mode 100644 index 000000000000..b4ae8bdf4d76 --- /dev/null +++ b/vllm/attention/utils/kv_sharing_utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +def validate_kv_sharing_target(current_layer_name, target_layer_name, + static_forward_context): + error_msg = (f"Specified KV sharing target layer for {current_layer_name} " + f"is not valid: target layer {target_layer_name} ") + + if current_layer_name == target_layer_name: + raise ValueError(error_msg + + "cannot be the same as the current layer.") + + if target_layer_name not in static_forward_context: + from vllm.model_executor.models.utils import extract_layer_index + + # If target layer name is not in the static fwd context, it means either + # a) the target layer does not come BEFORE the current layer, or + # b) the target layer is not an Attention layer that exists in the model + current_layer_idx = extract_layer_index(current_layer_name) + target_layer_idx = extract_layer_index(target_layer_name) + if current_layer_idx <= target_layer_idx: + raise ValueError(error_msg + "must come before the current layer.") + else: + raise ValueError(error_msg + + "is not a valid Attention layer in the model.") + + # Currently KV sharing is only supported between layers of the same type + target_layer_attn_type = static_forward_context[ + target_layer_name].attn_type + expected = static_forward_context[current_layer_name].attn_type + if target_layer_attn_type != expected: + raise ValueError( + error_msg + + f"must be the same type as the current layer ({expected}).") diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index b3688d2340e4..45b58035ebe3 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -481,6 +481,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser): choices=["sharegpt", "burstgpt", "sonnet", "random", "hf", "custom"], help="Name of the dataset to benchmark on.", ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Do not load the dataset in streaming mode.", + ) parser.add_argument( "--dataset-path", type=str, @@ -649,6 +654,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: dataset_class = ASRDataset args.hf_split = "train" + elif args.dataset_path in MLPerfDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MLPerfDataset + args.hf_split = "train" else: supported_datasets = set([ dataset_name for cls in HuggingFaceDataset.__subclasses__() @@ -674,6 +682,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]: dataset_subset=args.hf_subset, dataset_split=args.hf_split, random_seed=args.seed, + no_stream=args.no_stream, ).sample( num_requests=args.num_prompts, tokenizer=tokenizer, @@ -971,6 +980,7 @@ def __init__( self, dataset_path: str, dataset_split: str, + no_stream: bool = False, dataset_subset: Optional[str] = None, **kwargs, ) -> None: @@ -978,6 +988,7 @@ def __init__( self.dataset_split = dataset_split self.dataset_subset = dataset_subset + self.load_stream = not no_stream self.load_data() def load_data(self) -> None: @@ -986,7 +997,7 
@@ def load_data(self) -> None: self.dataset_path, name=self.dataset_subset, split=self.dataset_split, - streaming=True, + streaming=self.load_stream, ) self.data = self.data.shuffle(seed=self.random_seed) @@ -1439,3 +1450,82 @@ def sample( ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests + + +# ----------------------------------------------------------------------------- +# MLPerf Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MLPerfDataset(HuggingFaceDataset): + """ + MLPerf Inference Dataset. + + Dataset on HF: + https://huggingface.co/datasets/mgoin/mlperf-inference-llama2-data + https://huggingface.co/datasets/mgoin/mlperf-inference-llama3.1-data + + Each record contains: + - "system_prompt": system role instruction. + - "question": user question. + - "output": reference answer. + + We combine the system prompt and question into a chat-formatted prompt + (using the tokenizer's chat template) and set the expected output length to + the tokenized length of the provided reference answer. + """ + + SUPPORTED_DATASET_PATHS = { + "mgoin/mlperf-inference-llama2-data", + "mgoin/mlperf-inference-llama3.1-data", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list[SampleRequest]: + # Force dynamic output length based on reference completion. + dynamic_output = output_len is None + sampled_requests: list[SampleRequest] = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + + system_prompt = item["system_prompt"] + question = item["question"] + reference_answer = item["output"] + + # Build chat-style prompt using tokenizer template, if available. + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": question}, + ] + prompt_formatted = tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=False + ) + prompt_len = len(tokenizer(prompt_formatted).input_ids) + + # Determine output length from reference answer tokens. + ref_out_len = len( + tokenizer(reference_answer, add_special_tokens=False).input_ids + ) + expected_output_len = ref_out_len if dynamic_output else output_len + + # Validate sequence lengths. + if not is_valid_sequence(prompt_len, expected_output_len): + continue + + sampled_requests.append( + SampleRequest( + prompt=prompt_formatted, + prompt_len=prompt_len, + expected_output_len=expected_output_len, + ) + ) + + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 8b16fea9e3d3..f4506c9ce6f4 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -138,31 +138,54 @@ async def get_request( input_requests = list(input_requests) total_requests = len(input_requests) - request_index = 0 + assert total_requests > 0, "No requests provided." 
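The MLPerfDataset added in vllm/benchmarks/datasets.py above builds each request by pushing the record's system prompt and question through the tokenizer's chat template and sizing the expected output from the reference answer. The sketch below illustrates that flow outside of vLLM; it is not part of the diff, and the tokenizer name and sample record are placeholder assumptions.

# Illustrative sketch of the prompt / output-length construction used by
# MLPerfDataset.sample() above (placeholder tokenizer and record).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model

record = {  # placeholder record with the documented fields
    "system_prompt": "You are a helpful assistant.",
    "question": "What is the capital of France?",
    "output": "The capital of France is Paris.",
}

messages = [
    {"role": "system", "content": record["system_prompt"]},
    {"role": "user", "content": record["question"]},
]
# Chat-formatted prompt string, as stored in SampleRequest.prompt.
prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False)
prompt_len = len(tokenizer(prompt).input_ids)
# Expected output length is the token count of the reference answer.
expected_output_len = len(
    tokenizer(record["output"], add_special_tokens=False).input_ids)
print(prompt_len, expected_output_len)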
- for request in input_requests: + # Precompute delays between requests to minimize request send lag + request_rates = [] + delay_ts = [] + for request_index, request in enumerate(input_requests): current_request_rate = _get_current_request_rate(ramp_up_strategy, ramp_up_start_rps, ramp_up_end_rps, request_index, total_requests, request_rate) - - yield request, current_request_rate - - request_index += 1 - + request_rates.append(current_request_rate) if current_request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - - theta = 1.0 / (current_request_rate * burstiness) - - # Sample the request interval from the gamma distribution. - # If burstiness is 1, it follows exponential distribution. - interval = np.random.gamma(shape=burstiness, scale=theta) - # The next request will be sent after the interval. - await asyncio.sleep(interval) + delay_ts.append(0) + else: + theta = 1.0 / (current_request_rate * burstiness) + + # Sample the request interval from the gamma distribution. + # If burstiness is 1, it follows exponential distribution. + delay_ts.append(np.random.gamma(shape=burstiness, scale=theta)) + + # Calculate the cumulative delay time relative to the first request sent. + for i in range(1, len(delay_ts)): + delay_ts[i] += delay_ts[i - 1] + if ramp_up_strategy is None and delay_ts[-1] != 0: + # When ramp_up_strategy is not set, we assume the request rate is fixed + # and all requests should be sent within target_total_delay_s; the + # following logic re-scales the delay times so that the final delay_ts + # aligns with target_total_delay_s. + # + # NOTE: If we simply accumulate the random delta values + # from the gamma distribution, their sum would have a 1-2% gap + # from target_total_delay_s. The purpose of the following logic is to + # close that gap, stabilizing the throughput data + # across different random seeds.
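The comments above describe sampling gamma-distributed inter-arrival gaps and then re-scaling their cumulative sum so that the last request lands exactly at total_requests / request_rate. A small standalone sketch of that arithmetic follows; it is not part of the diff, and the rates and seed are arbitrary.

import numpy as np

np.random.seed(0)
total_requests = 1000
request_rate = 10.0  # requests per second
burstiness = 1.0     # shape=1 makes the gamma an exponential distribution

theta = 1.0 / (request_rate * burstiness)
gaps = np.random.gamma(shape=burstiness, scale=theta, size=total_requests)
delay_ts = np.cumsum(gaps)

# The raw cumulative sum is typically a few percent away from the target.
target_total_delay_s = total_requests / request_rate  # 100 s here
print(f"before re-scaling: {delay_ts[-1]:.2f} s")

# Re-scale so the final request is scheduled exactly at the target duration.
delay_ts *= target_total_delay_s / delay_ts[-1]
print(f"after re-scaling:  {delay_ts[-1]:.2f} s")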
+ target_total_delay_s = total_requests / request_rate + normalize_factor = target_total_delay_s / delay_ts[-1] + delay_ts = [delay * normalize_factor for delay in delay_ts] + + start_ts = time.time() + for request_index, request in enumerate(input_requests): + if delay_ts[request_index] > 0: + current_ts = time.time() + sleep_interval_s = start_ts + delay_ts[request_index] - current_ts + if sleep_interval_s > 0: + await asyncio.sleep(sleep_interval_s) + yield request, request_rates[request_index] def calculate_metrics( diff --git a/vllm/collect_env.py b/vllm/collect_env.py index 64172a9bf91d..ee43ad12e8a5 100644 --- a/vllm/collect_env.py +++ b/vllm/collect_env.py @@ -96,25 +96,30 @@ def run(command): """Return (return-code, stdout, stderr).""" shell = True if type(command) is str else False - p = subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=shell) - raw_output, raw_err = p.communicate() - rc = p.returncode - if get_platform() == 'win32': - enc = 'oem' - else: - enc = locale.getpreferredencoding() - output = raw_output.decode(enc) - if command == 'nvidia-smi topo -m': - # don't remove the leading whitespace of `nvidia-smi topo -m` - # because they are meaningful - output = output.rstrip() - else: - output = output.strip() - err = raw_err.decode(enc) - return rc, output, err.strip() + try: + p = subprocess.Popen(command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=shell) + raw_output, raw_err = p.communicate() + rc = p.returncode + if get_platform() == 'win32': + enc = 'oem' + else: + enc = locale.getpreferredencoding() + output = raw_output.decode(enc) + if command == 'nvidia-smi topo -m': + # don't remove the leading whitespace of `nvidia-smi topo -m` + # because they are meaningful + output = output.rstrip() + else: + output = output.strip() + err = raw_err.decode(enc) + return rc, output, err.strip() + + except FileNotFoundError: + cmd_str = command if isinstance(command, str) else command[0] + return 127, '', f"Command not found: {cmd_str}" def run_and_read_all(run_lambda, command): @@ -148,7 +153,7 @@ def get_conda_packages(run_lambda, patterns=None): if patterns is None: patterns = DEFAULT_CONDA_PATTERNS conda = os.environ.get('CONDA_EXE', 'conda') - out = run_and_read_all(run_lambda, "{} list".format(conda)) + out = run_and_read_all(run_lambda, [conda, 'list']) if out is None: return out diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a2bb053cec4a..673fb5866234 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -120,10 +120,15 @@ def load(self, handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] compiled_graph = self.compiler.load(handle, graph, example_inputs, graph_index, runtime_shape) - logger.debug( - "Directly load the %s-th graph for shape %s from %s via " - "handle %s", graph_index, str(runtime_shape), self.compiler.name, - handle) + if runtime_shape is None: + logger.debug( + "Directly load the %s-th graph for dynamic shape from %s via " + "handle %s", graph_index, self.compiler.name, handle) + else: + logger.debug( + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), + self.compiler.name, handle) return compiled_graph def compile(self, @@ -152,9 +157,15 @@ def compile(self, # there can be multiple graphs due to piecewise compilation. 
now = time.time() elapsed = now - compilation_start_time - logger.info( - "Directly load the compiled graph(s) for shape %s " - "from the cache, took %.3f s", str(runtime_shape), elapsed) + if runtime_shape is None: + logger.info( + "Directly load the compiled graph(s) for dynamic shape " + "from the cache, took %.3f s", elapsed) + else: + logger.info( + "Directly load the compiled graph(s) for shape %s " + "from the cache, took %.3f s", str(runtime_shape), + elapsed) return compiled_graph # no compiler cached the graph, or the cache is disabled, @@ -172,17 +183,28 @@ def compile(self, assert compiled_graph is not None, "Failed to compile the graph" # store the artifact in the cache - if handle is not None: + if not envs.VLLM_DISABLE_COMPILE_CACHE and handle is not None: self.cache[(runtime_shape, graph_index, self.compiler.name)] = handle + compilation_counter.num_cache_entries_updated += 1 self.is_cache_updated = True if graph_index == 0: # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug( - "store the %s-th graph for shape %s from %s via handle %s", - graph_index, str(runtime_shape), self.compiler.name, handle) + if runtime_shape is None: + logger.info( + "Cache the graph for dynamic shape for later use") + else: + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + if runtime_shape is None: + logger.debug( + "Store the %s-th graph for dynamic shape from %s via " + "handle %s", graph_index, self.compiler.name, handle) + else: + logger.debug( + "Store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, + handle) # after compiling the last graph, record the end time if graph_index == num_graphs - 1: @@ -190,7 +212,7 @@ def compile(self, elapsed = now - compilation_start_time compilation_config.compilation_time += elapsed if runtime_shape is None: - logger.info("Compiling a graph for general shape takes %.2f s", + logger.info("Compiling a graph for dynamic shape takes %.2f s", elapsed) else: logger.info("Compiling a graph for shape %s takes %.2f s", @@ -308,7 +330,7 @@ def call_module(self, target: torch.fx.node.Target, i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_general_shape = self.vllm_backend.\ + compiled_graph_for_dynamic_shape = self.vllm_backend.\ compiler_manager.compile( submod, args, @@ -323,7 +345,7 @@ def call_module(self, target: torch.fx.node.Target, self.module.__dict__[target] = piecewise_backend( submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, - compiled_graph_for_general_shape, self.vllm_backend) + compiled_graph_for_dynamic_shape, self.vllm_backend) compilation_counter.num_piecewise_capturable_graphs_seen += 1 diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index f754fc2388b2..0e7961841bd3 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,23 +1,41 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from importlib.util import find_spec from typing import Optional import torch import torch._inductor.pattern_matcher as pm import torch.fx as fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import 
enable_symm_mem_for_group from vllm.config import VllmConfig -from vllm.distributed import get_tp_group +from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.utils import direct_register_custom_op from .vllm_inductor_pass import VllmInductorPass +if find_spec("flashinfer"): + try: + import flashinfer.comm as flashinfer_comm + flashinfer_comm = (flashinfer_comm if hasattr( + flashinfer_comm, "trtllm_allreduce_fusion") else None) + except ImportError: + flashinfer_comm = None +else: + flashinfer_comm = None +from vllm.platforms import current_platform + logger = init_logger(__name__) +ALLREDUCE_OP = torch.ops.vllm.all_reduce.default +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default + class BasePattern: @@ -43,7 +61,8 @@ def pattern(mul: torch.Tensor, mm_weight: torch.Tensor): mm, dim=0, world_size=self.tp_size, - group_name=self.tp.unique_name) + group_name=self.tp.unique_name, + ) return reduce_scatter def replacement(mul: torch.Tensor, mm_weight: torch.Tensor): @@ -79,7 +98,8 @@ def pattern( x, dim=0, world_size=self.tp_size, - group_name=self.tp.unique_name) + group_name=self.tp.unique_name, + ) return torch.ops.aten.mm.default(all_gather, weight) @@ -125,3 +145,343 @@ def __call__(self, graph: fx.Graph): logger.debug("Replaced %s patterns", count) self.dump_graph(graph, "after_async_tp_pass") self.end_and_log() + + +if flashinfer_comm is not None: + _FI_WORKSPACE_TENSOR = None + + MiB = 1024 * 1024 + # Max size of the input tensor per world size + # to use flashinfer fused allreduce + _FI_MAX_SIZES = { + 2: MiB, # 1MB + 4: MiB, # 1MB + 6: MiB // 2, # 512KB + 8: MiB // 2, # 512KB + } + # opt for a more conservative default value + # when world size is not in _FI_MAX_SIZES + _DEFAULT_FI_MAX_SIZE = MiB // 2 + + def call_trtllm_fused_allreduce_norm( + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + norm_out: Optional[torch.Tensor] = None, + ) -> None: + + num_tokens, hidden_size = allreduce_in.shape + element_size = allreduce_in.element_size() + current_tensor_size = num_tokens * hidden_size * element_size + max_fusion_size = max_token_num * hidden_size * element_size + use_flashinfer = current_tensor_size <= min( + _FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE), + max_fusion_size, + ) + + if use_flashinfer: + assert (_FI_WORKSPACE_TENSOR is not None + ), "Flashinfer must be enabled when using flashinfer" + if norm_out is None: + norm_out = allreduce_in + residual_out = residual + else: + # return residual_out as allreduce_out with zeroed residual_in + # as flashinfer does not support rms_norm + # and allreduce_out together + residual_out = allreduce_in + # For the sizes that are smaller than the max size, + # we only use flashinfer one shot allreduce + flashinfer_comm.trtllm_allreduce_fusion( + allreduce_in=allreduce_in, + token_num=allreduce_in.shape[0], + residual_in=residual, + residual_out=residual_out, + norm_out=norm_out, + rms_gamma=rms_gamma, + rms_eps=rms_eps, + world_rank=world_rank, + world_size=world_size, + hidden_dim=allreduce_in.shape[-1], + workspace_ptrs=_FI_WORKSPACE_TENSOR, + launch_with_pdl=launch_with_pdl, + 
use_oneshot=True, + trigger_completion_at_end=trigger_completion_at_end, + fp32_acc=fp32_acc, + pattern_code=flashinfer_comm.AllReduceFusionPattern. + kARResidualRMSNorm, + allreduce_out=None, + quant_out=None, + scale_out=None, + layout_code=None, + scale_factor=None, + ) + else: + allreduce_out = tensor_model_parallel_all_reduce(allreduce_in) + if norm_out is None: + torch.ops._C.fused_add_rms_norm(allreduce_out, residual, + rms_gamma, rms_eps) + else: + torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, + rms_eps) + allreduce_in.copy_(allreduce_out) + + def call_trtllm_fused_allreduce_norm_fake( + allreduce_in: torch.Tensor, + residual: torch.Tensor, + rms_gamma: torch.Tensor, + rms_eps: float, + world_rank: int, + world_size: int, + launch_with_pdl: bool, + trigger_completion_at_end: bool, + fp32_acc: bool, + max_token_num: int, + norm_out: Optional[torch.Tensor] = None, + ) -> None: + pass + + direct_register_custom_op( + op_name="flashinfer_trtllm_fused_allreduce_norm", + op_func=call_trtllm_fused_allreduce_norm, + mutates_args=[ + "allreduce_in", + "residual", + "norm_out", + ], + fake_impl=call_trtllm_fused_allreduce_norm_fake, + dispatch_key=current_platform.dispatch_key, + ) + flashinfer_trtllm_fused_allreduce_norm = ( + torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default) + + +class FlashInferFusedAllReduceParams: + """Parameters for FlashInfer fused allreduce operations.""" + + def __init__( + self, + rank: int, + world_size: int, + use_fp32_lamport: bool = False, + max_token_num: int = 1024, + ): + self.rank = rank + self.world_size = world_size + self.use_fp32_lamport = use_fp32_lamport + self.trigger_completion_at_end = True + self.launch_with_pdl = True + self.fp32_acc = True + self.use_oneshot = False + self.max_token_num = max_token_num + + def get_trtllm_fused_allreduce_kwargs(self): + return { + "world_rank": self.rank, + "world_size": self.world_size, + "launch_with_pdl": self.launch_with_pdl, + "trigger_completion_at_end": self.trigger_completion_at_end, + "fp32_acc": self.fp32_acc, + "max_token_num": self.max_token_num, + } + + +class AllReduceRMSNORMPattern(BasePattern): + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([1, 8, 4], device=self.device, dtype=self.dtype) + rms_result = torch.empty([1, 8, 4], + device=self.device, + dtype=self.dtype) + weight = torch.empty([4], device=self.device, dtype=self.dtype) + + return [input, rms_result, weight] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(input: torch.Tensor, rms_result: torch.Tensor, + weight: torch.Tensor): + all_reduce_output = tensor_model_parallel_all_reduce(input) + rms = auto_functionalized( + RMS_OP, + result=rms_result, + input=all_reduce_output, + weight=weight, + epsilon=self.epsilon, + ) + return rms[1], all_reduce_output + + def replacement(input: torch.Tensor, rms_result: torch.Tensor, + weight: torch.Tensor): + residual = torch.zeros_like(input) + allreduce = auto_functionalized( + torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + allreduce_in=input, + residual=residual, + norm_out=rms_result, + rms_gamma=weight, + rms_eps=self.epsilon, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + + return allreduce[3], allreduce[1] + + pm.register_replacement(pattern, replacement, self.get_inputs(), 
+ pm.fwd_only, pm_pass) + + +class AllReduceFusedAddRMSNormPattern(BasePattern): + + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str, + allreduce_params: FlashInferFusedAllReduceParams, + ): + super().__init__(dtype, device) + self.epsilon = epsilon + self.allreduce_params = allreduce_params + + def get_inputs(self): + input = torch.empty([4, 4], device=self.device, dtype=self.dtype) + residual = torch.empty([4, 4], device=self.device, dtype=self.dtype) + weight = torch.empty([4, 4], device=self.device, dtype=self.dtype) + return [ + residual, + input, + weight, + ] + + def register(self, pm_pass: PatternMatcherPass): + + def pattern(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor): + all_reduce_output = tensor_model_parallel_all_reduce(input) + rms = auto_functionalized( + RMS_ADD_OP, + input=all_reduce_output, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return rms[1], rms[2] + + def replacement(residual: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor): + allreduce = auto_functionalized( + torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default, + allreduce_in=input, + residual=residual, + rms_gamma=weight, + rms_eps=self.epsilon, + norm_out=None, + **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(), + ) + return allreduce[1], allreduce[2] + + pm.register_replacement(pattern, replacement, self.get_inputs(), + pm.fwd_only, pm_pass) + + +class AllReduceFusionPass(VllmInductorPass): + + def __init__(self, config: VllmConfig): + super().__init__(config) + self.disabled = True + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + return + self.patterns: PatternMatcherPass = PatternMatcherPass( + pass_name="all_reduce_fusion_pass") + if config.model_config is None: + return + self.hidden_dim = config.model_config.get_hidden_size() + self.group = get_tp_group().device_group + rank = get_tensor_model_parallel_rank() + use_fp32_lamport = self.model_dtype == torch.float32 + if flashinfer_comm is None: + logger.warning( + "Flashinfer is not installed or comm module not found, " + "skipping allreduce fusion pass") + return + # Check if the world size is supported + if self.tp_size not in _FI_MAX_SIZES: + logger.warning( + "Flashinfer allreduce fusion is not " + "supported for world size %s", + self.tp_size, + ) + return + + self.ipc_handles, workspace_tensor = ( + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( + tp_rank=rank, + tp_size=self.tp_size, + max_token_num=config.compilation_config.pass_config. + fi_allreduce_fusion_max_token_num, + hidden_dim=self.hidden_dim, + group=self.group, + use_fp32_lamport=use_fp32_lamport, + )) + + global _FI_WORKSPACE_TENSOR + _FI_WORKSPACE_TENSOR = workspace_tensor + self.allreduce_params = FlashInferFusedAllReduceParams( + rank=rank, + world_size=self.tp_size, + use_fp32_lamport=use_fp32_lamport, + max_token_num=config.compilation_config.pass_config. 
+ fi_allreduce_fusion_max_token_num, + ) + + for epsilon in [1e-5, 1e-6]: + AllReduceRMSNORMPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + AllReduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + self.allreduce_params, + ).register(self.patterns) + + self.disabled = False + + def __call__(self, graph: fx.Graph): + if self.disabled: + return + self.begin() + self.dump_graph(graph, "before_all_reduce_fusion_pass") + count = self.patterns.apply(graph) + logger.debug("Replaced %s patterns", count) + self.dump_graph(graph, "after_all_reduce_fusion_pass") + self.end_and_log() + + def __del__(self): + if self.disabled: + return + if flashinfer_comm is not None: + flashinfer_comm.trtllm_destroy_ipc_workspace( + self.ipc_handles, self.group) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index fd39a6127d00..7158fd685964 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -213,7 +213,9 @@ def compile( # Save the compiled artifact to disk in the specified path assert key is not None path = os.path.join(self.cache_dir, key) - compiled_graph.save(path=path, format="unpacked") + if not envs.VLLM_DISABLE_COMPILE_CACHE: + compiled_graph.save(path=path, format="unpacked") + compilation_counter.num_compiled_artifacts_saved += 1 return compiled_graph, (key, path) def load(self, @@ -421,6 +423,12 @@ def _get_shape_env() -> AlwaysHitShapeEnv: if is_torch_equal_or_newer("2.6"): stack.enter_context( torch._inductor.config.patch(fx_graph_remote_cache=False)) + # InductorAdaptor (unfortunately) requires AOTAutogradCache + # to be turned off to run. It will fail to acquire the hash_str + # and error if not. + # StandaloneInductorAdaptor (PyTorch 2.8+) fixes this problem. + stack.enter_context( + torch._functorch.config.patch(enable_autograd_cache=False)) stack.enter_context( torch._functorch.config.patch( enable_remote_autograd_cache=False)) diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 9d7a25689b56..6acb8abb3deb 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -23,6 +23,10 @@ class CompilationCounter: num_inductor_compiles: int = 0 # EagerAdapter.compile calls num_eager_compiles: int = 0 + # The number of time vLLM's compiler cache entry was updated + num_cache_entries_updated: int = 0 + # The number of standalone_compile compiled artifacts saved + num_compiled_artifacts_saved: int = 0 def clone(self) -> "CompilationCounter": return copy.deepcopy(self) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 05e4ca9f08b3..f3592324d8cf 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -20,9 +20,38 @@ logger = init_logger(__name__) +IGNORE_COMPILE_KEY = "_ignore_compile_vllm" + _T = TypeVar("_T", bound=type[nn.Module]) +def ignore_torch_compile(cls: _T) -> _T: + """ + A decorator to ignore support_torch_compile decorator + on the class. This is useful when a parent class has + a support_torch_compile decorator, but we don't want to + compile the class `cls` that inherits the parent class. + This only ignores compiling the forward of the class the + decorator is applied to. + + If the parent has ignore_torch_compile but the child has + support_torch_compile, the child will still be compiled. 
+ + If the class has one or more submodules + that have support_torch_compile decorator applied, compile will + not be ignored for those submodules. + """ + setattr(cls, IGNORE_COMPILE_KEY, True) + return cls + + +def _should_ignore_torch_compile(cls) -> bool: + """ + Check if the class should be ignored for torch.compile. + """ + return getattr(cls, IGNORE_COMPILE_KEY, False) + + @overload def support_torch_compile( *, @@ -148,6 +177,8 @@ def _support_torch_compile( old_init = cls.__init__ + setattr(cls, IGNORE_COMPILE_KEY, False) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): old_init(self, vllm_config=vllm_config, prefix=prefix, **kwargs) self.vllm_config = vllm_config @@ -156,9 +187,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): self.do_not_compile = \ vllm_config.compilation_config.level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS - ] or not supports_dynamo() + ] or not supports_dynamo() or _should_ignore_torch_compile( + self.__class__) if self.do_not_compile: return + compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 951a2861e3a4..3dec939c2835 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Callable, ClassVar, NamedTuple, Optional +from typing import Callable, NamedTuple, Optional import torch import torch._inductor.pattern_matcher as pm @@ -11,6 +11,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform from .fx_utils import find_getitem_maybe @@ -33,27 +35,6 @@ def empty_fp32(*args, **kwargs): RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -# Use proxy as NamedTuple direct subclasses cannot have static members -class _GroupShape(NamedTuple): - row: int - col: int - - -class GroupShape(_GroupShape): - """ - This class describes the quantization group shape. - It includes static members for common shapes (per-tensor, per-token). - """ - - # Aliases for common quantization group shapes - PER_TENSOR: ClassVar['GroupShape'] - PER_TOKEN: ClassVar['GroupShape'] - - -GroupShape.PER_TENSOR = GroupShape(-1, -1) -GroupShape.PER_TOKEN = GroupShape(1, -1) - - class QuantKey(NamedTuple): """ Named tuple for identifying the type of quantization. 
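The ignore_torch_compile / support_torch_compile interplay documented in vllm/compilation/decorators.py above is easiest to see on a toy class hierarchy. The sketch below is not part of the diff: the model classes are made up, it assumes vLLM is importable, and it assumes forward's tensor argument is annotated so that support_torch_compile can infer its dynamic dimension.

import torch
from torch import nn

from vllm.compilation.decorators import (ignore_torch_compile,
                                         support_torch_compile)


@support_torch_compile
class CompiledParent(nn.Module):
    """forward() of this class is eligible for compilation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


@ignore_torch_compile
class UncompiledChild(CompiledParent):
    """Inherits the parent's decorator but opts out: its own forward() is
    not compiled. Per the docstring above, the reverse also holds (a child
    with @support_torch_compile under a parent with @ignore_torch_compile
    is still compiled), and submodules carrying @support_torch_compile
    themselves remain compiled."""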
diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 3ce00e3610c5..58216a1f0ed3 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -7,7 +7,7 @@ from vllm.logger import init_logger from .activation_quant_fusion import ActivationQuantFusionPass -from .collective_fusion import AsyncTPPass +from .collective_fusion import AllReduceFusionPass, AsyncTPPass from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass from .fusion_attn import AttnFusionPass @@ -62,7 +62,8 @@ def configure(self, config: VllmConfig): if self.pass_config.enable_attn_fusion: self.passes += [AttnFusionPass(config)] - + if self.pass_config.enable_fi_allreduce_fusion: + self.passes += [AllReduceFusionPass(config)] self.fix_functionalization = FixFunctionalizationPass(config) def add(self, pass_: InductorPass): diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 628e9e204c55..b822b05b0f1e 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -6,13 +6,7 @@ import torch from torch._dynamo.utils import lazy_format_graph_code -from vllm.config import PassConfig, VllmConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable +from vllm.config import VllmConfig from vllm.logger import init_logger from .inductor_pass import InductorPass @@ -34,22 +28,9 @@ def __init__(self, config: VllmConfig): else None self.pass_name = self.__class__.__name__ - def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False): + def dump_graph(self, graph: torch.fx.Graph, stage: str): lazy_format_graph_code(stage, graph.owning_module) - if stage in self.pass_config.dump_graph_stages or always: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.pass_config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("%s printing graph to %s", self.pass_name, filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) - def begin(self): self._start_time = time.perf_counter_ns() @@ -61,10 +42,9 @@ def end_and_log(self): class PrinterInductorPass(VllmInductorPass): - def __init__(self, name: str, config: PassConfig, always=False): + def __init__(self, name: str, config: VllmConfig): super().__init__(config) self.name = name - self.always = always def __call__(self, graph: torch.fx.Graph): - self.dump_graph(graph, self.name, always=self.always) + self.dump_graph(graph, self.name) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 2a261c84c3fc..8d5df1061eda 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -93,9 +93,10 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): return self.compiled_codes.append(new_code) - local_cache_dir = self.vllm_config.compilation_config.local_cache_dir - if isinstance(local_cache_dir, str): - decompiled_file = os.path.join(local_cache_dir, + debug_dump_dir = self.vllm_config.compilation_config.debug_dump_path + if isinstance(debug_dump_dir, str) and 
debug_dump_dir != "": + rank = self.vllm_config.parallel_config.rank + decompiled_file = os.path.join(debug_dump_dir, f"rank_{rank}", "transformed_code.py") if not os.path.exists(decompiled_file): try: @@ -105,6 +106,7 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): # not a reversible process. import depyf src = depyf.decompile(new_code) + with open(decompiled_file, "w") as f: f.write(src) diff --git a/vllm/config.py b/vllm/config.py index 718f218171f2..f038cdd64c67 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,7 +16,6 @@ replace) from functools import cached_property from importlib.util import find_spec -from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional, Protocol, TypeVar, Union, cast, get_args) @@ -27,7 +26,7 @@ from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.distributed import ProcessGroup, ReduceOp -from typing_extensions import Self, deprecated, runtime_checkable +from typing_extensions import Self, runtime_checkable import vllm.envs as envs from vllm import version @@ -72,6 +71,7 @@ ConfigType = type[DataclassInstance] HfOverrides = Union[dict, Callable[[type], type]] else: + DataclassInstance = Any PlacementGroup = Any PretrainedConfig = Any ExecutorBase = Any @@ -88,28 +88,23 @@ "vllm.model_executor.models") logger = init_logger(__name__) - +DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance) ConfigT = TypeVar("ConfigT", bound=ConfigType) TaskOption = Literal["auto", "generate", "embedding", "embed", "classify", - "score", "reward", "transcription"] + "score", "reward", "transcription", "draft"] -_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft", - "transcription"] +_ResolvedTask = Literal["generate", "transcription", "encode", "embed", + "classify", "reward", "draft"] -RunnerType = Literal["generate", "pooling", "draft", "transcription"] +RunnerOption = Literal["auto", "generate", "pooling", "draft"] -_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { - "generate": ["generate"], - "pooling": ["embed", "classify", "reward"], - "draft": ["draft"], - "transcription": ["transcription"], -} +RunnerType = Literal["generate", "pooling", "draft"] -_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = { - task: runner - for runner, tasks in _RUNNER_TASKS.items() - for task in tasks +_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = { + "generate": ["generate", "transcription"], + "pooling": ["encode", "embed", "classify", "reward"], + "draft": [], } @@ -224,6 +219,8 @@ def is_init_field(cls: ConfigType, name: str) -> bool: TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] +LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs", + "processed_logits"] @config @@ -231,15 +228,18 @@ def is_init_field(cls: ConfigType, name: str) -> bool: class ModelConfig: """Configuration for the model.""" - model: str = "facebook/opt-125m" + model: str = "Qwen/Qwen3-0.6B" """Name or path of the Hugging Face model to use. It is also used as the content for `model_name` tag in metrics output when `served_model_name` is not specified.""" - task: Literal[TaskOption, Literal["draft"]] = "auto" - """The task to use the model for. Each vLLM instance only supports one - task, even if the same model can be used for multiple tasks. 
When the model - only supports one task, "auto" can be used to select it; otherwise, you - must specify explicitly which task to use.""" + runner: RunnerOption = "auto" + """The type of model runner to use. Each vLLM instance only supports one + model runner, even if the same model can be used for multiple types.""" + task: TaskOption = "auto" + """The task to use the model for. If the model supports more than one + model runner, this is used to select which model runner to run. + + Note that the model may support other tasks using the same model runner.""" tokenizer: SkipValidation[str] = None # type: ignore """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" @@ -318,6 +318,13 @@ class ModelConfig: """Maximum number of log probabilities to return when `logprobs` is specified in `SamplingParams`. The default value comes the default for the OpenAI Chat Completions API.""" + logprobs_mode: LogprobsMode = "raw_logprobs" + """Indicates the content returned in the logprobs and prompt_logprobs. + Supported mode: + 1) raw_logprobs, 2) processed_logprobs, 3) raw_logits, 4) processed_logits. + Raw means the values before applying logit processors, like bad words. + Processed means the values after applying such processors. + """ disable_sliding_window: bool = False """Whether to disable sliding window. If True, we will disable the sliding window functionality of the model, capping to sliding window size. If the @@ -347,9 +354,12 @@ class ModelConfig: limit_mm_per_prompt: dict[str, int] = field(default_factory=dict) """Maximum number of data items per modality per prompt. Only applicable for multimodal models.""" + interleave_mm_strings: bool = False + """Enable fully interleaved support for multimodal prompts, while using + --chat-template-content-format=string. Defaults to False.""" media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - """Additional args passed to process media inputs, keyed by modalities. - For example, to set num_frames for video, set + """Additional args passed to process media inputs, keyed by modalities. 
+ For example, to set num_frames for video, set `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ use_async_output_proc: bool = True """Whether to use async output processor.""" @@ -532,16 +542,12 @@ def __post_init__(self) -> None: self.config_format = ConfigFormat(self.config_format) hf_config = get_config(self.hf_config_path or self.model, - self.trust_remote_code, self.revision, - self.code_revision, self.config_format) - - if hf_overrides_kw: - logger.debug("Overriding HF config with %s", hf_overrides_kw) - hf_config.update(hf_overrides_kw) - if hf_overrides_fn: - logger.debug("Overriding HF config with %s", hf_overrides_fn) - hf_config = hf_overrides_fn(hf_config) - + self.trust_remote_code, + self.revision, + self.code_revision, + self.config_format, + hf_overrides_kw=hf_overrides_kw, + hf_overrides_fn=hf_overrides_fn) self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) @@ -551,18 +557,49 @@ def __post_init__(self) -> None: self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision) - supported_tasks, task = self._resolve_task(self.task) - self.supported_tasks = supported_tasks - self.task = task - if self.task in ("draft", "generate"): - self.truncation_side = "left" - else: - self.truncation_side = "right" + # For pooling models, self.task is used to indicate the + # user-selected task + if self.task == "score": + if self._is_classify_task(self.architectures): + self.task = "classify" + else: + self.task = "embed" + elif self.task == "embedding": + msg = ("The 'embedding' task has been renamed to 'embed', please " + "use the new name. The old name will be removed in v1.0.") + warnings.warn(msg, DeprecationWarning, stacklevel=2) + + self.task = "embed" model_info, arch = self.registry.inspect_model_cls(self.architectures) self._model_info = model_info self._architecture = arch + all_supported_tasks = self._get_supported_tasks(self.task) + logger.debug("Tasks supported by runner type: %s", all_supported_tasks) + supported_runner_types = self._get_supported_runner_types( + all_supported_tasks) + runner_type = self._resolve_runner(self.runner, self.task, + supported_runner_types, + all_supported_tasks) + + logger.debug("Selected runner type: %s", runner_type) + # For pooling models, self.task is used to indicate the + # user-selected task + if runner_type == "pooling" and self.task == "auto": + selected_task = all_supported_tasks[runner_type][-1] + assert selected_task != "encode" + self.task = selected_task + self.supported_runner_types = supported_runner_types + self.runner_type = runner_type + self.supported_tasks = all_supported_tasks[runner_type] + + if self.runner_type in ("draft", + "generate") and self.task != "transcription": + self.truncation_side = "left" + else: + self.truncation_side = "right" + self.pooler_config = self._init_pooler_config() self.dtype = _get_and_verify_dtype( @@ -614,6 +651,8 @@ def __post_init__(self) -> None: self.original_max_model_len = self.max_model_len self.max_model_len = self.get_and_verify_max_len(self.max_model_len) self.multimodal_config = self._init_multimodal_config() + self.model_supports_multimodal_raw_input = ( + self.registry.supports_multimodal_raw_input(self.architectures)) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() @@ -646,6 +685,16 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": "max_model_len must be an integer after __post_init__.") return self + def 
_get_transformers_backend_cls(self) -> str: + """Determine which Transformers backend class will be used if + `model_impl` is set to `transformers` or `auto`.""" + if self.hf_config != self.hf_text_config: + # If 'hf_text_config' is the same as 'hf_config'. If not, it is + # probably a composite config, i.e. multimodal + return "TransformersForMultimodalLM" + else: + return "TransformersForCausalLM" + @property def registry(self): return me_models.ModelRegistry @@ -653,7 +702,19 @@ def registry(self): @property def architectures(self) -> list[str]: # architectures in the model config. - return getattr(self.hf_config, "architectures", []) + architectures = getattr(self.hf_config, "architectures", []) + # The registry assumes that it can always inspect the vLLM model class + # for a given architecture. This assumption breaks down for the + # Transformers backend, which may use a different class depending on + # the model type. To work around this, we add the correct Transformers + # backend class to the architectures list. We must do this here because + # we need access to the `hf_config` to determine the backend class. + transformers_backend_cls = self._get_transformers_backend_cls() + if (self.model_impl != ModelImpl.VLLM.value + and all(arch != transformers_backend_cls + for arch in architectures)): + architectures.append(transformers_backend_cls) + return architectures @property def architecture(self) -> str: @@ -684,8 +745,11 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, # If tokenizer is same as model, download to same directory if model == tokenizer: - s3_model.pull_files( - model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + s3_model.pull_files(model, + ignore_pattern=[ + "*.pt", "*.safetensors", "*.bin", + "*.tensors" + ]) self.tokenizer = s3_model.dir return @@ -693,7 +757,8 @@ def maybe_pull_model_tokenizer_for_s3(self, model: str, if is_s3(tokenizer): s3_tokenizer = S3Model() s3_tokenizer.pull_files( - model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) + model, + ignore_pattern=["*.pt", "*.safetensors", "*.bin", "*.tensors"]) self.tokenizer = s3_tokenizer.dir def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: @@ -703,7 +768,8 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: media_io_kwargs=self.media_io_kwargs, mm_processor_kwargs=self.mm_processor_kwargs, disable_mm_preprocessor_cache=self. 
- disable_mm_preprocessor_cache) + disable_mm_preprocessor_cache, + interleave_mm_strings=self.interleave_mm_strings) if self.limit_mm_per_prompt: raise ValueError("`limit_mm_per_prompt` is only supported for " @@ -714,6 +780,9 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self.disable_mm_preprocessor_cache: raise ValueError("`disable_mm_preprocessor_cache` is only " "supported for multimodal models.") + if self.interleave_mm_strings: + raise ValueError("`interleave_mm_strings` is only " + "supported for multimodal models.") return None @@ -770,107 +839,155 @@ def _verify_tokenizer_mode(self) -> None: f"one of {get_args(TokenizerMode)}.") self.tokenizer_mode = tokenizer_mode - def _get_preferred_task( + def _is_classify_task(self, architectures: list[str]): + for arch in architectures: + if arch.endswith("ForSequenceClassification"): + return True + return self.registry.is_cross_encoder_model(architectures) + + def _get_preferred_pooling_task( self, architectures: list[str], - supported_tasks: set[_ResolvedTask], - ) -> Optional[_ResolvedTask]: + ) -> _ResolvedTask: model_id = self.model if get_pooling_config(model_id, self.revision): return "embed" - if self.registry.is_cross_encoder_model(architectures): - return "classify" if self.registry.is_transcription_model(architectures): return "transcription" suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [ # Other models follow this pattern - ("ForCausalLM", "generate"), - ("ForConditionalGeneration", "generate"), - ("ForSequenceClassification", "classify"), - ("ChatModel", "generate"), - ("LMHeadModel", "generate"), ("EmbeddingModel", "embed"), ("RewardModel", "reward"), ] - _, arch = self.registry.inspect_model_cls(architectures) for suffix, pref_task in suffix_to_preferred_task: - if arch.endswith(suffix) and pref_task in supported_tasks: + if self.architecture.endswith(suffix): return pref_task - return None + return "embed" - def _resolve_task( + def _get_supported_generation_tasks( self, - task_option: Literal[TaskOption, Literal["draft"]], - ) -> tuple[set[_ResolvedTask], _ResolvedTask]: - if task_option == "draft": - return {"draft"}, "draft" + task_option: TaskOption, + ) -> list[_ResolvedTask]: + registry = self.registry + architectures = self.architectures + + if registry.is_transcription_only_model(architectures): + return ["transcription"] + + supported_tasks = list[_ResolvedTask]() + if registry.is_text_generation_model(architectures): + supported_tasks.append("generate") + if registry.is_transcription_model(architectures): + supported_tasks.append("transcription") + + return supported_tasks + + def _get_supported_pooling_tasks( + self, + task_option: TaskOption, + ) -> list[_ResolvedTask]: registry = self.registry architectures = self.architectures - runner_support: dict[RunnerType, bool] = { - # NOTE: Listed from highest to lowest priority, - # in case the model supports multiple of them - "transcription": registry.is_transcription_model(architectures), - "generate": registry.is_text_generation_model(architectures), - "pooling": registry.is_pooling_model(architectures), - } - supported_runner_types_lst: list[RunnerType] = [ - runner_type - for runner_type, is_supported in runner_support.items() - if is_supported - ] + supported_tasks = list[_ResolvedTask]() + if registry.is_pooling_model(architectures): + supported_tasks.append("encode") - supported_tasks_lst: list[_ResolvedTask] = [ - task for runner_type in supported_runner_types_lst - for task in _RUNNER_TASKS[runner_type] - ] - 
supported_tasks = set(supported_tasks_lst) + # For now, users must specify the task (other than "pooling") + # to use for pooling models + if task_option == "auto": + preferred_task = self._get_preferred_pooling_task( + architectures) - if task_option == "auto": - selected_task = next(iter(supported_tasks_lst)) + supported_tasks.append(preferred_task) + elif task_option in _RUNNER_TASKS["pooling"]: + supported_tasks.append(cast(_ResolvedTask, task_option)) - if len(supported_tasks_lst) > 1: - preferred_task = self._get_preferred_task( - architectures, supported_tasks) - if preferred_task is not None: - selected_task = preferred_task + return supported_tasks - logger.info( - "This model supports multiple tasks: %s. " - "Defaulting to '%s'.", supported_tasks, selected_task) + def _get_supported_tasks( + self, + task_option: TaskOption, + ) -> dict[RunnerType, list[_ResolvedTask]]: + if self._is_classify_task(self.architectures): + return {"generate": [], "pooling": ["classify"], "draft": []} else: - if task_option == "score": - if not runner_support["pooling"]: - msg = (f"This model does not support the '{task_option}' " - f"task. Supported tasks: {supported_tasks}") - raise ValueError(msg) - if self.registry.is_cross_encoder_model(architectures): - task_option = "classify" - else: - task_option = "embed" + return { + "generate": self._get_supported_generation_tasks(task_option), + "pooling": self._get_supported_pooling_tasks(task_option), + "draft": ["draft"] + } + + def _get_supported_runner_types( + self, + supported_tasks: dict[RunnerType, list[_ResolvedTask]], + ) -> set[RunnerType]: + return { + runner + for runner, runner_tasks in supported_tasks.items() + if len(runner_tasks) > 0 + } + + def _resolve_runner( + self, + runner_option: RunnerOption, + task_option: TaskOption, + supported_runner_types: set[RunnerType], + supported_tasks: dict[RunnerType, list[_ResolvedTask]], + ) -> RunnerType: + if not supported_runner_types: + raise ValueError("This model does not support any model runners!") + + if runner_option != "auto": + if runner_option not in supported_runner_types: + raise ValueError( + f"This model does not support runner={runner_option!r}. " + f"Available runners: {supported_runner_types}") + + return runner_option + + if task_option != "auto": + for runner, runner_tasks in supported_tasks.items(): + if task_option in runner_tasks: + return runner else: - # Aliases - if task_option == "embedding": - msg = ("The 'embedding' task has been renamed to " - "'embed', please use the new name. The old name " - "will be removed in v1.0.") - warnings.warn(msg, DeprecationWarning, stacklevel=2) + task_runner: RunnerType = next( + runner for runner, tasks in _RUNNER_TASKS.items() + if task_option in tasks) + raise ValueError( + f"This model does not support task={task_option!r}. " + f"Available tasks for runner={task_runner!r}: " + f"{supported_tasks[task_runner]}") - task_option = "embed" + if "classify" in supported_tasks.get("pooling", []): + # When multiple pooling tasks are present, default to + # pooling (eg cross-encoder) for non-standard architectures. + return "pooling" - if task_option not in supported_tasks: - msg = ( - f"This model does not support the '{task_option}' task. 
" - f"Supported tasks: {supported_tasks}") - raise ValueError(msg) + suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [ + ("ForCausalLM", "generate"), + ("ForConditionalGeneration", "generate"), + ("ChatModel", "generate"), + ("LMHeadModel", "generate"), + ("EmbeddingModel", "pooling"), + ("RewardModel", "pooling"), + ] + + for suffix, pref_runner in suffix_to_preferred_runner: + if self.architecture.endswith( + suffix) and pref_runner in supported_runner_types: + return pref_runner - selected_task = task_option + if "generate" in supported_runner_types: + return "generate" + if "pooling" in supported_runner_types: + return "pooling" - return supported_tasks, selected_task + raise AssertionError("This line should not be reached") def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -884,7 +1001,7 @@ def _verify_quantization(self) -> None: optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", - "quark", "modelopt_fp4", "bitblas", "gptq_bitblas" + "quark", "modelopt_fp4", "bitblas", "gptq_bitblas", "inc" ] if self.quantization is not None: self.quantization = cast(me_quant.QuantizationMethods, @@ -894,9 +1011,13 @@ def _verify_quantization(self) -> None: quant_cfg = self._parse_quant_hf_config() if quant_cfg is not None: + # Use the community standard 'quant_method' quant_method = quant_cfg.get("quant_method", "").lower() + + # Normalize library names quant_method = quant_method.replace("compressed_tensors", "compressed-tensors") + quant_cfg["quant_method"] = quant_method # Quantization methods which are overrides (i.e. they have a @@ -911,6 +1032,8 @@ def _verify_quantization(self) -> None: "awq_marlin", "ipex", "moe_wna16", + "modelopt", + "modelopt_fp4", ] quantization_methods = [ q for q in supported_quantization if q not in overrides @@ -1122,17 +1245,17 @@ def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]: return self.get_hf_config_sliding_window() def get_vocab_size(self) -> int: - return self.hf_text_config.vocab_size + return getattr(self.hf_text_config, "vocab_size", 0) def get_hidden_size(self) -> int: - return self.hf_text_config.hidden_size + return getattr(self.hf_text_config, "hidden_size", 0) @property def is_deepseek_mla(self) -> bool: if not hasattr(self.hf_text_config, "model_type"): return False elif self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'): return self.hf_text_config.kv_lora_rank is not None elif self.hf_text_config.model_type == 'eagle': # if the model is an EAGLE module, check for the @@ -1248,7 +1371,8 @@ def get_layers_start_end_indices( self, parallel_config: "ParallelConfig") -> tuple[int, int]: from vllm.distributed.utils import get_pp_indices if (self.hf_text_config.model_type == "deepseek_mtp" - or self.hf_config.model_type == "mimo_mtp"): + or self.hf_config.model_type == "mimo_mtp" + or self.hf_config.model_type == "glm4_moe_mtp"): total_num_hidden_layers = getattr(self.hf_text_config, "num_nextn_predict_layers", 0) else: @@ -1320,6 +1444,17 @@ def get_num_layers_by_block_type( return sum(t == 1 for t in attn_type_list[start:end]) + def get_mamba_chunk_size(self) -> Optional[int]: + """ + Returns the mamba chunk size if it exists + """ + # used by e.g. 
Bamba, FalconH1, Granite, PLaMo2 + chunk_size = getattr(self.hf_text_config, "mamba_chunk_size", None) + if chunk_size is None: + # used by e.g. Mamba2, NemotronH, Zamba + chunk_size = getattr(self.hf_text_config, "chunk_size", None) + return chunk_size + def get_multimodal_config(self) -> "MultiModalConfig": """ Get the multimodal configuration of the model. @@ -1428,14 +1563,6 @@ def is_cross_encoder(self) -> bool: def use_mla(self) -> bool: return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE - @property - def supported_runner_types(self) -> set[RunnerType]: - return {_TASK_RUNNER[task] for task in self.supported_tasks} - - @property - def runner_type(self) -> RunnerType: - return _TASK_RUNNER[cast(_ResolvedTask, self.task)] - @property def is_v1_compatible(self) -> bool: architectures = getattr(self.hf_config, "architectures", []) @@ -1443,21 +1570,25 @@ def is_v1_compatible(self) -> bool: @property def is_matryoshka(self) -> bool: - return (hasattr(self.hf_config, "matryoshka_dimensions") + return (bool(getattr(self.hf_config, "matryoshka_dimensions", None)) or getattr(self.hf_config, "is_matryoshka", False)) @property def matryoshka_dimensions(self): return getattr(self.hf_config, "matryoshka_dimensions", None) + @property + def use_pad_token(self) -> bool: + # cross_encoder models defaults to using pad_token. + # `llm as reranker` models defaults to not using pad_token. + return getattr(self.hf_config, "use_pad_token", True) + def get_and_verify_max_len(self, max_model_len: int): - # For pooling models, the tokenizer's `model_max_length` is often a - # reliable source for the maximum sequence length. However, for - # generative models, this can be incorrect and unduly limit the - # context window (e.g., DeepSeek-R1). Therefore, we only consider - # tokenizer_config for pooling models. + # Consider max_model_len in tokenizer_config only when + # pooling models use absolute position_embedding. tokenizer_config = None - if self.runner_type == "pooling": + if (self.runner_type == "pooling" and getattr( + self.hf_config, "position_embedding_type", "") == "absolute"): tokenizer_config = try_get_tokenizer_config( self.tokenizer, trust_remote_code=self.trust_remote_code, @@ -1475,8 +1606,8 @@ def get_and_verify_max_len(self, max_model_len: int): BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"] -PrefixCachingHashAlgo = Literal["builtin", "sha256"] +CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +PrefixCachingHashAlgo = Literal["builtin", "sha256", "sha256_cbor_64bit"] @config @@ -1505,7 +1636,7 @@ class CacheConfig: cache_dtype: CacheDType = "auto" """Data type for kv cache storage. If "auto", will use model data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports - fp8 (=fp8_e4m3).""" + fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" is_attention_free: bool = False """Whether the model is attention-free. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -1521,7 +1652,12 @@ class CacheConfig: prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin" """Set the hash algorithm for prefix caching:\n - "builtin" is Python's built-in hash.\n - - "sha256" is collision resistant but with certain overheads.""" + - "sha256" is collision resistant but with certain overheads. + This option uses Pickle for object serialization before hashing.\n + - "sha256_cbor_64bit" provides a reproducible, cross-language compatible + hash. 
It serializes objects using canonical CBOR and hashes them with + SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256 + digest.""" cpu_offload_gb: float = 0 """The space in GiB to offload to CPU, per GPU. Default is 0, which means no offloading. Intuitively, this argument can be seen as a virtual way to @@ -1537,6 +1673,9 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" + mamba_page_size_padded: Optional[int] = None + """ Optional override for mamba page size; used by hybrid mamba/attention + models to ensure exact alignment with attention page size.""" # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) @@ -1595,7 +1734,7 @@ def _verify_cache_dtype(self) -> None: "Using fp8 data type to store kv cache. It reduces the GPU " "memory footprint and boosts the performance. " "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor") + "scaling factor.") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") @@ -1634,35 +1773,6 @@ def verify_with_parallel_config( logger.warning("Possibly too large swap space. %s", msg) -@config -@dataclass -class TokenizerPoolConfig: - """This config is deprecated and will be removed in a future release. - - Passing these parameters will have no effect. Please remove them from your - configurations. - """ - - pool_size: int = 0 - """This parameter is deprecated and will be removed in a future release. - Passing this parameter will have no effect. Please remove it from your - configurations.""" - pool_type: str = "ray" - """This parameter is deprecated and will be removed in a future release. - Passing this parameter will have no effect. Please remove it from your - configurations.""" - extra_config: dict = field(default_factory=dict) - """This parameter is deprecated and will be removed in a future release. - Passing this parameter will have no effect. Please remove it from your - configurations.""" - - def __post_init__(self) -> None: - logger.warning_once( - "TokenizerPoolConfig is deprecated and will be removed in a " - "future release. Passing this parameter will have no effect. " - "Please remove it from your configurations.") - - class LoadFormat(str, enum.Enum): AUTO = "auto" PT = "pt" @@ -1714,6 +1824,9 @@ class LoadConfig: default_factory=dict) """Extra config for model loader. This will be passed to the model loader corresponding to the chosen load_format.""" + device: Optional[str] = None + """Device to which model weights will be loaded, default to + device_config.device""" ignore_patterns: Optional[Union[list[str], str]] = None """The list of patterns to ignore when loading the model. Default to "original/**/*" to avoid repeated loading of llama's checkpoints.""" @@ -1795,8 +1908,16 @@ class ParallelConfig: """Backend to use for data parallel, either "mp" or "ray".""" data_parallel_external_lb: bool = False """Whether to use "external" DP LB mode. Applies only to online serving - and when data_parallel_size > 0. Set implicitly when - data_parallel_rank is provided explicitly to vllm serve.""" + and when data_parallel_size > 0. This is useful for a "one-pod-per-rank" + wide-EP setup in Kuberentes. Set implicitly when --data-parallel-rank + is provided explicitly to vllm serve.""" + data_parallel_hybrid_lb: bool = False + """Whether to use "hybrid" DP LB mode. 
Applies only to online serving + and when data_parallel_size > 0. Enables running an AsyncLLM + and API server on a "per-node" basis where vLLM load balances + between local data parallel ranks, but an external LB balances + between vLLM nodes/replicas. Set explicitly in conjunction with + --data-parallel-start-rank.""" enable_expert_parallel: bool = False """Use expert parallelism instead of tensor parallelism for MoE layers.""" enable_eplb: bool = False @@ -1826,10 +1947,6 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" - tokenizer_pool_config: Optional[TokenizerPoolConfig] = None - """This parameter is deprecated and will be removed in a future release. - Please remove it from your configs""" - ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" @@ -1844,7 +1961,7 @@ class ParallelConfig: or equal to the number of GPUs available, "mp" will be used to keep processing on a single host. Otherwise, this will default to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference.""" + only support Ray for distributed inference.""" worker_cls: str = "auto" """The full name of the worker class to use. If "auto", the worker class @@ -1938,6 +2055,19 @@ def has_unfinished_dp(dp_group: "ProcessGroup", aggregated_has_unfinished = bool(tensor.item()) return aggregated_has_unfinished + @staticmethod + def sync_kv_cache_memory_size(dp_group: "ProcessGroup", + kv_cache_memory: int) -> int: + if kv_cache_memory == -1: + kv_cache_memory = torch.iinfo(torch.int64).max + tensor = torch.tensor([kv_cache_memory], + dtype=torch.int64, + device="cpu") + # we cannot use broadcast for stateless dp group since it depends + # on global rank + torch.distributed.all_reduce(tensor, op=ReduceOp.MIN, group=dp_group) + return tensor.item() + def compute_hash(self): """ Provide a hash that uniquely identifies all the configs @@ -1997,6 +2127,15 @@ def __post_init__(self) -> None: raise ValueError( "num_redundant_experts must be non-negative, but got " f"{self.num_redundant_experts}.") + if not self.enable_expert_parallel: + raise ValueError( + "enable_expert_parallel must be True to use EPLB.") + if self.tensor_parallel_size * self.data_parallel_size <= 1: + raise ValueError( + "EPLB requires tensor_parallel_size or data_parallel_size " + f"to be greater than 1, but got " + f"TP={self.tensor_parallel_size},DP={self.data_parallel_size}." + ) else: if self.num_redundant_experts != 0: raise ValueError( @@ -2017,10 +2156,11 @@ def __post_init__(self) -> None: elif (current_platform.is_cuda() and cuda_device_count_stateless() < self.world_size): if not ray_found: - raise ValueError("Unable to load Ray which is " + raise ValueError("Unable to load Ray: " + f"{ray_utils.ray_import_err}. Ray is " "required for multi-node inference, " "please install Ray with `pip install " - "ray`.") from ray_utils.ray_import_err + "ray`.") backend = "ray" elif self.data_parallel_backend == "ray": logger.info("Using ray distributed inference because " @@ -2131,11 +2271,12 @@ class SchedulerConfig: NOTE: This will be replaced by speculative config in the future; it is present to enable correctness tests until then.""" - cuda_graph_sizes: list[int] = field(default_factory=lambda: [512]) - """Cuda graph capture sizes, default is 512. - 1. 
if one value is provided, then the capture list would follow the + cuda_graph_sizes: list[int] = field(default_factory=list) + """Cuda graph capture sizes + 1. if none provided, then default set to [min(max_num_seqs * 2, 512)] + 2. if one value is provided, then the capture list would follow the pattern: [1, 2, 4] + [i for i in range(8, cuda_graph_sizes + 1, 8)] - 2. more than one value (e.g. 1 2 128) is provided, then the capture list + 3. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" delay_factor: float = 0.0 @@ -2214,6 +2355,13 @@ class SchedulerConfig: like full attention and sliding window attention. """ + async_scheduling: bool = False + """EXPERIMENTAL: If set to True, perform async scheduling. This may help + reduce the CPU overheads, leading to better latency and throughput. However, + async scheduling is currently not supported with some features such as + structured outputs, speculative decoding, and pipeline parallelism. + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2300,6 +2448,17 @@ def __post_init__(self) -> None: self.max_num_partial_prefills, self.max_long_partial_prefills, self.long_prefill_token_threshold) + # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)]. + # This avoids OOM in tight memory scenarios with small max_num_seqs, + # and prevents capture of many large graphs (>512) that would greatly + # increase startup time with limited performance benefit. + if not self.cuda_graph_sizes: + self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)] + + if self.async_scheduling: + self.scheduler_cls = ( + "vllm.v1.core.sched.async_scheduler.AsyncScheduler") + @model_validator(mode='after') def _verify_args(self) -> Self: if (self.max_num_batched_tokens < self.max_model_len @@ -2367,7 +2526,7 @@ def is_multi_step(self) -> bool: return self.num_scheduler_steps > 1 -Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] +Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu"] @config @@ -2434,8 +2593,6 @@ def __post_init__(self): SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa", "mlp_speculator", "draft_model", "deepseek_mtp"] -SpeculativeAcceptanceMethod = Literal["rejection_sampler", - "typical_acceptance_sampler"] @config @@ -2458,13 +2615,6 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" - acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler" - """The method to use for accepting draft tokens:\n - - "rejection_sampler" maps to `RejectionSampler`.\n - - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`. - - If using `typical_acceptance_sampler`, the related configuration - `posterior_threshold` and `posterior_alpha` should be considered.""" draft_tensor_parallel_size: Optional[int] = None """The degree of the tensor parallelism for the draft model. 
Can only be 1 or the same as the target model's tensor parallel size.""" @@ -2491,9 +2641,6 @@ class SpeculativeConfig: will use the default version.""" # Advanced control - disable_mqa_scorer: bool = False - """Disable the MQA scorer and fall back to batch expansion for scoring - proposals.""" disable_by_batch_size: Optional[int] = None """Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided.""" @@ -2506,16 +2653,6 @@ class SpeculativeConfig: """Minimum size of ngram token window when using Ngram proposer, if provided. Defaults to 1.""" - # Typical acceptance sampler configuration - posterior_threshold: Optional[float] = None - """A threshold value that sets a lower bound on the posterior probability - of a token in the target model for it to be accepted. This threshold is - used only when we use the `TypicalAcceptanceSampler` for token acceptance. - """ - posterior_alpha: Optional[float] = None - """Scaling factor for entropy-based threshold, applied when using - `TypicalAcceptanceSampler`.""" - speculative_token_tree: Optional[str] = None """Specifies the tree structure for speculative token generation. """ @@ -2583,7 +2720,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: "n_predict": n_predict, "architectures": ["MiMoMTPModel"] }) - return hf_config + + if hf_config.architectures[0] == "Glm4MoeForCausalLM": + hf_config.model_type = "glm4_moe_mtp" + n_predict = getattr(hf_config, "num_nextn_predict_layers", None) + hf_config.update({ + "num_hidden_layers": 0, + "n_predict": n_predict, + "architectures": ["Glm4MoeMTPModel"] + }) return hf_config @@ -2659,7 +2804,7 @@ def __post_init__(self): if self.model is not None: self.draft_model_config = ModelConfig( model=self.model, - task="draft", + runner="draft", tokenizer=self.target_model_config.tokenizer, tokenizer_mode=self.target_model_config.tokenizer_mode, trust_remote_code=self.target_model_config. @@ -2693,8 +2838,8 @@ def __post_init__(self): elif (self.draft_model_config.hf_config.model_type == "mlp_speculator"): self.method = "mlp_speculator" - elif (self.draft_model_config.hf_config.model_type == - "deepseek_mtp"): + elif (self.draft_model_config.hf_config.model_type + in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")): self.method = "deepseek_mtp" if self.num_speculative_tokens > 1: logger.warning( @@ -2704,6 +2849,11 @@ def __post_init__(self): ) else: self.method = "draft_model" + raise NotImplementedError( + "Speculative decoding with draft model is not " + "supported yet. Please consider using other " + "speculative decoding methods such as ngram, medusa, " + "eagle, or deepseek_mtp.") # Replace hf_config for EAGLE draft_model if self.method in ("eagle", "eagle3"): @@ -2762,12 +2912,6 @@ def __post_init__(self): self.target_parallel_config, self.draft_tensor_parallel_size)) - if self.acceptance_method == "typical_acceptance_sampler": - if self.posterior_threshold is None: - self.posterior_threshold = 0.09 - if self.posterior_alpha is None: - self.posterior_alpha = 0.3 - @staticmethod def _maybe_override_draft_max_model_len( speculative_max_model_len: Optional[int], @@ -2873,30 +3017,6 @@ def _verify_args(self) -> Self: if self.draft_model_config: self.draft_model_config.verify_with_parallel_config( self.draft_parallel_config) - # Validate and set draft token acceptance related settings. - - if self.acceptance_method is None: - raise ValueError("acceptance_method is not set. 
" - "Expected values are rejection_sampler or " - "typical_acceptance_sampler.") - - if (self.acceptance_method != 'rejection_sampler' - and self.acceptance_method != 'typical_acceptance_sampler'): - raise ValueError( - "Expected acceptance_method to be either " - "rejection_sampler or typical_acceptance_sampler. Instead it " - f"is {self.acceptance_method}") - - if self.acceptance_method == "typical_acceptance_sampler" and ( - (self.posterior_threshold is not None - and self.posterior_threshold < 0) or - (self.posterior_alpha is not None and self.posterior_alpha < 0)): - raise ValueError( - "Expected the posterior_threshold and posterior_alpha of " - "typical_acceptance_sampler to be > 0. " - "Instead found posterior_threshold = " - f"{self.posterior_threshold} and posterior_alpha = " - f"{self.posterior_alpha}") if (self.disable_by_batch_size is not None and self.disable_by_batch_size < 2): @@ -2959,12 +3079,7 @@ class LoRAConfig: (added to the base model vocabulary).""" lora_vocab_padding_size: ClassVar[int] = current_platform\ .get_lora_vocab_padding_size() - long_lora_scaling_factors: Optional[tuple[float, ...]] = None - """Specify multiple scaling factors (which can be different from base model - scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters - trained with those scaling factors to be used at the same time. If not - specified, only adapters trained with the base model scaling factor are - allowed.""" + default_mm_loras: Optional[dict[str, str]] = None """Dictionary mapping specific modalities to LoRA model paths; this field is only applicable to multimodal models and should be leveraged when a @@ -2997,7 +3112,6 @@ def compute_hash(self) -> str: factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) factors.append(self.lora_vocab_padding_size) - factors.append(self.long_lora_scaling_factors) factors.append(self.bias_enabled) hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() @@ -3036,64 +3150,6 @@ def verify_with_model_config(self, model_config: ModelConfig): elif isinstance(self.lora_dtype, str): self.lora_dtype = getattr(torch, self.lora_dtype) - def verify_lora_support(self): - if self.long_lora_scaling_factors is not None and envs.VLLM_USE_V1: - raise ValueError( - "V1 LoRA does not support long LoRA, please use V0.") - - -@config -@dataclass(config=ConfigDict(arbitrary_types_allowed=True)) -class PromptAdapterConfig: - """Configuration for PromptAdapters.""" - - max_prompt_adapters: int = 1 - """Max number of PromptAdapters in a batch.""" - max_prompt_adapter_token: int = 0 - """Max number of PromptAdapters tokens.""" - max_cpu_prompt_adapters: Optional[int] = None - """Maximum number of PromptAdapters to store in CPU memory. Must be >= than - `max_prompt_adapters`.""" - prompt_adapter_dtype: Union[torch.dtype, str] = "auto" - """Data type for PromptAdapter. If auto, will default to base model dtype. - """ - - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, - ensure that it is included in the factors list if - it affects the computation graph. - - Provide a hash that uniquely identifies all the configs - that affect the structure of the computation - graph from input ids/embeddings to the final hidden states, - excluding anything before input ids/embeddings and after - the final hidden states. - """ - # no factors to consider. - # this config will not affect the computation graph. 
- factors: list[Any] = [] - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest() - return hash_str - - def __post_init__(self): - - if self.max_prompt_adapters < 1: - raise ValueError(f"max_prompt_adapters " - f"({self.max_prompt_adapters}) must be >= 1.") - if self.max_prompt_adapter_token == 0: - raise ValueError("max_prompt_adapter_token must be set.") - if self.max_cpu_prompt_adapters is None: - self.max_cpu_prompt_adapters = self.max_prompt_adapters - - def verify_with_model_config(self, model_config: ModelConfig): - if self.prompt_adapter_dtype == "auto": - self.prompt_adapter_dtype = model_config.dtype - elif isinstance(self.prompt_adapter_dtype, str): - self.prompt_adapter_dtype = getattr(torch, - self.prompt_adapter_dtype) - @config @dataclass @@ -3111,8 +3167,8 @@ class MultiModalConfig: """ media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict) - """Additional args passed to process media inputs, keyed by modalities. - For example, to set num_frames for video, set + """Additional args passed to process media inputs, keyed by modalities. + For example, to set num_frames for video, set `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ mm_processor_kwargs: Optional[dict[str, object]] = None @@ -3131,6 +3187,11 @@ class MultiModalConfig: If `True`, disable caching of the processed multi-modal inputs. """ + interleave_mm_strings: bool = False + """ + Enable fully interleaved support for multimodal prompts. + """ + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -3561,7 +3622,8 @@ def get_served_model_name(model: str, GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer", "xgrammar", "guidance"] -GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"] + +GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance", "outlines"] GuidedDecodingBackend = Literal[GuidedDecodingBackendV0, GuidedDecodingBackendV1] @@ -3571,18 +3633,6 @@ def get_served_model_name(model: str, class DecodingConfig: """Dataclass which contains the decoding strategy of the engine.""" - @property - @deprecated( - "`guided_decoding_backend` is deprecated and has been renamed to " - "`backend`. This will be removed in v0.10.0. Please use the " - "`backend` argument instead.") - def guided_decoding_backend(self) -> GuidedDecodingBackend: - return self.backend - - @guided_decoding_backend.setter - def guided_decoding_backend(self, value: GuidedDecodingBackend): - self.backend = value - backend: GuidedDecodingBackend = "auto" if envs.VLLM_USE_V1 else "xgrammar" """Which engine will be used for guided decoding (JSON schema / regex etc) by default. With "auto", we will make opinionated choices based on request @@ -3625,9 +3675,6 @@ def compute_hash(self) -> str: return hash_str def __post_init__(self): - if ":" in self.backend: - self._extract_backend_options() - if envs.VLLM_USE_V1: valid_guided_backends = get_args(GuidedDecodingBackendV1) else: @@ -3643,24 +3690,6 @@ def __post_init__(self): raise ValueError("disable_additional_properties is only supported " "for the guidance backend.") - @deprecated( - "Passing guided decoding backend options inside backend in the format " - "'backend:...' is deprecated. This will be removed in v0.10.0. 
Please " - "use the dedicated arguments '--disable-fallback', " - "'--disable-any-whitespace' and '--disable-additional-properties' " - "instead.") - def _extract_backend_options(self): - """Extract backend options from the backend string.""" - backend, options = self.backend.split(":") - self.backend = cast(GuidedDecodingBackend, backend) - options_set = set(options.strip().split(",")) - if "no-fallback" in options_set: - self.disable_fallback = True - if "disable-any-whitespace" in options_set: - self.disable_any_whitespace = True - if "no-additional-properties" in options_set: - self.disable_additional_properties = True - DetailedTraceModules = Literal["model", "worker", "all"] @@ -3911,11 +3940,6 @@ class PassConfig: don't all have access to full configuration - that would create a cycle as the `PassManager` is set as a property of config.""" - dump_graph_stages: list[str] = field(default_factory=list) - """List of stages for which we want to dump the graph. Each pass defines - its own stages (before, after, maybe in-between).""" - dump_graph_dir: Path = Path(".") - """Directory to dump the graphs.""" enable_fusion: bool = field(default_factory=lambda: not envs.VLLM_USE_V1) """Whether to enable the custom fusion (RMSNorm/SiluMul+quant) pass.""" enable_attn_fusion: bool = False @@ -3926,6 +3950,10 @@ class PassConfig: """Whether to enable sequence parallelism.""" enable_async_tp: bool = False """Whether to enable async TP.""" + enable_fi_allreduce_fusion: bool = False + """Whether to enable flashinfer allreduce fusion.""" + fi_allreduce_fusion_max_token_num: int = 1024 + """Max number of tokens to used in flashinfer allreduce fusion.""" # TODO(luka) better pass enabling system. @@ -3933,12 +3961,9 @@ def uuid(self): """ Produces a hash unique to the pass configuration. Any new fields that affect compilation should be added to the hash. - Do not include dump_graph_* in the hash - they don't affect - compilation. + Any future fields that don't affect compilation should be excluded. """ - exclude = {"dump_graph_stages", "dump_graph_dir"} - dict_ = {k: v for k, v in asdict(self).items() if k not in exclude} - return InductorPass.hash_dict(dict_) + return InductorPass.hash_dict(asdict(self)) def __post_init__(self) -> None: if not self.enable_noop: @@ -4043,7 +4068,7 @@ class CompilationConfig: - True: inductor compilation is used (custom_ops disabled by default). One graph for symbolic shape and one graph per size in compile_sizes are compiled using configurations in inductor_compile_config. - + This setting is ignored if level<PIECEWISE.""" compile_sizes: Optional[list[Union[int, str]]] = None """Sizes to compile for inductor. In addition @@ -4299,6 +4324,7 @@ def set_splitting_ops_for_v1(self): self.splitting_ops = [] if self.full_cuda_graph else [ "vllm.unified_attention", "vllm.unified_attention_with_output", + "vllm.mamba_mixer2", ] @@ -4331,8 +4357,6 @@ class VllmConfig: """Decoding configuration.""" observability_config: Optional[ObservabilityConfig] = None """Observability configuration.""" - prompt_adapter_config: Optional[PromptAdapterConfig] = None - """Prompt adapter configuration.""" quant_config: Optional[QuantizationConfig] = None """Quantization configuration.""" compilation_config: CompilationConfig = field( @@ -4341,7 +4365,7 @@ class VllmConfig: As a shorthand, `-O<n>` can be used to directly specify the compilation level `n`: `-O3` is equivalent to `-O.level=3` (same as `-O='{"level":3}'`). 
- Currently, -O <n> and -O=<n> are supported as well but this will likely be + Currently, -O <n> and -O=<n> are supported as well but this will likely be removed in favor of clearer -O<n> syntax in the future. NOTE: level 0 is the default level without any optimization. level 1 and 2 @@ -4429,10 +4453,6 @@ def compute_hash(self) -> str: vllm_factors.append(self.observability_config.compute_hash()) else: vllm_factors.append("None") - if self.prompt_adapter_config: - vllm_factors.append(self.prompt_adapter_config.compute_hash()) - else: - vllm_factors.append("None") if self.quant_config: pass # should be captured by model_config.quantization if self.compilation_config: @@ -4540,10 +4560,6 @@ def __post_init__(self): if self.lora_config is not None: self.lora_config.verify_with_cache_config(self.cache_config) self.lora_config.verify_with_model_config(self.model_config) - self.lora_config.verify_lora_support() - if self.prompt_adapter_config is not None: - self.prompt_adapter_config.verify_with_model_config( - self.model_config) if self.quant_config is None and self.model_config is not None: self.quant_config = VllmConfig._get_quantization_config( @@ -4651,6 +4667,13 @@ def __post_init__(self): if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True + if self.model_config is not None and \ + self.model_config.attention_chunk_size is not None and \ + self.speculative_config is not None and \ + self.speculative_config.use_eagle(): + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + self.scheduler_config.disable_hybrid_kv_cache_manager = True def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: @@ -4703,7 +4726,6 @@ def _set_cudagraph_sizes(self): # calculate the default `batch_size_capture_list` if not envs.VLLM_USE_V1: batch_size_capture_list = [] - max_batchsize_to_capture = 0 if self.scheduler_config is not None and \ self.model_config is not None and \ not self.model_config.enforce_eager: @@ -4769,11 +4791,15 @@ def try_verify_and_update_config(self): if architecture is None: return - from vllm.model_executor.models.config import MODELS_CONFIG_MAP + from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) cls = MODELS_CONFIG_MAP.get(architecture, None) if cls is not None: cls.verify_and_update_config(self) + if self.model_config.is_hybrid: + HybridAttentionMambaModelConfig.verify_and_update_config(self) + if self.model_config.task == "classify": # Maybe convert ForCausalLM into ForSequenceClassification model. from vllm.model_executor.models.adapters import ( @@ -4922,3 +4948,52 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig, vllm_config.compilation_config.static_forward_context.items() if isinstance(layer, layer_type) } + + +@config +@dataclass +class SpeechToTextConfig: + """Configuration for speech-to-text models.""" + + sample_rate: float = 16_000 + """Sample rate (Hz) to resample input audio to. Most speech models expect + 16kHz audio input. The input audio will be automatically resampled to this + rate before processing.""" + + max_audio_clip_s: int = 30 + """Maximum duration in seconds for a single audio clip without chunking. 
+ Audio longer than this will be split into smaller chunks if + `allow_audio_chunking` evaluates to True, otherwise it will be rejected.""" + + overlap_chunk_second: int = 1 + """Overlap duration in seconds between consecutive audio chunks when + splitting long audio. This helps maintain context across chunk boundaries + and improves transcription quality at split points.""" + + min_energy_split_window_size: Optional[int] = 1600 + """Window size in samples for finding low-energy (quiet) regions to split + audio chunks. The algorithm looks for the quietest moment within this + window to minimize cutting through speech. Default 1600 samples ≈ 100ms + at 16kHz. If None, no chunking will be done.""" + + @property + def allow_audio_chunking(self) -> bool: + return self.min_energy_split_window_size is not None + + +def update_config(config: DataclassInstanceT, + overrides: dict[str, Any]) -> DataclassInstanceT: + processed_overrides = {} + for field_name, value in overrides.items(): + assert hasattr( + config, field_name), f"{type(config)} has no field `{field_name}`" + current_value = getattr(config, field_name) + if is_dataclass(current_value) and not is_dataclass(value): + assert isinstance(value, dict), ( + f"Overrides to {type(config)}.{field_name} must be a dict" + f" or {type(current_value)}, but got {type(value)}") + value = update_config( + current_value, # type: ignore[type-var] + value) + processed_overrides[field_name] = value + return replace(config, **processed_overrides) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index ea490c32791c..92bc5e157e14 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -7,7 +7,6 @@ DeviceAwareBlockAllocator) from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator -from vllm.platforms import current_platform from vllm.utils import Device @@ -56,8 +55,7 @@ def create( - The block IDs are assigned contiguously, with GPU block IDs coming before CPU block IDs. """ - # For HPU, block id 0 is used only for padding - reserved_blocks = 1 if current_platform.is_hpu() else 0 + reserved_blocks = 0 block_ids = list( range(reserved_blocks, num_gpu_blocks + num_cpu_blocks)) num_gpu_blocks -= reserved_blocks diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0ef0396996b6..61346da145bb 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -15,7 +15,6 @@ from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupBase, SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceStage, @@ -165,8 +164,6 @@ def __post_init__(self): if self.num_loras > 0: self._sort_by_lora_ids() - self.num_prompt_adapters: int = len(self.prompt_adapter_requests) - def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. 
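# Editor's note: illustrative sketch, not part of this diff. It mirrors how the
# `update_config` helper added to vllm/config.py above recurses into nested
# dataclass fields: a plain dict override is applied field-by-field and the
# result is rebuilt with dataclasses.replace. `Inner`/`Outer` are hypothetical
# stand-ins, not vLLM types.
from dataclasses import dataclass, field, is_dataclass, replace
from typing import Any


@dataclass
class Inner:
    window: int = 1600
    overlap_s: int = 1


@dataclass
class Outer:
    name: str = "demo"
    inner: Inner = field(default_factory=Inner)


def update_config_sketch(config, overrides: dict[str, Any]):
    processed: dict[str, Any] = {}
    for field_name, value in overrides.items():
        current = getattr(config, field_name)
        if is_dataclass(current) and isinstance(value, dict):
            # Recurse so a partial override of a nested config keeps its
            # other fields intact.
            value = update_config_sketch(current, value)
        processed[field_name] = value
    return replace(config, **processed)


cfg = update_config_sketch(Outer(), {"inner": {"window": 3200}})
assert cfg.inner.window == 3200 and cfg.inner.overlap_s == 1 and cfg.name == "demo"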
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in @@ -194,14 +191,6 @@ def lora_requests(self) -> Set[LoRARequest]: if g.seq_group.lora_request is not None } - @property - def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: - return { - g.seq_group.prompt_adapter_request - for g in self.scheduled_seq_groups - if g.seq_group.prompt_adapter_request is not None - } - @dataclass class SchedulerRunningOutputs: @@ -1648,7 +1637,6 @@ def schedule( multi_modal_placeholders=( seq_group.multi_modal_placeholders if scheduler_outputs.num_prefill_groups > 0 else None), - prompt_adapter_request=seq_group.prompt_adapter_request, ) else: # When SPMD mode is enabled, we only send delta data except for diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 1bc2d8e0281c..dc5923cdc5a0 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading -from typing import Optional +from typing import Optional, Union from weakref import WeakValueDictionary import torch @@ -138,6 +138,14 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_size[dim + 1:]) return output_tensor + def all_gatherv( + self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None + ) -> Union[torch.Tensor, list[torch.Tensor]]: + raise NotImplementedError + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: @@ -172,6 +180,12 @@ def reduce_scatter(self, # Reshape before returning return output_tensor.movedim(0, dim).contiguous() + def reduce_scatterv(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[list[int]] = None) -> torch.Tensor: + raise NotImplementedError + def gather(self, input_: torch.Tensor, dst: int = 0, @@ -240,8 +254,7 @@ def prepare_communication_buffer_for_model(self, if module.__class__.__name__ == "FusedMoE" ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module.moe_config, - module.quant_config) + module.quant_method.init_prepare_finalize(module.moe_config) def dispatch( self, hidden_states: torch.Tensor, diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index 94effa0b2ca8..bda567f8489c 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -2,11 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional +from typing import Any, Optional, Union import torch from torch.distributed import ProcessGroup +from vllm.distributed.utils import pickle from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum @@ -26,7 +27,8 @@ def __init__(self, if (current_platform.get_cpu_architecture() == CpuArchEnum.X86) and hasattr( torch.ops._C, - "init_shm_manager") and unique_name.startswith("tp"): + "init_shm_manager") and (unique_name.startswith("tp") + or unique_name.startswith("pp")): self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): @@ -94,6 +96,19 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_size[dim + 1:]) return output_tensor + 
def send_tensor_dict( + self, + tensor_dict: dict[str, Union[torch.Tensor, Any]], + dst: int, + ) -> None: + return self.dist_module.send_tensor_dict(tensor_dict, dst) + + def recv_tensor_dict( + self, + src: int, + ) -> dict[str, Union[torch.Tensor, Any]]: + return self.dist_module.recv_tensor_dict(src) + class _CPUSHMDistributed: @@ -143,3 +158,44 @@ def all_gather_into_tensor(self, input: torch.Tensor, group: Optional[ProcessGroup] = None) -> None: torch.ops._C.shm_all_gather(self.handle, input, output) + + def send_tensor_dict( + self, + tensor_dict: dict[str, Union[torch.Tensor, Any]], + dst: int, + ) -> None: + key_list = list(tensor_dict.keys()) + value_list = list(tensor_dict.values()) + size_list = [] + for v in value_list: + if not isinstance(v, torch.Tensor): + raise RuntimeError( + "CpuCommunicator only supports sending tensors.") + size_list.append(v.size()) + key_size_tensor = torch.frombuffer(pickle.dumps([key_list, size_list]), + dtype=torch.uint8) + value_list.append(key_size_tensor) + + torch.ops._C.shm_send_tensor_list(self.handle, value_list, dst) + + return None + + def recv_tensor_dict( + self, + src: int, + ) -> dict[str, Union[torch.Tensor, Any]]: + tensor_list = torch.ops._C.shm_recv_tensor_list(self.handle, src) + + value_list: list[torch.Tensor] = tensor_list[:-1] + key_size_tensor = tensor_list[-1] + + key_size = pickle.loads(key_size_tensor.numpy().tobytes()) + key_list = key_size[0] + size_list = key_size[1] + assert len(key_list) == len(size_list) + assert len(key_list) == len(value_list) + + tensor_dict: dict[str, torch.Tensor] = {} + for key, size, t in zip(key_list, size_list, value_list): + tensor_dict[key] = t.view(size) + return tensor_dict diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 3958d566b174..e4804691f0f6 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Optional, Union import torch from torch.distributed import ProcessGroup @@ -142,6 +142,42 @@ def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): # Reshape before returning return output.movedim(0, dim).contiguous() + def reduce_scatterv(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[list[int]] = None): + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + + # Note: This will produce an incorrect answer if we don't make + # the input_tensor contiguous. Possible bug in reduce_scatter_tensor? 
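# Editor's note: illustrative sketch, not part of this diff. It shows the
# metadata trick used by _CPUSHMDistributed.send_tensor_dict above: the dict
# keys and tensor shapes are pickled into a uint8 tensor that travels with the
# payload tensors, and the receiver unpickles it to rebuild the dict.
import pickle

import torch

payload = {"a": torch.zeros(2, 3), "b": torch.arange(4, dtype=torch.float32)}
keys, values = list(payload.keys()), list(payload.values())
sizes = [v.size() for v in values]

# bytearray() keeps the buffer writable, which avoids a torch warning here.
meta = torch.frombuffer(bytearray(pickle.dumps([keys, sizes])),
                        dtype=torch.uint8)

# (In the real code, `meta` is appended to `values` and the whole tensor list
# crosses the shared-memory channel.)

keys_back, sizes_back = pickle.loads(meta.numpy().tobytes())
restored = {k: v.view(s) for k, s, v in zip(keys_back, sizes_back, values)}
assert restored["a"].shape == (2, 3) and restored["b"].shape == (4,)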
+ input_tensor = input_.movedim(0, dim).contiguous() + + if sizes is not None: + assert len(sizes) == world_size + assert input_tensor.shape[0] == sum(sizes) + chunk_size = sizes[self.rank_in_group] + else: + assert input_tensor.shape[0] % world_size == 0 + chunk_size = input_tensor.shape[0] // world_size + output_shape = (chunk_size, ) + input_tensor.shape[1:] + + output = torch.empty(output_shape, + dtype=input_tensor.dtype, + device=input_tensor.device) + + if sizes is not None: + pynccl_comm.reduce_scatterv(output, input_, sizes=sizes) + else: + pynccl_comm.reduce_scatter(output, input_) + + # Reshape before returning + return output.movedim(0, dim).contiguous() + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: """Sends a tensor to the destination rank in a non-blocking way""" """NOTE: `dst` is the local rank of the destination rank.""" @@ -180,6 +216,51 @@ def destroy(self): self.all2all_manager.destroy() self.all2all_manager = None + def all_gatherv(self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None): + if dim != 0: + raise NotImplementedError("only dim 0 all-gatherv is supported") + world_size = self.world_size + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None and not pynccl_comm.disabled + + # 'sizes' is not needed if all inputs in the same group have the same + # shape + if sizes is not None and all(s == sizes[0] for s in sizes): + sizes = None + + def _all_gather_single(input_: torch.Tensor, + sizes: Optional[list[int]] = None): + input_size = input_.size() + if sizes is not None: + assert len(sizes) == world_size + assert input_.shape[dim] == sizes[self.rank_in_group] + output_size = (sum(sizes), ) + input_size[1:] + else: + output_size = (input_size[0] * world_size, ) + input_size[1:] + # Allocate output tensor. 
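# Editor's note: illustrative sketch, not part of this diff. It spells out the
# `sizes` contract behind the reduce_scatterv/all_gatherv helpers above: with
# an explicit `sizes` list, rank r owns sizes[r] rows of a sum(sizes)-row
# concatenation; with sizes=None the leading dimension must divide evenly by
# world_size, matching the asserts in the code.
import torch

world_size, rank, hidden = 4, 2, 8

# Uneven case: explicit per-rank row counts.
sizes = [5, 3, 7, 1]
full = torch.randn(sum(sizes), hidden)              # reduce_scatterv input
offsets = [sum(sizes[:r]) for r in range(world_size)]
my_rows = full[offsets[rank]:offsets[rank] + sizes[rank]]
assert my_rows.shape == (sizes[rank], hidden)       # rank 2 receives 7 rows

# Even case (sizes is None): rows must split evenly across ranks.
even = torch.randn(16, hidden)
assert even.shape[0] % world_size == 0
chunk = even.shape[0] // world_size
assert even[rank * chunk:(rank + 1) * chunk].shape == (chunk, hidden)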
+ output_tensor = torch.empty(output_size, + dtype=input_.dtype, + device=input_.device) + if sizes is not None: + pynccl_comm.all_gatherv(output_tensor, input_, sizes=sizes) + else: + pynccl_comm.all_gather(output_tensor, input_) + return output_tensor + + if isinstance(input_, torch.Tensor): + return _all_gather_single(input_, sizes) + + output_list = [] + pynccl_comm.group_start() + for inp in input_: + output_list.append(_all_gather_single(inp, sizes=sizes)) + pynccl_comm.group_end() + + return output_list + def dispatch( self, hidden_states: torch.Tensor, router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/distributed/device_communicators/hpu_communicator.py b/vllm/distributed/device_communicators/hpu_communicator.py deleted file mode 100644 index f00f6b62bf24..000000000000 --- a/vllm/distributed/device_communicators/hpu_communicator.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.distributed as dist - -from vllm.platforms import current_platform - -from .base_device_communicator import DeviceCommunicatorBase - -if current_platform.is_hpu(): - import habana_frameworks.torch as htorch # noqa: F401 - - -class HpuCommunicator(DeviceCommunicatorBase): - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - # FIXME(kzawora): this is a workaround for a bug in Habana PT bridge - # occurring when PT_HPU_ENABLE_LAZY_COLLECTIVES=true env var is used - # (which is required for tensor parallel HPUGraph inference) - htorch.core.mark_step() - dist.all_reduce(input_, group=self.device_group) - return input_ - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) - # All-gather. 
- htorch.core.mark_step() - dist.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) - return output_tensor diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 29486292996a..502bfd39005a 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -152,6 +152,40 @@ def all_gather(self, ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, cudaStream_t(stream.cuda_stream)) + def all_gatherv( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + sizes: list[int], + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + assert output_tensor.shape[0] == sum(sizes) + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + dst_slice = output_tensor[split_offset:split_offset + split_size] + self.nccl.ncclBroadcast( + buffer_type(input_tensor.data_ptr()), + buffer_type(dst_slice.data_ptr()), + dst_slice.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + root, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + split_offset += split_size + self.nccl.ncclGroupEnd() + def reduce_scatter(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, @@ -174,6 +208,38 @@ def reduce_scatter(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + def reduce_scatterv( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + sizes: list[int], + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = current_stream() + + split_offset = 0 + self.nccl.ncclGroupStart() + for root, split_size in enumerate(sizes): + chunk = input_tensor[split_offset:split_offset + split_size, ...] 
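# Editor's note: illustrative single-process sketch, not part of this diff. It
# mimics what the grouped per-root ncclReduce calls below compute: for each
# root r, chunk r of every rank's input is summed and delivered only to rank
# r, so rank r ends up with sizes[r] reduced rows.
import torch

sizes = [2, 3]                                      # rows owned by ranks 0, 1
rank_inputs = [torch.ones(sum(sizes), 4) * (r + 1)
               for r in range(len(sizes))]

received = []
split_offset = 0
for root, split_size in enumerate(sizes):
    chunks = [inp[split_offset:split_offset + split_size]
              for inp in rank_inputs]
    received.append(torch.stack(chunks).sum(dim=0))  # lands only on `root`
    split_offset += split_size

assert received[0].shape == (2, 4) and received[1].shape == (3, 4)
assert torch.all(received[0] == 3) and torch.all(received[1] == 3)  # 1 + 2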
+ self.nccl.ncclReduce( + buffer_type(chunk.data_ptr()), + buffer_type(output_tensor.data_ptr()), chunk.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), root, self.comm, + cudaStream_t(stream.cuda_stream)) + split_offset += split_size + self.nccl.ncclGroupEnd() + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return @@ -216,3 +282,9 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) + + def group_start(self): + self.nccl.ncclGroupStart() + + def group_end(self): + self.nccl.ncclGroupEnd() diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 3018a92da07c..a930b63bc26f 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -154,6 +154,17 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, int root, + # ncclComm_t comm, cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ctypes.c_int, ncclComm_t, cudaStream_t + ]), + # ncclResult_t ncclAllGather( # const void* sendbuff, void* recvbuff, size_t count, # ncclDataType_t datatype, ncclComm_t comm, @@ -207,6 +218,10 @@ class NCCLLibrary: # it is better not to call it at all. # ncclResult_t ncclCommDestroy(ncclComm_t comm); Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + # ncclResult_t ncclGroupStart(); + Function("ncclGroupStart", ncclResult_t, []), + # ncclResult_t ncclGroupEnd(); + Function("ncclGroupEnd", ncclResult_t, []), ] # class attribute to store the mapping from the path to the library @@ -300,6 +315,18 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, root: int, + comm: ncclComm_t, stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduce"](sendbuff, recvbuff, count, + datatype, op, root, comm, + stream)) + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, count: int, datatype: int, op: int, comm: ncclComm_t, stream: cudaStream_t) -> None: @@ -342,6 +369,12 @@ def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + def ncclGroupStart(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupStart"]()) + + def ncclGroupEnd(self) -> None: + self.NCCL_CHECK(self._funcs["ncclGroupEnd"]()) + __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", diff --git a/vllm/distributed/device_communicators/xpu_communicator.py b/vllm/distributed/device_communicators/xpu_communicator.py index 216ff85c8bb7..dee5ed7a2883 100644 --- 
a/vllm/distributed/device_communicators/xpu_communicator.py +++ b/vllm/distributed/device_communicators/xpu_communicator.py @@ -53,3 +53,6 @@ def gather(self, else: output_tensor = None return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0) -> None: + dist.broadcast(input_, src=src, group=self.device_group) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 6b0a126ca9b2..af6462084968 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -29,12 +29,15 @@ import time from collections.abc import Sequence from dataclasses import dataclass +from typing import Optional, Union import torch -from torch.distributed import all_gather, all_reduce +from torch.distributed import ProcessGroup, all_gather, all_reduce from vllm.config import ParallelConfig -from vllm.distributed.parallel_state import get_ep_group, get_node_count +from vllm.distributed.parallel_state import (get_ep_group, get_node_count, + in_the_same_node_as) +from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -172,6 +175,9 @@ def build( model: MixtureOfExperts, device: torch.device, parallel_config: ParallelConfig, + global_expert_load: Optional[torch.Tensor] = None, + old_global_expert_indices: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None, ) -> "EplbState": """ Build the initial EPLB state. @@ -185,8 +191,16 @@ def build( physical_to_logical_map_list, device=device, ) + # Assuming 8 GPUs per node, this supports up to + # (1023 + 1) / 8 = 128 nodes for now. + # TODO(rui): make this configurable + MAX_EXPERT_REDUNDANCY = 1023 + assert model.num_redundant_experts <= MAX_EXPERT_REDUNDANCY, ( + f"num_redundant_experts {model.num_redundant_experts} " + f"must be less than or equal to {MAX_EXPERT_REDUNDANCY}") + max_slots_per_logical_expert = MAX_EXPERT_REDUNDANCY + 1 logical_to_physical_map = torch.full( - (model.num_logical_experts, model.num_redundant_experts + 1), + (model.num_logical_experts, max_slots_per_logical_expert), -1, device=device, ) @@ -235,11 +249,63 @@ def build( expert_rearrangement_step = max( 0, eplb_step_interval - eplb_step_interval // 4) + if global_expert_load is not None: + ep_group = get_ep_group().device_group + assert global_expert_load.shape == (model.num_moe_layers, + model.num_logical_experts) + assert global_expert_load.dtype == torch.int64 + + num_replicas = model.num_physical_experts + num_groups = model.num_expert_groups + num_nodes = get_node_count() + num_gpus = ep_group.size() + + if num_gpus % num_nodes != 0: + num_nodes = 1 + logger.warning_once( + f"num_gpus % num_nodes != 0, " + "not using hierarchical rearrangement algorithm.\n" + f"{num_gpus=}, {num_nodes=}") + + # Get new expert mappings + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = (rebalance_experts( + global_expert_load, + num_replicas, + num_groups, + num_nodes, + num_gpus, + )) + + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert max_physical_slots <= logical_to_physical_map.shape[-1] + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + (0, logical_to_physical_map.shape[-1] - max_physical_slots), + value=-1, + ) + physical_to_logical_map = new_physical_to_logical_map.to(device) + logical_to_physical_map.copy_(new_logical_to_physical_map) + 
logical_replica_count.copy_(new_logical_replica_count) + model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) + if global_expert_load is not None: + rearrange_expert_weights_inplace( + old_global_expert_indices, + new_physical_to_logical_map, + model.expert_weights, + ep_group, + False, + rank_mapping, + ) + expert_rearrangement_step = 0 return cls( physical_to_logical_map, @@ -337,7 +403,10 @@ def step(self, def rearrange(self, model: MixtureOfExperts, - is_profile: bool = False) -> None: + is_profile: bool = False, + execute_shuffle: bool = True, + global_expert_load: Optional[torch.Tensor] = None, + rank_mapping: Optional[dict[int, int]] = None) -> None: """ Rearrange the experts according to the current load. """ @@ -353,42 +422,79 @@ def rearrange(self, logger.info("Rearranging experts %s...", "(profile)" if is_profile else "") - # This mapping is only used here, so we do not store it in the state - physical_expert_start = ep_rank * model.num_local_physical_experts - physical_expert_end = (physical_expert_start + - model.num_local_physical_experts) - # (num_moe_layers, num_local_physical_experts) - local_physical_to_logical_map = self.physical_to_logical_map[ - :, - physical_expert_start:physical_expert_end, - ] + if global_expert_load is None: + # This mapping is only used here, so we do not store it in the state + physical_expert_start = ep_rank * model.num_local_physical_experts + physical_expert_end = (physical_expert_start + + model.num_local_physical_experts) + # (num_moe_layers, num_local_physical_experts) + local_physical_to_logical_map = self.physical_to_logical_map[ + :, + physical_expert_start:physical_expert_end, + ] - # Map the local physical expert load to global logical experts - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - model.num_moe_layers, - model.num_logical_experts, - dtype=self.expert_load_window.dtype, - device=self.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=local_physical_to_logical_map.unsqueeze(0).expand_as( - self.expert_load_window).long(), - src=self.expert_load_window, - ) + # Map the local physical expert load to global logical experts + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + model.num_moe_layers, + model.num_logical_experts, + dtype=self.expert_load_window.dtype, + device=self.expert_load_window.device, + ) + logical_expert_load_window.scatter_add_( + dim=-1, + index=local_physical_to_logical_map.unsqueeze(0).expand_as( + self.expert_load_window).long(), + src=self.expert_load_window, + ) - # Perform all-reduce to get the expert load across all ranks - global_expert_load_window = logical_expert_load_window.sum(dim=0) - all_reduce(global_expert_load_window, group=ep_group) + if not execute_shuffle: + metadata = torch.tensor( + [ + model.num_moe_layers, model.num_logical_experts, + self.physical_to_logical_map.shape[1] + ], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.broadcast(metadata, + group=get_ep_group().cpu_group, + group_src=0) + + # Perform all-reduce to get the expert load across all ranks + global_expert_load_window = logical_expert_load_window.sum(dim=0) + all_reduce(global_expert_load_window, group=ep_group) + + if not execute_shuffle: + # (num_moe_layers, old_num_physical_experts) + old_global_expert_indices = self.physical_to_logical_map + torch.distributed.broadcast(old_global_expert_indices, + group=ep_group, + group_src=0) + return global_expert_load_window + else: + 
assert execute_shuffle + global_expert_load_window = global_expert_load # TODO(bowen): Treat differently for prefill and decode nodes num_replicas = model.num_physical_experts num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() + if rank_mapping is not None and len(rank_mapping) == ep_group.size(): + # NOTE(yongji): scale down, we need to rebalance the experts on + # remaining GPUs, transfer the experts while we haven't shutdown + # the GPUs to be released. + cpu_group = get_ep_group().cpu_group + num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) + num_gpus = sum(new_rank != -1 + for new_rank in rank_mapping.values()) + num_replicas = num_replicas // ep_group.size( + ) * num_gpus # handle num replicas change + else: + num_nodes = get_node_count() + num_gpus = ep_group.size() if num_gpus % num_nodes != 0: + self.num_nodes = 1 logger.warning_once( f"num_gpus % num_nodes != 0, " "not using hierarchical rearrangement algorithm.\n" @@ -414,10 +520,24 @@ def rearrange(self, model.expert_weights, ep_group, is_profile, + rank_mapping, ) if not is_profile: - self.physical_to_logical_map.copy_(new_physical_to_logical_map) + if self.physical_to_logical_map.shape[ + 1] != new_physical_to_logical_map.shape[1]: + self.physical_to_logical_map = new_physical_to_logical_map.to( + self.physical_to_logical_map.device) + else: + self.physical_to_logical_map.copy_(new_physical_to_logical_map) + max_physical_slots = new_logical_to_physical_map.shape[-1] + assert max_physical_slots <= self.logical_to_physical_map.shape[-1] + new_logical_to_physical_map = torch.nn.functional.pad( + new_logical_to_physical_map, + (0, + self.logical_to_physical_map.shape[-1] - max_physical_slots), + value=-1, + ) self.logical_to_physical_map.copy_(new_logical_to_physical_map) self.logical_replica_count.copy_(new_logical_replica_count) @@ -430,3 +550,69 @@ def rearrange(self, " (profile) " if is_profile else " ", time_end - time_start, ) + + @staticmethod + def recv_state() -> tuple[torch.Tensor, torch.Tensor]: + """ + Receive the expert load and old placement from the master rank. 
+ """ + ep_group = get_ep_group() + metadata = torch.empty(3, dtype=torch.int32, device="cpu") + torch.distributed.broadcast(metadata, + group=ep_group.cpu_group, + group_src=0) + num_moe_layers, num_logical_experts, num_old_physical_experts = ( + metadata.tolist()) + global_expert_load = torch.zeros( + (num_moe_layers, num_logical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + all_reduce(global_expert_load, group=ep_group.device_group) + old_global_expert_indices = torch.empty( + (num_moe_layers, num_old_physical_experts), + dtype=torch.int64, + device=ep_group.device, + ) + torch.distributed.broadcast(old_global_expert_indices, + group=ep_group.device_group, + group_src=0) + + return global_expert_load, old_global_expert_indices + + +def _node_count_with_rank_mapping( + pg: Union[ProcessGroup, StatelessProcessGroup], + rank_mapping: dict[int, int], +) -> int: + if isinstance(pg, ProcessGroup): + world_size = torch.distributed.get_world_size(group=pg) + else: + world_size = pg.world_size + + if world_size == 1: + return 1 + + # Build node assignment map + node_assignment = [0] * world_size # rank -> node_id + next_node_id = 0 + + for current_rank in range(world_size): + if node_assignment[current_rank] != 0: + continue # Already assigned to a node + + assert current_rank in rank_mapping + if rank_mapping[current_rank] == -1: + continue # Pending shutdown + + # Assign current rank to a new node + next_node_id += 1 + node_assignment[current_rank] = next_node_id + + # Find all ranks on the same node as current_rank + same_node_flags = in_the_same_node_as(pg, current_rank) + for other_rank, is_same_node in enumerate(same_node_flags): + if is_same_node and node_assignment[other_rank] == 0: + node_assignment[other_rank] = next_node_id + + return next_node_id diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 2ef8587b559b..f8a7d1170bb0 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -8,6 +8,7 @@ from collections.abc import Iterable, MutableSequence, Sequence from functools import partial +from typing import Optional import torch from torch.distributed import (P2POp, ProcessGroup, all_gather, @@ -127,6 +128,8 @@ def shuffle_layer( dst_global = local2global(dst) if is_received_locally[dst]: continue + if old_indices[src_global] == -1 or new_indices[dst_global] == -1: + continue if old_indices[src_global] == new_indices[dst_global]: is_received_locally[dst] = True for weight, buffer in zip(expert_weights, @@ -139,6 +142,8 @@ def shuffle_layer( experts_send_loc: dict[int, int] = {} for src in range(num_local_experts): expert = old_indices[local2global(src)] + if expert == -1: + continue if expert in experts_send_loc: continue experts_send_loc[expert] = src @@ -181,6 +186,8 @@ def shuffle_layer( if is_received_locally[dst]: continue expert = new_indices[local2global(dst)] + if expert == -1: + continue if expert in experts_recv_loc: continue experts_recv_loc[expert] = dst @@ -227,6 +234,8 @@ def shuffle_layer( weight[dst].copy_(buffer[dst]) else: expert = new_indices[local2global(dst)] + if expert == -1: + continue src = experts_recv_loc[expert] for weight, buffer in zip(expert_weights, expert_weights_buffer): weight[dst].copy_(buffer[src]) @@ -238,6 +247,7 @@ def rearrange_expert_weights_inplace( expert_weights: Sequence[Iterable[torch.Tensor]], ep_group: ProcessGroup, is_profile: bool = False, + rank_mapping: Optional[dict[int, int]] = None, ) -> None: """ Rearranges 
the expert weights in place according to the new expert indices. @@ -256,7 +266,28 @@ def rearrange_expert_weights_inplace( is_profile (bool): If `True`, do not perform any actual weight copy. This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. + rank_mapping: A dictionary mapping old rank to new rank. """ + if rank_mapping is not None: + if len(rank_mapping) == ep_group.size(): + # scale down + new_global_expert_indices = \ + _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices, + rank_mapping, + ) + else: + # scale up + old_global_expert_indices = \ + _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices, + rank_mapping, + ep_group.size(), + ) + + assert old_global_expert_indices.shape[ + 1] == new_global_expert_indices.shape[1] + num_moe_layers, num_physical_experts = old_global_expert_indices.shape assert len(expert_weights) == num_moe_layers @@ -304,4 +335,90 @@ def rearrange_expert_weights_inplace( ) +def _map_old_expert_indices_with_rank_mapping( + old_global_expert_indices: torch.Tensor, + rank_mapping: dict[int, int], + new_ep_size: int, +) -> torch.Tensor: + """ + Map the old global expert indices to the new global expert indices. + + Args: + old_global_expert_indices: + Shape (num_layers, old_ep_size * num_local_physical_experts). + rank_mapping: Mapping from old rank to new rank. + new_ep_size: New expert parallelism size. + + Returns: + Mapped expert indices with shape + (num_layers, new_ep_size * num_local_physical_experts). + """ + num_layers, old_num_physical_experts = old_global_expert_indices.shape + assert rank_mapping, "Rank mapping is required" + + # Get sizes from parameters and rank_mapping + old_ep_size = len(rank_mapping) + num_local_physical_experts = old_num_physical_experts // old_ep_size + new_num_physical_experts = new_ep_size * num_local_physical_experts + + # Create mapped tensor with new shape, initialized to -1 + mapped_expert_indices = torch.full( + (num_layers, new_num_physical_experts), + fill_value=-1, + dtype=old_global_expert_indices.dtype, + device=old_global_expert_indices.device, + ) + + # Handle rank mapping (scale up/down with rank changes) + for old_rank in range(old_ep_size): + new_rank = rank_mapping.get(old_rank) + if new_rank is not None and new_rank >= 0 and new_rank < new_ep_size: + # This old rank exists in the new configuration + old_start_idx = old_rank * num_local_physical_experts + old_end_idx = (old_rank + 1) * num_local_physical_experts + new_start_idx = new_rank * num_local_physical_experts + new_end_idx = (new_rank + 1) * num_local_physical_experts + + mapped_expert_indices[:, new_start_idx:new_end_idx] = \ + old_global_expert_indices[:, old_start_idx:old_end_idx] + # If new_rank is None or >= new_ep_size, the experts remain -1 + # (scale down case) + + return mapped_expert_indices + + +def _map_new_expert_indices_with_rank_mapping( + new_global_expert_indices: torch.Tensor, + rank_mapping: dict[int, int], +) -> torch.Tensor: + num_layers, new_num_physical_experts = new_global_expert_indices.shape + assert rank_mapping, "Rank mapping is required" + + # Get sizes from parameters and rank_mapping + old_ep_size = len(rank_mapping) + new_ep_size = sum(new_rank != -1 for new_rank in rank_mapping.values()) + num_local_physical_experts = new_num_physical_experts // new_ep_size + old_num_physical_experts = old_ep_size * num_local_physical_experts + + mapped_expert_indices = torch.full( + (num_layers, old_num_physical_experts), + 
fill_value=-1, + dtype=new_global_expert_indices.dtype, + device=new_global_expert_indices.device, + ) + + for old_rank in range(old_ep_size): + new_rank = rank_mapping[old_rank] + if new_rank >= 0 and new_rank < new_ep_size: + old_start_idx = old_rank * num_local_physical_experts + old_end_idx = (old_rank + 1) * num_local_physical_experts + new_start_idx = new_rank * num_local_physical_experts + new_end_idx = (new_rank + 1) * num_local_physical_experts + + mapped_expert_indices[:, old_start_idx:old_end_idx] = \ + new_global_expert_indices[:, new_start_idx:new_end_idx] + + return mapped_expert_indices + + __all__ = ["rearrange_expert_weights_inplace"] diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 5cbc8ca31752..c179d6cc29b7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -3,12 +3,18 @@ """ KV cache helper for store. """ +from collections import defaultdict +from collections.abc import Sequence +from concurrent.futures import CancelledError, Future +from typing import Optional, cast + import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config from vllm.logger import init_logger +from vllm.v1.outputs import ModelRunnerOutput logger = init_logger(__name__) @@ -107,3 +113,87 @@ def get_kv_connector_cache_layout(): "layout to HND for better xfer performance.") return "HND" return "NHD" + + +class KVOutputAggregator: + """Utility class to aggregate the output of all workers into a single + output corresponding to Rank 0 for scheduler.""" + + def __init__(self, world_size: int): + # Complete transfer tracker. Used by to track finished requests + # [req_id -> n_finished_workers] + self._recv_remaining_count = defaultdict[str, int](lambda: world_size) + self._send_remaining_count = defaultdict[str, int](lambda: world_size) + + def aggregate(self, + outputs: list[ModelRunnerOutput], + output_rank: int = 0) -> ModelRunnerOutput: + # aggregate finished_sending, finished_recving from all workers + + def update_finished_set(req_ids: Optional[set[str]], + remaining_count_dict: dict[str, int], + finished_set: set[str]) -> None: + for req_id in req_ids or (): + new_count = remaining_count_dict[req_id] - 1 + if new_count == 0: + finished_set.add(req_id) + del remaining_count_dict[req_id] + else: + remaining_count_dict[req_id] = new_count + + finished_sending = set[str]() + finished_recving = set[str]() + for output in outputs: + update_finished_set(output.finished_sending, + self._send_remaining_count, finished_sending) + update_finished_set(output.finished_recving, + self._recv_remaining_count, finished_recving) + + # select output of the worker specified by output_rank + output = outputs[output_rank] + + # set the aggregated finished_sending / finished_recving + # if output.finished_sending/recving is not empty, but the other ranks + # still have unfinished send/recv, we want to set the aggregated + # finished_sending/recving to None until all ranks have finished + # send/recv + output.finished_sending = finished_sending if finished_sending else None + output.finished_recving = finished_recving if finished_recving else None + + return output + + def async_aggregate(self, + output_futures: Sequence[Future[ModelRunnerOutput]], + output_rank: int = 0) -> Future[ModelRunnerOutput]: + """Takes a list of futures and returns a single future which resolves + to the respective list of 
outputs.""" + result_future: Future[ModelRunnerOutput] = Future() + + outputs: list[Optional[ModelRunnerOutput]] = [None + ] * len(output_futures) + + def make_callback(idx): + + def callback(fut): + if result_future.done(): + return + + try: + outputs[idx] = fut.result() + except CancelledError: + result_future.cancel() + except Exception as e: + result_future.set_exception(e) + + # this check assumes io_thread_pool uses a single thread + if all(outputs): + result_future.set_result( + self.aggregate(cast(list[ModelRunnerOutput], outputs), + output_rank)) + + return callback + + for i, output_future in enumerate(output_futures): + output_future.add_done_callback(make_callback(i)) + + return result_future diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index f80b5eba235d..e1245775bea3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -57,7 +57,7 @@ class KVConnectorRole(enum.Enum): WORKER = 1 -class KVConnectorMetadata: +class KVConnectorMetadata(ABC): # noqa: B024 """ Abstract Metadata used to communicate between the Scheduler KVConnector and Worker KVConnector. @@ -71,7 +71,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design.") - self._connector_metadata = KVConnectorMetadata() + self._connector_metadata: Optional[KVConnectorMetadata] = None self._vllm_config = vllm_config self._role = role @@ -102,7 +102,7 @@ def clear_connector_metadata(self) -> None: This function should be called by the model runner every time after the model execution. """ - self._connector_metadata = KVConnectorMetadata() + self._connector_metadata = None def _get_connector_metadata(self) -> KVConnectorMetadata: """Get the connector metadata. @@ -112,6 +112,9 @@ def _get_connector_metadata(self) -> KVConnectorMetadata: Returns: ConnectorMetadata: the connector metadata. """ + + # Should only be called while set to valid metadata. + assert self._connector_metadata is not None return self._connector_metadata def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): @@ -190,7 +193,9 @@ def get_finished( ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have - finished generating tokens. + finished generating tokens on the worker. + The scheduler process (via the Executors) will use this output + to track which workers are done. 
Returns: ids of requests that have finished asynchronous transfer diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index be3c23399419..a2eaa0040191 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -47,7 +47,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): assert ktcs is not None for ktc in ktcs: temp_config = copy.copy(vllm_config) - temp_config.kv_transfer_config = KVTransferConfig(**ktc) + engine_id = ktc.get("engine_id", + vllm_config.kv_transfer_config.engine_id) + temp_config.kv_transfer_config = KVTransferConfig( + **ktc, engine_id=engine_id) self._connectors.append( KVConnectorFactory.create_connector_v1(temp_config, role)) @@ -187,7 +190,7 @@ def request_finished( async_saves += 1 if txfer_params is not None: if kv_txfer_params is not None: - #TODO we can probably change this to merge the dicts here, + # TODO we can probably change this to merge the dicts here, # checking for key clashes. raise RuntimeError( "Only one connector can produce KV transfer params") diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 56ae1acf8571..0c5986bfafaa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -79,7 +79,8 @@ class ReqMeta: class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): - self.requests: dict[ReqId, ReqMeta] = {} + self.reqs_to_recv: dict[ReqId, ReqMeta] = {} + self.reqs_to_send: dict[ReqId, float] = {} def add_new_req( self, @@ -87,7 +88,7 @@ def add_new_req( local_block_ids: list[int], kv_transfer_params: dict[str, Any], ): - self.requests[request_id] = ReqMeta( + self.reqs_to_recv[request_id] = ReqMeta( local_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], @@ -194,10 +195,12 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): vllm_config.parallel_config.tensor_parallel_size) logger.info("Initializing NIXL Scheduler %s", engine_id) - # Requests that need to start recv. + # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} + # Reqs to send and their expiration time + self._reqs_need_send: dict[ReqId, float] = {} def get_num_new_matched_tokens( self, request: "Request", @@ -284,6 +287,9 @@ def build_connector_meta( # Clear the list once workers start the transfers self._reqs_need_recv.clear() + meta.reqs_to_send = self._reqs_need_send + self._reqs_need_send = {} + return meta def request_finished( @@ -325,6 +331,11 @@ def request_finished( # If prompt < block_size, no xfer so free blocks immediately. delay_free_blocks = len(computed_block_ids) > 0 + if delay_free_blocks: + # Prefill request on remote. It will be read from D upon completion + self._reqs_need_send[request.request_id] = time.perf_counter( + ) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, @@ -394,14 +405,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # In progress transfers. 
# [req_id -> list[handle]] self._recving_transfers = defaultdict[ReqId, list[Transfer]](list) - - # Complete transfer tracker. Used by the rank 0 to track finished - # transactions on ranks 1 to N-1. - # [req_id -> count] - self._done_recving_count: defaultdict[ReqId, - int] = defaultdict(lambda: 0) - self._done_sending_count: defaultdict[ReqId, - int] = defaultdict(lambda: 0) + # Track the expiration time of requests that are waiting to be sent. + self._reqs_to_send: dict[ReqId, float] = {} # Background thread for handling new handshake requests. self._nixl_handshake_listener_t: Optional[threading.Thread] = None @@ -475,8 +480,13 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata, "Connection listener got unexpected message %s", msg) sock.send_multipart((identity, b"", encoded_data)) - def _nixl_handshake(self, host: str, port: int, - remote_tp_size: int) -> dict[int, str]: + def _nixl_handshake( + self, + host: str, + port: int, + remote_tp_size: int, + expected_engine_id: str, + ) -> dict[int, str]: """Do a NIXL handshake with a remote instance.""" start_time = time.perf_counter() @@ -485,26 +495,6 @@ def _nixl_handshake(self, host: str, port: int, # a hack to keep us moving. We will switch when moving to etcd # or where we have a single ZMQ socket in the scheduler. - def handshake(path: str, rank: int) -> str: - # Send query for the request. - with zmq_ctx(zmq.REQ, path) as sock: - sock.send(GET_META_MSG) - metadata_bytes = sock.recv() - decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) - metadata = decoder.decode(metadata_bytes) - got_metadata_time = time.perf_counter() - - # Register Remote agent. - remote_agent_name = self.add_remote_agent( - metadata, rank, remote_tp_size) - setup_agent_time = time.perf_counter() - - logger.debug("NIXL handshake: get metadata took: %s", - got_metadata_time - start_time) - logger.debug("NIXL handshake: add agent took: %s", - setup_agent_time - got_metadata_time) - return remote_agent_name - # Handshake only with the remote TP rank that current local rank will # pull from. With homogeneous TP it happens to be the same rank_i. tp_ratio = self._tp_size[self.engine_id] // remote_tp_size @@ -512,8 +502,32 @@ def handshake(path: str, rank: int) -> str: path = make_zmq_path("tcp", host, port + p_remote_rank) logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank) + + # Send query for the request. + with zmq_ctx(zmq.REQ, path) as sock: + sock.send(GET_META_MSG) + metadata_bytes = sock.recv() + decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + metadata = decoder.decode(metadata_bytes) + got_metadata_time = time.perf_counter() + logger.debug("NIXL handshake: get metadata took: %s", + got_metadata_time - start_time) + + # Ensure engine id matches. + if metadata.engine_id != expected_engine_id: + raise RuntimeError(f"Remote NIXL agent engine ID mismatch. " + f"Expected {expected_engine_id}," + f"received {metadata.engine_id}.") + + # Register Remote agent. + remote_agent_name = self.add_remote_agent(metadata, p_remote_rank, + remote_tp_size) + setup_agent_time = time.perf_counter() + logger.debug("NIXL handshake: add agent took: %s", + setup_agent_time - got_metadata_time) + # Remote rank -> agent name. 
- return {p_remote_rank: handshake(path, p_remote_rank)} + return {p_remote_rank: remote_agent_name} def _background_nixl_handshake(self, req_id: str, remote_engine_id: EngineId, meta: ReqMeta): @@ -522,7 +536,7 @@ def _background_nixl_handshake(self, req_id: str, if fut is None: fut = self._handshake_initiation_executor.submit( self._nixl_handshake, meta.remote_host, meta.remote_port, - meta.tp_size) + meta.tp_size, remote_engine_id) self._handshake_futures[remote_engine_id] = fut def done_callback(f: Future[dict[int, str]], eid=remote_engine_id): @@ -725,10 +739,10 @@ def add_remote_agent(self, if remote_tp_rank in self._remote_agents.get(engine_id, {}): return self._remote_agents[engine_id][remote_tp_rank] - if engine_id in self._tp_size: - assert self._tp_size[engine_id] == remote_tp_size - else: + if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size + else: + assert self._tp_size[engine_id] == remote_tp_size # We may eventually enable this after asserting equality in cache # layout and close outputs. assert nixl_agent_meta.attn_backend_name == self.backend_name @@ -808,15 +822,9 @@ def add_remote_agent(self, def get_finished(self) -> tuple[set[str], set[str]]: """ - Get requests that are done sending or recving. - - In TP>1 setup, each rank exchanges KVs with its counterpart - ranks independently. get_finished() runs in a worker creates - the done_sending and done_recving sets that are sent to the - scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs - are done before adding to finished, Ranks 1 to N-1 communicate - to Rank 0 once their transaction is done + Rank 0 returns - finished sets to Scheduler only once all ranks are done. + Get requests that are done sending or recving on this specific worker. + The scheduler process (via the MultiprocExecutor) will use this output + to track which workers are done. """ done_sending = self._get_new_notifs() done_recving = self._pop_done_transfers(self._recving_transfers) @@ -826,50 +834,17 @@ def get_finished(self) -> tuple[set[str], set[str]]: "and %s requests done recving", self.tp_rank, len(done_sending), len(done_recving)) - if self.world_size == 1: - return done_sending, done_recving - - # Rank 0: get finished from all other ranks. - if self.tp_rank == 0: - for req_id in done_sending: - self._done_sending_count[req_id] += 1 - for req_id in done_recving: - self._done_recving_count[req_id] += 1 - - # Keep track of how many other ranks have finished. - other_ranks_finished_ids: list[str] = [] - for i in range(1, self.world_size): - other_ranks_finished_ids.extend( - self.tp_group.recv_object(src=i)) - for req_id in other_ranks_finished_ids: - if (req_id in self._done_recving_count - or req_id in self._recving_transfers): - self._done_recving_count[req_id] += 1 - else: - self._done_sending_count[req_id] += 1 - - # Return ids that finished on all ranks to the scheduler. - all_done_recving: set[str] = set() - for req_id in list(self._done_recving_count.keys()): - if self._done_recving_count[req_id] == self.world_size: - del self._done_recving_count[req_id] - all_done_recving.add(req_id) - - all_done_sending: set[str] = set() - for req_id in list(self._done_sending_count.keys()): - if self._done_sending_count[req_id] == self.world_size: - del self._done_sending_count[req_id] - all_done_sending.add(req_id) + # Handle timeout to avoid stranding blocks on remote. 
+ now = time.perf_counter() + while self._reqs_to_send: + req_id, expires = next(iter(self._reqs_to_send.items())) + # Sorted dict, oldest requests are put first so we can exit early. + if now < expires: + break + del self._reqs_to_send[req_id] + done_sending.add(req_id) - return all_done_sending, all_done_recving - - # Ranks 1 to N-1: send finished ids to Rank 0. - else: - finished_req_ids = list(done_recving.union(done_sending)) - self.tp_group.send_object(finished_req_ids, dst=0) - - # Unused as only Rank 0 results are sent to scheduler. - return done_sending, done_recving + return done_sending, done_recving def _get_new_notifs(self) -> set[str]: """ @@ -887,6 +862,7 @@ def _get_new_notifs(self) -> set[str]: tp_ratio): notified_req_ids.add(req_id) del self.consumer_notification_counts_by_req[req_id] + del self._reqs_to_send[req_id] return notified_req_ids def _pop_done_transfers( @@ -921,7 +897,7 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): Start loading by triggering non-blocking nixl_xfer. We check for these trnxs to complete in each step(). """ - for req_id, meta in metadata.requests.items(): + for req_id, meta in metadata.reqs_to_recv.items(): remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " @@ -943,6 +919,9 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): while not self._ready_requests.empty(): self._read_blocks_for_req(*self._ready_requests.get_nowait()) + # Add to requests that are waiting to be read and track expiration. + self._reqs_to_send.update(metadata.reqs_to_send) + def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug( "Remote agent %s available, calling _read_blocks for req %s", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 52f589a6d718..d47a75461d72 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( P2pNcclEngine) from vllm.distributed.parallel_state import get_world_group -from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import MLACommonMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -238,32 +237,16 @@ def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, assert self.p2p_nccl_engine is not None - def extract_kv_from_layer( - layer: torch.Tensor, - slot_mapping: torch.Tensor, - ) -> torch.Tensor: - """Extract the KV cache from the layer. - - Assume the shape of the layer is (2, num_pages, page_size, xxx) - if MLA is not used, and (num_pages, page_size, xxx) otherwise. - """ - if isinstance(attn_metadata, MLACommonMetadata): - num_pages, page_size = layer.shape[0], layer.shape[1] - return layer.reshape(num_pages * page_size, -1)[slot_mapping, - ...] - num_pages, page_size = layer.shape[1], layer.shape[2] - return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, - ...] 
- connector_metadata = self._get_connector_metadata() assert isinstance(connector_metadata, P2pNcclConnectorMetadata) for request in connector_metadata.requests: request_id = request.request_id ip, port = self.parse_request_id(request_id, True) remote_address = ip + ":" + str(port + self._rank) - kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping) - self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name, - kv_cache, remote_address) + self.p2p_nccl_engine.send_tensor( + request_id + "#" + layer_name, kv_layer, remote_address, + request.slot_mapping, + isinstance(attn_metadata, MLACommonMetadata)) def wait_for_save(self): if self.is_producer: @@ -286,9 +269,10 @@ def get_finished( assert self.p2p_nccl_engine is not None - forward_context: ForwardContext = get_forward_context() + no_compile_layers = ( + self._vllm_config.compilation_config.static_forward_context) return self.p2p_nccl_engine.get_finished(finished_req_ids, - forward_context) + no_compile_layers) # ============================== # Scheduler-side methods @@ -418,14 +402,6 @@ def build_connector_meta( block_ids=block_ids, block_size=self._block_size) - # Requests loaded asynchronously are not in the scheduler_output. - # for request_id in self._requests_need_load: - # request, block_ids = self._requests_need_load[request_id] - # meta.add_request(request_id=request.request_id, - # token_ids=request.prompt_token_ids, - # block_ids=block_ids, - # block_size=self._block_size) - self._requests_need_load.clear() return meta diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py index 6c9ccb2e301e..b94f2296dcb3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py @@ -8,7 +8,8 @@ import typing from collections import deque from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional +from dataclasses import dataclass +from typing import Any, Optional import msgpack import torch @@ -21,9 +22,6 @@ TensorMemoryPool) from vllm.utils import current_stream, get_ip -if TYPE_CHECKING: - from vllm.forward_context import ForwardContext - logger = logging.getLogger(__name__) DEFAULT_MEM_POOL_SIZE_GB = 32 @@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str): os.environ.pop(var, None) +@dataclass +class SendQueueItem: + tensor_id: str + remote_address: str + tensor: torch.Tensor + slot_mapping: torch.Tensor + is_mla: bool + + class P2pNcclEngine: def __init__(self, @@ -112,24 +119,26 @@ def __init__(self, self.send_stream = torch.cuda.Stream() self.recv_stream = torch.cuda.Stream() - mem_pool_size_gb = self.config.get_from_extra_config( - "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB) - self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) * - 1024**3) # GB + mem_pool_size_gb = float( + self.config.get_from_extra_config("mem_pool_size_gb", + DEFAULT_MEM_POOL_SIZE_GB)) + self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb * + 1024**3)) # GB # The sending type includes tree mutually exclusive options: # PUT, GET, PUT_ASYNC. 
- self.send_type = self.config.get_from_extra_config("send_type", "PUT") + self.send_type = self.config.get_from_extra_config( + "send_type", "PUT_ASYNC") if self.send_type == "GET": # tensor_id: torch.Tensor self.send_store: dict[str, torch.Tensor] = {} else: # PUT or PUT_ASYNC # tensor_id: torch.Tensor - self.send_queue: deque[list[Any]] = deque() + self.send_queue: deque[SendQueueItem] = deque() self.send_request_id_to_tensor_ids: dict[str, set[str]] = {} if self.send_type == "PUT_ASYNC": - self._send_thread = threading.Thread(target=self._send_async, + self._send_thread = threading.Thread(target=self.send_async, daemon=True) self._send_thread.start() @@ -146,13 +155,12 @@ def __init__(self, "nccl_num_channels", "8") self._listener_thread = threading.Thread( - target=self._listen_for_requests, daemon=True) + target=self.listen_for_requests, daemon=True) self._listener_thread.start() self._ping_thread = None if port_offset == 0 and self.proxy_address != "": - self._ping_thread = threading.Thread(target=self._ping, - daemon=True) + self._ping_thread = threading.Thread(target=self.ping, daemon=True) self._ping_thread.start() logger.info( @@ -162,7 +170,7 @@ def __init__(self, self.http_address, self.zmq_address, self.proxy_address, self.send_type, self.buffer_size_threshold, self.nccl_num_channels) - def _create_connect(self, remote_address: typing.Optional[str] = None): + def create_connect(self, remote_address: typing.Optional[str] = None): assert remote_address is not None if remote_address not in self.socks: sock = self.context.socket(zmq.DEALER) @@ -184,7 +192,7 @@ def _create_connect(self, remote_address: typing.Optional[str] = None): comm: ncclComm_t = self.nccl.ncclCommInitRank( 2, unique_id, rank) self.comms[remote_address] = (comm, rank) - logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s", + logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s", self.zmq_address, remote_address, rank) return self.socks[remote_address], self.comms[remote_address] @@ -194,44 +202,54 @@ def send_tensor( tensor_id: str, tensor: torch.Tensor, remote_address: typing.Optional[str] = None, + slot_mapping: torch.Tensor = None, + is_mla: bool = False, ) -> bool: if remote_address is None: with self.recv_store_cv: self.recv_store[tensor_id] = tensor self.recv_store_cv.notify() return True - else: - if self.send_type == "PUT": - return self._send_sync(tensor_id, tensor, remote_address) - elif self.send_type == "PUT_ASYNC": - with self.send_queue_cv: - self.send_queue.append([tensor_id, remote_address, tensor]) - self.send_queue_cv.notify() - else: # GET - with self.send_store_cv: - tensor_size = tensor.element_size() * tensor.numel() - while (self.buffer_size + tensor_size - > self.buffer_size_threshold): - oldest_tenser_id = next(iter(self.send_store)) - oldest_tenser = self.send_store.pop(oldest_tenser_id) - oldest_tenser_size = oldest_tenser.element_size( - ) * oldest_tenser.numel() - self.buffer_size -= oldest_tenser_size - logger.info( - "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," - " buffer_size:%d, oldest_tenser_size:%d, rank:%d", - remote_address, tensor_id, tensor_size, - self.buffer_size, oldest_tenser_size, self.rank) - - self.send_store[tensor_id] = tensor - self.buffer_size += tensor_size - logger.debug( - "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " - "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", - remote_address, tensor_id, tensor_size, tensor.shape, - self.rank, self.buffer_size, - self.buffer_size / self.buffer_size_threshold * 100) + item = 
SendQueueItem(tensor_id=tensor_id, + remote_address=remote_address, + tensor=tensor, + slot_mapping=slot_mapping, + is_mla=is_mla) + + if self.send_type == "PUT": + return self.send_sync(item) + + if self.send_type == "PUT_ASYNC": + with self.send_queue_cv: + self.send_queue.append(item) + self.send_queue_cv.notify() + return True + + # GET + with self.send_store_cv: + tensor_size = tensor.element_size() * tensor.numel() + while (self.buffer_size + tensor_size + > self.buffer_size_threshold): + oldest_tenser_id = next(iter(self.send_store)) + oldest_tenser = self.send_store.pop(oldest_tenser_id) + oldest_tenser_size = oldest_tenser.element_size( + ) * oldest_tenser.numel() + self.buffer_size -= oldest_tenser_size + logger.info( + "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d," + " buffer_size:%d, oldest_tenser_size:%d, rank:%d", + remote_address, tensor_id, tensor_size, self.buffer_size, + oldest_tenser_size, self.rank) + + self.send_store[tensor_id] = tensor + self.buffer_size += tensor_size + logger.debug( + "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, " + "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address, + tensor_id, tensor_size, tensor.shape, self.rank, + self.buffer_size, + self.buffer_size / self.buffer_size_threshold * 100) return True def recv_tensor( @@ -267,7 +285,7 @@ def recv_tensor( return None if remote_address not in self.socks: - self._create_connect(remote_address) + self.create_connect(remote_address) sock = self.socks[remote_address] comm, rank = self.comms[remote_address] @@ -282,121 +300,121 @@ def recv_tensor( remote_address, tensor_id, data["ret"]) return None - tensor = torch.empty(data["shape"], - dtype=getattr(torch, data["dtype"]), - device=self.device) + with torch.cuda.stream(self.recv_stream): + tensor = torch.empty(data["shape"], + dtype=getattr(torch, data["dtype"]), + device=self.device) - self._recv(comm, tensor, rank ^ 1, self.recv_stream) + self.recv(comm, tensor, rank ^ 1, self.recv_stream) return tensor - def _listen_for_requests(self): + def listen_for_requests(self): while True: socks = dict(self.poller.poll()) - if self.router_socket in socks: - remote_address, message = self.router_socket.recv_multipart() - data = msgpack.loads(message) - if data["cmd"] == "NEW": - unique_id = self.nccl.unique_id_from_bytes( - bytes(data["unique_id"])) - with torch.cuda.device(self.device): - rank = 1 - with set_p2p_nccl_context(self.nccl_num_channels): - comm: ncclComm_t = self.nccl.ncclCommInitRank( - 2, unique_id, rank) - self.comms[remote_address.decode()] = (comm, rank) - logger.info( - "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", - self.zmq_address, remote_address.decode(), rank) - elif data["cmd"] == "PUT": - tensor_id = data["tensor_id"] - try: - with torch.cuda.stream(self.recv_stream): - tensor = torch.empty(data["shape"], - dtype=getattr( - torch, data["dtype"]), - device=self.device) - self.router_socket.send_multipart( - [remote_address, b"0"]) - comm, rank = self.comms[remote_address.decode()] - self._recv(comm, tensor, rank ^ 1, self.recv_stream) - tensor_size = tensor.element_size() * tensor.numel() - if (self.buffer_size + tensor_size - > self.buffer_size_threshold): - # Store Tensor in memory pool - addr = self.pool.store_tensor(tensor) - tensor = (addr, tensor.dtype, tensor.shape) - logger.warning( - "🔴[PUT]Recv Tensor, Out Of Threshold, " - "%s👈%s, data:%s, addr:%d", self.zmq_address, - remote_address.decode(), data, addr) - else: - self.buffer_size += tensor_size - - except torch.cuda.OutOfMemoryError: - 
self.router_socket.send_multipart( - [remote_address, b"1"]) - tensor = None + if self.router_socket not in socks: + continue + + remote_address, message = self.router_socket.recv_multipart() + data = msgpack.loads(message) + if data["cmd"] == "NEW": + unique_id = self.nccl.unique_id_from_bytes( + bytes(data["unique_id"])) + with torch.cuda.device(self.device): + rank = 1 + with set_p2p_nccl_context(self.nccl_num_channels): + comm: ncclComm_t = self.nccl.ncclCommInitRank( + 2, unique_id, rank) + self.comms[remote_address.decode()] = (comm, rank) + logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s", + self.zmq_address, remote_address.decode(), + rank) + elif data["cmd"] == "PUT": + tensor_id = data["tensor_id"] + try: + with torch.cuda.stream(self.recv_stream): + tensor = torch.empty(data["shape"], + dtype=getattr( + torch, data["dtype"]), + device=self.device) + self.router_socket.send_multipart([remote_address, b"0"]) + comm, rank = self.comms[remote_address.decode()] + self.recv(comm, tensor, rank ^ 1, self.recv_stream) + tensor_size = tensor.element_size() * tensor.numel() + if (self.buffer_size + tensor_size + > self.buffer_size_threshold): + # Store Tensor in memory pool + addr = self.pool.store_tensor(tensor) + tensor = (addr, tensor.dtype, tensor.shape) logger.warning( - "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " - "data:%s", self.zmq_address, - remote_address.decode(), data) - - with self.recv_store_cv: - self.recv_store[tensor_id] = tensor - self._have_received_tensor_id(tensor_id) - self.recv_store_cv.notify() - - elif data["cmd"] == "GET": - tensor_id = data["tensor_id"] - with self.send_store_cv: - tensor = self.send_store.pop(tensor_id, None) - if tensor is not None: - data = { - "ret": 0, - "shape": tensor.shape, - "dtype": - str(tensor.dtype).replace("torch.", "") - } - # LRU - self.send_store[tensor_id] = tensor - self._have_sent_tensor_id(tensor_id) - else: - data = {"ret": 1} - - self.router_socket.send_multipart( - [remote_address, msgpack.dumps(data)]) - - if data["ret"] == 0: - comm, rank = self.comms[remote_address.decode()] - self._send(comm, tensor.to(self.device), rank ^ 1, - self.send_stream) - else: + "🔴[PUT]Recv Tensor, Out Of Threshold, " + "%s👈%s, data:%s, addr:%d", self.zmq_address, + remote_address.decode(), data, addr) + else: + self.buffer_size += tensor_size + + except torch.cuda.OutOfMemoryError: + self.router_socket.send_multipart([remote_address, b"1"]) + tensor = None logger.warning( - "🚧Unexpected, Received message from %s, data:%s", - remote_address, data) + "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, " + "data:%s", self.zmq_address, remote_address.decode(), + data) - def _have_sent_tensor_id(self, tensor_id: str): + with self.recv_store_cv: + self.recv_store[tensor_id] = tensor + self.have_received_tensor_id(tensor_id) + self.recv_store_cv.notify() + + elif data["cmd"] == "GET": + tensor_id = data["tensor_id"] + with self.send_store_cv: + tensor = self.send_store.pop(tensor_id, None) + if tensor is not None: + data = { + "ret": 0, + "shape": tensor.shape, + "dtype": str(tensor.dtype).replace("torch.", "") + } + # LRU + self.send_store[tensor_id] = tensor + self.have_sent_tensor_id(tensor_id) + else: + data = {"ret": 1} + + self.router_socket.send_multipart( + [remote_address, msgpack.dumps(data)]) + + if data["ret"] == 0: + comm, rank = self.comms[remote_address.decode()] + self.send(comm, tensor.to(self.device), rank ^ 1, + self.send_stream) + else: + logger.warning( + "🚧Unexpected, Received message from %s, data:%s", + remote_address, data) 
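For reference, the PUT branch above keeps a received tensor on the GPU only while the running `buffer_size` stays under `buffer_size_threshold`; past that point the payload is parked in the memory pool and only an `(addr, dtype, shape)` tuple is kept in `recv_store`. A minimal, dependency-free sketch of that bookkeeping follows; `FakePool` and `receive` are illustrative stand-ins, not vLLM APIs.

```python
import torch


class FakePool:
    """Illustrative stand-in for vLLM's TensorMemoryPool (not a real vLLM API)."""

    def __init__(self):
        self._store = {}
        self._next_addr = 1

    def store_tensor(self, tensor: torch.Tensor) -> int:
        addr, self._next_addr = self._next_addr, self._next_addr + 1
        self._store[addr] = tensor.detach().cpu().clone()
        return addr

    def free(self, addr: int) -> None:
        self._store.pop(addr, None)


def receive(recv_store, pool, buffer_size, threshold, tensor_id, tensor):
    """Mirror the PUT branch: keep the tensor hot if it fits, else spill to the pool."""
    size = tensor.element_size() * tensor.numel()
    if buffer_size + size > threshold:
        addr = pool.store_tensor(tensor)
        recv_store[tensor_id] = (addr, tensor.dtype, tensor.shape)  # metadata only
    else:
        buffer_size += size
        recv_store[tensor_id] = tensor
    return buffer_size


store, pool, used = {}, FakePool(), 0
used = receive(store, pool, used, 1024, "req#layer0", torch.ones(8, 16))  # 512 B, kept hot
used = receive(store, pool, used, 1024, "req#layer1", torch.ones(8, 32))  # 1 KiB, spilled
print(type(store["req#layer0"]).__name__, type(store["req#layer1"]).__name__)
# -> Tensor tuple
```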
+ + def have_sent_tensor_id(self, tensor_id: str): request_id = tensor_id.split('#')[0] if request_id not in self.send_request_id_to_tensor_ids: self.send_request_id_to_tensor_ids[request_id] = set() self.send_request_id_to_tensor_ids[request_id].add(tensor_id) - def _have_received_tensor_id(self, tensor_id: str): + def have_received_tensor_id(self, tensor_id: str): request_id = tensor_id.split('#')[0] if request_id not in self.recv_request_id_to_tensor_ids: self.recv_request_id_to_tensor_ids[request_id] = set() self.recv_request_id_to_tensor_ids[request_id].add(tensor_id) - def _send_async(self): + def send_async(self): while True: with self.send_queue_cv: while not self.send_queue: self.send_queue_cv.wait() - tensor_id, remote_address, tensor = self.send_queue.popleft() + item = self.send_queue.popleft() if not self.send_queue: self.send_queue_cv.notify() - self._send_sync(tensor_id, tensor, remote_address) + self.send_sync(item) def wait_for_sent(self): if self.send_type == "PUT_ASYNC": @@ -409,22 +427,21 @@ def wait_for_sent(self): "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue" " to be empty, rank:%d", duration * 1000, self.rank) - def _send_sync( - self, - tensor_id: str, - tensor: torch.Tensor, - remote_address: typing.Optional[str] = None, - ) -> bool: - if remote_address is None: + def send_sync(self, item: SendQueueItem) -> bool: + if item.remote_address is None: return False - if remote_address not in self.socks: - self._create_connect(remote_address) + if item.remote_address not in self.socks: + self.create_connect(item.remote_address) - sock = self.socks[remote_address] - comm, rank = self.comms[remote_address] + with self.send_stream: + tensor = self.extract_kv_from_layer(item.is_mla, item.tensor, + item.slot_mapping) + + sock = self.socks[item.remote_address] + comm, rank = self.comms[item.remote_address] data = { "cmd": "PUT", - "tensor_id": tensor_id, + "tensor_id": item.tensor_id, "shape": tensor.shape, "dtype": str(tensor.dtype).replace("torch.", "") } @@ -435,20 +452,21 @@ def _send_sync( logger.error( "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, " "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s", - self.zmq_address, remote_address, rank, data, tensor.shape, + self.zmq_address, item.remote_address, rank, data, + tensor.shape, tensor.element_size() * tensor.numel() / 1024**3, response.decode()) return False - self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) + self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream) if self.send_type == "PUT_ASYNC": - self._have_sent_tensor_id(tensor_id) + self.have_sent_tensor_id(item.tensor_id) return True def get_finished( - self, finished_req_ids: set[str], forward_context: "ForwardContext" + self, finished_req_ids: set[str], no_compile_layers ) -> tuple[Optional[set[str]], Optional[set[str]]]: """ Notifies worker-side connector ids of requests that have @@ -463,7 +481,7 @@ def get_finished( # Clear the buffer upon request completion. 
for request_id in finished_req_ids: - for layer_name in forward_context.no_compile_layers: + for layer_name in no_compile_layers: tensor_id = request_id + "#" + layer_name if tensor_id in self.recv_store: with self.recv_store_cv: @@ -472,7 +490,6 @@ def get_finished( request_id, None) self.recv_request_id_to_tensor_ids.pop( request_id, None) - addr = 0 if isinstance(tensor, tuple): addr, _, _ = tensor self.pool.free(addr) @@ -485,7 +502,7 @@ def get_finished( return finished_sending or None, finished_recving or None - def _ping(self): + def ping(self): sock = self.context.socket(zmq.DEALER) sock.setsockopt_string(zmq.IDENTITY, self.zmq_address) logger.debug("ping start, zmq_address:%s", self.zmq_address) @@ -499,7 +516,7 @@ def _ping(self): sock.send(msgpack.dumps(data)) time.sleep(3) - def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): + def send(self, comm, tensor: torch.Tensor, dst: int, stream=None): assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") @@ -512,7 +529,7 @@ def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None): comm, cudaStream_t(stream.cuda_stream)) stream.synchronize() - def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None): + def recv(self, comm, tensor: torch.Tensor, src: int, stream=None): assert tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") @@ -531,3 +548,21 @@ def close(self) -> None: self._send_thread.join() if self._ping_thread is not None: self._ping_thread.join() + + @staticmethod + def extract_kv_from_layer( + is_mla: bool, + layer: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> torch.Tensor: + """Extract the KV cache from the layer. + Assume the shape of the layer is (2, num_pages, page_size, xxx) + if MLA is not used, and (num_pages, page_size, xxx) otherwise. + """ + if is_mla: + num_pages, page_size = layer.shape[0], layer.shape[1] + return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] + + num_pages, page_size = layer.shape[1], layer.shape[2] + return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, + ...] 
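With the `p2p_nccl_engine.py` diff complete, it may help to see the indexing that the relocated `extract_kv_from_layer` performs. The sketch below uses made-up sizes purely for illustration: a non-MLA layer is laid out as `(2, num_pages, page_size, ...)` (separate K and V planes), and the slots named in `slot_mapping` come back as a contiguous `(2, num_slots, ...)` tensor.

```python
import torch

num_pages, page_size, head_dim = 4, 16, 8
layer = torch.randn(2, num_pages, page_size, head_dim)   # K and V planes
slot_mapping = torch.tensor([0, 5, 17, 63])               # flat slot indices

flat = layer.reshape(2, num_pages * page_size, -1)        # (2, 64, head_dim)
kv = flat[:, slot_mapping, ...]                           # (2, 4, head_dim)
assert kv.shape == (2, len(slot_mapping), head_dim)

# MLA-style layout has no separate K/V plane: (num_pages, page_size, ...)
mla_layer = torch.randn(num_pages, page_size, head_dim)
mla_kv = mla_layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
assert mla_kv.shape == (len(slot_mapping), head_dim)
```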
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c53601a22f21..1f7a14920c41 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -240,6 +240,8 @@ def __init__( if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_xpu(): + self.device = torch.device(f"xpu:{local_rank}") elif current_platform.is_out_of_tree(): self.device = torch.device( f"{current_platform.device_name}:{local_rank}") @@ -270,6 +272,9 @@ def __init__( self.use_custom_op_call = (current_platform.is_cuda_alike() or current_platform.is_tpu()) + self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr( + torch.ops._C, "init_shm_manager")) + @property def first_rank(self): """Return the global rank of the first process in the group""" @@ -381,6 +386,12 @@ def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: return self.device_communicator.all_gather(input_, dim) + def all_gatherv(self, + input_: Union[torch.Tensor, list[torch.Tensor]], + dim: int = 0, + sizes: Optional[list[int]] = None): + return self.device_communicator.all_gatherv(input_, dim, sizes) + def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: @@ -399,6 +410,12 @@ def reduce_scatter(self, else: return self._reduce_scatter_out_place(input_, dim) + def reduce_scatterv(self, + input_: torch.Tensor, + dim: int = -1, + sizes: Optional[list[int]] = None) -> torch.Tensor: + return self.device_communicator.reduce_scatterv(input_, dim, sizes) + def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor: return self.device_communicator.reduce_scatter(input_, dim) @@ -649,6 +666,11 @@ def send_tensor_dict( dst = (self.rank_in_group + 1) % self.world_size assert dst < self.world_size, f"Invalid dst rank ({dst})" + if self.use_cpu_custom_send_recv: + self.device_communicator.send_tensor_dict( # type: ignore + tensor_dict, dst) + return None + metadata_list: list[tuple[Any, Any]] = [] assert isinstance( tensor_dict, @@ -704,6 +726,10 @@ def recv_tensor_dict( src = (self.rank_in_group - 1) % self.world_size assert src < self.world_size, f"Invalid src rank ({src})" + if self.use_cpu_custom_send_recv: + return self.device_communicator.recv_tensor_dict( # type: ignore + src) + recv_metadata_list = self.recv_object(src=src) tensor_dict: dict[str, Any] = {} for key, value in recv_metadata_list: @@ -1317,13 +1343,13 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup], def is_global_first_rank() -> bool: """ - Check if the current process is the first rank globally across all + Check if the current process is the first rank globally across all parallelism strategies (PP, TP, DP, EP, etc.). - + Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0` or `get_pp_group().is_first_rank`, this function checks the global rank across all parallelism dimensions. - + Returns: bool: True if this is the global first rank (rank 0), False otherwise. Returns True if distributed is not initialized (single process). 
@@ -1352,7 +1378,7 @@ def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int: Args: pg: The process group to analyze - + Returns: int: The total number of nodes """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f6efcb80efd5..aec75f82631a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -9,16 +9,16 @@ import json import sys import threading -import warnings from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, - Type, TypeVar, Union, cast, get_args, get_origin) +from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, + Literal, Optional, Type, TypeVar, Union, cast, get_args, + get_origin) import regex as re import torch from pydantic import TypeAdapter, ValidationError -from typing_extensions import TypeIs, deprecated +from typing_extensions import TypeIs import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, @@ -26,26 +26,33 @@ DetailedTraceModules, Device, DeviceConfig, DistributedExecutorBackend, GuidedDecodingBackend, GuidedDecodingBackendV1, HfOverrides, KVEventsConfig, - KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ModelDType, ModelImpl, MultiModalConfig, - ObservabilityConfig, ParallelConfig, PoolerConfig, - PrefixCachingHashAlgo, PromptAdapterConfig, + KVTransferConfig, LoadConfig, LoadFormat, + LogprobsMode, LoRAConfig, ModelConfig, ModelDType, + ModelImpl, MultiModalConfig, ObservabilityConfig, + ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - TaskOption, TokenizerMode, TokenizerPoolConfig, - VllmConfig, get_attr_docs, get_field) -from vllm.executor.executor_base import ExecutorBase + TaskOption, TokenizerMode, VllmConfig, get_attr_docs, + get_field) from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file -from vllm.usage.usage_lib import UsageContext from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor) # yapf: enable +if TYPE_CHECKING: + from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.layers.quantization import QuantizationMethods + from vllm.usage.usage_lib import UsageContext +else: + ExecutorBase = Any + QuantizationMethods = Any + UsageContext = Any + logger = init_logger(__name__) # object is used to allow for special typing forms @@ -58,8 +65,6 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: def _parse_type(val: str) -> T: try: - if return_type is json.loads and not re.match("^{.*}$", val): - return cast(T, nullable_kvs(val)) return return_type(val) except ValueError as e: raise argparse.ArgumentTypeError( @@ -80,47 +85,11 @@ def _optional_type(val: str) -> Optional[T]: def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: - if not re.match("^{.*}$", val): + if not re.match(r"(?s)^\s*{.*}\s*$", val): return str(val) return optional_type(json.loads)(val) -@deprecated( - "Passing a JSON argument as a string containing comma separated key=value " - "pairs is deprecated. 
This will be removed in v0.10.0. Please use a JSON " - "string instead.") -def nullable_kvs(val: str) -> dict[str, int]: - """Parses a string containing comma separate key [str] to value [int] - pairs into a dictionary. - - Args: - val: String value to be parsed. - - Returns: - Dictionary with parsed values. - """ - out_dict: dict[str, int] = {} - for item in val.split(","): - kv_parts = [part.lower().strip() for part in item.split("=")] - if len(kv_parts) != 2: - raise argparse.ArgumentTypeError( - "Each item should be in the form KEY=VALUE") - key, value = kv_parts - - try: - parsed_value = int(value) - except ValueError as exc: - msg = f"Failed to parse value of item {key}={value}" - raise argparse.ArgumentTypeError(msg) from exc - - if key in out_dict and out_dict[key] != parsed_value: - raise argparse.ArgumentTypeError( - f"Conflicting values specified for key: {key}") - out_dict[key] = parsed_value - - return out_dict - - def is_type(type_hint: TypeHint, type: TypeHintT) -> TypeIs[TypeHintT]: """Check if the type hint is a specific type.""" return type_hint is type or get_origin(type_hint) is type @@ -170,6 +139,10 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]: return type_hints +def is_online_quantization(quantization: Any) -> bool: + return quantization in ["inc"] + + @functools.lru_cache(maxsize=30) def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) @@ -198,14 +171,17 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = """\n\nShould either be a valid JSON string or JSON keys - passed individually. For example, the following sets of arguments are - equivalent:\n\n - - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n - - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n - Additionally, list elements can be passed individually using '+': - - `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n - - `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`\n\n""" + json_tip = """Should either be a valid JSON string or JSON keys +passed individually. 
For example, the following sets of arguments are +equivalent: + +- `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n +- `--json-arg.key1 value1 --json-arg.key2.key3 value2` + +Additionally, list elements can be passed individually using `+`: + +- `--json-arg '{"key4": ["value3", "value4", "value5"]}'`\n +- `--json-arg.key4+ value3 --json-arg.key4+='value4,value5'`""" if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: @@ -217,7 +193,7 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: raise argparse.ArgumentTypeError(repr(e)) from e kwargs[name]["type"] = parse_dataclass - kwargs[name]["help"] += json_tip + kwargs[name]["help"] += f"\n\n{json_tip}" elif contains_type(type_hints, bool): # Creates --no-<name> and --<name> flags kwargs[name]["action"] = argparse.BooleanOptionalAction @@ -253,7 +229,7 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): kwargs[name]["type"] = parse_type(json.loads) - kwargs[name]["help"] += json_tip + kwargs[name]["help"] += f"\n\n{json_tip}" elif (contains_type(type_hints, str) or any(is_not_builtin(th) for th in type_hints)): kwargs[name]["type"] = str @@ -319,9 +295,11 @@ class EngineArgs: tensor_parallel_size: int = ParallelConfig.tensor_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None + data_parallel_start_rank: Optional[int] = None data_parallel_size_local: Optional[int] = None data_parallel_address: Optional[str] = None data_parallel_rpc_port: Optional[int] = None + data_parallel_hybrid_lb: bool = False data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_eplb: bool = ParallelConfig.enable_eplb @@ -337,7 +315,6 @@ class EngineArgs: CacheConfig.prefix_caching_hash_algo disable_sliding_window: bool = ModelConfig.disable_sliding_window disable_cascade_attn: bool = ModelConfig.disable_cascade_attn - use_v2_block_manager: bool = True swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization @@ -349,6 +326,7 @@ class EngineArgs: SchedulerConfig.long_prefill_token_threshold max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = ModelConfig.max_logprobs + logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode disable_log_stats: bool = False revision: Optional[str] = ModelConfig.revision code_revision: Optional[str] = ModelConfig.code_revision @@ -361,15 +339,9 @@ class EngineArgs: enforce_eager: bool = ModelConfig.enforce_eager max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - # The following three fields are deprecated and will be removed in a future - # release. Setting them will have no effect. Please remove them from your - # configurations. 
- tokenizer_pool_size: int = TokenizerPoolConfig.pool_size - tokenizer_pool_type: str = TokenizerPoolConfig.pool_type - tokenizer_pool_extra_config: dict = \ - get_field(TokenizerPoolConfig, "extra_config") limit_mm_per_prompt: dict[str, int] = \ get_field(MultiModalConfig, "limit_per_prompt") + interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings media_io_kwargs: dict[str, dict[str, Any]] = get_field(MultiModalConfig, "media_io_kwargs") @@ -388,15 +360,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size - long_lora_scaling_factors: Optional[tuple[float, ...]] = \ - LoRAConfig.long_lora_scaling_factors - # PromptAdapter fields - enable_prompt_adapter: bool = False - max_prompt_adapters: int = PromptAdapterConfig.max_prompt_adapters - max_prompt_adapter_token: int = \ - PromptAdapterConfig.max_prompt_adapter_token - device: Device = DeviceConfig.device num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight @@ -428,7 +392,6 @@ class EngineArgs: speculative_config: Optional[Dict[str, Any]] = None - qlora_adapter_name_or_path: Optional[str] = None show_hidden_metrics_for_version: Optional[str] = \ ObservabilityConfig.show_hidden_metrics_for_version otlp_traces_endpoint: Optional[str] = \ @@ -462,7 +425,6 @@ class EngineArgs: additional_config: dict[str, Any] = \ get_field(VllmConfig, "additional_config") - enable_reasoning: Optional[bool] = None # DEPRECATED reasoning_parser: str = DecodingConfig.reasoning_backend use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load @@ -471,6 +433,10 @@ class EngineArgs: enable_multimodal_encoder_data_parallel: bool = \ ParallelConfig.enable_multimodal_encoder_data_parallel + async_scheduling: bool = SchedulerConfig.async_scheduling + # DEPRECATED + enable_prompt_adapter: bool = False + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -478,13 +444,6 @@ def __post_init__(self): if isinstance(self.compilation_config, (int, dict)): self.compilation_config = CompilationConfig.from_cli( str(self.compilation_config)) - if self.qlora_adapter_name_or_path is not None: - warnings.warn( - "The `qlora_adapter_name_or_path` is deprecated " - "and will be removed in v0.10.0. 
", - DeprecationWarning, - stacklevel=2, - ) # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() @@ -531,6 +490,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **model_kwargs["max_seq_len_to_capture"]) model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) + model_group.add_argument("--logprobs-mode", + **model_kwargs["logprobs_mode"]) model_group.add_argument("--disable-sliding-window", **model_kwargs["disable_sliding_window"]) model_group.add_argument("--disable-cascade-attn", @@ -597,14 +558,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **load_kwargs["ignore_patterns"]) load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"]) - load_group.add_argument( - "--qlora-adapter-name-or-path", - type=str, - default=None, - help="The `--qlora-adapter-name-or-path` has no effect, do not set" - " it, and it will be removed in v0.10.0.", - deprecated=True, - ) load_group.add_argument('--pt-load-map-location', **load_kwargs["pt_load_map_location"]) @@ -625,15 +578,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: guided_decoding_group.add_argument( "--guided-decoding-disable-additional-properties", **guided_decoding_kwargs["disable_additional_properties"]) - guided_decoding_group.add_argument( - "--enable-reasoning", - action=argparse.BooleanOptionalAction, - deprecated=True, - help="[DEPRECATED] The `--enable-reasoning` flag is deprecated as " - "of v0.9.0. Use `--reasoning-parser` to specify the reasoning " - "parser backend instead. This flag (`--enable-reasoning`) will be " - "removed in v0.10.0. When `--reasoning-parser` is specified, " - "reasoning mode is automatically enabled.") guided_decoding_group.add_argument( "--reasoning-parser", # This choices is a special case because it's not static @@ -662,6 +606,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, help='Data parallel rank of this instance. 
' 'When set, enables external load balancer mode.') + parallel_group.add_argument('--data-parallel-start-rank', + '-dpr', + type=int, + help='Starting data parallel rank ' + 'for secondary nodes.') parallel_group.add_argument('--data-parallel-size-local', '-dpl', type=int, @@ -683,6 +632,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default='mp', help='Backend for data parallel, either ' '"mp" or "ray".') + parallel_group.add_argument( + "--data-parallel-hybrid-lb", + **parallel_kwargs["data_parallel_hybrid_lb"]) parallel_group.add_argument( "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]) @@ -736,19 +688,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument("--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]) - # Tokenizer arguments - tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) - tokenizer_group = parser.add_argument_group( - title="TokenizerPoolConfig", - description=TokenizerPoolConfig.__doc__, - ) - tokenizer_group.add_argument("--tokenizer-pool-size", - **tokenizer_kwargs["pool_size"]) - tokenizer_group.add_argument("--tokenizer-pool-type", - **tokenizer_kwargs["pool_type"]) - tokenizer_group.add_argument("--tokenizer-pool-extra-config", - **tokenizer_kwargs["extra_config"]) - # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) multimodal_group = parser.add_argument_group( @@ -765,6 +704,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: multimodal_group.add_argument( "--disable-mm-preprocessor-cache", **multimodal_kwargs["disable_mm_preprocessor_cache"]) + multimodal_group.add_argument( + "--interleave-mm-strings", + **multimodal_kwargs["interleave_mm_strings"]) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -787,8 +729,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--lora-dtype", **lora_kwargs["lora_dtype"], ) - lora_group.add_argument("--long-lora-scaling-factors", - **lora_kwargs["long_lora_scaling_factors"]) lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"]) lora_group.add_argument("--fully-sharded-loras", @@ -796,33 +736,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) - # PromptAdapter related configs - prompt_adapter_kwargs = get_kwargs(PromptAdapterConfig) - prompt_adapter_group = parser.add_argument_group( - title="PromptAdapterConfig", - description=PromptAdapterConfig.__doc__, - ) - prompt_adapter_group.add_argument( - "--enable-prompt-adapter", - action=argparse.BooleanOptionalAction, - help="If True, enable handling of PromptAdapters.") - prompt_adapter_group.add_argument( - "--max-prompt-adapters", - **prompt_adapter_kwargs["max_prompt_adapters"]) - prompt_adapter_group.add_argument( - "--max-prompt-adapter-token", - **prompt_adapter_kwargs["max_prompt_adapter_token"]) - - # Device arguments - device_kwargs = get_kwargs(DeviceConfig) - device_group = parser.add_argument_group( - title="DeviceConfig", - description=DeviceConfig.__doc__, - ) - device_group.add_argument("--device", - **device_kwargs["device"], - deprecated=True) - # Speculative arguments speculative_group = parser.add_argument_group( title="SpeculativeConfig", @@ -905,6 +818,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: scheduler_group.add_argument( "--disable-hybrid-kv-cache-manager", 
**scheduler_kwargs["disable_hybrid_kv_cache_manager"]) + scheduler_group.add_argument("--async-scheduling", + **scheduler_kwargs["async_scheduling"]) # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) @@ -922,18 +837,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **vllm_kwargs["additional_config"]) # Other arguments - parser.add_argument('--use-v2-block-manager', - action='store_true', - default=True, - deprecated=True, - help='[DEPRECATED] block manager v1 has been ' - 'removed and SelfAttnBlockSpaceManager (i.e. ' - 'block manager v2) is now the default. ' - 'Setting this flag to True or False' - ' has no effect on vLLM behavior.') parser.add_argument('--disable-log-stats', action='store_true', help='Disable logging statistics.') + parser.add_argument('--enable-prompt-adapter', + action='store_true', + deprecated=True, + help='[DEPRECATED] Prompt adapter has been ' + 'removed. Setting this flag to True or False' + ' has no effect on vLLM behavior.') return parser @@ -979,12 +891,14 @@ def create_model_config(self) -> ModelConfig: enforce_eager=self.enforce_eager, max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, + logprobs_mode=self.logprobs_mode, disable_sliding_window=self.disable_sliding_window, disable_cascade_attn=self.disable_cascade_attn, skip_tokenizer_init=self.skip_tokenizer_init, enable_prompt_embeds=self.enable_prompt_embeds, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, + interleave_mm_strings=self.interleave_mm_strings, media_io_kwargs=self.media_io_kwargs, use_async_output_proc=not self.disable_async_output_proc, config_format=self.config_format, @@ -1000,14 +914,33 @@ def create_model_config(self) -> ModelConfig: override_attention_dtype=self.override_attention_dtype, ) + def validate_tensorizer_args(self): + from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig) + for key in self.model_loader_extra_config: + if key in TensorizerConfig._fields: + self.model_loader_extra_config["tensorizer_config"][ + key] = self.model_loader_extra_config[key] + def create_load_config(self) -> LoadConfig: if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" + if self.load_format == "tensorizer": + if hasattr(self.model_loader_extra_config, "to_serializable"): + self.model_loader_extra_config = ( + self.model_loader_extra_config.to_serializable()) + self.model_loader_extra_config["tensorizer_config"] = {} + self.model_loader_extra_config["tensorizer_config"][ + "tensorizer_dir"] = self.model + self.validate_tensorizer_args() + return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, + device="cpu" + if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, @@ -1049,6 +982,7 @@ def create_speculative_config( def create_engine_config( self, usage_context: Optional[UsageContext] = None, + headless: bool = False, ) -> VllmConfig: """ Create the VllmConfig. @@ -1063,7 +997,6 @@ def create_engine_config( If VLLM_USE_V1 is specified by the user but the VllmConfig is incompatible, we raise an error. """ - from vllm.platforms import current_platform current_platform.pre_register_and_update() device_config = DeviceConfig( @@ -1090,9 +1023,16 @@ def create_engine_config( # Set default arguments for V0 or V1 Engine. 
if use_v1: self._set_default_args_v1(usage_context, model_config) + # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1 + if current_platform.is_cpu( + ) and current_platform.get_cpu_architecture() in ( + CpuArchEnum.POWERPC, CpuArchEnum.ARM): + logger.info( + "Chunked prefill is not supported for ARM and POWER CPUs; " + "disabling it for V1 backend.") + self.enable_chunked_prefill = False else: self._set_default_args_v0(model_config) - assert self.enable_chunked_prefill is not None if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]: @@ -1131,15 +1071,41 @@ def create_engine_config( # but we should not do this here. placement_group = ray.util.get_current_placement_group() + assert not headless or not self.data_parallel_hybrid_lb, ( + "data_parallel_hybrid_lb is not applicable in " + "headless mode") + data_parallel_external_lb = self.data_parallel_rank is not None + # Local DP rank = 1, use pure-external LB. if data_parallel_external_lb: assert self.data_parallel_size_local in (1, None), ( "data_parallel_size_local must be 1 when data_parallel_rank " "is set") data_parallel_size_local = 1 + # Use full external lb if we have local_size of 1. + self.data_parallel_hybrid_lb = False elif self.data_parallel_size_local is not None: data_parallel_size_local = self.data_parallel_size_local + + if self.data_parallel_start_rank and not headless: + # Infer hybrid LB mode. + self.data_parallel_hybrid_lb = True + + if self.data_parallel_hybrid_lb and data_parallel_size_local == 1: + # Use full external lb if we have local_size of 1. + data_parallel_external_lb = True + self.data_parallel_hybrid_lb = False + + if data_parallel_size_local == self.data_parallel_size: + # Disable hybrid LB mode if set for a single node + self.data_parallel_hybrid_lb = False + + self.data_parallel_rank = self.data_parallel_start_rank or 0 else: + assert not self.data_parallel_hybrid_lb, ( + "data_parallel_size_local must be set to use " + "data_parallel_hybrid_lb.") + # Local DP size defaults to global DP size if not set. data_parallel_size_local = self.data_parallel_size @@ -1166,6 +1132,26 @@ def create_engine_config( self.data_parallel_rpc_port is not None) else ParallelConfig.data_parallel_rpc_port + if self.async_scheduling: + # Async scheduling does not work with the uniprocess backend. + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "mp" + logger.info("Using mp-based distributed executor backend " + "for async scheduling.") + if self.distributed_executor_backend == "uni": + raise ValueError("Async scheduling is not supported with " + "uni-process backend.") + if self.pipeline_parallel_size > 1: + raise ValueError("Async scheduling is not supported with " + "pipeline-parallel-size > 1.") + + # Currently, async scheduling does not support speculative decoding. + # TODO(woosuk): Support it. 
+ if self.speculative_config is not None: + raise ValueError( + "Currently, speculative decoding is not supported with " + "async scheduling.") + parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, tensor_parallel_size=self.tensor_parallel_size, @@ -1176,6 +1162,7 @@ def create_engine_config( data_parallel_master_ip=data_parallel_address, data_parallel_rpc_port=data_parallel_rpc_port, data_parallel_backend=self.data_parallel_backend, + data_parallel_hybrid_lb=self.data_parallel_hybrid_lb, enable_expert_parallel=self.enable_expert_parallel, enable_eplb=self.enable_eplb, num_redundant_experts=self.num_redundant_experts, @@ -1209,7 +1196,6 @@ def create_engine_config( if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: raise ValueError("Multi-Step Chunked-Prefill is not supported " "for pipeline-parallel-size > 1") - from vllm.platforms import current_platform if current_platform.is_cpu(): logger.warning("Multi-Step (--num-scheduler-steps > 1) is " "currently not supported for CPUs and has been " @@ -1247,6 +1233,7 @@ def create_engine_config( long_prefill_token_threshold=self.long_prefill_token_threshold, disable_hybrid_kv_cache_manager=self. disable_hybrid_kv_cache_manager, + async_scheduling=self.async_scheduling, ) if not model_config.is_multimodal_model and self.default_mm_loras: @@ -1261,7 +1248,6 @@ def create_engine_config( default_mm_loras=self.default_mm_loras, fully_sharded_loras=self.fully_sharded_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, - long_lora_scaling_factors=self.long_lora_scaling_factors, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None @@ -1272,11 +1258,6 @@ def create_engine_config( load_config = self.create_load_config() - prompt_adapter_config = PromptAdapterConfig( - max_prompt_adapters=self.max_prompt_adapters, - max_prompt_adapter_token=self.max_prompt_adapter_token) \ - if self.enable_prompt_adapter else None - decoding_config = DecodingConfig( backend=self.guided_decoding_backend, disable_fallback=self.guided_decoding_disable_fallback, @@ -1287,8 +1268,8 @@ def create_engine_config( ) observability_config = ObservabilityConfig( - show_hidden_metrics_for_version=self. - show_hidden_metrics_for_version, + show_hidden_metrics_for_version=( + self.show_hidden_metrics_for_version), otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) @@ -1304,7 +1285,6 @@ def create_engine_config( load_config=load_config, decoding_config=decoding_config, observability_config=observability_config, - prompt_adapter_config=prompt_adapter_config, compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, @@ -1364,7 +1344,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # Skip this check if we are running on a non-GPU platform, # or if the device capability is not available # (e.g. in a Ray actor without GPUs). - from vllm.platforms import current_platform if (current_platform.is_cuda() and current_platform.get_device_capability() and current_platform.get_device_capability().major < 8): @@ -1374,29 +1353,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # No Fp8 KV cache so far. 
if self.kv_cache_dtype != "auto": - fp8_attention = self.kv_cache_dtype.startswith("fp8") - will_use_fa = ( - current_platform.is_cuda() - and not envs.is_set("VLLM_ATTENTION_BACKEND") - ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" - supported = False - if current_platform.is_rocm(): - supported = True - elif fp8_attention and will_use_fa: - from vllm.attention.utils.fa_utils import ( - flash_attn_supports_fp8) - supported = flash_attn_supports_fp8() + supported = current_platform.is_kv_cache_dtype_supported( + self.kv_cache_dtype) if not supported: _raise_or_fallback(feature_name="--kv-cache-dtype", recommend_to_remove=False) return False - # No Prompt Adapter so far. - if self.enable_prompt_adapter: - _raise_or_fallback(feature_name="--enable-prompt-adapter", - recommend_to_remove=False) - return False - # No text embedding inputs so far. if self.enable_prompt_embeds: _raise_or_fallback(feature_name="--enable-prompt-embeds", @@ -1430,28 +1393,12 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return False # V1 supports N-gram, Medusa, and Eagle speculative decoding. - is_ngram_enabled = False - is_eagle_enabled = False - is_medusa_enabled = False - if self.speculative_config is not None: - # This is supported but experimental (handled below). - speculative_method = self.speculative_config.get("method") - if speculative_method: - if speculative_method in ("ngram", "[ngram]"): - is_ngram_enabled = True - elif speculative_method == "medusa": - is_medusa_enabled = True - elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"): - is_eagle_enabled = True - else: - speculative_model = self.speculative_config.get("model") - if speculative_model in ("ngram", "[ngram]"): - is_ngram_enabled = True - if not (is_ngram_enabled or is_eagle_enabled or is_medusa_enabled): - # Other speculative decoding methods are not supported yet. - _raise_or_fallback(feature_name="Speculative Decoding", - recommend_to_remove=False) - return False + if (self.speculative_config is not None + and self.speculative_config.get("method") == "draft_model"): + raise NotImplementedError( + "Speculative decoding with draft model is not supported yet. " + "Please consider using other speculative decoding methods " + "such as ngram, medusa, eagle, or deepseek_mtp.") # No XFormers so far. V1_BACKENDS = [ @@ -1527,7 +1474,6 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None: # Enable chunked prefill by default for long context (> 32K) # models to avoid OOM errors in initial memory profiling phase. elif use_long_context: - from vllm.platforms import current_platform is_gpu = current_platform.is_cuda() use_sliding_window = (model_config.get_sliding_window() is not None) @@ -1535,7 +1481,6 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None: if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora - and not self.enable_prompt_adapter and model_config.runner_type != "pooling"): self.enable_chunked_prefill = True logger.warning( @@ -1625,7 +1570,6 @@ def _set_default_args_v1(self, usage_context: UsageContext, # as the platform that vLLM is running on (e.g. the case of scaling # vLLM with Ray) and has no GPUs. In this case we use the default # values for non-H100/H200 GPUs. 
- from vllm.platforms import current_platform try: device_memory = current_platform.get_device_total_memory() device_name = current_platform.get_device_name().lower() @@ -1636,6 +1580,7 @@ def _set_default_args_v1(self, usage_context: UsageContext, # NOTE(Kuntai): Setting large `max_num_batched_tokens` for A100 reduces # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. + from vllm.usage.usage_lib import UsageContext if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { @@ -1674,13 +1619,14 @@ def _set_default_args_v1(self, usage_context: UsageContext, # cpu specific default values. if current_platform.is_cpu(): + world_size = self.pipeline_parallel_size * self.tensor_parallel_size default_max_num_batched_tokens = { - UsageContext.LLM_CLASS: 4096, - UsageContext.OPENAI_API_SERVER: 2048, + UsageContext.LLM_CLASS: 4096 * world_size, + UsageContext.OPENAI_API_SERVER: 2048 * world_size, } default_max_num_seqs = { - UsageContext.LLM_CLASS: 128, - UsageContext.OPENAI_API_SERVER: 32, + UsageContext.LLM_CLASS: 256 * world_size, + UsageContext.OPENAI_API_SERVER: 128 * world_size, } use_context_value = usage_context.value if usage_context else None @@ -1728,7 +1674,6 @@ def add_cli_args(parser: FlexibleArgumentParser, parser.add_argument('--disable-log-requests', action='store_true', help='Disable logging requests.') - from vllm.platforms import current_platform current_platform.pre_register_and_update(parser) return parser diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3d7d28055dd0..39642d89167b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -29,7 +29,6 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -435,9 +434,9 @@ async def add_request_async( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> None: """ Async version of @@ -467,7 +466,7 @@ async def add_request_async( processed_inputs = await self.input_preprocessor.preprocess_async( prompt, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, + tokenization_kwargs=tokenization_kwargs, ) if isinstance(params, SamplingParams) and \ @@ -489,7 +488,6 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, priority=priority, ) @@ -859,9 +857,9 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[Union[RequestOutput, PoolingRequestOutput], None]: if not 
self.is_running: if self.start_engine_loop: @@ -886,9 +884,9 @@ async def add_request( arrival_time=arrival_time or time.time(), lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, priority=priority, data_parallel_rank=data_parallel_rank, + tokenization_kwargs=tokenization_kwargs, ) return stream.generator() @@ -900,7 +898,6 @@ async def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, ) -> AsyncGenerator[RequestOutput, None]: @@ -918,8 +915,6 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use - for generation, if any. priority: The priority of the request. Only applicable with priority scheduling. data_parallel_rank: The (global) data parallel rank that must @@ -979,7 +974,6 @@ async def generate( sampling_params, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, priority=priority, data_parallel_rank=data_parallel_rank, ): @@ -996,6 +990,7 @@ async def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """Generate outputs for a request from a pooling model. @@ -1070,6 +1065,7 @@ async def encode( lora_request=lora_request, trace_headers=trace_headers, priority=priority, + tokenization_kwargs=tokenization_kwargs, ): yield LLMEngine.validate_output(output, PoolingRequestOutput) except asyncio.CancelledError: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 25fa1c3058be..e7919d90442f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -44,7 +44,6 @@ from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup, PoolingSequenceGroupOutput, Sequence, SequenceGroup, @@ -223,7 +222,6 @@ def __init__( self.load_config = vllm_config.load_config self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa ) - self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa ) @@ -238,14 +236,14 @@ def __init__( self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs - if not self.model_config.skip_tokenizer_init: - self.tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) - tokenizer_group = self.get_tokenizer_group() - else: + if self.model_config.skip_tokenizer_init: self.tokenizer = None self.detokenizer = None tokenizer_group = None + else: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() # Ensure that the function doesn't contain a reference to self, # to avoid engine GC issues @@ -294,8 +292,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # Feature flags "enable_lora": 
bool(self.lora_config), - "enable_prompt_adapter": - bool(self.prompt_adapter_config), "enable_prefix_caching": self.cache_config.enable_prefix_caching, "enforce_eager": @@ -542,9 +538,6 @@ def _verify_args(self) -> None: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) - if self.prompt_adapter_config: - self.prompt_adapter_config.verify_with_model_config( - self.model_config) def _add_processed_request( self, @@ -553,7 +546,6 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, ) -> Optional[SequenceGroup]: @@ -569,7 +561,6 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, priority=priority, ) return None @@ -583,11 +574,10 @@ def _add_processed_request( encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, - lora_request, prompt_adapter_request) + lora_request) encoder_seq = (None if encoder_inputs is None else Sequence( - seq_id, encoder_inputs, block_size, eos_token_id, lora_request, - prompt_adapter_request)) + seq_id, encoder_inputs, block_size, eos_token_id, lora_request)) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -598,7 +588,6 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, priority=priority) elif isinstance(params, PoolingParams): @@ -608,7 +597,6 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, priority=priority) else: @@ -637,7 +625,6 @@ def add_request( lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: """Add a request to the engine's request pool. @@ -658,7 +645,6 @@ def add_request( the current monotonic time. lora_request: The LoRA request to add. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: The prompt adapter request to add. priority: The priority of the request. Only applicable with priority scheduling. 
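With the prompt-adapter plumbing removed above, `LLMEngine.add_request` is driven purely by the prompt, the sampling/pooling params, and the optional LoRA/trace/priority fields. A minimal end-to-end sketch of the resulting call shape follows; the model name is a placeholder, not something prescribed by this change.

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

# Build a small engine (model name is illustrative).
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# add_request no longer takes prompt_adapter_request.
engine.add_request(
    request_id="req-0",
    prompt="Hello, my name is",
    params=SamplingParams(max_tokens=16),
    priority=0,
)

# Drain the engine and print finished outputs.
while engine.has_unfinished_requests():
    for out in engine.step():
        if out.finished:
            print(out.outputs[0].text)
```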
@@ -719,7 +705,6 @@ def add_request( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, ) self._add_processed_request( @@ -728,7 +713,6 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, priority=priority, ) @@ -741,7 +725,6 @@ def _create_sequence_group_with_sampling( arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, priority: int = 0, ) -> SequenceGroup: @@ -769,17 +752,15 @@ def _create_sequence_group_with_sampling( if self.vllm_config.speculative_config is not None: draft_size = \ self.vllm_config.speculative_config.num_speculative_tokens + 1 - seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - sampling_params=sampling_params, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority, - draft_size=draft_size) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + encoder_seq=encoder_seq, + priority=priority, + draft_size=draft_size) return seq_group @@ -790,7 +771,6 @@ def _create_sequence_group_with_pooling( pooling_params: PoolingParams, arrival_time: float, lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], encoder_seq: Optional[Sequence] = None, priority: int = 0, ) -> SequenceGroup: @@ -798,15 +778,13 @@ def _create_sequence_group_with_pooling( # Defensive copy of PoolingParams, which are used by the pooler pooling_params = pooling_params.clone() # Create the sequence group. - seq_group = SequenceGroup( - request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority) + seq_group = SequenceGroup(request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + encoder_seq=encoder_seq, + priority=priority) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -1780,13 +1758,6 @@ def _get_stats(self, num_generation_tokens_from_prefill_groups) num_tokens_iter = (num_generation_tokens_iter + num_prompt_tokens_iter) - # Spec decode, if enabled, emits specialized metrics from the worker in - # sampler output. 
- if model_output and isinstance(model_output[0], SamplerOutput) and ( - model_output[0].spec_decode_worker_metrics is not None): - spec_decode_metrics = model_output[0].spec_decode_worker_metrics - else: - spec_decode_metrics = None return Stats( now=now, @@ -1808,7 +1779,6 @@ def _get_stats(self, num_tokens_iter=num_tokens_iter, time_to_first_tokens_iter=time_to_first_tokens_iter, time_per_output_tokens_iter=time_per_output_tokens_iter, - spec_decode_metrics=spec_decode_metrics, num_preemption_iter=num_preemption_iter, # Request stats @@ -1842,16 +1812,6 @@ def list_loras(self) -> Set[int]: def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.model_executor.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_executor.remove_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> List[int]: - return self.model_executor.list_prompt_adapters() - def start_profile(self) -> None: self.model_executor.start_profile() diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 8d51f0472351..ba8dbd1fad79 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time -from typing import TYPE_CHECKING from typing import Counter as CollectionsCounter from typing import Dict, List, Optional, Type, Union, cast @@ -19,9 +18,6 @@ else: ray_metrics = None -if TYPE_CHECKING: - from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics - logger = init_logger(__name__) prometheus_client.disable_created_metrics() @@ -199,30 +195,6 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): documentation="Count of successfully processed requests.", labelnames=labelnames + [Metrics.labelname_finish_reason]) - # Speculative decoding stats - self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls( - name="vllm:spec_decode_draft_acceptance_rate", - documentation="Speulative token acceptance rate.", - labelnames=labelnames, - multiprocess_mode="sum") - self.gauge_spec_decode_efficiency = self._gauge_cls( - name="vllm:spec_decode_efficiency", - documentation="Speculative decoding system efficiency.", - labelnames=labelnames, - multiprocess_mode="sum") - self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( - name="vllm:spec_decode_num_accepted_tokens_total", - documentation="Number of accepted tokens.", - labelnames=labelnames)) - self.counter_spec_decode_num_draft_tokens = self._counter_cls( - name="vllm:spec_decode_num_draft_tokens_total", - documentation="Number of draft tokens.", - labelnames=labelnames) - self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( - name="vllm:spec_decode_num_emitted_tokens_total", - documentation="Number of emitted tokens.", - labelnames=labelnames)) - # --8<-- [end:metrics-definitions] @@ -391,9 +363,6 @@ def log(self, stats: Stats) -> None: self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter) - # Update spec decode metrics - self.maybe_update_spec_decode_metrics(stats) - # Log locally every local_interval seconds. 
if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): @@ -435,10 +404,6 @@ def log(self, stats: Stats) -> None: stats.gpu_prefix_cache_hit_rate * 100, stats.cpu_prefix_cache_hit_rate * 100, ) - if self.spec_decode_metrics is not None: - log_fn( - self._format_spec_decode_metrics_str( - self.spec_decode_metrics)) self._reset(stats, prompt_throughput, generation_throughput) @@ -447,21 +412,9 @@ def _reset(self, stats, prompt_throughput, generation_throughput) -> None: self.num_prompt_tokens = [] self.num_generation_tokens = [] self.last_local_log = stats.now - self.spec_decode_metrics = None self.last_prompt_throughput = prompt_throughput self.last_generation_throughput = generation_throughput - def _format_spec_decode_metrics_str( - self, metrics: "SpecDecodeWorkerMetrics") -> str: - - return ("Speculative metrics: " - f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, " - f"System efficiency: {metrics.system_efficiency:.3f}, " - f"Number of speculative tokens: {metrics.num_spec_tokens}, " - f"Number of accepted tokens: {metrics.accepted_tokens}, " - f"Number of draft tokens: {metrics.draft_tokens}, " - f"Number of emitted tokens: {metrics.emitted_tokens}.") - def info(self, type: str, obj: SupportsMetricsInfo) -> None: raise NotImplementedError @@ -579,33 +532,14 @@ def log(self, stats: Stats): self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) self.num_generation_tokens.append(stats.num_generation_tokens_iter) - # Update spec decode metrics - self.maybe_update_spec_decode_metrics(stats) - # Log locally every local_interval seconds. if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): - if self.spec_decode_metrics is not None: - self._log_gauge( - self.metrics.gauge_spec_decode_draft_acceptance_rate, - self.spec_decode_metrics.draft_acceptance_rate) - self._log_gauge(self.metrics.gauge_spec_decode_efficiency, - self.spec_decode_metrics.system_efficiency) - self._log_counter( - self.metrics.counter_spec_decode_num_accepted_tokens, - self.spec_decode_metrics.accepted_tokens) - self._log_counter( - self.metrics.counter_spec_decode_num_draft_tokens, - self.spec_decode_metrics.draft_tokens) - self._log_counter( - self.metrics.counter_spec_decode_num_emitted_tokens, - self.spec_decode_metrics.emitted_tokens) # Reset tracked stats for next interval. 
self.num_prompt_tokens = [] self.num_generation_tokens = [] self.last_local_log = stats.now - self.spec_decode_metrics = None def info(self, type: str, obj: SupportsMetricsInfo) -> None: # Info type metrics are syntactic sugar for a gauge permanently set to 1 diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 9375dc4c495b..3281a9121a9d 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -16,10 +16,9 @@ import time from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional +from typing import List from vllm.config import SupportsMetricsInfo, VllmConfig -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @dataclass @@ -65,8 +64,6 @@ class Stats: running_lora_adapters: List[str] max_lora: str - spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None - class StatLoggerBase(ABC): """Base class for StatLogger.""" @@ -77,7 +74,6 @@ def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None: self.num_generation_tokens: List[int] = [] self.last_local_log = time.time() self.local_interval = local_interval - self.spec_decode_metrics: Optional[SpecDecodeWorkerMetrics] = None @abstractmethod def log(self, stats: Stats) -> None: @@ -86,9 +82,3 @@ def log(self, stats: Stats) -> None: @abstractmethod def info(self, type: str, obj: SupportsMetricsInfo) -> None: raise NotImplementedError - - def maybe_update_spec_decode_metrics(self, stats: Stats): - """Save spec decode metrics (since they are unlikely - to be emitted at same time as log interval).""" - if stats.spec_decode_metrics is not None: - self.spec_decode_metrics = stats.spec_decode_metrics diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index db968cd6b5d8..ff0405d2f843 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -10,7 +10,6 @@ from vllm.inputs import PromptType from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.utils import Device @@ -33,7 +32,6 @@ class RPCProcessRequest: request_id: str lora_request: Optional[LoRARequest] = None trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None priority: int = 0 def __init__( @@ -43,7 +41,6 @@ def __init__( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: super().__init__() @@ -53,7 +50,6 @@ def __init__( self.request_id = request_id self.lora_request = lora_request self.trace_headers = trace_headers - self.prompt_adapter_request = prompt_adapter_request self.priority = priority diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 9e018ec7f344..67d9a3bf6ce2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -45,7 +45,6 @@ from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import PoolingRequestOutput, RequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import Device @@ -448,7 +447,6 @@ def 
generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. @@ -465,8 +463,6 @@ def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request to use - for generation, if any. priority: Priority of the request (lower means earlier handling). Any priority other than 0 will lead to an error if the scheduling policy is not "priority". @@ -474,8 +470,7 @@ def generate( return cast( AsyncGenerator[RequestOutput, None], self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, - prompt_adapter_request, priority)) + lora_request, trace_headers, priority)) def encode( self, @@ -521,7 +516,6 @@ async def _process_request( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ PoolingRequestOutput, None]]: @@ -575,7 +569,6 @@ async def _process_request( request_id=request_id, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, priority=priority, )) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index ef088bd3933a..fe6eb0d8c2f1 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -304,14 +304,12 @@ def _handle_process_request(self, request: RPCProcessRequest): self._send_outputs(rpc_err) try: - self.engine.add_request( - request_id=request_id, - prompt=request.prompt, - params=request.params, - lora_request=request.lora_request, - trace_headers=request.trace_headers, - prompt_adapter_request=request.prompt_adapter_request, - priority=request.priority) + self.engine.add_request(request_id=request_id, + prompt=request.prompt, + params=request.params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + priority=request.priority) if self.log_requests: logger.info("Added request %s.", request.request_id) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index e0fa6a00ecfa..8b66ef0dc765 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -104,11 +104,6 @@ def process_outputs(self, seqs = sequence_group.get_seqs( status=SequenceStatus.FINISHED_ABORTED) - for output in outputs: - if output.samples[0].output_token != VLLM_INVALID_TOKEN_ID: - sequence_group.metrics.spec_token_acceptance_counts[ - output.step_index] += 1 - assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" assert len(seqs) == 1, ( "Beam search not supported in multi-step decoding.") diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 8688fcc82cd9..671e9648a3d0 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -16,7 +16,6 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams 
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Device, collect_from_async_generator, random_uuid @@ -55,7 +54,6 @@ def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request.""" @@ -324,3 +322,9 @@ async def is_sleeping(self) -> bool: async def add_lora(self, lora_request: LoRARequest) -> None: """Load a new LoRA adapter into the engine for future requests.""" ... + + async def scale_elastic_ep(self, + new_data_parallel_size: int, + drain_timeout: int = 300) -> None: + """Scale the engine""" + raise NotImplementedError diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4b6c50526b10..a6602391d408 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -4,7 +4,7 @@ import asyncio import json from abc import ABC, abstractmethod -from collections import defaultdict, deque +from collections import Counter, defaultdict, deque from collections.abc import Awaitable, Iterable from functools import cached_property, lru_cache, partial from pathlib import Path @@ -28,6 +28,7 @@ ChatCompletionToolMessageParam) from openai.types.chat.chat_completion_content_part_input_audio_param import ( InputAudio) +from openai.types.responses import ResponseInputImageParam from PIL import Image from pydantic import BaseModel, ConfigDict, TypeAdapter # yapf: enable @@ -38,7 +39,6 @@ from vllm.config import ModelConfig from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.models import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.utils import MediaConnector @@ -52,6 +52,12 @@ logger = init_logger(__name__) +MODALITY_PLACEHOLDERS_MAP = { + "image": "<##IMAGE##>", + "audio": "<##AUDIO##>", + "video": "<##VIDEO##>", +} + class AudioURL(TypedDict, total=False): url: Required[str] @@ -145,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): video_url: Required[str] +class CustomThinkCompletionContentParam(TypedDict, total=False): + """A Think Completion Content Param that accepts a plain text and a boolean. 
+ + Example: + { + "thinking": "I am thinking about the answer", + "closed": True, + "type": "thinking" + } + """ + + thinking: Required[str] + """The thinking content.""" + + closed: bool + """Whether the thinking is closed.""" + + type: Required[Literal["thinking"]] + """The thinking type.""" + + ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, ChatCompletionContentPartInputAudioParam, @@ -153,7 +180,8 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False): CustomChatCompletionContentSimpleImageParam, ChatCompletionContentPartImageEmbedsParam, CustomChatCompletionContentSimpleAudioParam, - CustomChatCompletionContentSimpleVideoParam, str] + CustomChatCompletionContentSimpleVideoParam, str, + CustomThinkCompletionContentParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -354,6 +382,7 @@ def resolve_mistral_chat_template( "so it will be ignored.") return None + @deprecate_kwargs( "trust_remote_code", additional_message="Please use `model_config.trust_remote_code` instead.", @@ -517,6 +546,7 @@ def model_config(self) -> ModelConfig: @cached_property def model_cls(self): + from vllm.model_executor.model_loader import get_model_cls return get_model_cls(self.model_config) @property @@ -633,15 +663,22 @@ class BaseMultiModalContentParser(ABC): def __init__(self) -> None: super().__init__() - # multimodal placeholder_string : count - self._placeholder_counts: dict[str, int] = defaultdict(lambda: 0) - - def _add_placeholder(self, placeholder: Optional[str]): + # stores model placehodlers list with corresponding + # general MM placeholder: + # { + # "<##IMAGE##>": ["<image>", "<image>", "<image>"], + # "<##AUDIO##>": ["<audio>", "<audio>"] + # } + self._placeholder_storage: dict[str, list] = defaultdict(list) + + def _add_placeholder(self, modality: ModalityStr, + placeholder: Optional[str]): + mod_placeholder = MODALITY_PLACEHOLDERS_MAP[modality] if placeholder: - self._placeholder_counts[placeholder] += 1 + self._placeholder_storage[mod_placeholder].append(placeholder) - def mm_placeholder_counts(self) -> dict[str, int]: - return dict(self._placeholder_counts) + def mm_placeholder_storage(self) -> dict[str, list]: + return dict(self._placeholder_storage) @abstractmethod def parse_image(self, image_url: str) -> None: @@ -685,7 +722,7 @@ def parse_image(self, image_url: str) -> None: image = self._connector.fetch_image(image_url) placeholder = self._tracker.add("image", image) - self._add_placeholder(placeholder) + self._add_placeholder("image", placeholder) def parse_image_embeds(self, image_embeds: Union[str, dict[str, str]]) -> None: @@ -700,17 +737,17 @@ def parse_image_embeds(self, embedding = self._connector.fetch_image_embedding(image_embeds) placeholder = self._tracker.add("image_embeds", embedding) - self._add_placeholder(placeholder) + self._add_placeholder("image", placeholder) def parse_image_pil(self, image_pil: Image.Image) -> None: placeholder = self._tracker.add("image", image_pil) - self._add_placeholder(placeholder) + self._add_placeholder("image", placeholder) def parse_audio(self, audio_url: str) -> None: audio = self._connector.fetch_audio(audio_url) placeholder = self._tracker.add("audio", audio) - self._add_placeholder(placeholder) + self._add_placeholder("audio", placeholder) def parse_input_audio(self, input_audio: InputAudio) -> None: audio_data = input_audio.get("data", "") @@ -723,7 +760,7 @@ def parse_video(self, video_url: str) -> None: video = 
self._connector.fetch_video(video_url=video_url) placeholder = self._tracker.add("video", video) - self._add_placeholder(placeholder) + self._add_placeholder("video", placeholder) class AsyncMultiModalContentParser(BaseMultiModalContentParser): @@ -741,7 +778,7 @@ def parse_image(self, image_url: str) -> None: image_coro = self._connector.fetch_image_async(image_url) placeholder = self._tracker.add("image", image_coro) - self._add_placeholder(placeholder) + self._add_placeholder("image", placeholder) def parse_image_embeds(self, image_embeds: Union[str, dict[str, str]]) -> None: @@ -760,20 +797,20 @@ def parse_image_embeds(self, future.set_result(embedding) placeholder = self._tracker.add("image_embeds", future) - self._add_placeholder(placeholder) + self._add_placeholder("image", placeholder) def parse_image_pil(self, image_pil: Image.Image) -> None: future: asyncio.Future[Image.Image] = asyncio.Future() future.set_result(image_pil) placeholder = self._tracker.add("image", future) - self._add_placeholder(placeholder) + self._add_placeholder("image", placeholder) def parse_audio(self, audio_url: str) -> None: audio_coro = self._connector.fetch_audio_async(audio_url) placeholder = self._tracker.add("audio", audio_coro) - self._add_placeholder(placeholder) + self._add_placeholder("audio", placeholder) def parse_input_audio(self, input_audio: InputAudio) -> None: audio_data = input_audio.get("data", "") @@ -786,7 +823,7 @@ def parse_video(self, video_url: str) -> None: video = self._connector.fetch_video_async(video_url=video_url) placeholder = self._tracker.add("video", video) - self._add_placeholder(placeholder) + self._add_placeholder("video", placeholder) def validate_chat_template(chat_template: Optional[Union[Path, str]]): @@ -856,12 +893,40 @@ def load_chat_template( return _cached_load_chat_template(chat_template, is_literal=is_literal) +def _get_interleaved_text_prompt(placeholder_storage: dict[str, list], + texts: list[str]) -> str: + for idx, elem in enumerate(texts): + if elem in placeholder_storage: + texts[idx] = placeholder_storage[elem].pop(0) + + return "\n".join(texts) + + # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], - text_prompt: str) -> str: +def _get_full_multimodal_text_prompt(placeholder_storage: dict[str, list], + texts: list[str], + interleave_strings: bool + ) -> str: """Combine multimodal prompts for a multimodal language model.""" + # flatten storage to make it looks like + # { + # "<|image|>": 2, + # "<|audio|>": 1 + # } + placeholder_counts = Counter( + [v for elem in placeholder_storage.values() for v in elem] + ) + + if interleave_strings: + text_prompt = _get_interleaved_text_prompt(placeholder_storage, texts) + else: + text_prompt = "\n".join(texts) + + # Pass interleaved text further in case the user used image placeholders + # himself, but forgot to disable the 'interleave_strings' flag + # Look through the text prompt to check for missing placeholders missing_placeholders: list[str] = [] for placeholder in placeholder_counts: @@ -870,6 +935,13 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], placeholder_counts[placeholder] -= text_prompt.count(placeholder) if placeholder_counts[placeholder] < 0: + logger.error( + "Placeholder count is negative! 
" + "Ensure that the 'interleave_strings' flag is disabled " + "(current value: %s) " + "when manually placing image placeholders.", interleave_strings + ) + logger.debug("Input prompt: %s", text_prompt) raise ValueError( f"Found more '{placeholder}' placeholders in input prompt than " "actual multimodal data items.") @@ -877,8 +949,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], missing_placeholders.extend([placeholder] * placeholder_counts[placeholder]) - # NOTE: For now we always add missing placeholders at the front of - # the prompt. This may change to be customizable in the future. + # NOTE: Default behaviour: we always add missing placeholders + # at the front of the prompt, if interleave_strings=False return "\n".join(missing_placeholders + [text_prompt]) @@ -888,11 +960,14 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], _InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) _PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam) +_ThinkParser = partial(cast, CustomThinkCompletionContentParam) # Need to validate url objects _ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python _AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python _VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python +_ResponsesInputImageParser = TypeAdapter( + ResponseInputImageParam).validate_python _ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio, PILImage] # Define a mapping from part types to their corresponding parsing functions. @@ -902,6 +977,12 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int], ] = { "text": lambda part: _TextParser(part).get("text", None), + "thinking": + lambda part: _ThinkParser(part).get("thinking", None), + "input_text": + lambda part: _TextParser(part).get("text", None), + "input_image": + lambda part: _ResponsesInputImageParser(part).get("image_url", None), "image_url": lambda part: _ImageParser(part).get("image_url", {}).get("url", None), "image_embeds": @@ -986,6 +1067,7 @@ def _parse_chat_message_content_parts( mm_tracker: BaseMultiModalItemTracker, *, wrap_dicts: bool, + interleave_strings: bool, ) -> list[ConversationMessage]: content = list[_ContentPart]() @@ -996,6 +1078,7 @@ def _parse_chat_message_content_parts( part, mm_parser, wrap_dicts=wrap_dicts, + interleave_strings=interleave_strings ) if parse_res: content.append(parse_res) @@ -1005,11 +1088,14 @@ def _parse_chat_message_content_parts( return [ConversationMessage(role=role, content=content)] # type: ignore texts = cast(list[str], content) - text_prompt = "\n".join(texts) - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, - text_prompt) + mm_placeholder_storage = mm_parser.mm_placeholder_storage() + if mm_placeholder_storage: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_storage, + texts, + interleave_strings) + else: + text_prompt = "\n".join(texts) + return [ConversationMessage(role=role, content=text_prompt)] @@ -1018,6 +1104,7 @@ def _parse_chat_message_content_part( mm_parser: BaseMultiModalContentParser, *, wrap_dicts: bool, + interleave_strings: bool, ) -> Optional[_ContentPart]: """Parses a single part of a conversation. 
If wrap_dicts is True, structured dictionary pieces for texts and images will be @@ -1028,10 +1115,8 @@ def _parse_chat_message_content_part( """ if isinstance(part, str): # Handle plain text parts return part - # Handle structured dictionary parts part_type, content = _parse_chat_message_content_mm_part(part) - # if part_type is text/refusal/image_url/audio_url/video_url/input_audio but # content is None, log a warning and skip if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None: @@ -1040,41 +1125,44 @@ def _parse_chat_message_content_part( "with empty / unparsable content.", part, part_type) return None - if part_type in ("text", "refusal"): + if part_type in ("text", "input_text", "refusal", "thinking"): str_content = cast(str, content) if wrap_dicts: return {'type': 'text', 'text': str_content} else: return str_content + modality = None if part_type == "image_pil": image_content = cast(Image.Image, content) mm_parser.parse_image_pil(image_content) - return {'type': 'image'} if wrap_dicts else None - if part_type == "image_url": + modality = "image" + elif part_type in ("image_url", "input_image"): str_content = cast(str, content) mm_parser.parse_image(str_content) - return {'type': 'image'} if wrap_dicts else None - if part_type == "image_embeds": + modality = "image" + elif part_type == "image_embeds": content = cast(Union[str, dict[str, str]], content) mm_parser.parse_image_embeds(content) - return {'type': 'image'} if wrap_dicts else None - if part_type == "audio_url": + modality = "image" + elif part_type == "audio_url": str_content = cast(str, content) mm_parser.parse_audio(str_content) - return {'type': 'audio'} if wrap_dicts else None - - if part_type == "input_audio": + modality = "audio" + elif part_type == "input_audio": dict_content = cast(InputAudio, content) mm_parser.parse_input_audio(dict_content) - return {'type': 'audio'} if wrap_dicts else None - - if part_type == "video_url": + modality = "audio" + elif part_type == "video_url": str_content = cast(str, content) mm_parser.parse_video(str_content) - return {'type': 'video'} if wrap_dicts else None + modality = "video" + else: + raise NotImplementedError(f"Unknown part type: {part_type}") - raise NotImplementedError(f"Unknown part type: {part_type}") + return {'type': modality} if wrap_dicts else ( + MODALITY_PLACEHOLDERS_MAP[modality] if interleave_strings else None + ) # No need to validate using Pydantic again @@ -1086,6 +1174,7 @@ def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, content_format: _ChatTemplateContentFormat, + interleave_strings: bool, ) -> list[ConversationMessage]: role = message["role"] content = message.get("content") @@ -1101,6 +1190,7 @@ def _parse_chat_message_content( content, # type: ignore mm_tracker, wrap_dicts=(content_format == "openai"), + interleave_strings=interleave_strings, ) for result_msg in result: @@ -1153,6 +1243,11 @@ def parse_chat_messages( msg, mm_tracker, content_format, + interleave_strings=( + content_format == "string" + and model_config.multimodal_config is not None + and model_config.multimodal_config.interleave_mm_strings + ) ) conversation.extend(sub_messages) @@ -1176,6 +1271,11 @@ def parse_chat_messages_futures( msg, mm_tracker, content_format, + interleave_strings=( + content_format == "string" + and model_config.multimodal_config is not None + and model_config.multimodal_config.interleave_mm_strings + ) ) conversation.extend(sub_messages) diff --git a/vllm/entrypoints/cli/main.py 
b/vllm/entrypoints/cli/main.py index 3e09d45b2ed7..fed3ea650405 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -7,17 +7,6 @@ from __future__ import annotations import importlib.metadata -import signal -import sys - - -def register_signal_handlers(): - - def signal_handler(sig, frame): - sys.exit(0) - - signal.signal(signal.SIGINT, signal_handler) - signal.signal(signal.SIGTSTP, signal_handler) def main(): diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 5ddaee5b52af..e71f77ba8067 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -55,7 +55,7 @@ def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None: try: input_message = input("> ") except EOFError: - return + break conversation.append({"role": "user", "content": input_message}) chat_completion = client.chat.completions.create(model=model_name, @@ -118,7 +118,7 @@ def cmd(args: argparse.Namespace) -> None: try: input_message = input("> ") except EOFError: - return + break conversation.append({"role": "user", "content": input_message}) chat_completion = client.chat.completions.create( @@ -170,7 +170,10 @@ def cmd(args: argparse.Namespace) -> None: print("Please enter prompt to complete:") while True: - input_prompt = input("> ") + try: + input_prompt = input("> ") + except EOFError: + break completion = client.completions.create(model=model_name, prompt=input_prompt) output = completion.choices[0].text diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 9e24b31e1aae..72460c2d91c7 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -45,9 +45,6 @@ def cmd(args: argparse.Namespace) -> None: if args.headless or args.api_server_count < 1: run_headless(args) else: - if args.data_parallel_start_rank: - raise ValueError("data_parallel_start_rank is only " - "applicable in headless mode") if args.api_server_count > 1: run_multi_api_server(args) else: @@ -65,36 +62,6 @@ def subparser_init( help="Start the vLLM OpenAI Compatible API server.", description="Start the vLLM OpenAI Compatible API server.", usage="vllm serve [model_tag] [options]") - serve_parser.add_argument("model_tag", - type=str, - nargs='?', - help="The model tag to serve " - "(optional if specified in config)") - serve_parser.add_argument( - "--headless", - action='store_true', - default=False, - help="Run in headless mode. See multi-node data parallel " - "documentation for more details.") - serve_parser.add_argument( - '--data-parallel-start-rank', - '-dpr', - type=int, - default=0, - help='Starting data parallel rank for secondary nodes.') - serve_parser.add_argument('--api-server-count', - '-asc', - type=int, - default=1, - help='How many API server processes to run.') - serve_parser.add_argument( - "--config", - type=str, - default='', - required=False, - help="Read CLI options from a config file. " - "Must be a YAML with the following options: " - "https://docs.vllm.ai/en/latest/configuration/serve_args.html") serve_parser = make_arg_parser(serve_parser) show_filtered_argument_or_group_from_help(serve_parser, ["serve"]) @@ -114,13 +81,14 @@ def run_headless(args: argparse.Namespace): # Create the EngineConfig. 
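# --- Editorial sketch (not part of the patch) ------------------------------
# The cli/openai.py hunks above switch the interactive loops from `return`
# to `break` on EOFError, so Ctrl-D leaves the loop instead of the whole
# command and any code placed after the loop still runs. A minimal
# standalone version of the same pattern (names are illustrative):
def interactive_loop(handle_line) -> None:
    while True:
        try:
            line = input("> ")
        except EOFError:
            break  # Ctrl-D: exit the loop, fall through to shared cleanup
        handle_line(line)
    print()  # teardown after the loop now runs on Ctrl-D as well
# ---------------------------------------------------------------------------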
engine_args = vllm.AsyncEngineArgs.from_cli_args(args) usage_context = UsageContext.OPENAI_API_SERVER - vllm_config = engine_args.create_engine_config(usage_context=usage_context) + vllm_config = engine_args.create_engine_config(usage_context=usage_context, + headless=True) if not envs.VLLM_USE_V1: raise ValueError("Headless mode is only supported for V1") - if engine_args.data_parallel_rank is not None: - raise ValueError("data_parallel_rank is not applicable in " + if engine_args.data_parallel_hybrid_lb: + raise ValueError("data_parallel_hybrid_lb is not applicable in " "headless mode") parallel_config = vllm_config.parallel_config @@ -150,7 +118,7 @@ def signal_handler(signum, frame): engine_manager = CoreEngineProcManager( target_fn=EngineCoreProc.run_engine_core, local_engine_count=local_engine_count, - start_index=args.data_parallel_start_rank, + start_index=vllm_config.parallel_config.data_parallel_rank, local_start_index=0, vllm_config=vllm_config, local_client=False, @@ -197,6 +165,11 @@ def run_multi_api_server(args: argparse.Namespace): " api_server_count > 1") model_config.disable_mm_preprocessor_cache = True + if vllm_config.parallel_config.data_parallel_hybrid_lb: + raise NotImplementedError( + "Hybrid load balancing with --api-server-count > 0" + "is not yet supported.") + executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2cd473adfd38..2f766a2dae57 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -28,8 +28,11 @@ apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) -from vllm.entrypoints.score_utils import (_cosine_similarity, - _validate_score_input_lens) +from vllm.entrypoints.score_utils import (ScoreContentPartParam, + ScoreMultiModalParam, + _cosine_similarity, + _validate_score_input_lens, + get_score_prompt) from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt @@ -41,8 +44,7 @@ from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, PoolingRequestOutput, RequestOutput, ScoringRequestOutput) -from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.pooling_params import PoolingParams, PoolingTask from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, @@ -311,7 +313,6 @@ def generate( *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, ) -> list[RequestOutput]: @@ -327,7 +328,6 @@ def generate( prompt_token_ids: Optional[list[int]] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, ) -> list[RequestOutput]: @@ -343,7 +343,6 @@ def generate( prompt_token_ids: Optional[list[list[int]]] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = 
None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, ) -> list[RequestOutput]: @@ -360,7 +359,6 @@ def generate( prompt_token_ids: list[int], use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, ) -> list[RequestOutput]: @@ -377,7 +375,6 @@ def generate( prompt_token_ids: list[list[int]], use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, ) -> list[RequestOutput]: @@ -392,7 +389,6 @@ def generate( prompt_token_ids: Union[list[int], list[list[int]]], use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, ) -> list[RequestOutput]: @@ -412,7 +408,6 @@ def generate( prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, GuidedDecodingRequest]] = None, priority: Optional[list[int]] = None, @@ -437,8 +432,6 @@ def generate( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for - generation, if any. priority: The priority of the requests, if any. Only applicable when priority scheduling policy is enabled. @@ -451,20 +444,19 @@ def generate( considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. """ - runner_type = self.llm_engine.model_config.runner_type - if runner_type not in ["generate", "transcription"]: + model_config = self.llm_engine.model_config + runner_type = model_config.runner_type + if runner_type != "generate": messages = [ - "LLM.generate() is only supported for (conditional) generation " - "models (XForCausalLM, XForConditionalGeneration).", + "LLM.generate() is only supported for generative models." ] - supported_runner_types = self.llm_engine.model_config \ - .supported_runner_types - if "generate" in supported_runner_types: + if "generate" in model_config.supported_runner_types: messages.append( "Your model supports the 'generate' runner, but is " f"currently initialized for the '{runner_type}' runner. 
" - "Please initialize vLLM using `--task generate`.") + "Please initialize vLLM using `--task generate` or " + "`--task transcription`.") raise ValueError(" ".join(messages)) @@ -505,7 +497,6 @@ def generate( params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, guided_options=guided_options_request, tokenization_kwargs=tokenization_kwargs, priority=priority, @@ -961,7 +952,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: ... @@ -976,7 +968,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: ... @@ -991,7 +984,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: ... @@ -1007,7 +1001,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: ... @@ -1023,7 +1018,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: ... @@ -1037,7 +1033,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: ... @@ -1056,7 +1053,8 @@ def encode( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, + pooling_task: PoolingTask = "encode", + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> list[PoolingRequestOutput]: """Apply pooling to the hidden states corresponding to the input prompts. @@ -1076,8 +1074,7 @@ def encode( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for - generation, if any. 
+ pooling_task: Override the pooling task to use. Returns: A list of `PoolingRequestOutput` objects containing the @@ -1088,13 +1085,12 @@ def encode( considered legacy and may be deprecated in the future. You should instead pass them via the `inputs` parameter. """ - runner_type = self.llm_engine.model_config.runner_type + model_config = self.llm_engine.model_config + runner_type = model_config.runner_type if runner_type != "pooling": messages = ["LLM.encode() is only supported for pooling models."] - supported_runner_types = self.llm_engine.model_config \ - .supported_runner_types - if "pooling" in supported_runner_types: + if "pooling" in model_config.supported_runner_types: messages.append( "Your model supports the 'pooling' runner, but is " f"currently initialized for the '{runner_type}' runner. " @@ -1115,15 +1111,18 @@ def encode( if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() - elif isinstance(pooling_params, PoolingParams): - pooling_params.verify(self.llm_engine.model_config) + + if isinstance(pooling_params, PoolingParams): + pooling_params.verify(pooling_task, model_config) else: for pooling_param in pooling_params: - pooling_param.verify(self.llm_engine.model_config) + pooling_param.verify(pooling_task, model_config) - tokenization_kwargs: dict[str, Any] = {} - _validate_truncation_size(self.llm_engine.model_config.max_model_len, - truncate_prompt_tokens, tokenization_kwargs) + if tokenization_kwargs is None: + tokenization_kwargs = dict[str, Any]() + _validate_truncation_size(model_config.max_model_len, + truncate_prompt_tokens, + tokenization_kwargs) self._validate_and_add_requests( prompts=parsed_prompts, @@ -1131,7 +1130,6 @@ def encode( use_tqdm=use_tqdm, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, - prompt_adapter_request=prompt_adapter_request, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -1148,7 +1146,6 @@ def embed( pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[EmbeddingRequestOutput]: """ Generate an embedding vector for each prompt. @@ -1168,23 +1165,24 @@ def embed( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for - generation, if any. Returns: A list of `EmbeddingRequestOutput` objects containing the embedding vectors in the same order as the input prompts. """ - if self.llm_engine.model_config.task != "embed": - raise ValueError( - "Embedding API is only enabled for `--task embed`") + model_config = self.llm_engine.model_config + if "embed" not in model_config.supported_tasks: + raise ValueError("Embedding API is not supported by this model. 
" + "Please set `--task embed`.") - items = self.encode(prompts, - truncate_prompt_tokens=truncate_prompt_tokens, - use_tqdm=use_tqdm, - pooling_params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + items = self.encode( + prompts, + truncate_prompt_tokens=truncate_prompt_tokens, + use_tqdm=use_tqdm, + pooling_params=pooling_params, + lora_request=lora_request, + pooling_task="embed", + ) return [EmbeddingRequestOutput.from_base(item) for item in items] @@ -1195,7 +1193,6 @@ def classify( *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ClassificationRequestOutput]: """ Generate class logits for each prompt. @@ -1213,21 +1210,23 @@ def classify( it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for - generation, if any. Returns: A list of `ClassificationRequestOutput` objects containing the embedding vectors in the same order as the input prompts. """ - if self.llm_engine.model_config.task != "classify": + model_config = self.llm_engine.model_config + if "classify" not in model_config.supported_tasks: raise ValueError( - "Classification API is only enabled for `--task classify`") + "Classification API is not supported by this model. " + "Please set `--task classify`.") - items = self.encode(prompts, - use_tqdm=use_tqdm, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + items = self.encode( + prompts, + use_tqdm=use_tqdm, + lora_request=lora_request, + pooling_task="classify", + ) return [ClassificationRequestOutput.from_base(item) for item in items] @@ -1239,7 +1238,6 @@ def _embedding_score( truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: encoded_output: list[PoolingRequestOutput] = self.encode( @@ -1247,7 +1245,8 @@ def _embedding_score( truncate_prompt_tokens=truncate_prompt_tokens, use_tqdm=use_tqdm, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + pooling_task="embed", + ) encoded_output_1: list[PoolingRequestOutput] = encoded_output[ 0:len(text_1)] @@ -1268,46 +1267,68 @@ def _embedding_score( def _cross_encoding_score( self, tokenizer: AnyTokenizer, - text_1: list[str], - text_2: list[str], + data_1: Union[list[str], list[ScoreContentPartParam]], + data_2: Union[list[str], list[ScoreContentPartParam]], truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: if isinstance(tokenizer, MistralTokenizer): raise ValueError( "Score API is only enabled for `--task embed or score`") - if len(text_1) == 1: - text_1 = text_1 * len(text_2) - - input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] - - pooling_params = PoolingParams(use_cross_encoder=True) + if len(data_1) == 1: + data_1 = data_1 * len(data_2) + pooling_params = PoolingParams(task="score") tokenization_kwargs: dict[str, Any] = {} _validate_truncation_size(self.llm_engine.model_config.max_model_len, 
truncate_prompt_tokens, tokenization_kwargs) parsed_prompts = [] - for q, t in input_pairs: - prompt_inputs = tokenizer(text=q, - text_pair=t, - **tokenization_kwargs) - engine_prompt = TokensPrompt( - prompt_token_ids=prompt_inputs["input_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) - parsed_prompts.append(engine_prompt) + input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] + + if self.llm_engine.model_config.is_multimodal_model: + + model_config = self.llm_engine.model_config + + for q, d in input_pairs: + _, engine_prompt = get_score_prompt( + model_config=model_config, + data_1=q, + data_2=d, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + ) + + parsed_prompts.append(engine_prompt) + + else: + + for q, t in input_pairs: + if self.llm_engine.model_config.use_pad_token: + # cross_encoder models defaults to using pad_token. + prompt_inputs = tokenizer( + text=q, # type: ignore[arg-type] + text_pair=t, # type: ignore[arg-type] + **tokenization_kwargs) + else: + # `llm as reranker` models defaults to not using pad_token. + prompt_inputs = tokenizer( + text=q + t, # type: ignore[operator] + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) self._validate_and_add_requests( prompts=parsed_prompts, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -1318,50 +1339,56 @@ def _cross_encoding_score( def score( self, - text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], - text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], + ScoreMultiModalParam], + data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], + ScoreMultiModalParam], /, *, truncate_prompt_tokens: Optional[int] = None, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> list[ScoringRequestOutput]: - """Generate similarity scores for all pairs `<text,text_pair>`. + """Generate similarity scores for all pairs `<text,text_pair>` or + `<multi-modal data, multi-modal data pair>`. The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. - In the `1 - N` case the `text_1` sentence will be replicated `N` - times to pair with the `text_2` sentences. + In the `1 - N` case the `data_1` input will be replicated `N` + times to pair with the `data_2` inputs. The input pairs are used to build a list of prompts for the cross encoder model. This class automatically batches the prompts, considering the memory constraint. For the best performance, put all - of your texts into a single list and pass it to this method. + of your inputs into a single list and pass it to this method. + + Supports both text and multi-modal data (images, etc.) when used with + appropriate multi-modal models. For multi-modal inputs, ensure the + prompt structure matches the model's expected input format. Args: - text_1: can be a single prompt or a list of prompts, in which - case it has to have the same length as the `text_2` list - text_2: The texts to pair with the query to form the input - to the LLM. See [PromptType][vllm.inputs.PromptType] for - more details about the format of each prompts. 
+ data_1: Can be a single prompt, a list of prompts or + `ScoreMultiModalParam`, which can contain either text or + multi-modal data. When a list, it must have the same length as + the `data_2` list. + data_2: The data to pair with the query to form the input to + the LLM. Can be text or multi-modal data. See [PromptType] + [vllm.inputs.PromptType] for more details about the format of + each prompt. use_tqdm: If `True`, shows a tqdm progress bar. If a callable (e.g., `functools.partial(tqdm, leave=False)`), it is used to create the progress bar. If `False`, no progress bar is created. lora_request: LoRA request to use for generation, if any. - prompt_adapter_request: Prompt Adapter request to use for - generation, if any. Returns: A list of `ScoringRequestOutput` objects containing the generated scores in the same order as the input prompts. """ - runner_type = self.llm_engine.model_config.runner_type + model_config = self.llm_engine.model_config + runner_type = model_config.runner_type if runner_type != "pooling": messages = ["LLM.score() is only supported for pooling models."] - supported_runner_types = self.llm_engine.model_config \ - .supported_runner_types - if "pooling" in supported_runner_types: + if "pooling" in model_config.supported_runner_types: messages.append( "Your model supports the 'pooling' runner, but is " f"currently initialized for the '{runner_type}' runner. " @@ -1370,12 +1397,13 @@ def score( raise ValueError(" ".join(messages)) - if self.llm_engine.model_config.task not in ("embed", "classify"): - raise ValueError("Score API is only enabled for " - "`--task embed or --task classify`.") + if all(t not in model_config.supported_tasks + for t in ("embed", "classify")): + raise ValueError("Score API is not supported by this model. " + "Please set `--task embed` or `--task classify`.") - if (self.llm_engine.model_config.task == "classify" - and self.llm_engine.model_config.hf_config.num_labels != 1): + if (model_config.task == "classify" + and getattr(model_config.hf_config, "num_labels", 0) != 1): raise ValueError("Score API is only enabled for num_labels == 1.") # the tokenizer for models such as @@ -1383,46 +1411,72 @@ def score( # lists of tokens to the `text` and `text_pair` kwargs tokenizer = self.get_tokenizer() - def ensure_str(prompt: SingletonPrompt): - if isinstance(prompt, dict): - if "multi_modal_data" in prompt: - raise ValueError("Multi-modal prompt is not " - "supported for scoring") - elif "prompt_token_ids" in prompt: - prompt = tokenizer.decode( - cast(TokensPrompt, prompt)["prompt_token_ids"]) - elif "prompt" in prompt: - prompt = cast(TextPrompt, prompt)["prompt"] - assert type(prompt) is str - return prompt - - if isinstance(text_1, (str, dict)): - # Convert a single prompt to a list. - text_1 = [text_1] - input_text_1: list[str] = [ensure_str(t) for t in text_1] + if not self.llm_engine.model_config.is_multimodal_model: - if isinstance(text_2, (str, dict)): - # Convert a single prompt to a list. 
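# --- Editorial usage sketch (not part of the patch) ------------------------
# Text-only scoring keeps working through the renamed data_1/data_2 path
# above: a single query is replicated to pair with every candidate (the
# "1 -> N" case from the docstring). The model name is illustrative; any
# cross-encoder served with `--task score` is assumed to behave the same.
from vllm import LLM

def rank_passages() -> None:
    llm = LLM(model="BAAI/bge-reranker-base", task="score")
    outputs = llm.score(
        "What is the capital of France?",  # data_1, replicated to pair with each item
        [
            "Paris is the capital of France.",
            "The Eiffel Tower is a landmark in Paris.",
        ],
    )
    for passage_output in outputs:
        print(passage_output.outputs.score)
# ---------------------------------------------------------------------------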
- text_2 = [text_2] - input_text_2: list[str] = [ensure_str(t) for t in text_2] + def check_data_type(data: Union[SingletonPrompt, + Sequence[SingletonPrompt], + ScoreMultiModalParam]): + if isinstance(data, dict) and "content" in data: + raise ValueError( + f"ScoreMultiModalParam is not supported for {self.llm_engine.model_config.architecture}", # noqa: E501 + ) + + check_data_type(data_1) + check_data_type(data_2) + + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for scoring") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + if isinstance(data_1, (str, dict)): + # Convert a single prompt to a list. + data_1 = [data_1] # type: ignore[list-item] + + data_1 = [ensure_str(t) for t in data_1] + + if isinstance(data_2, (str, dict)): + # Convert a single prompt to a list. + data_2 = [data_2] # type: ignore[list-item] - _validate_score_input_lens(input_text_1, input_text_2) + data_2 = [ensure_str(t) for t in data_2] + + if isinstance(data_1, dict) and "content" in data_1: + data_1 = data_1.get("content") # type: ignore[assignment] + elif isinstance(data_1, str): + data_1 = [data_1] + + if isinstance(data_2, dict) and "content" in data_2: + data_2 = data_2.get("content") # type: ignore[assignment] + elif isinstance(data_2, str): + data_2 = [data_2] + + _validate_score_input_lens(data_1, data_2) # type: ignore[arg-type] if self.llm_engine.model_config.is_cross_encoder: - return self._cross_encoding_score(tokenizer, input_text_1, - input_text_2, - truncate_prompt_tokens, use_tqdm, - lora_request, - prompt_adapter_request) + return self._cross_encoding_score( + tokenizer, + data_1, # type: ignore[arg-type] + data_2, # type: ignore[arg-type] + truncate_prompt_tokens, + use_tqdm, + lora_request) else: return self._embedding_score( tokenizer, - input_text_1, # type: ignore[arg-type] - input_text_2, # type: ignore[arg-type] + data_1, # type: ignore[arg-type] + data_2, # type: ignore[arg-type] truncate_prompt_tokens, use_tqdm, - lora_request, - prompt_adapter_request) + lora_request) def start_profile(self) -> None: self.llm_engine.start_profile() @@ -1533,7 +1587,6 @@ def _validate_and_add_requests( *, use_tqdm: Union[bool, Callable[..., tqdm]] = True, lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], - prompt_adapter_request: Optional[PromptAdapterRequest], tokenization_kwargs: Optional[dict[str, Any]] = None, guided_options: Optional[GuidedDecodingRequest] = None, priority: Optional[list[int]] = None, @@ -1579,7 +1632,6 @@ def _validate_and_add_requests( tokenization_kwargs=tokenization_kwargs, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, - prompt_adapter_request=prompt_adapter_request, priority=priority[i] if priority else 0, ) @@ -1589,7 +1641,6 @@ def _add_request( params: Union[SamplingParams, PoolingParams], tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: request_id = str(next(self.request_counter)) @@ -1599,7 +1650,6 @@ def _add_request( params, lora_request=lora_request, tokenization_kwargs=tokenization_kwargs, - prompt_adapter_request=prompt_adapter_request, priority=priority, ) diff 
--git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index f3aee188dae9..06ff3b417f83 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -8,7 +8,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams logger = init_logger(__name__) @@ -30,7 +29,6 @@ def log_inputs( params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: max_log_len = self.max_log_len if max_log_len is not None: @@ -44,7 +42,6 @@ def log_inputs( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " "prompt_embeds shape: %s, " - "lora_request: %s, prompt_adapter_request: %s.", request_id, - prompt, params, prompt_token_ids, + "lora_request: %s.", request_id, prompt, params, prompt_token_ids, prompt_embeds.shape if prompt_embeds is not None else None, - lora_request, prompt_adapter_request) + lora_request) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 749cd9d35174..d4135519aa45 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,9 +18,10 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Annotated, Any, Optional +from typing import Annotated, Any, Callable, Optional import prometheus_client +import pydantic import regex as re import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request @@ -60,17 +61,14 @@ CompletionResponse, DetokenizeRequest, DetokenizeResponse, - EmbeddingChatRequest, - EmbeddingCompletionRequest, EmbeddingRequest, EmbeddingResponse, ErrorResponse, LoadLoRAAdapterRequest, - PoolingChatRequest, - PoolingCompletionRequest, PoolingRequest, PoolingResponse, RerankRequest, RerankResponse, - ScoreRequest, ScoreResponse, - TokenizeRequest, + ResponsesRequest, + ResponsesResponse, ScoreRequest, + ScoreResponse, TokenizeRequest, TokenizeResponse, TranscriptionRequest, TranscriptionResponse, @@ -88,6 +86,7 @@ LoRAModulePath, OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling +from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses from vllm.entrypoints.openai.serving_score import ServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) @@ -369,6 +368,10 @@ def models(request: Request) -> OpenAIServingModels: return request.app.state.openai_serving_models +def responses(request: Request) -> Optional[OpenAIServingResponses]: + return request.app.state.openai_serving_responses + + def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat @@ -427,6 +430,7 @@ async def get_server_load_metrics(request: Request): # - /v1/chat/completions # - /v1/completions # - /v1/audio/transcriptions + # - /v1/audio/translations # - /v1/embeddings # - /pooling # - /classify @@ -518,6 +522,19 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): assert_never(generator) +def maybe_register_tokenizer_info_endpoint(args): + """Conditionally register the tokenizer info endpoint if enabled.""" + if getattr(args, 'enable_tokenizer_info_endpoint', False): + + @router.get("/tokenizer_info") + async def 
get_tokenizer_info(raw_request: Request): + """Get comprehensive tokenizer information.""" + result = await tokenization(raw_request).get_tokenizer_info() + return JSONResponse(content=result.model_dump(), + status_code=result.code if isinstance( + result, ErrorResponse) else 200) + + @router.get("/v1/models") async def show_available_models(raw_request: Request): handler = models(raw_request) @@ -532,6 +549,71 @@ async def show_version(): return JSONResponse(content=ver) +@router.post("/v1/responses", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "content": { + "text/event-stream": {} + } + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.NOT_FOUND.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +@with_cancellation +async def create_responses(request: ResponsesRequest, raw_request: Request): + handler = responses(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Responses API") + + generator = await handler.create_responses(request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ResponsesResponse): + return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.get("/v1/responses/{response_id}") +async def retrieve_responses(response_id: str, raw_request: Request): + handler = responses(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Responses API") + + response = await handler.retrieve_responses(response_id) + + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + return JSONResponse(content=response.model_dump()) + + +@router.post("/v1/responses/{response_id}/cancel") +async def cancel_responses(response_id: str, raw_request: Request): + handler = responses(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Responses API") + + response = await handler.cancel_responses(response_id) + + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + return JSONResponse(content=response.model_dump()) + + @router.post("/v1/chat/completions", dependencies=[Depends(validate_json_request)], responses={ @@ -885,31 +967,6 @@ async def do_rerank_v2(request: RerankRequest, raw_request: Request): return await do_rerank(request, raw_request) -TASK_HANDLERS: dict[str, dict[str, tuple]] = { - "generate": { - "messages": (ChatCompletionRequest, create_chat_completion), - "default": (CompletionRequest, create_completion), - }, - "embed": { - "messages": (EmbeddingChatRequest, create_embedding), - "default": (EmbeddingCompletionRequest, create_embedding), - }, - "score": { - "default": (RerankRequest, do_rerank) - }, - "rerank": { - "default": (RerankRequest, do_rerank) - }, - "reward": { - "messages": (PoolingChatRequest, create_pooling), - "default": (PoolingCompletionRequest, create_pooling), - }, - "classify": { - "messages": (PoolingChatRequest, create_pooling), - "default": (PoolingCompletionRequest, create_pooling), - }, -} - if envs.VLLM_SERVER_DEV_MODE: logger.warning("SECURITY WARNING: Development endpoints 
are enabled! " "This should NOT be used in production!") @@ -961,6 +1018,97 @@ async def is_sleeping(raw_request: Request): return JSONResponse(content={"is_sleeping": is_sleeping}) +@router.post("/scale_elastic_ep", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: { + "model": dict + }, + HTTPStatus.BAD_REQUEST.value: { + "model": ErrorResponse + }, + HTTPStatus.REQUEST_TIMEOUT.value: { + "model": ErrorResponse + }, + HTTPStatus.INTERNAL_SERVER_ERROR.value: { + "model": ErrorResponse + }, + }) +async def scale_elastic_ep(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, + detail="Invalid JSON format") from e # noqa: B904 + + new_data_parallel_size = body.get("new_data_parallel_size") + drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes + + if new_data_parallel_size is None: + raise HTTPException(status_code=400, + detail="new_data_parallel_size is required") + + if not isinstance(new_data_parallel_size, + int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, + detail="new_data_parallel_size must be a positive integer") + + if not isinstance(drain_timeout, int) or drain_timeout <= 0: + raise HTTPException(status_code=400, + detail="drain_timeout must be a positive integer") + + # Set scaling flag to prevent new requests + global _scaling_elastic_ep + _scaling_elastic_ep = True + client = engine_client(raw_request) + try: + await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) + return JSONResponse({ + "message": + f"Scaled to {new_data_parallel_size} " + "data parallel engines", + }) + except TimeoutError as e: + raise HTTPException(status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds") from e + except Exception as e: + logger.error("Scale failed: %s", e) + raise HTTPException(status_code=500, detail="Scale failed") from e + finally: + _scaling_elastic_ep = False + + +@router.post("/is_scaling_elastic_ep") +async def is_scaling_elastic_ep(raw_request: Request): + return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep}) + + +# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers +# (requires typing_extensions >= 4.13) +RequestType = Any +GetHandlerFn = Callable[[Request], Optional[OpenAIServing]] +EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] + +# NOTE: Items defined earlier take higher priority +INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [ + (ChatCompletionRequest, (chat, create_chat_completion)), + (CompletionRequest, (completion, create_completion)), + (EmbeddingRequest, (embedding, create_embedding)), + (ClassificationRequest, (classify, create_classify)), + (ScoreRequest, (score, create_score)), + (RerankRequest, (rerank, do_rerank)), + (PoolingRequest, (pooling, create_pooling)), +] + +# NOTE: Construct the TypeAdapters only once +INVOCATION_VALIDATORS = [ + (pydantic.TypeAdapter(request_type), (get_handler, endpoint)) + for request_type, (get_handler, endpoint) in INVOCATION_TYPES +] + + @router.post("/invocations", dependencies=[Depends(validate_json_request)], responses={ @@ -975,32 +1123,34 @@ async def is_sleeping(raw_request: Request): }, }) async def invocations(raw_request: Request): - """ - For SageMaker, routes requests to other handlers based on model `task`. 
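# --- Editorial usage sketch (not part of the patch) ------------------------
# The new /scale_elastic_ep route above expects a JSON body with a required
# positive integer new_data_parallel_size and an optional drain_timeout in
# seconds (the handler defaults it to 120). Host and port are assumptions
# for a locally running server; only the standard library is used here.
import json
import urllib.request

def scale_to(new_dp_size: int, drain_timeout: int = 300) -> dict:
    payload = json.dumps({
        "new_data_parallel_size": new_dp_size,
        "drain_timeout": drain_timeout,
    }).encode("utf-8")
    request = urllib.request.Request(
        "http://localhost:8000/scale_elastic_ep",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request) as response:
        # e.g. {"message": "Scaled to 4 data parallel engines"}
        return json.loads(response.read())
# ---------------------------------------------------------------------------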
- """ + """For SageMaker, routes requests based on the request type.""" try: body = await raw_request.json() except json.JSONDecodeError as e: raise HTTPException(status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {e}") from e - task = raw_request.app.state.task + valid_endpoints = [(validator, endpoint) + for validator, (get_handler, + endpoint) in INVOCATION_VALIDATORS + if get_handler(raw_request) is not None] - if task not in TASK_HANDLERS: - raise HTTPException( - status_code=400, - detail=f"Unsupported task: '{task}' for '/invocations'. " - f"Expected one of {set(TASK_HANDLERS.keys())}") + for request_validator, endpoint in valid_endpoints: + try: + request = request_validator.validate_python(body) + except pydantic.ValidationError: + continue - handler_config = TASK_HANDLERS[task] - if "messages" in body: - request_model, handler = handler_config["messages"] - else: - request_model, handler = handler_config["default"] + return await endpoint(request, raw_request) - # this is required since we lose the FastAPI automatic casting - request = request_model.model_validate(body) - return await handler(request, raw_request) + type_names = [ + t.__name__ if isinstance(t := validator._type, type) else str(t) + for validator, _ in valid_endpoints + ] + msg = ("Cannot find suitable handler for request. " + f"Expected one of: {type_names}") + res = base(raw_request).create_error_response(message=msg) + return JSONResponse(content=res.model_dump(), status_code=res.code) if envs.VLLM_TORCH_PROFILER_DIR: @@ -1133,6 +1283,177 @@ async def send_with_request_id(message: Message) -> None: return self.app(scope, receive, send_with_request_id) +# Global variable to track scaling state +_scaling_elastic_ep = False + + +class ScalingMiddleware: + """ + Middleware that checks if the model is currently scaling and + returns a 503 Service Unavailable response if it is. + + This middleware applies to all HTTP requests and prevents + processing when the model is in a scaling state. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, + send: Send) -> Awaitable[None]: + if scope["type"] != "http": + return self.app(scope, receive, send) + + # Check global scaling state + global _scaling_elastic_ep + if _scaling_elastic_ep: + # Return 503 Service Unavailable response + response = JSONResponse(content={ + "error": + "The model is currently scaling. Please try again later." 
+ }, + status_code=503) + return response(scope, receive, send) + + return self.app(scope, receive, send) + + +def _extract_content_from_chunk(chunk_data: dict) -> str: + """Extract content from a streaming response chunk.""" + try: + from vllm.entrypoints.openai.protocol import ( + ChatCompletionStreamResponse, CompletionStreamResponse) + + # Try using Completion types for type-safe parsing + if chunk_data.get('object') == 'chat.completion.chunk': + chat_response = ChatCompletionStreamResponse.model_validate( + chunk_data) + if chat_response.choices and chat_response.choices[0].delta.content: + return chat_response.choices[0].delta.content + elif chunk_data.get('object') == 'text_completion': + completion_response = CompletionStreamResponse.model_validate( + chunk_data) + if completion_response.choices and completion_response.choices[ + 0].text: + return completion_response.choices[0].text + except pydantic.ValidationError: + # Fallback to manual parsing + if 'choices' in chunk_data and chunk_data['choices']: + choice = chunk_data['choices'][0] + if 'delta' in choice and choice['delta'].get('content'): + return choice['delta']['content'] + elif choice.get('text'): + return choice['text'] + return "" + + +class SSEDecoder: + """Robust Server-Sent Events decoder for streaming responses.""" + + def __init__(self): + self.buffer = "" + self.content_buffer = [] + + def decode_chunk(self, chunk: bytes) -> list[dict]: + """Decode a chunk of SSE data and return parsed events.""" + import json + + try: + chunk_str = chunk.decode('utf-8') + except UnicodeDecodeError: + # Skip malformed chunks + return [] + + self.buffer += chunk_str + events = [] + + # Process complete lines + while '\n' in self.buffer: + line, self.buffer = self.buffer.split('\n', 1) + line = line.rstrip('\r') # Handle CRLF + + if line.startswith('data: '): + data_str = line[6:].strip() + if data_str == '[DONE]': + events.append({'type': 'done'}) + elif data_str: + try: + event_data = json.loads(data_str) + events.append({'type': 'data', 'data': event_data}) + except json.JSONDecodeError: + # Skip malformed JSON + continue + + return events + + def extract_content(self, event_data: dict) -> str: + """Extract content from event data.""" + return _extract_content_from_chunk(event_data) + + def add_content(self, content: str) -> None: + """Add content to the buffer.""" + if content: + self.content_buffer.append(content) + + def get_complete_content(self) -> str: + """Get the complete buffered content.""" + return ''.join(self.content_buffer) + + +def _log_streaming_response(response, response_body: list) -> None: + """Log streaming response with robust SSE parsing.""" + from starlette.concurrency import iterate_in_threadpool + + sse_decoder = SSEDecoder() + chunk_count = 0 + + def buffered_iterator(): + nonlocal chunk_count + + for chunk in response_body: + chunk_count += 1 + yield chunk + + # Parse SSE events from chunk + events = sse_decoder.decode_chunk(chunk) + + for event in events: + if event['type'] == 'data': + content = sse_decoder.extract_content(event['data']) + sse_decoder.add_content(content) + elif event['type'] == 'done': + # Log complete content when done + full_content = sse_decoder.get_complete_content() + if full_content: + # Truncate if too long + if len(full_content) > 2048: + full_content = full_content[:2048] + "" + "...[truncated]" + logger.info( + "response_body={streaming_complete: " \ + "content='%s', chunks=%d}", + full_content, chunk_count) + else: + logger.info( + "response_body={streaming_complete: " \ 
+ "no_content, chunks=%d}", + chunk_count) + return + + response.body_iterator = iterate_in_threadpool(buffered_iterator()) + logger.info("response_body={streaming_started: chunks=%d}", + len(response_body)) + + +def _log_non_streaming_response(response_body: list) -> None: + """Log non-streaming response.""" + try: + decoded_body = response_body[0].decode() + logger.info("response_body={%s}", decoded_body) + except UnicodeDecodeError: + logger.info("response_body={<binary_data>}") + + def build_app(args: Namespace) -> FastAPI: if args.disable_fastapi_docs: app = FastAPI(openapi_url=None, @@ -1185,6 +1506,9 @@ async def validation_exception_handler(_: Request, if args.enable_request_id_headers: app.add_middleware(XRequestIdMiddleware) + # Add scaling middleware to check for scaling state + app.add_middleware(ScalingMiddleware) + if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE: logger.warning("CAUTION: Enabling log response in the API Server. " "This can include sensitive information and should be " @@ -1197,8 +1521,17 @@ async def log_response(request: Request, call_next): section async for section in response.body_iterator ] response.body_iterator = iterate_in_threadpool(iter(response_body)) - logger.info("response_body={%s}", - response_body[0].decode() if response_body else None) + # Check if this is a streaming response by looking at content-type + content_type = response.headers.get("content-type", "") + is_streaming = content_type == "text/event-stream; charset=utf-8" + + # Log response body based on type + if not response_body: + logger.info("response_body={<empty>}") + elif is_streaming: + _log_streaming_response(response, response_body) + else: + _log_non_streaming_response(response_body) return response for middleware in args.middleware: @@ -1287,9 +1620,22 @@ async def init_app_state( model_config=model_config, base_model_paths=base_model_paths, lora_modules=lora_modules, - prompt_adapters=args.prompt_adapters, ) await state.openai_serving_models.init_static_loras() + state.openai_serving_responses = OpenAIServingResponses( + engine_client, + model_config, + state.openai_serving_models, + request_logger=request_logger, + chat_template=resolved_chat_template, + chat_template_content_format=args.chat_template_content_format, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser, + reasoning_parser=args.reasoning_parser, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, + enable_force_include_usage=args.enable_force_include_usage, + ) if "generate" in model_config.supported_tasks else None state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, @@ -1300,21 +1646,20 @@ async def init_app_state( chat_template_content_format=args.chat_template_content_format, return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, - expand_tools_even_if_tool_choice_none=args. 
- expand_tools_even_if_tool_choice_none, tool_parser=args.tool_call_parser, reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, - ) if model_config.runner_type == "generate" else None + ) if "generate" in model_config.supported_tasks else None state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_prompt_tokens_details=args.enable_prompt_tokens_details, enable_force_include_usage=args.enable_force_include_usage, - ) if model_config.runner_type == "generate" else None + ) if "generate" in model_config.supported_tasks else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, model_config, @@ -1322,7 +1667,7 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, - ) if model_config.runner_type == "pooling" else None + ) if "encode" in model_config.supported_tasks else None state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, @@ -1330,27 +1675,24 @@ async def init_app_state( request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, - ) if model_config.task == "embed" else None + ) if "embed" in model_config.supported_tasks else None state.openai_serving_classification = ServingClassification( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if model_config.task == "classify" else None + ) if "classify" in model_config.supported_tasks else None - enable_serving_reranking = (model_config.task == "classify" and getattr( - model_config.hf_config, "num_labels", 0) == 1) - state.jinaai_serving_reranking = ServingScores( - engine_client, - model_config, - state.openai_serving_models, - request_logger=request_logger) if enable_serving_reranking else None + enable_serving_reranking = ("classify" in model_config.supported_tasks + and getattr(model_config.hf_config, + "num_labels", 0) == 1) state.openai_serving_scores = ServingScores( engine_client, model_config, state.openai_serving_models, - request_logger=request_logger) if ( - model_config.task == "embed" or enable_serving_reranking) else None + request_logger=request_logger, + ) if ("embed" in model_config.supported_tasks + or enable_serving_reranking) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, @@ -1365,13 +1707,13 @@ async def init_app_state( model_config, state.openai_serving_models, request_logger=request_logger, - ) if model_config.runner_type == "transcription" else None + ) if "transcription" in model_config.supported_tasks else None state.openai_serving_translation = OpenAIServingTranslation( engine_client, model_config, state.openai_serving_models, request_logger=request_logger, - ) if model_config.runner_type == "transcription" else None + ) if "transcription" in model_config.supported_tasks else None state.task = model_config.task state.enable_server_load_tracking = args.enable_server_load_tracking @@ -1467,6 +1809,7 @@ async def run_server_worker(listen_address, uvicorn_kwargs['log_config'] = log_config async with build_async_engine_client(args, client_config) as engine_client: + maybe_register_tokenizer_info_endpoint(args) app = build_app(args) 
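The api_server.py changes above add `/scale_elastic_ep` (drain, then resize the data-parallel engine set), `/is_scaling_elastic_ep` (expose the scaling flag), and `ScalingMiddleware` (reject requests with 503 while scaling). A hedged client-side sketch follows; it assumes a server on `localhost:8000` and the `requests` package, and the sizes and timeouts are placeholders, not recommended values.

```python
import requests

BASE = "http://localhost:8000"  # assumed local vLLM API server

# Ask the server to scale to 4 data-parallel engines, allowing up to
# 300 s for in-flight requests to drain before scaling starts.
resp = requests.post(
    f"{BASE}/scale_elastic_ep",
    json={"new_data_parallel_size": 4, "drain_timeout": 300},
    timeout=600,
)

if resp.status_code == 200:
    print(resp.json()["message"])
elif resp.status_code == 408:
    print("drain timed out; scaling was aborted")
else:
    print(f"scale failed: {resp.status_code} {resp.text}")

# While scaling is in progress, ScalingMiddleware answers other requests
# with 503; the flag can also be polled explicitly:
state = requests.post(f"{BASE}/is_scaling_elastic_ep", timeout=10).json()
print("scaling in progress:", state["is_scaling_elastic_ep"])
```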
vllm_config = await engine_client.get_vllm_config() diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4f8aaab772fd..3025a6263682 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -10,14 +10,17 @@ import json import ssl from collections.abc import Sequence -from typing import Optional, Union, get_args +from dataclasses import field +from typing import Literal, Optional, Union + +from pydantic.dataclasses import dataclass import vllm.envs as envs +from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs, optional_type from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) -from vllm.entrypoints.openai.serving_models import (LoRAModulePath, - PromptAdapterPath) +from vllm.entrypoints.openai.serving_models import LoRAModulePath from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser @@ -61,240 +64,176 @@ def __call__( setattr(namespace, self.dest, lora_list) -class PromptAdapterParserAction(argparse.Action): +@config +@dataclass +class FrontendArgs: + """Arguments for the OpenAI-compatible frontend server.""" + host: Optional[str] = None + """Host name.""" + port: int = 8000 + """Port number.""" + uvicorn_log_level: Literal["debug", "info", "warning", "error", "critical", + "trace"] = "info" + """Log level for uvicorn.""" + disable_uvicorn_access_log: bool = False + """Disable uvicorn access log.""" + allow_credentials: bool = False + """Allow credentials.""" + allowed_origins: list[str] = field(default_factory=lambda: ["*"]) + """Allowed origins.""" + allowed_methods: list[str] = field(default_factory=lambda: ["*"]) + """Allowed methods.""" + allowed_headers: list[str] = field(default_factory=lambda: ["*"]) + """Allowed headers.""" + api_key: Optional[str] = None + """If provided, the server will require this key to be presented in the + header.""" + lora_modules: Optional[list[LoRAModulePath]] = None + """LoRA modules configurations in either 'name=path' format or JSON format + or JSON list format. Example (old format): `'name=path'` Example (new + format): `{\"name\": \"name\", \"path\": \"lora_path\", + \"base_model_name\": \"id\"}`""" + chat_template: Optional[str] = None + """The file path to the chat template, or the template in single-line form + for the specified model.""" + chat_template_content_format: ChatTemplateContentFormatOption = "auto" + """The format to render message content within a chat template. - def __call__( - self, - parser: argparse.ArgumentParser, - namespace: argparse.Namespace, - values: Optional[Union[str, Sequence[str]]], - option_string: Optional[str] = None, - ): - if values is None: - values = [] - if isinstance(values, str): - raise TypeError("Expected values to be a list") +* "string" will render the content as a string. Example: `"Hello World"` +* "openai" will render the content as a list of dictionaries, similar to OpenAI +schema. 
Example: `[{"type": "text", "text": "Hello world!"}]`""" + response_role: str = "assistant" + """The role name to return if `request.add_generation_prompt=true`.""" + ssl_keyfile: Optional[str] = None + """The file path to the SSL key file.""" + ssl_certfile: Optional[str] = None + """The file path to the SSL cert file.""" + ssl_ca_certs: Optional[str] = None + """The CA certificates file.""" + enable_ssl_refresh: bool = False + """Refresh SSL Context when SSL certificate files change""" + ssl_cert_reqs: int = int(ssl.CERT_NONE) + """Whether client certificate is required (see stdlib ssl module's).""" + root_path: Optional[str] = None + """FastAPI root_path when app is behind a path based routing proxy.""" + middleware: list[str] = field(default_factory=lambda: []) + """Additional ASGI middleware to apply to the app. We accept multiple + --middleware arguments. The value should be an import path. If a function + is provided, vLLM will add it to the server using + `@app.middleware('http')`. If a class is provided, vLLM will + add it to the server using `app.add_middleware()`.""" + return_tokens_as_token_ids: bool = False + """When `--max-logprobs` is specified, represents single tokens as + strings of the form 'token_id:{token_id}' so that tokens that are not + JSON-encodable can be identified.""" + disable_frontend_multiprocessing: bool = False + """If specified, will run the OpenAI frontend server in the same process as + the model serving engine.""" + enable_request_id_headers: bool = False + """If specified, API server will add X-Request-Id header to responses. + Caution: this hurts performance at high QPS.""" + enable_auto_tool_choice: bool = False + """Enable auto tool choice for supported models. Use `--tool-call-parser` + to specify which parser to use.""" + tool_call_parser: Optional[str] = None + """Select the tool call parser depending on the model that you're using. + This is used to parse the model-generated tool call into OpenAI API format. + Required for `--enable-auto-tool-choice`. You can choose any option from + the built-in parsers or register a plugin via `--tool-parser-plugin`.""" + tool_parser_plugin: str = "" + """Special the tool parser plugin write to parse the model-generated tool + into OpenAI API format, the name register in this plugin can be used in + `--tool-call-parser`.""" + log_config_file: Optional[str] = envs.VLLM_LOGGING_CONFIG_PATH + """Path to logging config JSON file for both vllm and uvicorn""" + max_log_len: Optional[int] = None + """Max number of prompt characters or prompt ID numbers being printed in + log. The default of None means unlimited.""" + disable_fastapi_docs: bool = False + """Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint.""" + enable_prompt_tokens_details: bool = False + """If set to True, enable prompt_tokens_details in usage.""" + enable_server_load_tracking: bool = False + """If set to True, enable tracking server_load_metrics in the app state.""" + enable_force_include_usage: bool = False + """If set to True, including usage on every request.""" + enable_tokenizer_info_endpoint: bool = False + """Enable the /get_tokenizer_info endpoint. 
May expose chat + templates and other tokenizer configuration.""" - adapter_list: list[PromptAdapterPath] = [] - for item in values: - name, path = item.split('=') - adapter_list.append(PromptAdapterPath(name, path)) - setattr(namespace, self.dest, adapter_list) + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + from vllm.engine.arg_utils import get_kwargs + frontend_kwargs = get_kwargs(FrontendArgs) -def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: - parser.add_argument("--host", - type=optional_type(str), - default=None, - help="Host name.") - parser.add_argument("--port", type=int, default=8000, help="Port number.") - parser.add_argument( - "--uvicorn-log-level", - type=str, - default="info", - choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], - help="Log level for uvicorn.") - parser.add_argument("--disable-uvicorn-access-log", - action="store_true", - help="Disable uvicorn access log.") - parser.add_argument("--allow-credentials", - action="store_true", - help="Allow credentials.") - parser.add_argument("--allowed-origins", - type=json.loads, - default=["*"], - help="Allowed origins.") - parser.add_argument("--allowed-methods", - type=json.loads, - default=["*"], - help="Allowed methods.") - parser.add_argument("--allowed-headers", - type=json.loads, - default=["*"], - help="Allowed headers.") - parser.add_argument("--api-key", - type=optional_type(str), - default=None, - help="If provided, the server will require this key " - "to be presented in the header.") - parser.add_argument( - "--lora-modules", - type=optional_type(str), - default=None, - nargs='+', - action=LoRAParserAction, - help="LoRA module configurations in either 'name=path' format" - "or JSON format. " - "Example (old format): ``'name=path'`` " - "Example (new format): " - "``{\"name\": \"name\", \"path\": \"lora_path\", " - "\"base_model_name\": \"id\"}``") - parser.add_argument( - "--prompt-adapters", - type=optional_type(str), - default=None, - nargs='+', - action=PromptAdapterParserAction, - help="Prompt adapter configurations in the format name=path. " - "Multiple adapters can be specified.") - parser.add_argument("--chat-template", - type=optional_type(str), - default=None, - help="The file path to the chat template, " - "or the template in single-line form " - "for the specified model.") - parser.add_argument( - '--chat-template-content-format', - type=str, - default="auto", - choices=get_args(ChatTemplateContentFormatOption), - help='The format to render message content within a chat template.' - '\n\n' - '* "string" will render the content as a string. ' - 'Example: ``"Hello World"``\n' - '* "openai" will render the content as a list of dictionaries, ' - 'similar to OpenAI schema. 
' - 'Example: ``[{"type": "text", "text": "Hello world!"}]``') - parser.add_argument("--response-role", - type=optional_type(str), - default="assistant", - help="The role name to return if " - "``request.add_generation_prompt=true``.") - parser.add_argument("--ssl-keyfile", - type=optional_type(str), - default=None, - help="The file path to the SSL key file.") - parser.add_argument("--ssl-certfile", - type=optional_type(str), - default=None, - help="The file path to the SSL cert file.") - parser.add_argument("--ssl-ca-certs", - type=optional_type(str), - default=None, - help="The CA certificates file.") - parser.add_argument( - "--enable-ssl-refresh", - action="store_true", - default=False, - help="Refresh SSL Context when SSL certificate files change") - parser.add_argument( - "--ssl-cert-reqs", - type=int, - default=int(ssl.CERT_NONE), - help="Whether client certificate is required (see stdlib ssl module's)." - ) - parser.add_argument( - "--root-path", - type=optional_type(str), - default=None, - help="FastAPI root_path when app is behind a path based routing proxy." - ) - parser.add_argument( - "--middleware", - type=optional_type(str), - action="append", - default=[], - help="Additional ASGI middleware to apply to the app. " - "We accept multiple --middleware arguments. " - "The value should be an import path. " - "If a function is provided, vLLM will add it to the server " - "using ``@app.middleware('http')``. " - "If a class is provided, vLLM will add it to the server " - "using ``app.add_middleware()``. ") - parser.add_argument( - "--return-tokens-as-token-ids", - action="store_true", - help="When ``--max-logprobs`` is specified, represents single tokens " - " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.") - parser.add_argument( - "--disable-frontend-multiprocessing", - action="store_true", - help="If specified, will run the OpenAI frontend server in the same " - "process as the model serving engine.") - parser.add_argument( - "--enable-request-id-headers", - action="store_true", - help="If specified, API server will add X-Request-Id header to " - "responses.") - parser.add_argument( - "--enable-auto-tool-choice", - action="store_true", - default=False, - help="Enable auto tool choice for supported models. Use " - "``--tool-call-parser`` to specify which parser to use.") - parser.add_argument( - "--expand-tools-even-if-tool-choice-none", - action="store_true", - default=False, - deprecated=True, - help="Include tool definitions in prompts " - "even when tool_choice='none'. " - "This is a transitional option that will be removed in v0.10.0. " - "In v0.10.0, tool definitions will always be included regardless of " - "tool_choice setting. Use this flag now to test the new behavior " - "before the breaking change.") - - valid_tool_parsers = ToolParserManager.tool_parsers.keys() - parser.add_argument( - "--tool-call-parser", - type=str, - metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " - "--tool-parser-plugin", - default=None, - help= - "Select the tool call parser depending on the model that you're using." - " This is used to parse the model-generated tool call into OpenAI API " - "format. 
Required for ``--enable-auto-tool-choice``.") + # Special case: allowed_origins, allowed_methods, allowed_headers all + # need json.loads type + # Should also remove nargs + frontend_kwargs["allowed_origins"]["type"] = json.loads + frontend_kwargs["allowed_methods"]["type"] = json.loads + frontend_kwargs["allowed_headers"]["type"] = json.loads + del frontend_kwargs["allowed_origins"]["nargs"] + del frontend_kwargs["allowed_methods"]["nargs"] + del frontend_kwargs["allowed_headers"]["nargs"] - parser.add_argument( - "--tool-parser-plugin", - type=str, - default="", - help= - "Special the tool parser plugin write to parse the model-generated tool" - " into OpenAI API format, the name register in this plugin can be used " - "in ``--tool-call-parser``.") + # Special case: LoRA modules need custom parser action and + # optional_type(str) + frontend_kwargs["lora_modules"]["type"] = optional_type(str) + frontend_kwargs["lora_modules"]["action"] = LoRAParserAction - parser.add_argument( - "--log-config-file", - type=str, - default=envs.VLLM_LOGGING_CONFIG_PATH, - help="Path to logging config JSON file for both vllm and uvicorn", - ) + # Special case: Middleware needs append action + frontend_kwargs["middleware"]["action"] = "append" + frontend_kwargs["middleware"]["type"] = str + if "nargs" in frontend_kwargs["middleware"]: + del frontend_kwargs["middleware"]["nargs"] + frontend_kwargs["middleware"]["default"] = [] - parser = AsyncEngineArgs.add_cli_args(parser) + # Special case: Tool call parser shows built-in options. + valid_tool_parsers = list(ToolParserManager.tool_parsers.keys()) + frontend_kwargs["tool_call_parser"]["choices"] = valid_tool_parsers - parser.add_argument('--max-log-len', - type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - ' The default of None means unlimited.') + frontend_group = parser.add_argument_group( + title="Frontend", + description=FrontendArgs.__doc__, + ) + for key, value in frontend_kwargs.items(): + frontend_group.add_argument(f"--{key.replace('_', '-')}", **value) + + return parser + + +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Create the CLI argument parser used by the OpenAI API server. + + We rely on the helper methods of `FrontendArgs` and `AsyncEngineArgs` to + register all arguments instead of manually enumerating them here. This + avoids code duplication and keeps the argument definitions in one place. + """ + parser.add_argument("model_tag", + type=str, + nargs="?", + help="The model tag to serve " + "(optional if specified in config)") parser.add_argument( - "--disable-fastapi-docs", - action='store_true', - default=False, - help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint." - ) - parser.add_argument( - "--enable-prompt-tokens-details", - action='store_true', - default=False, - help="If set to True, enable prompt_tokens_details in usage.") - parser.add_argument( - "--enable-force-include-usage", - action='store_true', + "--headless", + action="store_true", default=False, - help="If set to True, including usage on every request.") + help="Run in headless mode. See multi-node data parallel " + "documentation for more details.") + parser.add_argument("--api-server-count", + "-asc", + type=int, + default=1, + help="How many API server processes to run.") parser.add_argument( - "--enable-server-load-tracking", - action='store_true', - default=False, - help= - "If set to True, enable tracking server_load_metrics in the app state." 
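The hand-written `parser.add_argument` calls being removed here are replaced by the `FrontendArgs` dataclass, whose fields and docstrings are converted to CLI flags via `get_kwargs` and then special-cased where argparse needs a custom `type` or `action`. The following is a simplified sketch of that dataclass-to-argparse idea only; it is not vLLM's `get_kwargs` implementation, and the field set is a toy example.

```python
import argparse
import dataclasses
from typing import Optional


@dataclasses.dataclass
class MiniFrontendArgs:
    """Toy frontend arguments (illustrative subset)."""
    host: Optional[str] = None
    port: int = 8000
    api_key: Optional[str] = None


def add_dataclass_args(parser: argparse.ArgumentParser,
                       cls=MiniFrontendArgs) -> argparse.ArgumentParser:
    # Derive one --flag per dataclass field, keeping the field's default.
    group = parser.add_argument_group(title="Frontend")
    for f in dataclasses.fields(cls):
        flag = f"--{f.name.replace('_', '-')}"
        # Optional[str] fields fall back to plain str in this sketch.
        arg_type = int if f.type is int else str
        group.add_argument(flag, type=arg_type, default=f.default)
    return parser


parser = add_dataclass_args(argparse.ArgumentParser("mini-serve"))
args = parser.parse_args(["--port", "9000", "--api-key", "secret"])
print(args.host, args.port, args.api_key)  # None 9000 secret
```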
- ) + "--config", + help="Read CLI options from a config file. " + "Must be a YAML with the following options: " + "https://docs.vllm.ai/en/latest/configuration/serve_args.html") + parser = FrontendArgs.add_cli_args(parser) + parser = AsyncEngineArgs.add_cli_args(parser) return parser @@ -311,9 +250,6 @@ def validate_parsed_serve_args(args: argparse.Namespace): if args.enable_auto_tool_choice and not args.tool_call_parser: raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") - if args.enable_prompt_embeds and args.enable_prompt_adapter: - raise ValueError( - "Cannot use prompt embeds and prompt adapter at the same time.") def log_non_default_args(args: argparse.Namespace): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d4db238f456e..6c6ec207a3ca 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,6 +11,18 @@ import regex as re import torch from fastapi import HTTPException, UploadFile +# yapf: disable +from openai.types.chat.chat_completion_audio import ( + ChatCompletionAudio as OpenAIChatCompletionAudio) +from openai.types.chat.chat_completion_message import ( + Annotation as OpenAIAnnotation) +# yapf: enable +from openai.types.responses import (ResponseInputParam, ResponseOutputItem, + ResponseOutputMessage, ResponsePrompt, + ResponseStatus, ResponseTextConfig) +from openai.types.responses.response import ToolChoice +from openai.types.responses.tool import Tool +from openai.types.shared import Metadata, Reasoning from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) from typing_extensions import TypeAlias @@ -18,6 +30,8 @@ from vllm import envs from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, random_tool_call_id) +from vllm.entrypoints.score_utils import (ScoreContentPartParam, + ScoreMultiModalParam) from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, @@ -220,6 +234,146 @@ def get_logits_processors(processors: Optional[LogitsProcessors], return None +class ResponsesRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/responses/create + background: Optional[bool] = False + include: Optional[list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ]] = None + input: Union[str, ResponseInputParam] + instructions: Optional[str] = None + max_output_tokens: Optional[int] = None + max_tool_calls: Optional[int] = None + metadata: Optional[Metadata] = None + model: Optional[str] = None + parallel_tool_calls: Optional[bool] = True + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", + "priority"] = "auto" + store: Optional[bool] = True + stream: Optional[bool] = False + temperature: Optional[float] = None + text: Optional[ResponseTextConfig] = None + tool_choice: ToolChoice = "auto" + tools: list[Tool] = Field(default_factory=list) + top_logprobs: Optional[int] = 0 + top_p: Optional[float] = None + truncation: Optional[Literal["auto", "disabled"]] = "disabled" + user: Optional[str] = None + + # --8<-- 
[start:responses-extra-params] + request_id: str = Field( + default_factory=lambda: f"resp_{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response."), + ) + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling."), + ) + cache_salt: Optional[str] = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit). Not supported by vLLM engine V0.")) + # --8<-- [end:responses-extra-params] + + _DEFAULT_SAMPLING_PARAMS = { + "temperature": 1.0, + "top_p": 1.0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: + if self.max_output_tokens is None: + max_tokens = default_max_tokens + else: + max_tokens = min(self.max_output_tokens, default_max_tokens) + + default_sampling_params = default_sampling_params or {} + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + + # Structured output + guided_decoding = None + if self.text is not None and self.text.format is not None: + response_format = self.text.format + if response_format.type == "json_schema": + guided_decoding = GuidedDecodingParams.from_optional( + json=response_format.schema_) + elif response_format.type == "json_object": + raise NotImplementedError("json_object is not supported") + + # TODO: add more parameters + return SamplingParams.from_optional( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + logprobs=self.top_logprobs, + output_kind=(RequestOutputKind.DELTA + if self.stream else RequestOutputKind.FINAL_ONLY), + guided_decoding=guided_decoding, + ) + + @model_validator(mode="before") + def validate_background(cls, data): + if not data.get("background"): + return data + if not data.get("store", True): + raise ValueError( + "background can only be used when `store` is true") + return data + + @model_validator(mode="before") + def validate_prompt(cls, data): + if data.get("prompt") is not None: + raise ValueError("prompt template is not supported") + return data + + @model_validator(mode="before") + def check_cache_salt_support(cls, data): + if data.get("cache_salt") is not None: + if not envs.VLLM_USE_V1: + raise ValueError( + "Parameter 'cache_salt' is not supported with " + "this instance of vLLM, which uses engine V0.") + if not isinstance(data["cache_salt"], + str) or not data["cache_salt"]: + raise ValueError("Parameter 'cache_salt' must be a " + "non-empty string if provided.") + return data + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # 
https://platform.openai.com/docs/api-reference/chat/create @@ -583,6 +737,24 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: "required": ["name", "parameters"] } + def get_tool_schema_defs( + tools: list[ChatCompletionToolsParam]) -> dict: + all_defs = dict[str, dict[str, Any]]() + for tool in tools: + if tool.function.parameters is None: + continue + defs = tool.function.parameters.pop("$defs", {}) + for def_name, def_schema in defs.items(): + if def_name in all_defs and all_defs[ + def_name] != def_schema: + raise ValueError( + f"Tool definition '{def_name}' has " + "multiple schemas, which is not " + "supported.") + else: + all_defs[def_name] = def_schema + return all_defs + json_schema = { "type": "array", "minItems": 1, @@ -591,6 +763,9 @@ def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: "anyOf": [get_tool_schema(tool) for tool in self.tools] } } + json_schema_defs = get_tool_schema_defs(self.tools) + if json_schema_defs: + json_schema["$defs"] = json_schema_defs return json_schema return None @@ -666,7 +841,7 @@ def check_tool_usage(cls, data): return data # if "tool_choice" is specified -- validation - if "tool_choice" in data: + if "tool_choice" in data and data["tool_choice"] is not None: # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: @@ -678,7 +853,7 @@ def check_tool_usage(cls, data): if data["tool_choice"] not in [ "auto", "required" ] and not isinstance(data["tool_choice"], dict): - raise NotImplementedError( + raise ValueError( f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ 'Only named tools, "none", "auto" or "required" '\ 'are supported.' @@ -851,6 +1026,16 @@ class CompletionRequest(OpenAIBaseModel): " as strings of the form 'token_id:{token_id}' so that tokens " "that are not JSON-encodable can be identified.")) + cache_salt: Optional[str] = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit). 
Not supported by vLLM engine V0.")) + kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters used for disaggregated serving.") @@ -1027,6 +1212,20 @@ def validate_prompt_and_prompt_embeds(cls, data): "At least one of `prompt` or `prompt_embeds` must be set.") return data + @model_validator(mode="before") + @classmethod + def check_cache_salt_support(cls, data): + if data.get("cache_salt") is not None: + if not envs.VLLM_USE_V1: + raise ValueError( + "Parameter 'cache_salt' is not supported with " + "this instance of vLLM, which uses engine V0.") + if not isinstance(data["cache_salt"], + str) or not data["cache_salt"]: + raise ValueError("Parameter 'cache_salt' must be a " + "non-empty string if provided.") + return data + class EmbeddingCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -1038,10 +1237,6 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): user: Optional[str] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:embedding-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:embedding-pooling-params] - # --8<-- [start:embedding-extra-params] add_special_tokens: bool = Field( default=True, @@ -1060,8 +1255,7 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): # --8<-- [end:embedding-extra-params] def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - additional_data=self.additional_data) + return PoolingParams(dimensions=self.dimensions) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1073,10 +1267,6 @@ class EmbeddingChatRequest(OpenAIBaseModel): user: Optional[str] = None truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:chat-embedding-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:chat-embedding-pooling-params] - # --8<-- [start:chat-embedding-extra-params] add_special_tokens: bool = Field( default=False, @@ -1124,8 +1314,7 @@ def check_generation_prompt(cls, data): return data def to_pooling_params(self): - return PoolingParams(dimensions=self.dimensions, - additional_data=self.additional_data) + return PoolingParams(dimensions=self.dimensions) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1137,15 +1326,17 @@ def to_pooling_params(self): class ScoreRequest(OpenAIBaseModel): model: Optional[str] = None - text_1: Union[list[str], str] - text_2: Union[list[str], str] + text_1: Union[list[str], str, ScoreMultiModalParam] + text_2: Union[list[str], str, ScoreMultiModalParam] truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:score-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:score-pooling-params] - # --8<-- [start:score-extra-params] + + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + priority: int = Field( default=0, description=( @@ -1156,23 +1347,24 @@ class ScoreRequest(OpenAIBaseModel): # --8<-- [end:score-extra-params] - def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data) + def to_pooling_params(self): + return PoolingParams() class RerankRequest(OpenAIBaseModel): model: Optional[str] = None - query: str - documents: list[str] + query: Union[str, ScoreMultiModalParam] + documents: Union[list[str], ScoreMultiModalParam] top_n: int = 
Field(default_factory=lambda: 0) truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None - # --8<-- [start:rerank-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:rerank-pooling-params] - # --8<-- [start:rerank-extra-params] + + mm_processor_kwargs: Optional[dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the HF processor."), + ) + priority: int = Field( default=0, description=( @@ -1183,13 +1375,13 @@ class RerankRequest(OpenAIBaseModel): # --8<-- [end:rerank-extra-params] - def to_pooling_params(self, *, use_cross_encoder: bool = False): - return PoolingParams(use_cross_encoder=use_cross_encoder, - additional_data=self.additional_data) + def to_pooling_params(self): + return PoolingParams() class RerankDocument(BaseModel): - text: str + text: Optional[str] = None + multi_modal: Optional[ScoreContentPartParam] = None class RerankResult(BaseModel): @@ -1234,11 +1426,16 @@ class CompletionResponseChoice(OpenAIBaseModel): class CompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: str = "text_completion" + object: Literal["text_completion"] = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseChoice] + service_tier: Optional[Literal["auto", "default", "flex", "scale", + "priority"]] = None + system_fingerprint: Optional[str] = None usage: UsageInfo + + # vLLM-specific fields that are not in OpenAI spec kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.") @@ -1317,10 +1514,6 @@ class ClassificationRequest(OpenAIBaseModel): truncate_prompt_tokens: Optional[int] = None user: Optional[str] = None - # --8<-- [start:classification-pooling-params] - additional_data: Optional[Any] = None - # --8<-- [end:classification-pooling-params] - # --8<-- [start:classification-extra-params] priority: int = Field( default=0, @@ -1333,7 +1526,7 @@ class ClassificationRequest(OpenAIBaseModel): # --8<-- [end:classification-extra-params] def to_pooling_params(self): - return PoolingParams(additional_data=self.additional_data) + return PoolingParams() class ClassificationData(OpenAIBaseModel): @@ -1390,10 +1583,16 @@ class ExtractedToolCallInformation(BaseModel): class ChatMessage(OpenAIBaseModel): role: str - reasoning_content: Optional[str] = None content: Optional[str] = None + refusal: Optional[str] = None + annotations: Optional[OpenAIAnnotation] = None + audio: Optional[OpenAIChatCompletionAudio] = None + function_call: Optional[FunctionCall] = None tool_calls: list[ToolCall] = Field(default_factory=list) + # vLLM-specific fields that are not in OpenAI spec + reasoning_content: Optional[str] = None + class ChatCompletionLogProb(OpenAIBaseModel): token: str @@ -1428,7 +1627,12 @@ class ChatCompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseChoice] + service_tier: Optional[Literal["auto", "default", "flex", "scale", + "priority"]] = None + system_fingerprint: Optional[str] = None usage: UsageInfo + + # vLLM-specific fields that are not in OpenAI spec prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, description="KVTransfer parameters.") @@ -1473,6 +1677,83 @@ class TranscriptionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) +class 
ResponseReasoningItem(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"rs_{random_uuid()}") + text: str + summary: list = Field(default_factory=list) + type: Literal["reasoning"] = "reasoning" + encrypted_content: Optional[str] = None + status: Optional[Literal["in_progress", "completed", "incomplete"]] + + +class ResponsesResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") + created_at: int = Field(default_factory=lambda: int(time.time())) + # error: Optional[ResponseError] = None + # incomplete_details: Optional[IncompleteDetails] = None + instructions: Optional[str] = None + metadata: Optional[Metadata] = None + model: str + object: Literal["response"] = "response" + output: list[Union[ResponseOutputMessage, ResponseReasoningItem]] + parallel_tool_calls: bool + temperature: float + tool_choice: ToolChoice + tools: list[Tool] + top_p: float + background: bool + max_output_tokens: int + max_tool_calls: Optional[int] = None + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] + status: ResponseStatus + text: Optional[ResponseTextConfig] = None + top_logprobs: int + truncation: Literal["auto", "disabled"] + usage: Optional[UsageInfo] = None + user: Optional[str] = None + + @classmethod + def from_request( + cls, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + created_time: int, + output: list[ResponseOutputItem], + status: ResponseStatus, + usage: Optional[UsageInfo] = None, + ) -> "ResponsesResponse": + return cls( + id=request.request_id, + created_at=created_time, + instructions=request.instructions, + metadata=request.metadata, + model=model_name, + output=output, + parallel_tool_calls=request.parallel_tool_calls, + temperature=sampling_params.temperature, + tool_choice=request.tool_choice, + tools=request.tools, + top_p=sampling_params.top_p, + background=request.background, + max_output_tokens=sampling_params.max_tokens, + max_tool_calls=request.max_tool_calls, + previous_response_id=request.previous_response_id, + prompt=request.prompt, + reasoning=request.reasoning, + service_tier=request.service_tier, + status=status, + text=request.text, + top_logprobs=sampling_params.logprobs, + truncation=request.truncation, + user=request.user, + usage=usage, + ) + + BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest] @@ -1648,6 +1929,16 @@ class DetokenizeResponse(OpenAIBaseModel): prompt: str +class TokenizerInfoResponse(OpenAIBaseModel): + """ + Response containing tokenizer configuration + equivalent to tokenizer_config.json + """ + + model_config = ConfigDict(extra="allow") + tokenizer_class: str + + class LoadLoRAAdapterRequest(BaseModel): lora_name: str lora_path: str @@ -1712,7 +2003,7 @@ class TranscriptionRequest(OpenAIBaseModel): """ stream: Optional[bool] = False - """When set, it will enable output to be streamed in a similar fashion + """When set, it will enable output to be streamed in a similar fashion as the Chat Completion endpoint. """ # --8<-- [start:transcription-extra-params] @@ -1974,9 +2265,9 @@ class TranslationRequest(OpenAIBaseModel): """ stream: Optional[bool] = False - """Custom field not present in the original OpenAI definition. When set, + """Custom field not present in the original OpenAI definition. 
When set, it will enable output to be streamed in a similar fashion as the Chat - Completion endpoint. + Completion endpoint. """ # Flattened stream option to simplify form data. stream_include_usage: Optional[bool] = False diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index e112e2f893a0..ef5bf6f9a812 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -337,7 +337,6 @@ async def main(args): model_config=model_config, base_model_paths=base_model_paths, lora_modules=None, - prompt_adapters=None, ) openai_serving_chat = OpenAIServingChat( engine, @@ -348,7 +347,7 @@ async def main(args): chat_template=None, chat_template_content_format="auto", enable_prompt_tokens_details=args.enable_prompt_tokens_details, - ) if model_config.runner_type == "generate" else None + ) if "generate" in model_config.supported_tasks else None openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, @@ -356,17 +355,19 @@ async def main(args): request_logger=request_logger, chat_template=None, chat_template_content_format="auto", - ) if model_config.task == "embed" else None + ) if "embed" in model_config.supported_tasks else None - enable_serving_reranking = (model_config.task == "classify" and getattr( - model_config.hf_config, "num_labels", 0) == 1) + enable_serving_reranking = ("classify" in model_config.supported_tasks + and getattr(model_config.hf_config, + "num_labels", 0) == 1) - openai_serving_scores = (ServingScores( + openai_serving_scores = ServingScores( engine, model_config, openai_serving_models, request_logger=request_logger, - ) if (model_config.task == "embed" or enable_serving_reranking) else None) + ) if ("embed" in model_config.supported_tasks + or enable_serving_reranking) else None tracker = BatchProgressTracker() logger.info("Reading batch from %s...", args.input_file) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 86895cf7dd77..33d80743420c 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -63,7 +63,6 @@ def __init__( return_tokens_as_token_ids: bool = False, reasoning_parser: str = "", enable_auto_tools: bool = False, - expand_tools_even_if_tool_choice_none: bool = False, tool_parser: Optional[str] = None, enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, @@ -112,8 +111,6 @@ def __init__( raise TypeError("Error: --enable-auto-tool-choice requires " f"tool_parser:'{tool_parser}' which has not " "been registered") from e - self.expand_tools_even_if_tool_choice_none = ( - expand_tools_even_if_tool_choice_none) self.enable_prompt_tokens_details = enable_prompt_tokens_details self.enable_force_include_usage = enable_force_include_usage @@ -150,11 +147,8 @@ async def create_chat_completion( raise self.engine_client.dead_error try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request, - supports_default_mm_loras=True) + lora_request = self._maybe_get_adapters( + request, supports_default_mm_loras=True) model_name = self._get_model_name(request.model, lora_request) @@ -182,20 +176,6 @@ async def create_chat_completion( if request.tools is None: tool_dicts = None - elif (request.tool_choice == "none" - and not self.expand_tools_even_if_tool_choice_none): - if len(request.tools) > 0: - logger.warning_once( - "Tools are specified but tool_choice is set to 'none' " - "and --expand-tools-even-if-tool-choice-none is not " - "enabled. 
Tool definitions will be excluded from the " - "prompt. This behavior will change in vLLM v0.10 where " - "tool definitions will be included by default even " - "with tool_choice='none'. To adopt the new behavior " - "now, use --expand-tools-even-if-tool-choice-none. " - "To suppress this warning, either remove tools from " - "the request or set tool_choice to a different value.") - tool_dicts = None else: tool_dicts = [tool.model_dump() for tool in request.tools] @@ -256,8 +236,7 @@ async def create_chat_completion( self._log_inputs(request_id, request_prompts[i], params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) @@ -276,7 +255,6 @@ async def create_chat_completion( request_id, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, priority=request.priority, ) @@ -630,8 +608,13 @@ async def chat_completion_stream_generator( previous_text = previous_texts[i] previous_token_ids = all_previous_token_ids[i] current_text = previous_text + delta_text - current_token_ids = previous_token_ids + list( - output.token_ids) + + # avoid the None + list error. + if previous_token_ids: + current_token_ids = previous_token_ids + list( + output.token_ids) + else: + current_token_ids = list(output.token_ids) # handle streaming deltas for tools with named tool_choice if tool_choice_function_name: @@ -1050,6 +1033,7 @@ async def chat_completion_full_generator( message = ChatMessage( role=role, content="", + reasoning_content=reasoning_content, tool_calls=[ tool_call_class(function=FunctionCall( name=tool_call.name, @@ -1093,9 +1077,17 @@ async def chat_completion_full_generator( else: # FOR NOW make it a chat message; we will have to detect # the type to make it later. + ret_content = content + + # try to use content return from tool parser first, + # tool parser may do some modify for the content. 
+ if (tool_call_info.content + and len(tool_call_info.content) > 0): + ret_content = tool_call_info.content + message = ChatMessage(role=role, reasoning_content=reasoning_content, - content=content) + content=ret_content) # undetermined case that is still important to handle else: diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py index 3ac4f01ea602..377f7f684717 100644 --- a/vllm/entrypoints/openai/serving_classification.py +++ b/vllm/entrypoints/openai/serving_classification.py @@ -6,6 +6,7 @@ import numpy as np from fastapi import Request +from typing_extensions import override from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient @@ -21,12 +22,14 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import ClassificationOutput, PoolingRequestOutput +from vllm.pooling_params import PoolingParams logger = init_logger(__name__) class ClassificationMixin(OpenAIServing): + @override async def _preprocess( self, ctx: ServeContext, @@ -46,19 +49,11 @@ async def _preprocess( return None try: - ( - ctx.lora_request, - ctx.prompt_adapter_request, - ) = self._maybe_get_adapters(ctx.request) + ctx.lora_request = self._maybe_get_adapters(ctx.request) ctx.tokenizer = await self.engine_client.get_tokenizer( ctx.lora_request) - if ctx.prompt_adapter_request is not None: - raise NotImplementedError( - "Prompt adapter is not supported for classification models" - ) - ( ctx.request_prompts, ctx.engine_prompts, @@ -75,6 +70,7 @@ async def _preprocess( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + @override def _build_response( self, ctx: ServeContext, @@ -158,3 +154,31 @@ async def create_classify( ) return await super().handle(ctx) # type: ignore + + @override + def _validate_request( + self, + ctx: ClassificationServeContext, + ) -> Optional[ErrorResponse]: + if error := super()._validate_request(ctx): + return error + + ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens + + return None + + @override + def _create_pooling_params( + self, + ctx: ClassificationServeContext, + ) -> Union[PoolingParams, ErrorResponse]: + pooling_params = super()._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params + + try: + pooling_params.verify("classify", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + + return pooling_params diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6c9c29b71445..323795ca4372 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -23,6 +23,7 @@ CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, + PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo) from vllm.entrypoints.openai.serving_engine import ( @@ -56,21 +57,28 @@ def __init__( *, request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, + enable_prompt_tokens_details: bool = False, enable_force_include_usage: bool = False, ): - super().__init__(engine_client=engine_client, - model_config=model_config, - models=models, - request_logger=request_logger, - return_tokens_as_token_ids=return_tokens_as_token_ids, - enable_force_include_usage=enable_force_include_usage) + super().__init__( + engine_client=engine_client, + model_config=model_config, + 
models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage, + ) + self.enable_prompt_tokens_details = enable_prompt_tokens_details self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) if self.default_sampling_params: source = self.model_config.generation_config source = "model" if source == "auto" else source - logger.info("Using default completion sampling params from %s: %s", - source, self.default_sampling_params) + logger.info( + "Using default completion sampling params from %s: %s", + source, + self.default_sampling_params, + ) async def create_completion( self, @@ -113,10 +121,7 @@ async def create_completion( raw_request.state.request_metadata = request_metadata try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) @@ -169,23 +174,27 @@ async def create_completion( max_model_len=self.max_model_len, request=request, input_length=input_length, - default_sampling_params=self.default_sampling_params) + default_sampling_params=self.default_sampling_params, + ) if request.use_beam_search: sampling_params = request.to_beam_search_params( max_tokens, self.default_sampling_params) else: sampling_params = request.to_sampling_params( - max_tokens, self.model_config.logits_processor_pattern, - self.default_sampling_params) + max_tokens, + self.model_config.logits_processor_pattern, + self.default_sampling_params, + ) request_id_item = f"{request_id}-{i}" - self._log_inputs(request_id_item, - request_prompts[i], - params=sampling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + self._log_inputs( + request_id_item, + request_prompts[i], + params=sampling_params, + lora_request=lora_request, + ) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) @@ -208,7 +217,6 @@ async def create_completion( sampling_params, request_id_item, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, priority=request.priority, ) @@ -242,7 +250,8 @@ async def create_completion( num_prompts=num_prompts, tokenizer=tokenizer, request_metadata=request_metadata, - enable_force_include_usage=self.enable_force_include_usage) + enable_force_include_usage=self.enable_force_include_usage, + ) # Non-streaming response final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts @@ -313,13 +322,15 @@ async def completion_stream_generator( previous_num_tokens = [0] * num_choices * num_prompts has_echoed = [False] * num_choices * num_prompts num_prompt_tokens = [0] * num_prompts + num_cached_tokens = None + first_iteration = True stream_options = request.stream_options if stream_options: - include_usage = stream_options.include_usage or \ - enable_force_include_usage - include_continuous_usage = include_usage and \ - stream_options.continuous_usage_stats + include_usage = (stream_options.include_usage + or enable_force_include_usage) + include_continuous_usage = (include_usage and + stream_options.continuous_usage_stats) else: include_usage, include_continuous_usage = False, False @@ -328,6 +339,10 @@ async def completion_stream_generator( prompt_token_ids = res.prompt_token_ids prompt_logprobs = res.prompt_logprobs + if first_iteration: + num_cached_tokens = res.num_cached_tokens + first_iteration = False + 
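With `enable_prompt_tokens_details`, the completion path now captures `num_cached_tokens` from the first stream result and reports it as `usage.prompt_tokens_details.cached_tokens` in the final usage payload. A hedged sketch of reading that field from the client side is shown below; it assumes the server was started with `--enable-prompt-tokens-details`, a recent `openai` client that accepts `stream_options`, and a placeholder model name.

```python
from openai import OpenAI

# Assumed local server started with --enable-prompt-tokens-details.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.completions.create(
    model="my-model",                        # placeholder model name
    prompt="The capital of France is",
    max_tokens=16,
    stream=True,
    stream_options={"include_usage": True},  # request the final usage chunk
)

for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].text, end="", flush=True)
    if chunk.usage is not None:
        # cached_tokens is only populated when the server enables
        # prompt_tokens_details and prefix-cache hits occurred.
        details = getattr(chunk.usage, "prompt_tokens_details", None)
        cached = getattr(details, "cached_tokens", None) if details else None
        print(f"\nprompt={chunk.usage.prompt_tokens} cached={cached}")
```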
if res.prompt is not None: prompt_text = res.prompt else: @@ -361,7 +376,8 @@ async def completion_stream_generator( # echo the prompt and first token delta_text = prompt_text + output.text delta_token_ids = [ - *prompt_token_ids, *output.token_ids + *prompt_token_ids, + *output.token_ids, ] out_logprobs = [ *(prompt_logprobs or []), @@ -374,8 +390,8 @@ async def completion_stream_generator( delta_token_ids = output.token_ids out_logprobs = output.logprobs - if not delta_text and not delta_token_ids \ - and not previous_num_tokens[i]: + if (not delta_text and not delta_token_ids + and not previous_num_tokens[i]): # Chunked prefill case, don't return empty chunks continue @@ -411,7 +427,8 @@ async def completion_stream_generator( finish_reason=finish_reason, stop_reason=stop_reason, ) - ]) + ], + ) if include_continuous_usage: prompt_tokens = num_prompt_tokens[prompt_idx] completion_tokens = previous_num_tokens[i] @@ -429,7 +446,12 @@ async def completion_stream_generator( final_usage_info = UsageInfo( prompt_tokens=total_prompt_tokens, completion_tokens=total_completion_tokens, - total_tokens=total_prompt_tokens + total_completion_tokens) + total_tokens=total_prompt_tokens + total_completion_tokens, + ) + + if self.enable_prompt_tokens_details and num_cached_tokens: + final_usage_info.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=num_cached_tokens) if include_usage: final_usage_chunk = CompletionStreamResponse( @@ -439,8 +461,8 @@ async def completion_stream_generator( choices=[], usage=final_usage_info, ) - final_usage_data = (final_usage_chunk.model_dump_json( - exclude_unset=False, exclude_none=True)) + final_usage_data = final_usage_chunk.model_dump_json( + exclude_unset=False, exclude_none=True) yield f"data: {final_usage_data}\n\n" # report to FastAPI middleware aggregate usage across all choices @@ -465,8 +487,10 @@ def request_output_to_completion_response( choices: list[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 - + kv_transfer_params = None + last_final_res = None for final_res in final_res_batch: + last_final_res = final_res prompt_token_ids = final_res.prompt_token_ids assert prompt_token_ids is not None prompt_logprobs = clamp_prompt_logprobs(final_res.prompt_logprobs) @@ -535,15 +559,22 @@ def request_output_to_completion_response( total_tokens=num_prompt_tokens + num_generated_tokens, ) - request_metadata.final_usage_info = usage + if (self.enable_prompt_tokens_details and last_final_res + and last_final_res.num_cached_tokens): + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=last_final_res.num_cached_tokens) + request_metadata.final_usage_info = usage + if final_res_batch: + kv_transfer_params = final_res_batch[0].kv_transfer_params return CompletionResponse( id=request_id, created=created_time, model=model_name, choices=choices, usage=usage, - kv_transfer_params=final_res_batch[0].kv_transfer_params) + kv_transfer_params=kv_transfer_params, + ) def _create_completion_logprobs( self, @@ -562,8 +593,9 @@ def _create_completion_logprobs( last_token_len = 0 - should_return_as_token_id = return_as_token_id if \ - return_as_token_id is not None else self.return_tokens_as_token_ids + should_return_as_token_id = (return_as_token_id + if return_as_token_id is not None else + self.return_tokens_as_token_ids) for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] if step_top_logprobs is None: @@ -595,10 +627,12 @@ def _create_completion_logprobs( out_top_logprobs.append({ # Convert 
float("-inf") to the # JSON-serializable float that OpenAI uses - self._get_decoded_token(top_lp[1], - top_lp[0], - tokenizer, - return_as_token_id=should_return_as_token_id): + self._get_decoded_token( + top_lp[1], + top_lp[0], + tokenizer, + return_as_token_id=should_return_as_token_id, + ): max(top_lp[1].logprob, -9999.0) for i, top_lp in enumerate(step_top_logprobs.items()) if num_output_top_logprobs >= i diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e87decfe636a..697f43c018b2 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -24,6 +24,7 @@ from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) +from vllm.pooling_params import PoolingParams logger = init_logger(__name__) @@ -45,24 +46,18 @@ def _get_embedding( class EmbeddingMixin(OpenAIServing): + @override async def _preprocess( self, ctx: ServeContext, ) -> Optional[ErrorResponse]: ctx = cast(EmbeddingServeContext, ctx) try: - ( - ctx.lora_request, - ctx.prompt_adapter_request, - ) = self._maybe_get_adapters(ctx.request) + ctx.lora_request = self._maybe_get_adapters(ctx.request) tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request ) - if ctx.prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for embedding models") - if isinstance(ctx.request, EmbeddingChatRequest): ( _, @@ -97,6 +92,7 @@ async def _preprocess( logger.exception("Error in preprocessing prompt inputs") return self.create_error_response(str(e)) + @override def _build_response( self, ctx: ServeContext, @@ -191,11 +187,20 @@ def _validate_request( ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens - pooling_params = ctx.request.to_pooling_params() + return None + + @override + def _create_pooling_params( + self, + ctx: ServeContext[EmbeddingRequest], + ) -> Union[PoolingParams, ErrorResponse]: + pooling_params = super()._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params try: - pooling_params.verify(self.model_config) + pooling_params.verify("embed", self.model_config) except ValueError as e: return self.create_error_response(str(e)) - return None + return pooling_params diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index c267c04e0cb5..edc366f9b8a8 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio import base64 import io import json import sys import time -from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping, - Sequence) -from concurrent.futures.thread import ThreadPoolExecutor +from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence +from concurrent.futures import ThreadPoolExecutor from http import HTTPStatus from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional, TypeVar, Union, cast, overload) @@ -18,11 +18,6 @@ from starlette.datastructures import Headers from typing_extensions import TypeIs -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - if sys.version_info >= (3, 12): from typing import TypedDict else: @@ -53,7 +48,8 @@ EmbeddingRequest, EmbeddingResponse, ErrorResponse, 
PoolingResponse, RerankRequest, - ScoreRequest, ScoreResponse, + ResponsesRequest, ScoreRequest, + ScoreResponse, TokenizeChatRequest, TokenizeCompletionRequest, TokenizeResponse, @@ -72,14 +68,13 @@ MultiModalDataDict) from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.sequence import Logprob, PromptLogprobs from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import (is_list_of, make_async, merge_async_iterators, - random_uuid) +from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of, + merge_async_iterators, random_uuid) logger = init_logger(__name__) @@ -91,7 +86,8 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest, TokenizeChatRequest] SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest] -AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest] +AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest, + ResponsesRequest] AnyResponse = Union[ CompletionResponse, @@ -164,7 +160,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel, request_id: str created_time: int = Field(default_factory=lambda: int(time.time())) lora_request: Optional[LoRARequest] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None # Shared across most requests tokenizer: Optional[AnyTokenizer] = None @@ -224,11 +219,19 @@ def __init__( self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) - self._tokenize_prompt_input_async = make_async( - self._tokenize_prompt_input, executor=self._tokenizer_executor) - self._tokenize_prompt_input_or_inputs_async = make_async( - self._tokenize_prompt_input_or_inputs, - executor=self._tokenizer_executor) + self._async_tokenizer_pool: dict[AnyTokenizer, + AsyncMicrobatchTokenizer] = {} + + def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer: + """ + Return (and cache) an `AsyncMicrobatchTokenizer` bound to the + given tokenizer. 
+ """ + async_tokenizer = self._async_tokenizer_pool.get(tokenizer) + if async_tokenizer is None: + async_tokenizer = AsyncMicrobatchTokenizer(tokenizer) + self._async_tokenizer_pool[tokenizer] = async_tokenizer + return async_tokenizer async def _preprocess( self, @@ -300,6 +303,16 @@ def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]: " Please, select a smaller truncation size.") return None + def _create_pooling_params( + self, + ctx: ServeContext, + ) -> Union[PoolingParams, ErrorResponse]: + if not hasattr(ctx.request, "to_pooling_params"): + return self.create_error_response( + "Request type does not support pooling parameters") + + return ctx.request.to_pooling_params() + async def _prepare_generators( self, ctx: ServeContext, @@ -313,11 +326,9 @@ async def _prepare_generators( trace_headers = (None if ctx.raw_request is None else await self._get_trace_headers(ctx.raw_request.headers)) - if not hasattr(ctx.request, "to_pooling_params"): - return self.create_error_response( - "Request type does not support pooling parameters") - - pooling_params = ctx.request.to_pooling_params() + pooling_params = self._create_pooling_params(ctx) + if isinstance(pooling_params, ErrorResponse): + return pooling_params if ctx.engine_prompts is None: return self.create_error_response( @@ -330,12 +341,10 @@ async def _prepare_generators( return self.create_error_response( "Request prompts not available") - self._log_inputs( - request_id_item, - ctx.request_prompts[i], - params=pooling_params, - lora_request=ctx.lora_request, - prompt_adapter_request=ctx.prompt_adapter_request) + self._log_inputs(request_id_item, + ctx.request_prompts[i], + params=pooling_params, + lora_request=ctx.lora_request) # Mypy has an existing bug related to inferring the variance of # TypedDicts with `builtins.enumerate`: @@ -437,11 +446,6 @@ async def _check_model( if isinstance(load_result, ErrorResponse) and \ load_result.code == HTTPStatus.BAD_REQUEST.value: error_response = load_result - if request.model in [ - prompt_adapter.prompt_adapter_name - for prompt_adapter in self.models.prompt_adapter_requests - ]: - return None return error_response or self.create_error_response( message=f"The model `{request.model}` does not exist.", @@ -476,25 +480,21 @@ def _maybe_get_adapters( self, request: AnyRequest, supports_default_mm_loras: bool = False, - ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[ - None, PromptAdapterRequest]]: + ) -> Optional[LoRARequest]: if request.model in self.models.lora_requests: - return self.models.lora_requests[request.model], None + return self.models.lora_requests[request.model] # Currently only support default modality specific loras # if we have exactly one lora matched on the request. 
if supports_default_mm_loras: default_mm_lora = self._get_active_default_mm_loras(request) if default_mm_lora is not None: - return default_mm_lora, None + return default_mm_lora if self._is_model_supported(request.model): - return None, None + return None - for prompt_adapter in self.models.prompt_adapter_requests: - if request.model == prompt_adapter.prompt_adapter_name: - return None, prompt_adapter # if _check_model has been called earlier, this will be unreachable raise ValueError(f"The model `{request.model}` does not exist.") @@ -516,7 +516,7 @@ def _get_message_types(self, request: AnyRequest) -> set[str]: message_types.add(content_dict["type"].split("_")[0]) return message_types - def _normalize_prompt_text_to_input( + async def _normalize_prompt_text_to_input( self, request: AnyRequest, tokenizer: AnyTokenizer, @@ -524,38 +524,44 @@ def _normalize_prompt_text_to_input( truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]], add_special_tokens: bool, ) -> TextTokensPrompt: + async_tokenizer = self._get_async_tokenizer(tokenizer) + if (self.model_config.encoder_config is not None and self.model_config.encoder_config.get( "do_lower_case", False)): prompt = prompt.lower() if truncate_prompt_tokens is None: - encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + encoded = await async_tokenizer( + prompt, add_special_tokens=add_special_tokens) elif truncate_prompt_tokens < 0: # Negative means we cap at the model's max length - encoded = tokenizer(prompt, - add_special_tokens=add_special_tokens, - truncation=True, - max_length=self.max_model_len) + encoded = await async_tokenizer( + prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=self.max_model_len) else: - encoded = tokenizer(prompt, - add_special_tokens=add_special_tokens, - truncation=True, - max_length=truncate_prompt_tokens) + encoded = await async_tokenizer( + prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) input_ids = encoded.input_ids - input_text = prompt return self._validate_input(request, input_ids, input_text) - def _normalize_prompt_tokens_to_input( + async def _normalize_prompt_tokens_to_input( self, request: AnyRequest, tokenizer: AnyTokenizer, prompt_ids: list[int], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], ) -> TextTokensPrompt: + async_tokenizer = self._get_async_tokenizer(tokenizer) + if truncate_prompt_tokens is None: input_ids = prompt_ids elif truncate_prompt_tokens < 0: @@ -563,7 +569,7 @@ def _normalize_prompt_tokens_to_input( else: input_ids = prompt_ids[-truncate_prompt_tokens:] - input_text = tokenizer.decode(input_ids) + input_text = await async_tokenizer.decode(input_ids) return self._validate_input(request, input_ids, input_text) @@ -627,7 +633,7 @@ def _validate_input( return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) - def _tokenize_prompt_input( + async def _tokenize_prompt_input_async( self, request: AnyRequest, tokenizer: AnyTokenizer, @@ -640,23 +646,24 @@ def _tokenize_prompt_input( [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] that assumes single input. 
""" - return next( - self._tokenize_prompt_inputs( + async for result in self._tokenize_prompt_inputs_async( request, tokenizer, - [prompt_input], + [prompt_input], truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, - )) + ): + return result + raise ValueError("No results yielded from tokenization") - def _tokenize_prompt_inputs( + async def _tokenize_prompt_inputs_async( self, request: AnyRequest, tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, list[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None, add_special_tokens: bool = True, - ) -> Iterator[TextTokensPrompt]: + ) -> AsyncGenerator[TextTokensPrompt, None]: """ A simpler implementation of [`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs] @@ -664,7 +671,7 @@ def _tokenize_prompt_inputs( """ for text in prompt_inputs: if isinstance(text, str): - yield self._normalize_prompt_text_to_input( + yield await self._normalize_prompt_text_to_input( request, tokenizer, prompt=text, @@ -672,14 +679,14 @@ def _tokenize_prompt_inputs( add_special_tokens=add_special_tokens, ) else: - yield self._normalize_prompt_tokens_to_input( + yield await self._normalize_prompt_tokens_to_input( request, tokenizer, prompt_ids=text, truncate_prompt_tokens=truncate_prompt_tokens, ) - def _tokenize_prompt_input_or_inputs( + async def _tokenize_prompt_input_or_inputs_async( self, request: AnyRequest, tokenizer: AnyTokenizer, @@ -713,21 +720,31 @@ def _tokenize_prompt_input_or_inputs( # VSCode Pyright extension should still work properly # "is False" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - inputs_text.extend([ - self._normalize_prompt_text_to_input( - request, - tokenizer, - prompt=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens, - add_special_tokens=add_special_tokens) - if prompt_input["is_tokens"] is False else - self._normalize_prompt_tokens_to_input( - request, - tokenizer, - prompt_ids=prompt_input["content"], - truncate_prompt_tokens=truncate_prompt_tokens) - for prompt_input in parse_and_batch_prompt(input_or_inputs) - ]) + + # Parse and batch the input prompts + batch_inputs = parse_and_batch_prompt(input_or_inputs) + + # Process each input in the batch concurrently + tasks = [] + for prompt_input in batch_inputs: + if prompt_input["is_tokens"] is False: + task = self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens) + else: + task = self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens) + tasks.append(task) + + # Wait for all tokenization tasks to complete + results = await asyncio.gather(*tasks) + inputs_text.extend(results) return inputs_text, inputs_embeds @@ -789,6 +806,12 @@ async def _preprocess_completion( prompt_token_ids=request_prompt_text["prompt_token_ids"]) for request_prompt_text in request_prompts_text ] + cache_salt = request.cache_salt if ( + hasattr(request, "cache_salt") + and request.cache_salt is not None) else None + if cache_salt: + for prompt_text in engine_prompts_text: + prompt_text["cache_salt"] = cache_salt # This check is equivalent to simply checking if # `request_prompts_embeds` is empty, but it's difficult to propagate @@ -806,6 +829,9 @@ async def _preprocess_completion( 
prompt_embeds=request_prompt_embeds["prompt_embeds"]) for request_prompt_embeds in request_prompts_embeds ] + if cache_salt: + for prompt_embed in engine_prompts_embeds: + prompt_embed["cache_salt"] = cache_salt request_prompts = request_prompts_embeds + request_prompts_text engine_prompts = engine_prompts_embeds + engine_prompts_text @@ -813,7 +839,7 @@ async def _preprocess_completion( async def _preprocess_chat( self, - request: ChatLikeRequest, + request: Union[ChatLikeRequest, ResponsesRequest], tokenizer: AnyTokenizer, messages: list[ChatCompletionMessageParam], chat_template: Optional[str], @@ -948,7 +974,6 @@ def _log_inputs( params: Optional[Union[SamplingParams, PoolingParams, BeamSearchParams]], lora_request: Optional[LoRARequest], - prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if self.request_logger is None: return @@ -970,7 +995,6 @@ def _log_inputs( prompt_embeds, params=params, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, ) async def _get_trace_headers( diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py index bc4f523c82e3..27614fcb4112 100644 --- a/vllm/entrypoints/openai/serving_models.py +++ b/vllm/entrypoints/openai/serving_models.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import json -import pathlib from asyncio import Lock from collections import defaultdict from dataclasses import dataclass @@ -19,7 +17,6 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.utils import AtomicCounter logger = init_logger(__name__) @@ -31,12 +28,6 @@ class BaseModelPath: model_path: str -@dataclass -class PromptAdapterPath: - name: str - local_path: str - - @dataclass class LoRAModulePath: name: str @@ -60,7 +51,6 @@ def __init__( base_model_paths: list[BaseModelPath], *, lora_modules: Optional[list[LoRAModulePath]] = None, - prompt_adapters: Optional[list[PromptAdapterPath]] = None, ): super().__init__() @@ -81,20 +71,6 @@ def __init__( LoRAResolverRegistry.get_resolver(lora_resolver_name)) self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock) - self.prompt_adapter_requests = [] - if prompt_adapters is not None: - for i, prompt_adapter in enumerate(prompt_adapters, start=1): - with pathlib.Path(prompt_adapter.local_path, - "adapter_config.json").open() as f: - adapter_config = json.load(f) - num_virtual_tokens = adapter_config["num_virtual_tokens"] - self.prompt_adapter_requests.append( - PromptAdapterRequest( - prompt_adapter_name=prompt_adapter.name, - prompt_adapter_id=i, - prompt_adapter_local_path=prompt_adapter.local_path, - prompt_adapter_num_virtual_tokens=num_virtual_tokens)) - async def init_static_loras(self): """Loads all static LoRA modules. 
Raises if any fail to load""" @@ -141,14 +117,7 @@ async def show_available_models(self) -> ModelList: permission=[ModelPermission()]) for lora in self.lora_requests.values() ] - prompt_adapter_cards = [ - ModelCard(id=prompt_adapter.prompt_adapter_name, - root=self.base_model_paths[0].name, - permission=[ModelPermission()]) - for prompt_adapter in self.prompt_adapter_requests - ] model_cards.extend(lora_cards) - model_cards.extend(prompt_adapter_cards) return ModelList(data=model_cards) async def load_lora_adapter( diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index c2ed50d04d12..12334cdac365 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -94,17 +94,10 @@ async def create_pooling( try: truncate_prompt_tokens = _validate_truncation_size( self.max_model_len, truncate_prompt_tokens) - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) - if prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for pooling models") - if isinstance(request, PoolingChatRequest): ( _, @@ -142,14 +135,18 @@ async def create_pooling( try: pooling_params = request.to_pooling_params() + try: + pooling_params.verify("encode", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, request_prompts[i], params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request) trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py new file mode 100644 index 000000000000..64880a3a5377 --- /dev/null +++ b/vllm/entrypoints/openai/serving_responses.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import time +from collections.abc import AsyncGenerator, AsyncIterator +from http import HTTPStatus +from typing import Callable, Final, Optional, Union + +import jinja2 +from fastapi import Request +from openai.types.responses import ResponseOutputMessage, ResponseOutputText + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + ChatTemplateContentFormatOption) +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ErrorResponse, + PromptTokenUsageInfo, + RequestResponseMetadata, + ResponseReasoningItem, + ResponsesRequest, + ResponsesResponse, UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +class 
OpenAIServingResponses(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + chat_template_content_format: ChatTemplateContentFormatOption, + return_tokens_as_token_ids: bool = False, + reasoning_parser: str = "", + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None, + enable_prompt_tokens_details: bool = False, + enable_force_include_usage: bool = False, + ) -> None: + super().__init__( + engine_client=engine_client, + model_config=model_config, + models=models, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids, + enable_force_include_usage=enable_force_include_usage, + ) + + self.chat_template = chat_template + self.chat_template_content_format: Final = chat_template_content_format + + self.reasoning_parser: Optional[Callable[[AnyTokenizer], + ReasoningParser]] = None + if reasoning_parser: + try: + self.reasoning_parser = ( + ReasoningParserManager.get_reasoning_parser( + reasoning_parser)) + assert self.reasoning_parser is not None + except Exception as e: + raise TypeError( + f"{reasoning_parser=} has not been registered") from e + + self.enable_prompt_tokens_details = enable_prompt_tokens_details + self.enable_force_include_usage = enable_force_include_usage + self.default_sampling_params = ( + self.model_config.get_diff_sampling_param()) + if self.default_sampling_params: + source = self.model_config.generation_config + source = "model" if source == "auto" else source + logger.info("Using default chat sampling params from %s: %s", + source, self.default_sampling_params) + + # HACK(woosuk): This is a hack. We should use a better store. + # FIXME: This causes a memory leak since we never remove responses + # from the store. + self.response_store: dict[str, ResponsesResponse] = {} + self.response_store_lock = asyncio.Lock() + + # HACK(woosuk): This is a hack. We should use a better store. + # FIXME: This causes a memory leak since we never remove messages + # from the store. + self.msg_store: dict[str, list[ChatCompletionMessageParam]] = {} + + self.background_tasks: dict[str, asyncio.Task] = {} + + async def create_responses( + self, + request: ResponsesRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ResponsesResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Handle the previous response ID. + prev_response_id = request.previous_response_id + if prev_response_id is not None: + if not prev_response_id.startswith("resp_"): + return self._make_invalid_id_error(prev_response_id) + async with self.response_store_lock: + prev_response = self.response_store.get(prev_response_id) + if prev_response is None: + return self._make_not_found_error(prev_response_id) + else: + prev_response = None + # Construct the input messages. 
+ messages = self._construct_input_messages(request, prev_response) + + try: + lora_request = self._maybe_get_adapters(request) + model_name = self._get_model_name(request.model, lora_request) + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + _, request_prompts, engine_prompts = await self._preprocess_chat( + request, + tokenizer, + messages, + chat_template=self.chat_template, + chat_template_content_format=self.chat_template_content_format, + ) + except (ValueError, TypeError, RuntimeError, + jinja2.TemplateError) as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(f"{e} {e.__cause__}") + + request_metadata = RequestResponseMetadata( + request_id=request.request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # Schedule the request and get the result generator. + generators: list[AsyncGenerator[RequestOutput, None]] = [] + try: + for i, engine_prompt in enumerate(engine_prompts): + default_max_tokens = self.max_model_len - len( + engine_prompt["prompt_token_ids"]) + sampling_params = request.to_sampling_params( + default_max_tokens, self.default_sampling_params) + + self._log_inputs(request.request_id, + request_prompts[i], + params=sampling_params, + lora_request=lora_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.generate( + engine_prompt, + sampling_params, + request.request_id, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert len(generators) == 1 + result_generator, = generators + + # Store the input messages. + if request.store: + self.msg_store[request.request_id] = messages + + if request.background: + created_time = int(time.time()) + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=[], + status="queued", + usage=None, + ) + async with self.response_store_lock: + self.response_store[response.id] = response + + # Run the request in the background. + task = asyncio.create_task( + self._run_background_request( + request, + sampling_params, + result_generator, + model_name, + tokenizer, + request_metadata, + created_time, + ), + name=f"create_{response.id}", + ) + + # For cleanup. 
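# Aside (illustrative, not part of this patch): the cleanup step below follows
# the usual asyncio idiom of keeping a handle to every background generation so
# it can be cancelled later, and dropping the handle once the task finishes; a
# minimal self-contained sketch of that idiom:
import asyncio

background_tasks: dict[str, asyncio.Task] = {}

async def _demo() -> None:
    async def _noop() -> None:  # hypothetical stand-in for the real generator
        await asyncio.sleep(0)

    task = asyncio.create_task(_noop(), name="create_resp_demo")
    background_tasks["resp_demo"] = task
    # Pop the handle on completion so the registry cannot grow without bound.
    task.add_done_callback(lambda _: background_tasks.pop("resp_demo", None))
    await task

asyncio.run(_demo())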
+ response_id = response.id + self.background_tasks[response_id] = task + task.add_done_callback( + lambda _: self.background_tasks.pop(response_id, None)) + return response + + if request.stream: + raise NotImplementedError("Streaming responses are not supported") + + try: + return await self.responses_full_generator( + request, + sampling_params, + result_generator, + model_name, + tokenizer, + request_metadata, + ) + except Exception as e: + return self.create_error_response(str(e)) + + async def responses_full_generator( + self, + request: ResponsesRequest, + sampling_params: SamplingParams, + result_generator: AsyncIterator[RequestOutput], + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + created_time: Optional[int] = None, + ) -> Union[ErrorResponse, ResponsesResponse]: + if created_time is None: + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + assert final_res is not None + assert len(final_res.outputs) == 1 + final_output = final_res.outputs[0] + + if self.reasoning_parser: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content(final_output.text, + request=request)) + else: + reasoning_content = None + content = final_output.text + + output = [] + if reasoning_content: + reasoning_item = ResponseReasoningItem( + text=reasoning_content, + status=None, # NOTE: Only the last output item has status. + ) + output.append(reasoning_item) + if content: + output_text = ResponseOutputText( + text=content, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + message = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=[output_text], + role="assistant", + status="completed", + type="message", + ) + output.append(message) + + # Calculate usage. + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = len(final_output.token_ids) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + if self.enable_prompt_tokens_details and final_res.num_cached_tokens: + usage.prompt_tokens_details = PromptTokenUsageInfo( + cached_tokens=final_res.num_cached_tokens) + request_metadata.final_usage_info = usage + + response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=created_time, + output=output, + status="completed", + usage=usage, + ) + + if request.store: + async with self.response_store_lock: + stored_response = self.response_store.get(response.id) + # If the response is already cancelled, don't update it. 
+ if (stored_response is None + or stored_response.status != "cancelled"): + self.response_store[response.id] = response + return response + + def _construct_input_messages( + self, + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse] = None, + ) -> list[ChatCompletionMessageParam]: + messages: list[ChatCompletionMessageParam] = [] + if request.instructions: + messages.append({ + "role": "system", + "content": request.instructions, + }) + + # Prepend the conversation history. + if prev_response is not None: + # Add the previous messages. + prev_msg = self.msg_store[prev_response.id] + messages.extend(prev_msg) + + # Add the previous output. + for output_item in prev_response.output: + # NOTE: We skip the reasoning output. + if isinstance(output_item, ResponseOutputMessage): + for content in output_item.content: + messages.append({ + "role": "assistant", + "content": content.text, + }) + + # Append the new input. + # Responses API supports simple text inputs without chat format. + if isinstance(request.input, str): + messages.append({"role": "user", "content": request.input}) + else: + messages.extend(request.input) # type: ignore + return messages + + async def _run_background_request( + self, + request: ResponsesRequest, + *args, + **kwargs, + ): + try: + response = await self.responses_full_generator( + request, *args, **kwargs) + except Exception as e: + logger.exception("Background request failed for %s", + request.request_id) + response = self.create_error_response(str(e)) + + if isinstance(response, ErrorResponse): + # If the request has failed, update the status to "failed". + response_id = request.request_id + async with self.response_store_lock: + stored_response = self.response_store.get(response_id) + assert stored_response is not None + if stored_response.status not in ("completed", "cancelled"): + stored_response.status = "failed" + + async def retrieve_responses( + self, + response_id: str, + ) -> Union[ErrorResponse, ResponsesResponse]: + if not response_id.startswith("resp_"): + return self._make_invalid_id_error(response_id) + + async with self.response_store_lock: + response = self.response_store.get(response_id) + + if response is None: + return self._make_not_found_error(response_id) + return response + + async def cancel_responses( + self, + response_id: str, + ) -> Union[ErrorResponse, ResponsesResponse]: + if not response_id.startswith("resp_"): + return self._make_invalid_id_error(response_id) + + async with self.response_store_lock: + response = self.response_store.get(response_id) + if response is None: + return self._make_not_found_error(response_id) + + prev_status = response.status + if prev_status not in ("queued", "in_progress"): + return self.create_error_response( + err_type="invalid_request_error", + message="Cannot cancel a synchronous response.", + ) + + # Update the status to "cancelled". + response.status = "cancelled" + + # Abort the request. + if (task := self.background_tasks.get(response_id)): + task.cancel() + try: + await task + except asyncio.CancelledError: + logger.exception("Background task for %s was cancelled", + response_id) + return response + + def _make_invalid_id_error(self, response_id: str) -> ErrorResponse: + return self.create_error_response( + err_type="invalid_request_error", + message=(f"Invalid 'response_id': '{response_id}'. 
" + "Expected an ID that begins with 'resp'."), + ) + + def _make_not_found_error(self, response_id: str) -> ErrorResponse: + return self.create_error_response( + err_type="invalid_request_error", + message=f"Response with id '{response_id}' not found.", + status_code=HTTPStatus.NOT_FOUND, + ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index 328d4ff0e6c0..4da2094147ce 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -17,14 +17,16 @@ ScoreResponseData, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels -from vllm.entrypoints.score_utils import (_cosine_similarity, - _validate_score_input_lens) +from vllm.entrypoints.score_utils import (ScoreContentPartParam, + ScoreMultiModalParam, + _cosine_similarity, + _validate_score_input_lens, + get_score_prompt) from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.utils import make_async, merge_async_iterators @@ -52,14 +54,11 @@ async def _embedding_score( texts_1: list[str], texts_2: list[str], request: Union[RerankRequest, ScoreRequest], - request_id=str, + request_id: str, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[Union[LoRARequest, None]] = None, - prompt_adapter_request: Optional[Union[PromptAdapterRequest, - None]] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> list[PoolingRequestOutput]: - + ) -> Union[list[PoolingRequestOutput], ErrorResponse]: input_texts = texts_1 + texts_2 engine_prompts: list[TokensPrompt] = [] @@ -86,6 +85,11 @@ async def _embedding_score( generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] pooling_params = request.to_pooling_params() + try: + pooling_params.verify("embed", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) + for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" @@ -93,8 +97,7 @@ async def _embedding_score( self._log_inputs(request_id_item, input_texts[i], params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request) generators.append( self.engine_client.encode( @@ -137,58 +140,114 @@ async def _embedding_score( return final_res_batch + def _preprocess_score( + self, + request: Union[RerankRequest, ScoreRequest], + tokenizer: AnyTokenizer, + tokenization_kwargs: dict[str, Any], + data_1: Union[str, ScoreContentPartParam], + data_2: Union[str, ScoreContentPartParam], + ) -> tuple[str, TokensPrompt]: + + model_config = self.model_config + + full_prompt, engine_prompt = get_score_prompt( + model_config=model_config, + data_1=data_1, + data_2=data_2, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + ) + if request.mm_processor_kwargs is not None: + engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs + + return full_prompt, engine_prompt + async def _cross_encoding_score( self, tokenizer: AnyTokenizer, - texts_1: list[str], - texts_2: list[str], + data_1: Union[list[str], list[ScoreContentPartParam]], + data_2: Union[list[str], 
list[ScoreContentPartParam]], request: Union[RerankRequest, ScoreRequest], - request_id=str, + request_id: str, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[Union[LoRARequest, None]] = None, - prompt_adapter_request: Optional[Union[PromptAdapterRequest, - None]] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> list[PoolingRequestOutput]: - + ) -> Union[list[PoolingRequestOutput], ErrorResponse]: request_prompts: list[str] = [] engine_prompts: list[TokensPrompt] = [] - if len(texts_1) == 1: - texts_1 = texts_1 * len(texts_2) - - input_pairs = [(t1, t2) for t1, t2 in zip(texts_1, texts_2)] + if len(data_1) == 1: + data_1 = data_1 * len(data_2) if isinstance(tokenizer, MistralTokenizer): raise ValueError( "MistralTokenizer not supported for cross-encoding") - tokenize_async = make_async(tokenizer.__call__, - executor=self._tokenizer_executor) - tokenization_kwargs = tokenization_kwargs or {} - tokenized_prompts = await asyncio.gather( - *(tokenize_async(text=t1, text_pair=t2, **tokenization_kwargs) - for t1, t2 in input_pairs)) - for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): - sep_token = tokenizer.sep_token if tokenizer.sep_token else '' - request_prompt = f"{t1}{sep_token}{t2}" + input_pairs = [(t1, t2) for t1, t2 in zip(data_1, data_2)] - input_ids = prompt_inputs["input_ids"] - text_token_prompt = \ - self._validate_input(request, input_ids, request_prompt) - engine_prompt = TokensPrompt( - prompt_token_ids=text_token_prompt["prompt_token_ids"], - token_type_ids=prompt_inputs.get("token_type_ids")) + if self.model_config.is_multimodal_model: + + preprocess_async = make_async(self._preprocess_score, + executor=self._tokenizer_executor) + + preprocessed_prompts = await asyncio.gather( + *(preprocess_async(request=request, + tokenizer=tokenizer, + tokenization_kwargs=tokenization_kwargs, + data_1=t1, + data_2=t2) for t1, t2 in input_pairs)) - request_prompts.append(request_prompt) - engine_prompts.append(engine_prompt) + for full_prompt, engine_prompt in preprocessed_prompts: + request_prompts.append(full_prompt) + engine_prompts.append(engine_prompt) + + else: + tokenize_async = make_async(tokenizer.__call__, + executor=self._tokenizer_executor) + use_pad_token = self.model_config.use_pad_token + + if use_pad_token: + # cross_encoder models defaults to using pad_token. + tokenized_prompts = await asyncio.gather(*( + tokenize_async( + text=t1, # type: ignore[arg-type] + text_pair=t2, # type: ignore[arg-type] + **tokenization_kwargs) for t1, t2 in input_pairs)) + else: + # `llm as reranker` models defaults to not using pad_token. + tokenized_prompts = await asyncio.gather(*( + tokenize_async( + text=t1 + # type: ignore[operator] + t2, + **tokenization_kwargs) for t1, t2 in input_pairs)) + + for prompt_inputs, (t1, t2) in zip(tokenized_prompts, input_pairs): + sep_token = tokenizer.sep_token if (tokenizer.sep_token + and use_pad_token) else '' + request_prompt = f"{t1}{sep_token}{t2}" + + input_ids = prompt_inputs["input_ids"] + text_token_prompt = \ + self._validate_input(request, input_ids, request_prompt) + engine_prompt = TokensPrompt( + prompt_token_ids=text_token_prompt["prompt_token_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) # Schedule the request and get the result generator. 
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] - pooling_params = request.to_pooling_params(use_cross_encoder=True) + pooling_params = request.to_pooling_params() + + try: + pooling_params.verify("score", self.model_config) + except ValueError as e: + return self.create_error_response(str(e)) for i, engine_prompt in enumerate(engine_prompts): request_id_item = f"{request_id}-{i}" @@ -196,8 +255,7 @@ async def _cross_encoding_score( self._log_inputs(request_id_item, request_prompts[i], params=pooling_params, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request) generator = self.engine_client.encode( engine_prompt, @@ -223,22 +281,14 @@ async def _cross_encoding_score( async def _run_scoring( self, - texts_1: Union[str, list[str]], - texts_2: Union[str, list[str]], + data_1: Union[list[str], str, ScoreMultiModalParam], + data_2: Union[list[str], str, ScoreMultiModalParam], request: Union[ScoreRequest, RerankRequest], request_id: str, raw_request: Optional[Request] = None, truncate_prompt_tokens: Optional[int] = None, - ) -> list[PoolingRequestOutput]: - - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) - - if prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for scoring models") + ) -> Union[list[PoolingRequestOutput], ErrorResponse]: + lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) @@ -249,35 +299,44 @@ async def _run_scoring( trace_headers = (None if raw_request is None else await self._get_trace_headers(raw_request.headers)) - if isinstance(texts_1, str): - texts_1 = [texts_1] - if isinstance(texts_2, str): - texts_2 = [texts_2] + if not self.model_config.is_multimodal_model and (isinstance( + data_1, dict) or isinstance(data_2, dict)): + raise ValueError( + f"MultiModalParam is not supported for {self.model_config.architecture}" # noqa: E501 + ) + + if isinstance(data_1, str): + data_1 = [data_1] + elif isinstance(data_1, dict): + data_1 = data_1.get("content") # type: ignore[assignment] - _validate_score_input_lens(texts_1, texts_2) + if isinstance(data_2, str): + data_2 = [data_2] + elif isinstance(data_2, dict): + data_2 = data_2.get("content") # type: ignore[assignment] + + _validate_score_input_lens(data_1, data_2) # type: ignore[arg-type] if self.model_config.is_cross_encoder: return await self._cross_encoding_score( tokenizer=tokenizer, - texts_1=texts_1, - texts_2=texts_2, + data_1=data_1, # type: ignore[arg-type] + data_2=data_2, # type: ignore[arg-type] request=request, request_id=request_id, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers) else: return await self._embedding_score( tokenizer=tokenizer, - texts_1=texts_1, - texts_2=texts_2, + texts_1=data_1, # type: ignore[arg-type] + texts_2=data_2, # type: ignore[arg-type] request=request, request_id=request_id, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers) async def create_score( @@ -306,6 +365,8 @@ async def create_score( raw_request, request.truncate_prompt_tokens, ) + if isinstance(final_res_batch, ErrorResponse): + return final_res_batch return self.request_output_to_score_response( final_res_batch, @@ -339,7 +400,9 @@ async def do_rerank( request_id = f"rerank-{self._base_request_id(raw_request)}" 
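# Aside (illustrative sketch, not part of this patch): the two tokenization
# paths in _cross_encoding_score above produce different prompt layouts.
# Classic cross-encoders (use_pad_token=True) pair query and document around
# the tokenizer's sep token, while "LLM as reranker" models simply concatenate
# them. With hypothetical inputs:
query = "what is vllm?"
doc = "vLLM is a fast LLM inference engine."

use_pad_token = True  # assumption: model_config.use_pad_token
sep_token = "[SEP]" if use_pad_token else ""
request_prompt = f"{query}{sep_token}{doc}"
# use_pad_token=True  -> "what is vllm?[SEP]vLLM is a fast LLM inference engine."
# use_pad_token=False -> "what is vllm?vLLM is a fast LLM inference engine."
# The pad-token path tokenizes (text=query, text_pair=doc); the other path
# tokenizes the plain concatenation query + doc.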
documents = request.documents - top_n = request.top_n if request.top_n > 0 else len(documents) + top_n = request.top_n if request.top_n > 0 else ( + len(documents) + if isinstance(documents, list) else len(documents["content"])) try: final_res_batch = await self._run_scoring( @@ -350,6 +413,9 @@ async def do_rerank( raw_request, request.truncate_prompt_tokens, ) + if isinstance(final_res_batch, ErrorResponse): + return final_res_batch + return self.request_output_to_rerank_response( final_res_batch, request_id, @@ -400,7 +466,7 @@ def request_output_to_score_response( def request_output_to_rerank_response( self, final_res_batch: list[PoolingRequestOutput], request_id: str, - model_name: str, documents: list[str], + model_name: str, documents: Union[list[str], ScoreMultiModalParam], top_n: int) -> RerankResponse: """ Convert the output of do_rank to a RerankResponse @@ -412,7 +478,9 @@ def request_output_to_rerank_response( result = RerankResult( index=idx, - document=RerankDocument(text=documents[idx]), + document=RerankDocument(text=documents[idx]) if isinstance( + documents, list) else RerankDocument( + multi_modal=documents["content"][idx]), relevance_score=classify_res.outputs.score, ) results.append(result) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 3db0a71fadd1..58d720474768 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Final, Optional, Union +from dataclasses import dataclass +from typing import Any, Final, Optional, Union import jinja2 from fastapi import Request @@ -17,11 +17,13 @@ ErrorResponse, TokenizeChatRequest, TokenizeRequest, - TokenizeResponse) + TokenizeResponse, + TokenizerInfoResponse) # yapf: enable from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -58,10 +60,7 @@ async def create_tokenize( request_id = f"tokn-{self._base_request_id(raw_request)}" try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) @@ -102,11 +101,8 @@ async def create_tokenize( self._log_inputs(request_id, request_prompts[i], params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) + lora_request=lora_request) - # Silently ignore prompt adapter since it does not affect - # tokenization (Unlike in Embeddings API where an error is raised) if isinstance(engine_prompt, dict) and "prompt_token_ids" in engine_prompt: input_ids.extend(engine_prompt["prompt_token_ids"]) @@ -131,21 +127,14 @@ async def create_detokenize( request_id = f"tokn-{self._base_request_id(raw_request)}" - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + lora_request = self._maybe_get_adapters(request) tokenizer = await self.engine_client.get_tokenizer(lora_request) self._log_inputs(request_id, request.tokens, params=None, - lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request) - - # Silently ignore prompt adapter since it does not affect tokenization - # (Unlike in Embeddings API where an error is 
raised) + lora_request=lora_request) prompt_input = await self._tokenize_prompt_input_async( request, @@ -155,3 +144,49 @@ async def create_detokenize( input_text = prompt_input["prompt"] return DetokenizeResponse(prompt=input_text) + + async def get_tokenizer_info( + self, ) -> Union[TokenizerInfoResponse, ErrorResponse]: + """Get comprehensive tokenizer information.""" + try: + tokenizer = await self.engine_client.get_tokenizer() + info = TokenizerInfo(tokenizer, self.chat_template).to_dict() + return TokenizerInfoResponse(**info) + except Exception as e: + return self.create_error_response( + f"Failed to get tokenizer info: {str(e)}") + + +@dataclass +class TokenizerInfo: + tokenizer: AnyTokenizer + chat_template: Optional[str] + + def to_dict(self) -> dict[str, Any]: + """Return the tokenizer configuration.""" + return self._get_tokenizer_config() + + def _get_tokenizer_config(self) -> dict[str, Any]: + """Get tokenizer configuration directly from the tokenizer object.""" + config = dict(getattr(self.tokenizer, "init_kwargs", None) or {}) + + # Remove file path fields + config.pop("vocab_file", None) + config.pop("merges_file", None) + + config = self._make_json_serializable(config) + config["tokenizer_class"] = type(self.tokenizer).__name__ + if self.chat_template: + config["chat_template"] = self.chat_template + return config + + def _make_json_serializable(self, obj): + """Convert any non-JSON-serializable objects to serializable format.""" + if hasattr(obj, "content"): + return obj.content + elif isinstance(obj, dict): + return {k: self._make_json_serializable(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._make_json_serializable(item) for item in obj] + else: + return obj diff --git a/vllm/entrypoints/openai/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text.py index 0ab029e5305b..c2227a21a4b9 100644 --- a/vllm/entrypoints/openai/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text.py @@ -6,12 +6,12 @@ import time from collections.abc import AsyncGenerator from functools import cached_property -from math import ceil from typing import Callable, Literal, Optional, TypeVar, Union, cast import numpy as np from fastapi import Request +import vllm.envs as envs from vllm.config import ModelConfig from vllm.engine.protocol import EngineClient from vllm.entrypoints.logger import RequestLogger @@ -25,10 +25,8 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.inputs.data import PromptType from vllm.logger import init_logger -from vllm.model_executor.model_loader import get_model_cls from vllm.model_executor.models import SupportsTranscription from vllm.outputs import RequestOutput -from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import PlaceholderModule try: @@ -41,13 +39,6 @@ logger = init_logger(__name__) -# As per https://platform.openai.com/docs/guides/speech-to-text#overview. 
-# TODO configurable -MAX_AUDIO_CLIP_FILESIZE_MB = 25 -MAX_AUDIO_CLIP_SECONDS = 30 -OVERLAP_CHUNK_SECOND = 1 -MIN_ENERGY_WINDOW_SIZE = 1600 # 1600 ~ 100ms for 16000 Hz audio - class OpenAISpeechToText(OpenAIServing): """Base class for speech-to-text operations like transcription and @@ -71,63 +62,60 @@ def __init__( self.default_sampling_params = ( self.model_config.get_diff_sampling_param()) - processor = cached_get_processor(model_config.model) - self.max_audio_clip_s = processor.feature_extractor.chunk_length \ - if hasattr(processor.feature_extractor, 'chunk_length') \ - else MAX_AUDIO_CLIP_SECONDS - self.model_sr = processor.feature_extractor.sampling_rate - self.hop_length = processor.feature_extractor.hop_length self.task_type = task_type + self.asr_config = self.model_cls.get_speech_to_text_config( + model_config, task_type) + + self.max_audio_filesize_mb = envs.VLLM_MAX_AUDIO_CLIP_FILESIZE_MB + if self.default_sampling_params: logger.info( "Overwriting default completion sampling param with: %s", self.default_sampling_params) @cached_property - def model_cls(self): - return get_model_cls(self.model_config) + def model_cls(self) -> type[SupportsTranscription]: + from vllm.model_executor.model_loader import get_model_cls + model_cls = get_model_cls(self.model_config) + return cast(type[SupportsTranscription], model_cls) async def _preprocess_speech_to_text( self, request: SpeechToTextRequest, audio_data: bytes, ) -> tuple[list[PromptType], float]: - model_cls = cast(SupportsTranscription, self.model_cls) - # Validate request # TODO language should be optional and can be guessed. # For now we default to en. See # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 lang = request.language or "en" - model_cls.validate_language(lang) + self.model_cls.validate_language(lang) - if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB: + if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: raise ValueError("Maximum file size exceeded.") with io.BytesIO(audio_data) as bytes_: # NOTE resample to model SR here for efficiency. This is also a # pre-requisite for chunking, as it assumes Whisper SR. - y, sr = librosa.load(bytes_, sr=self.model_sr) + y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) duration = librosa.get_duration(y=y, sr=sr) - chunks = [y - ] if duration < self.max_audio_clip_s else self._split_audio( - y, int(sr)) + do_split_audio = (self.asr_config.allow_audio_chunking + and duration > self.asr_config.max_audio_clip_s) + chunks = [y] if not do_split_audio else self._split_audio(y, int(sr)) prompts = [] for chunk in chunks: - prompt = { - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "audio": (chunk, sr), - }, - }, - "decoder_prompt": - model_cls.get_decoder_prompt(lang, self.task_type, - request.prompt) - } - prompts.append(cast(PromptType, prompt)) + # The model has control over the construction, as long as it + # returns a valid PromptType. 
+ prompt = self.model_cls.get_generation_prompt( + audio=chunk, + stt_config=self.asr_config, + model_config=self.model_config, + language=lang, + task_type=self.task_type, + request_prompt=request.prompt) + prompts.append(prompt) return prompts, duration async def _create_speech_to_text( @@ -161,19 +149,12 @@ async def _create_speech_to_text( raw_request.state.request_metadata = request_metadata try: - ( - lora_request, - prompt_adapter_request, - ) = self._maybe_get_adapters(request) + lora_request = self._maybe_get_adapters(request) if lora_request: return self.create_error_response( "Currently do not support LoRA for " f"{self.task_type.title()}.") - if prompt_adapter_request: - return self.create_error_response( - f"Currently do not support PromptAdapter for " - f"{self.task_type.title()}.") prompts, duration_s = await self._preprocess_speech_to_text( request=request, @@ -196,10 +177,10 @@ async def _create_speech_to_text( self._log_inputs( request_id, - prompts[0]['decoder_prompt'], # type: ignore + # It will not display special tokens like <|startoftranscript|> + request.prompt, params=sampling_params, - lora_request=None, - prompt_adapter_request=None) + lora_request=None) list_result_generator = [ self.engine_client.generate( @@ -261,17 +242,11 @@ async def _speech_to_text_stream_generator( async for res in result_generator: # On first result. if res.prompt_token_ids is not None: - # Do not account the 4-tokens `<|startoftranscript|>..` - # Could be negative when language token - # is not specified. - num_prompt_tokens = max( - len(res.prompt_token_ids) - 4, 0) - # NOTE(NickLucche) user can't pass encoder - # prompts directly at least not to Whisper. - # One indicator of the encoder amount of processing - # is the log-mel spectogram length. 
- num_prompt_tokens += ceil( - audio_duration_s * self.model_sr / self.hop_length) + num_prompt_tokens = len(res.prompt_token_ids) + if audio_tokens := self.model_cls.get_num_audio_tokens( + audio_duration_s, self.asr_config, + self.model_config): + num_prompt_tokens += audio_tokens # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST @@ -347,8 +322,8 @@ async def _speech_to_text_stream_generator( def _split_audio(self, audio_data: np.ndarray, sample_rate: int) -> list[np.ndarray]: - chunk_size = sample_rate * self.max_audio_clip_s - overlap_size = sample_rate * OVERLAP_CHUNK_SECOND + chunk_size = sample_rate * self.asr_config.max_audio_clip_s + overlap_size = sample_rate * self.asr_config.overlap_chunk_second chunks = [] i = 0 while i < audio_data.shape[-1]: @@ -384,10 +359,10 @@ def _find_split_point(self, wav: np.ndarray, start_idx: int, # Calculate RMS energy in small windows min_energy = math.inf quietest_idx = 0 - for i in range(0, - len(segment) - MIN_ENERGY_WINDOW_SIZE, - MIN_ENERGY_WINDOW_SIZE): - window = segment[i:i + MIN_ENERGY_WINDOW_SIZE] + min_energy_window = self.asr_config.min_energy_split_window_size + assert min_energy_window is not None + for i in range(0, len(segment) - min_energy_window, min_energy_window): + window = segment[i:i + min_energy_window] energy = (window**2).mean()**0.5 if energy < min_energy: quietest_idx = i + start_idx diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 57e675515e12..88c8aa929b78 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -3,23 +3,41 @@ from .abstract_tool_parser import ToolParser, ToolParserManager from .deepseekv3_tool_parser import DeepSeekV3ToolParser +from .glm4_moe_tool_parser import Glm4MoeModelToolParser from .granite_20b_fc_tool_parser import Granite20bFCToolParser from .granite_tool_parser import GraniteToolParser from .hermes_tool_parser import Hermes2ProToolParser +from .hunyuan_a13b_tool_parser import HunyuanA13BToolParser from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser +from .kimi_k2_tool_parser import KimiK2ToolParser from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .minimax_tool_parser import MinimaxToolParser from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser from .pythonic_tool_parser import PythonicToolParser +from .qwen3coder_tool_parser import Qwen3CoderToolParser from .xlam_tool_parser import xLAMToolParser __all__ = [ - "ToolParser", "ToolParserManager", "Granite20bFCToolParser", - "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", - "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", - "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser", - "DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser" + "ToolParser", + "ToolParserManager", + "Granite20bFCToolParser", + "GraniteToolParser", + "Hermes2ProToolParser", + "MistralToolParser", + "Internlm2ToolParser", + "Llama3JsonToolParser", + "JambaToolParser", + "Llama4PythonicToolParser", + "PythonicToolParser", + "Phi4MiniJsonToolParser", + "DeepSeekV3ToolParser", + "xLAMToolParser", + "MinimaxToolParser", + "KimiK2ToolParser", + "HunyuanA13BToolParser", + "Glm4MoeModelToolParser", + "Qwen3CoderToolParser", ] diff --git 
a/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py new file mode 100644 index 000000000000..c3f9d7923575 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py @@ -0,0 +1,402 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# code modified from deepseekv3_tool_parser.py + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("glm4_moe") +class Glm4MoeModelToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.current_tool_name_sent = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id = -1 + self.streamed_args_for_tool: list[str] = [] + self.tool_call_start_token = "<tool_call>" + self.tool_call_end_token = "</tool_call>" + + self.tool_calls_start_token = self.tool_call_start_token + + # Updated regex for the XML-based format + self.tool_call_regex = re.compile( + r"<tool_call>\s*" + r"(?P<function_name>[^\n<]+)\s*" # 函数名(到换行或 <) + r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*" + r"<arg_value>[^<]*</arg_value>\s*)*)\s*" + r"</tool_call>", + re.DOTALL, + ) + + # Regex for parsing individual arguments + self.arg_regex = re.compile( + r"<arg_key>(?P<key>[^<]+)</arg_key>\s*<arg_value>(?P<value>[^<]*)</arg_value>", + re.DOTALL, + ) + + # Streaming regex + self.stream_tool_call_portion_regex = re.compile( + r"(?P<function_name>[^\n<]+)\s*" + r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*" + r"<arg_value>[^<]*</arg_value>\s*)*)", + re.DOTALL, + ) + + # For streaming, we also need a regex to match just the function name + self.stream_tool_call_name_regex = re.compile( + r"(?P<function_name>[^\n<]+)", + re.DOTALL, + ) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + def _parse_arguments(self, args_text: str) -> str: + """Parse XML-based arguments into JSON format.""" + if not args_text or not args_text.strip(): + return "{}" + + args_dict = {} + matches = self.arg_regex.findall(args_text) + + for key, value in matches: + args_dict[key.strip()] = value.strip() + + import json + return json.dumps(args_dict, ensure_ascii=False) + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # Find all tool calls in the output + function_call_matches = self.tool_call_regex.findall(model_output) + + logger.debug("function_call_matches: %s", function_call_matches) + + if not function_call_matches: + return ExtractedToolCallInformation( + 
tools_called=False, + tool_calls=[], + content=model_output, + ) + + tool_calls = [] + for i, match in enumerate(function_call_matches): + function_name, function_args_xml = match + function_name = function_name.strip() + + # Parse XML arguments to JSON + function_args_json = self._parse_arguments(function_args_xml) + + tool_calls.append( + ToolCall( + id=f"call_{i}", + type='function', + function=FunctionCall(name=function_name, + arguments=function_args_json), + )) + + # Extract content before the first tool call + content = model_output[:model_output.find(self. + tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=bool(tool_calls), + tool_calls=tool_calls, + content=content.strip() if content.strip() else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_call_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! 
skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_id, tool_args = (current_tool_call_matches.groups()) + tool_name = tool_id.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_id_str, = current_tool_call_name_matches.groups() + tool_name = tool_id_str.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id_str + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. 
+ if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + tool_id = current_tool_call.get("id") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. 
+ elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py new file mode 100644 index 000000000000..2b65f2579fb4 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py @@ -0,0 +1,372 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501, SIM102 + +import json +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import consume_space +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("hunyuan_a13b") +class HunyuanA13BToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # Initialize state for streaming mode + self.prev_tool_calls: list[dict] = [] + self.current_tool_id = -1 + self.current_tool_name_sent = False + self.streamed_args: list[str] = [ + ] # Track arguments sent for each tool + + # For backward compatibility with tests + self.current_tools_sent: list[bool] = [] + + # For backward compatibility with serving code + self.prev_tool_call_arr = [] + + # Regex patterns for preprocessing + self.answer_tool_calls_pattern = re.compile( + r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL) + + self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"') + + self.tool_empty_arg_reg = re.compile( + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}') + + # TODO: not support nested json object in fc arguments. 
+ self.tool_non_empty_arg_reg = re.compile( + r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})' + ) + + self.bot_string = "<tool_calls>" + + # Define streaming state type to be initialized later + self.streaming_state: dict[str, Any] = { + "current_tool_index": -1, + "tool_ids": [], + "sent_tools": [], + } + + def preprocess_model_output( + self, model_output: str) -> tuple[Optional[str], Optional[str]]: + # find the location tool call + for match in self.answer_tool_calls_pattern.finditer(model_output): + start, end = match.span() + # check tool_calls whether in side of <think> + think_regions = [(m.start(), m.end()) for m in re.finditer( + r"<think>(.*?)</think>", model_output, flags=re.DOTALL)] + in_think = any(start > t_start and end < t_end + for t_start, t_end in think_regions) + if not in_think: + content = model_output[:start] + tool_calls_content = match.group(1).strip() + try: + json.loads(tool_calls_content) + return content, tool_calls_content + except Exception: + continue + return model_output, None + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract tool calls from a complete model output. + """ + try: + # Preprocess the model output + content, potential_tool_calls = self.preprocess_model_output( + model_output) + + if not potential_tool_calls: + # some text should be filtered out for no function call + # this text is in a13b's chat template. + if content: + content = content.replace("助手:", "", 1) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=content) + + # Parse the potential tool calls as JSON + tool_calls_data = json.loads(potential_tool_calls) + + # Ensure it's an array + if not isinstance(tool_calls_data, list): + logger.debug("Tool calls data is not an array") + return ExtractedToolCallInformation( + tools_called=False, + tool_calls=[], + content=content or model_output, + ) + + tool_calls: list[ToolCall] = [] + + for idx, call in enumerate(tool_calls_data): + if (not isinstance(call, dict) or "name" not in call + or "arguments" not in call): + continue + + tool_call = ToolCall( + id=f"call_{random_uuid()}", + type="function", + function=FunctionCall( + name=call["name"], + arguments=(json.dumps(call["arguments"]) if isinstance( + call["arguments"], dict) else call["arguments"]), + ), + ) + tool_calls.append(tool_call) + + if not content or len(content.strip()) == 0: + # clear the whitespace content. + content = None + + return ExtractedToolCallInformation( + tools_called=len(tool_calls) > 0, + tool_calls=tool_calls, + content=content, + ) + + except Exception: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Extract tool calls for streaming mode. 
+ """ + + start_idx = consume_space(0, current_text) + if current_text[start_idx:].startswith(self.bot_string): + start_idx = consume_space(start_idx + len(self.bot_string), + current_text) + if not current_text or start_idx >= len( + current_text) or current_text[start_idx] != '[': + return DeltaMessage(content=delta_text) + + self._try_parse_json_tools(current_text[start_idx:]) + + test_delta = self._handle_test_compatibility(current_text) + if test_delta: + return test_delta + + name_matches = list(self.tool_name_reg.finditer(current_text)) + tool_count = len(name_matches) + if tool_count == 0: + return None + self._ensure_state_arrays(tool_count) + current_idx = self.streaming_state["current_tool_index"] + + name_delta = self._handle_tool_name_streaming(current_idx, tool_count, + name_matches) + if name_delta: + return name_delta + + args_delta = self._handle_tool_args_streaming(current_text, + current_idx, tool_count) + if args_delta: + return args_delta + + return None + + def _try_parse_json_tools(self, current_text: str): + try: + parsed_tools = json.loads(current_text) + if isinstance(parsed_tools, list): + self.prev_tool_call_arr = parsed_tools + except json.JSONDecodeError: + pass + + def _handle_test_compatibility(self, current_text: str): + if len(self.current_tools_sent) > 0: + if (len(self.current_tools_sent) == 1 + and self.current_tools_sent[0] is False): + name_match = self.tool_name_reg.search(current_text) + if name_match: + function_name = name_match.group(1) + tool_id = f"chatcmpl-tool-{random_uuid()}" + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=0, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + self.current_tools_sent = [True] + self.current_tool_id = 0 + self.streaming_state["current_tool_index"] = 0 + if len(self.streaming_state["sent_tools"]) == 0: + self.streaming_state["sent_tools"].append({ + "sent_name": + True, + "sent_arguments_prefix": + False, + "sent_arguments": + "", + }) + else: + self.streaming_state["sent_tools"][0][ + "sent_name"] = True + self.current_tool_name_sent = True + return delta + return None + + def _ensure_state_arrays(self, tool_count: int): + while len(self.streaming_state["sent_tools"]) < tool_count: + self.streaming_state["sent_tools"].append({ + "sent_name": False, + "sent_arguments_prefix": False, + "sent_arguments": "", + }) + while len(self.streaming_state["tool_ids"]) < tool_count: + self.streaming_state["tool_ids"].append(None) + + def _handle_tool_name_streaming(self, current_idx: int, tool_count: int, + name_matches): + if current_idx == -1 or current_idx < tool_count - 1: + next_idx = current_idx + 1 + if (next_idx < tool_count + and not self.streaming_state["sent_tools"][next_idx] + ["sent_name"]): + self.streaming_state["current_tool_index"] = next_idx + self.current_tool_id = next_idx + current_idx = next_idx + tool_name = name_matches[current_idx].group(1) + tool_id = f"call_{current_idx}_{random_uuid()}" + self.streaming_state["tool_ids"][current_idx] = tool_id + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + type="function", + id=tool_id, + function=DeltaFunctionCall(name=tool_name).model_dump( + exclude_none=True), + ) + ]) + self.streaming_state["sent_tools"][current_idx][ + "sent_name"] = True + self.current_tool_name_sent = True + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + return delta + return None + + def _handle_tool_args_streaming(self, current_text: 
str, current_idx: int, + tool_count: int): + + if current_idx >= 0 and current_idx < tool_count: + empty_args_match = self.tool_empty_arg_reg.search(current_text) + if empty_args_match and empty_args_match.start() > 0: + for i in range(tool_count): + if i == current_idx: + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"]: + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = "{}" + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{}" + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{}").model_dump( + exclude_none=True), + ) + ]) + if current_idx < tool_count - 1: + self.streaming_state["current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] + return delta + + args_matches = list( + self.tool_non_empty_arg_reg.finditer(current_text)) + if current_idx < len(args_matches): + args_text = args_matches[current_idx].group(1) + is_last_tool = current_idx == tool_count - 1 + if not is_last_tool: + next_tool_pos = current_text.find( + "},{", args_matches[current_idx].start()) + if next_tool_pos != -1: + args_end_pos = (next_tool_pos + 1) + args_text = ( + current_text[args_matches[current_idx].start( + ):args_end_pos].split('"arguments":')[1].strip()) + sent_args = self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] + if not self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] and args_text.startswith("{"): + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments_prefix"] = True + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = "{" + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += "{" + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments="{").model_dump(exclude_none=True), + ) + ]) + return delta + + if args_text.startswith(sent_args): + args_diff = args_text[len(sent_args):] + if args_diff: + self.streaming_state["sent_tools"][current_idx][ + "sent_arguments"] = args_text + while len(self.streamed_args) <= current_idx: + self.streamed_args.append("") + self.streamed_args[current_idx] += args_diff + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=current_idx, + function=DeltaFunctionCall( + arguments=args_diff).model_dump( + exclude_none=True), + ) + ]) + return delta + + if args_text.endswith("}") and args_text == sent_args: + if current_idx < tool_count - 1: + self.streaming_state["current_tool_index"] += 1 + self.current_tool_id = self.streaming_state[ + "current_tool_index"] + return None diff --git a/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py new file mode 100644 index 000000000000..b0df442dd864 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py @@ -0,0 +1,377 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# code modified from deepseekv3_tool_parser.py + +from collections.abc import Sequence +from typing import Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + 
FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["kimi_k2"]) +class KimiK2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: list[str] = ( + []) # map what has been streamed for each tool so far to a list + + self.tool_calls_start_token: str = "<|tool_calls_section_begin|>" + self.tool_calls_end_token: str = "<|tool_calls_section_end|>" + + self.tool_call_start_token: str = "<|tool_call_begin|>" + self.tool_call_end_token: str = "<|tool_call_end|>" + + self.tool_call_regex = re.compile( + r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*<\|tool_call_end\|>" + ) + + self.stream_tool_call_portion_regex = re.compile( + r"(?P<tool_call_id>[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*)" + ) + + self.stream_tool_call_name_regex = re.compile( + r"(?P<tool_call_id>[\w\.]+:\d+)\s*") + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_calls_start_token_id = self.vocab.get( + self.tool_calls_start_token) + self.tool_calls_end_token_id = self.vocab.get( + self.tool_calls_end_token) + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_calls_start_token_id is None + or self.tool_calls_end_token_id is None): + raise RuntimeError( + "Kimi-K2 Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_calls_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = self.tool_call_regex.findall( + model_output) + + logger.debug("function_call_tuples: %s", function_call_tuples) + + tool_calls = [] + for match in function_call_tuples: + function_id, function_args = match + # function_id: functions.get_weather:0 + function_name = function_id.split('.')[1].split(':')[0] + tool_calls.append( + ToolCall( + id=function_id, + type='function', + function=FunctionCall(name=function_name, + arguments=function_args), + )) + + content = model_output[:model_output. 
+ find(self.tool_calls_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception( + "Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_calls_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + delta_text = delta_text.replace(self.tool_calls_start_token, + "").replace(self.tool_calls_end_token, + "") + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + tool_call_portion = None + text_portion = None + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count + and self.tool_call_end_token not in delta_text): + logger.debug("Generating text content! skipping tool parsing.") + return DeltaMessage(content=delta_text) + + if self.tool_call_end_token in delta_text: + logger.debug("tool_call_end_token in delta_text") + full_text = current_text + delta_text + tool_call_portion = full_text.split( + self.tool_call_start_token)[-1].split( + self.tool_call_end_token)[0].rstrip() + delta_text = delta_text.split( + self.tool_call_end_token)[0].rstrip() + text_portion = delta_text.split( + self.tool_call_end_token)[-1].lstrip() + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. 
+ elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count >= prev_tool_end_count): + if self.prev_tool_call_arr is None or len( + self.prev_tool_call_arr) == 0: + logger.debug( + "attempting to close tool call, but no tool call") + return None + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = (diff.encode("utf-8").decode("unicode_escape") + if diff is str else diff) + if '"}' not in delta_text: + return None + end_loc = delta_text.rindex('"}') + diff = delta_text[:end_loc] + '"}' + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", + diff, + ) + self.streamed_args_for_tool[self.current_tool_id] += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + current_tool_call = dict() + if tool_call_portion: + current_tool_call_matches = ( + self.stream_tool_call_portion_regex.match( + tool_call_portion)) + if current_tool_call_matches: + tool_id, tool_args = (current_tool_call_matches.groups()) + tool_name = tool_id.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = tool_args + else: + current_tool_call_name_matches = ( + self.stream_tool_call_name_regex.match( + tool_call_portion)) + if current_tool_call_name_matches: + tool_id_str, = current_tool_call_name_matches.groups() + tool_name = tool_id_str.split('.')[1].split(':')[0] + current_tool_call['id'] = tool_id_str + current_tool_call["name"] = tool_name + current_tool_call["arguments"] = "" + else: + logger.debug("Not enough token") + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + if current_tool_call is None: + return None + function_name: Union[str, None] = current_tool_call.get("name") + tool_id = current_tool_call.get("id") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ]) + else: + return None + + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = (DeltaMessage( + content=delta_text) if text_portion is not None else None) + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. 
partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=cur_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + if (isinstance(delta_text, str) + and cur_arguments != prev_arguments + and len(cur_arguments) > len(prev_arguments) + and cur_arguments.startswith(prev_arguments)): + delta_arguments = cur_arguments[len(prev_arguments):] + logger.debug("got diff %s", delta_text) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] = cur_arguments + else: + delta = None + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[ + self.current_tool_id] = current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception: + logger.exception("Error trying to handle streaming tool call.") + return None # do not stream a delta. skip this token ID. 
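The Kimi-K2 parser above recovers OpenAI-style tool calls purely from the sentinel tokens and the `functions.<name>:<index>` call-id format. Below is a minimal, self-contained sketch of that mapping using the same regex as the parser; the model output string is hand-written for illustration (the `get_weather` call follows the `functions.get_weather:0` example in the code comment), not output captured from a real Kimi-K2 checkpoint.

import json
import regex as re

# Same pattern as KimiK2ToolParser.tool_call_regex above.
tool_call_regex = re.compile(
    r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[\w\.]+:\d+)\s*"
    r"<\|tool_call_argument_begin\|>\s*(?P<function_arguments>.*?)\s*"
    r"<\|tool_call_end\|>")

# Hand-written example output; the section/call sentinels come from the
# parser definition above, the argument JSON is made up for illustration.
model_output = (
    "Let me look that up."
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0"
    '<|tool_call_argument_begin|>{"city": "Beijing"}'
    "<|tool_call_end|>"
    "<|tool_calls_section_end|>")

for tool_call_id, arguments in tool_call_regex.findall(model_output):
    # "functions.get_weather:0" -> "get_weather", mirroring the name
    # extraction in extract_tool_calls().
    name = tool_call_id.split(".")[1].split(":")[0]
    print(name, json.loads(arguments))
    # -> get_weather {'city': 'Beijing'}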
\ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py new file mode 100644 index 000000000000..cf4d0b231aee --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -0,0 +1,669 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +import uuid +from collections.abc import Sequence +from typing import Any, Optional, Union + +import regex as re + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionToolsParam, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["qwen3_coder"]) +class Qwen3CoderToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: list[dict] = [] + self.streamed_args_for_tool: list[str] = [] + + # Sentinel tokens for streaming mode + self.tool_call_start_token: str = "<tool_call>" + self.tool_call_end_token: str = "</tool_call>" + self.tool_call_prefix: str = "<function=" + self.function_end_token: str = "</function>" + self.parameter_prefix: str = "<parameter=" + self.parameter_end_token: str = "</parameter>" + self.is_tool_call_started: bool = False + self.failed_count: int = 0 + + # Streaming state variables + self.current_tool_index: int = 0 + self.header_sent: bool = False + self.current_tool_string_id: Optional[str] = None + self.current_function_name: Optional[str] = None + self.current_param_name: Optional[str] = None + self.current_param_value: str = "" + self.param_count: int = 0 + self.in_param: bool = False + self.in_function: bool = False + self.accumulated_text: str = "" + self.json_started: bool = False + self.json_closed: bool = False + + # Enhanced streaming state - reset for each new message + self._reset_streaming_state() + + # Regex patterns + self.tool_call_complete_regex = re.compile( + r"<tool_call>(.*?)</tool_call>", re.DOTALL) + self.tool_call_regex = re.compile( + r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL) + self.tool_call_function_regex = re.compile( + r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL) + self.tool_call_parameter_regex = re.compile( + r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + + if (self.tool_call_start_token_id is None + or self.tool_call_end_token_id is None): + raise RuntimeError( + "Qwen3 XML Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + logger.debug("vLLM Successfully import tool parser %s !", + self.__class__.__name__) + + def _generate_tool_call_id(self) -> str: + """Generate a unique tool call ID.""" + return f"call_{uuid.uuid4().hex[:24]}" + + def _reset_streaming_state(self): + """Reset all streaming state.""" + self.current_tool_index = 0 + 
self.is_tool_call_started = False + self.header_sent = False + self.current_tool_string_id = None + self.current_function_name = None + self.current_param_name = None + self.current_param_value = "" + self.param_count = 0 + self.in_param = False + self.in_function = False + self.accumulated_text = "" + self.json_started = False + self.json_closed = False + + def _parse_xml_function_call( + self, function_call_str: str, + tools: Optional[list[ChatCompletionToolsParam]] + ) -> Optional[ToolCall]: + + def get_arguments_config(func_name: str) -> dict: + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not ( + hasattr(config, "function") + and hasattr(config.function, "name")): + continue + if (config.type == "function" + and config.function.name == func_name): + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def convert_param_value(param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + # Handle null value for any type + if param_value.lower() == "null": + return None + + converted_value: Any + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", param_name, func_name) + return param_value + + if (isinstance(param_config[param_name], dict) + and "type" in param_config[param_name]): + param_type = str( + param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in [ + "string", "str", "text", "varchar", "char", "enum" + ]: + return param_value + elif (param_type.startswith("int") or param_type.startswith("uint") + or param_type.startswith("long") + or param_type.startswith("short") + or param_type.startswith("unsigned")): + try: + converted_value = int(param_value) + return converted_value + except ValueError: + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + elif (param_type.startswith("num") + or param_type.startswith("float")): + try: + float_param_value = float(param_value) + converted_value = (float_param_value if float_param_value - + int(float_param_value) != 0 else + int(float_param_value)) + return converted_value + except ValueError: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a " + "boolean (`true` of `false`) in tool '%s', " + "degenerating to false.", param_value, param_name, + func_name) + return param_value == "true" + else: + if param_type == "object" or param_type.startswith("dict"): + try: + converted_value = json.loads(param_value) + return converted_value + except json.JSONDecodeError: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a " + "valid JSON object in tool '%s', will try other " + "methods to parse it.", param_value, param_name, + func_name) + try: + 
converted_value = eval(param_value) + return converted_value + except Exception: + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `eval()` in tool '%s', " + "degenerating to string.", param_value, param_name, + func_name) + return param_value + + # Extract function name + end_index = function_call_str.index(">") + function_name = function_call_str[:end_index] + param_config = get_arguments_config(function_name) + parameters = function_call_str[end_index + 1:] + param_dict = {} + for match in self.tool_call_parameter_regex.findall(parameters): + match_text = match[0] if match[0] else match[1] + idx = match_text.index(">") + param_name = match_text[:idx] + param_value = str(match_text[idx + 1:]) + # Remove prefix and trailing \n + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = convert_param_value( + param_value, param_name, param_config, function_name) + return ToolCall( + type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(param_dict, + ensure_ascii=False)), + ) + + def _get_function_calls(self, model_output: str) -> list[str]: + # Find all tool calls + matched_ranges = self.tool_call_regex.findall(model_output) + raw_tool_calls = [ + match[0] if match[0] else match[1] for match in matched_ranges + ] + + # Back-off strategy if no tool_call tags found + if len(raw_tool_calls) == 0: + raw_tool_calls = [model_output] + + raw_function_calls = [] + for tool_call in raw_tool_calls: + raw_function_calls.extend( + self.tool_call_function_regex.findall(tool_call)) + + function_calls = [ + match[0] if match[0] else match[1] for match in raw_function_calls + ] + return function_calls + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + # Quick check to avoid unnecessary processing + if self.tool_call_prefix not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + function_calls = self._get_function_calls(model_output) + if len(function_calls) == 0: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + tool_calls = [ + self._parse_xml_function_call(function_call_str, request.tools) + for function_call_str in function_calls + ] + + # Populate prev_tool_call_arr for serving layer to set + # finish_reason + self.prev_tool_call_arr.clear() # Clear previous calls + for tool_call in tool_calls: + if tool_call: + self.prev_tool_call_arr.append({ + "name": + tool_call.function.name, + "arguments": + tool_call.function.arguments, + }) + + # Extract content before tool calls + content_index = model_output.find(self.tool_call_start_token) + content_index = (content_index if content_index >= 0 else + model_output.find(self.tool_call_prefix)) + content = model_output[:content_index] # .rstrip() + + return ExtractedToolCallInformation( + tools_called=(len(tool_calls) > 0), + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error in extracting tool call from response.") + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: 
Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + # If no delta text, return None unless it's an EOS token after tool + # calls + if not delta_text: + # Check if this is an EOS token after all tool calls are complete + # We check for tool calls in the text even if is_tool_call_started + # is False because it might have been reset after processing all + # tools + if (delta_token_ids + and self.tool_call_end_token_id not in delta_token_ids): + # Count complete tool calls + complete_calls = len( + self.tool_call_complete_regex.findall(current_text)) + + # If we have completed tool calls and populated + # prev_tool_call_arr + if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + # Check if all tool calls are closed + open_calls = ( + current_text.count(self.tool_call_start_token) - + current_text.count(self.tool_call_end_token)) + if open_calls == 0: + # Return empty delta message to allow finish_reason + # processing + return DeltaMessage(content="") + elif not self.is_tool_call_started and current_text: + # This is a regular content response that's now complete + return DeltaMessage(content="") + return None + + # Check if this is the first call (reset state if needed) + if not previous_text: + self._reset_streaming_state() + + # Update accumulated text + self.accumulated_text = current_text + + # Check if we need to advance to next tool + if self.json_closed and not self.in_function: + # Check if this tool call has ended + tool_ends = current_text.count(self.tool_call_end_token) + if tool_ends > self.current_tool_index: + # This tool has ended, advance to next + self.current_tool_index += 1 + self.header_sent = False + self.param_count = 0 + self.json_started = False + self.json_closed = False + + # Check if there are more tool calls + tool_starts_count = current_text.count( + self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # No more tool calls + self.is_tool_call_started = False + # Continue processing next tool + return None + + # Handle normal content before tool calls + if not self.is_tool_call_started: + # Check if tool call is starting + if (self.tool_call_start_token_id in delta_token_ids + or self.tool_call_start_token in delta_text): + self.is_tool_call_started = True + # Return any content before the tool call + if self.tool_call_start_token in delta_text: + content_before = delta_text[:delta_text.index( + self.tool_call_start_token)] + if content_before: + return DeltaMessage(content=content_before) + return None + else: + # Check if we're between tool calls - skip whitespace + if (current_text.rstrip().endswith(self.tool_call_end_token) + and delta_text.strip() == ""): + # We just ended a tool call, skip whitespace + return None + # Normal content, no tool call + return DeltaMessage(content=delta_text) + + # Check if we're between tool calls (waiting for next one) + # Count tool calls we've seen vs processed + tool_starts_count = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts_count: + # We're past all tool calls, shouldn't be here + return None + + # We're in a tool call, find the current tool call portion + # Need to find the correct tool call based on current_tool_index + tool_starts: list[int] = [] + idx = 0 + while True: + idx = current_text.find(self.tool_call_start_token, idx) + if idx == -1: + break + tool_starts.append(idx) + idx += len(self.tool_call_start_token) + + if self.current_tool_index >= len(tool_starts): + # No more tool calls to process 
yet + return None + + tool_start_idx = tool_starts[self.current_tool_index] + # Find where this tool call ends (or current position if not ended yet) + tool_end_idx = current_text.find(self.tool_call_end_token, + tool_start_idx) + if tool_end_idx == -1: + tool_text = current_text[tool_start_idx:] + else: + tool_text = current_text[tool_start_idx:tool_end_idx + + len(self.tool_call_end_token)] + + # Looking for function header + if not self.header_sent: + if self.tool_call_prefix in tool_text: + func_start = (tool_text.find(self.tool_call_prefix) + + len(self.tool_call_prefix)) + func_end = tool_text.find(">", func_start) + + if func_end != -1: + # Found complete function name + self.current_function_name = tool_text[func_start:func_end] + self.current_tool_string_id = self._generate_tool_call_id() + self.header_sent = True + self.in_function = True + + # IMPORTANT: Add to prev_tool_call_arr immediately when we + # detect a tool call. This ensures + # finish_reason="tool_calls" even if parsing isn't complete + already_added = any( + tool.get("name") == self.current_function_name + for tool in self.prev_tool_call_arr) + if not already_added: + self.prev_tool_call_arr.append({ + "name": self.current_function_name, + "arguments": + "{}", # Placeholder, will be updated later + }) + + # Send header with function info + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_string_id, + function=DeltaFunctionCall( + name=self.current_function_name, arguments=""), + type="function", + ) + ]) + return None + + # We've sent header, now handle function body + if self.in_function: + # Send opening brace if not sent yet + if (not self.json_started + and self.parameter_prefix not in delta_text): + self.json_started = True + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="{"), + ) + ]) + + # Make sure json_started is set if we're processing parameters + if not self.json_started: + self.json_started = True + + # Check for function end in accumulated text + if not self.json_closed and self.function_end_token in tool_text: + # Close JSON + self.json_closed = True + + # Extract the complete tool call to update prev_tool_call_arr + # with final arguments. 
Find the function content + func_start = (tool_text.find(self.tool_call_prefix) + + len(self.tool_call_prefix)) + func_content_end = tool_text.find(self.function_end_token, + func_start) + if func_content_end != -1: + func_content = tool_text[func_start:func_content_end] + # Parse to get the complete arguments + try: + parsed_tool = self._parse_xml_function_call( + func_content, request.tools if request else None) + if parsed_tool: + # Update existing entry in prev_tool_call_arr with + # complete arguments + for i, tool in enumerate(self.prev_tool_call_arr): + if (tool.get("name") == + parsed_tool.function.name): + self.prev_tool_call_arr[i]["arguments"] = ( + parsed_tool.function.arguments) + break + except Exception: + pass # Ignore parsing errors during streaming + + result = DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments="}"), + ) + ]) + + # Reset state for next tool + self.in_function = False + self.json_closed = True + + return result + + # Look for parameters + # Count how many complete parameters we have processed + complete_params = tool_text.count(self.parameter_end_token) + + # Check if we should start a new parameter + if not self.in_param and self.param_count < complete_params: + # Find the unprocessed parameter + # Count parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) + + if len(param_starts) > self.param_count: + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] + + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] + + # Find where this parameter ends + param_end_idx = value_text.find( + self.parameter_end_token) + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + # Build complete JSON fragment for this parameter + if self.param_count == 0: + json_fragment = ( + '"' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + else: + json_fragment = ( + ', "' + self.current_param_name + '": "' + + json.dumps(param_value)[1:-1] + '"') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value + if self.in_param: + if self.parameter_end_token in delta_text: + # End of parameter + end_idx = delta_text.find(self.parameter_end_token) + value_chunk = delta_text[:end_idx] + + # Skip past > if at start + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if (not self.current_param_value + and value_chunk.startswith("\n")): + value_chunk = value_chunk[1:] + + # Calculate incremental JSON + full_value = self.current_param_value + value_chunk + prev_escaped = (json.dumps(self.current_param_value)[1:-1] + if self.current_param_value else "") + full_escaped = json.dumps(full_value)[1:-1] + delta_escaped = 
full_escaped[len(prev_escaped):] + + self.in_param = False + self.current_param_value = "" + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped + '"'), + ) + ]) + else: + # Continue accumulating value + value_chunk = delta_text + + # Handle first chunk after param name + if not self.current_param_value and ">" in value_chunk: + gt_idx = value_chunk.find(">") + value_chunk = value_chunk[gt_idx + 1:] + + if (not self.current_param_value + and value_chunk.startswith("\n")): + value_chunk = value_chunk[1:] + + if value_chunk: + # Stream the escaped delta + prev_escaped = (json.dumps( + self.current_param_value)[1:-1] + if self.current_param_value else "") + self.current_param_value += value_chunk + full_escaped = json.dumps( + self.current_param_value)[1:-1] + delta_escaped = full_escaped[len(prev_escaped):] + + if delta_escaped: + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=delta_escaped), + ) + ]) + + return None diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index c4e044f3a28e..f3f042355c9e 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -1,13 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Union +from typing import Any, Optional, Union, cast from torch.nn import CosineSimilarity +from typing_extensions import Required, TypeAlias, TypedDict +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + BaseMultiModalItemTracker, ChatCompletionContentPartImageEmbedsParam, + ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam, + MultiModalItemTracker, _ContentPart, _parse_chat_message_content_part) +from vllm.inputs import TokensPrompt +from vllm.model_executor.models.interfaces import supports_score_template +from vllm.multimodal.inputs import MultiModalDataDict from vllm.outputs import PoolingRequestOutput -from vllm.transformers_utils.tokenizer import (PreTrainedTokenizer, +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + PreTrainedTokenizer, PreTrainedTokenizerFast) +ScoreContentPartParam: TypeAlias = Union[ + ChatCompletionContentPartImageParam, + ChatCompletionContentPartImageEmbedsParam] + + +class ScoreMultiModalParam(TypedDict, total=False): + """ + A specialized parameter type for scoring multimodal content + + The reasons why don't reuse `CustomChatCompletionMessageParam` directly: + 1. Score tasks don't need the 'role' field (user/assistant/system) that's required in chat completions + 2. Including chat-specific fields would confuse users about their purpose in scoring + 3. 
This is a more focused interface that only exposes what's needed for scoring + """ # noqa: E501 + content: Required[list[ScoreContentPartParam]] + """The multimodal contents""" + def _cosine_similarity( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], @@ -39,12 +66,133 @@ def _cosine_similarity( def _validate_score_input_lens( - texts_1: Union[list[str], list[dict]], - texts_2: Union[list[str], list[dict]], + data_1: Union[list[str], list[ScoreContentPartParam]], + data_2: Union[list[str], list[ScoreContentPartParam]], ): - if len(texts_1) > 1 and len(texts_1) != len(texts_2): + len_1 = len(data_1) + len_2 = len(data_2) + + if len_1 > 1 and len_1 != len_2: raise ValueError("Input lengths must be either 1:1, 1:N or N:N") - if len(texts_1) == 0: + if len_1 == 0: raise ValueError("At least one text element must be given") - if len(texts_2) == 0: - raise ValueError("At least one text_pair element must be given") \ No newline at end of file + if len_2 == 0: + raise ValueError("At least one text_pair element must be given") + + +def parse_score_data( + data_1: Union[str, ScoreContentPartParam], + data_2: Union[str, ScoreContentPartParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, +) -> tuple[str, str, Optional[MultiModalDataDict]]: + mm_tracker = MultiModalItemTracker(model_config, tokenizer) + + content_1 = _parse_score_content(data_1, mm_tracker) + + content_2 = _parse_score_content(data_2, mm_tracker) + + def ensure_str(content: Optional[_ContentPart]) -> str: + if content is not None and isinstance(content, str): + return cast(str, content) + else: + raise ValueError( + f"Only string content is supported, but got {content}.") + + prompt_1 = ensure_str(content_1) + prompt_2 = ensure_str(content_2) + + return prompt_1, prompt_2, mm_tracker.all_mm_data() + + +def _parse_score_content( + data: Union[str, ScoreContentPartParam], + mm_tracker: BaseMultiModalItemTracker, +) -> Optional[_ContentPart]: + + if isinstance(data, str): + data = ChatCompletionContentPartTextParam(type="text", text=data) + + mm_parser = mm_tracker.create_parser() + + parse_res = _parse_chat_message_content_part( + data, + mm_parser, + wrap_dicts=False, + interleave_strings=False, + ) + + if parse_res: + return parse_res + + mm_placeholder_storage = mm_parser.mm_placeholder_storage() + + if len(mm_placeholder_storage) != 1 or len( + next(iter(mm_placeholder_storage.values()))) != 1: + raise ValueError("Only one multi-modal item is supported") + + return next(iter(mm_placeholder_storage.values()))[0] + + +def apply_score_template( + model_config: ModelConfig, + prompt_1: str, + prompt_2: str, +) -> str: + # NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. gguf) + from vllm.model_executor.model_loader import get_model_cls + + model = get_model_cls(model_config) + if supports_score_template(model): + full_prompt = model.get_score_template(prompt_1, prompt_2) + if full_prompt is None: + raise ValueError("Get empty score template from model") + return full_prompt + + raise ValueError( + f"Unsupported model architecture: {model_config.architecture}") + + +def post_process_tokens( + model_config: ModelConfig, + prompt: TokensPrompt, +) -> None: + """ + Perform architecture-specific manipulations on the input tokens. + + Note: + This is an in-place operation. + """ + # NOTE(Simon): lazy import to avoid bring in all dependencies (e.g. 
gguf) + from vllm.model_executor.model_loader import get_model_cls + + model = get_model_cls(model_config) + if supports_score_template(model): + model.post_process_tokens(prompt) + + +def get_score_prompt( + model_config: ModelConfig, + tokenizer: AnyTokenizer, + tokenization_kwargs: dict[str, Any], + data_1: Union[str, ScoreContentPartParam], + data_2: Union[str, ScoreContentPartParam], +) -> tuple[str, TokensPrompt]: + prompt_1, prompt_2, mm_data = parse_score_data( + data_1, + data_2, + model_config, + tokenizer, + ) + + full_prompt = apply_score_template(model_config, prompt_1, prompt_2) + + prompt_inputs = tokenizer(full_prompt, **tokenization_kwargs) + + engine_prompt = TokensPrompt(prompt_token_ids=prompt_inputs["input_ids"]) + + post_process_tokens(model_config, engine_prompt) + + if mm_data is not None: + engine_prompt["multi_modal_data"] = mm_data + return full_prompt, engine_prompt diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 423b99dbe565..87334f458fee 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -5,6 +5,7 @@ import asyncio import functools import os +import subprocess import sys from typing import Any, Optional, Union @@ -25,7 +26,8 @@ " - To view a argument group: --help=ModelConfig\n" " - To view a single argument: --help=max-num-seqs\n" " - To search by keyword: --help=max\n" - " - To list all groups: --help=listgroup") + " - To list all groups: --help=listgroup\n" + " - To view help with pager: --help=page") async def listen_for_disconnect(request: Request) -> None: @@ -33,10 +35,12 @@ async def listen_for_disconnect(request: Request) -> None: while True: message = await request.receive() if message["type"] == "http.disconnect": - if request.app.state.enable_server_load_tracking: - # on timeout/cancellation the BackgroundTask in load_aware_call - # cannot decrement the server load metrics. - # Must be decremented by with_cancellation instead. + # If load tracking is enabled *and* the counter exists, decrement + # it. Combines the previous nested checks into a single condition + # to satisfy the linter rule. 
+ if (getattr(request.app.state, "enable_server_load_tracking", + False) + and hasattr(request.app.state, "server_load_metrics")): request.app.state.server_load_metrics -= 1 break @@ -101,9 +105,14 @@ async def wrapper(*args, **kwargs): raise ValueError( "raw_request required when server load tracking is enabled") - if not raw_request.app.state.enable_server_load_tracking: + if not getattr(raw_request.app.state, "enable_server_load_tracking", + False): return await func(*args, **kwargs) + # ensure the counter exists + if not hasattr(raw_request.app.state, "server_load_metrics"): + raw_request.app.state.server_load_metrics = 0 + raw_request.app.state.server_load_metrics += 1 try: response = await func(*args, **kwargs) @@ -183,6 +192,24 @@ def _validate_truncation_size( return truncate_prompt_tokens +def _output_with_pager(text: str): + """Output text using scrolling view if available and appropriate.""" + + pagers = ['less -R', 'more'] + for pager_cmd in pagers: + try: + proc = subprocess.Popen(pager_cmd.split(), + stdin=subprocess.PIPE, + text=True) + proc.communicate(input=text) + return + except (subprocess.SubprocessError, OSError, FileNotFoundError): + continue + + # No pager worked, fall back to normal print + print(text) + + def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, subcommand_name: list[str]): @@ -201,16 +228,24 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, if arg.startswith('--help='): search_keyword = arg.split('=', 1)[1] + # Enable paged view for full help + if search_keyword == 'page': + help_text = parser.format_help() + _output_with_pager(help_text) + sys.exit(0) + # List available groups if search_keyword == 'listgroup': - print("\nAvailable argument groups:") + output_lines = ["\nAvailable argument groups:"] for group in parser._action_groups: if group.title and not group.title.startswith( "positional arguments"): - print(f" - {group.title}") + output_lines.append(f" - {group.title}") if group.description: - print(" " + group.description.strip()) - print() + output_lines.append(" " + + group.description.strip()) + output_lines.append("") + _output_with_pager("\n".join(output_lines)) sys.exit(0) # For group search @@ -222,7 +257,7 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, formatter.add_text(group.description) formatter.add_arguments(group._group_actions) formatter.end_section() - print(formatter.format_help()) + _output_with_pager(formatter.format_help()) sys.exit(0) # For single arg @@ -236,10 +271,10 @@ def show_filtered_argument_or_group_from_help(parser: argparse.ArgumentParser, matched_actions.append(action) if matched_actions: - print(f"\nParameters matching '{search_keyword}':\n") + header = f"\nParameters matching '{search_keyword}':\n" formatter = parser._get_formatter() formatter.add_arguments(matched_actions) - print(formatter.format_help()) + _output_with_pager(header + formatter.format_help()) sys.exit(0) print(f"\nNo group or parameter matching '{search_keyword}'") diff --git a/vllm/envs.py b/vllm/envs.py old mode 100644 new mode 100755 index 0cc6792d72bb..5c414e82d93b --- a/vllm/envs.py +++ b/vllm/envs.py @@ -42,9 +42,9 @@ VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None - VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_CPU_KVCACHE_SPACE: Optional[int] = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" - VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0 + VLLM_CPU_NUM_OF_RESERVED_CPU: 
Optional[int] = None VLLM_CPU_MOE_PREPACK: bool = True VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") @@ -61,6 +61,7 @@ VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25 VLLM_VIDEO_LOADER_BACKEND: str = "opencv" VLLM_MM_INPUT_CACHE_GIB: int = 8 VLLM_TARGET_DEVICE: str = "cuda" @@ -94,7 +95,6 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True - VLLM_QUARK_EMU_MEM_OPT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -107,8 +107,6 @@ VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_BUNDLE_INDICES: str = "" VLLM_CUDART_SO_PATH: Optional[str] = None - VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH: bool = True - VLLM_HPU_USE_DELAYED_SAMPLING: bool = False VLLM_DP_RANK: int = 0 VLLM_DP_RANK_LOCAL: int = -1 VLLM_DP_SIZE: int = 1 @@ -118,9 +116,12 @@ VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False + VLLM_V1_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None VLLM_USE_DEEP_GEMM: bool = False + VLLM_USE_FLASHINFER_MOE_FP8: bool = False + VLLM_USE_FLASHINFER_MOE_FP4: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -138,6 +139,10 @@ VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE" VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None + VLLM_NIXL_ABORT_REQUEST_TIMEOUT: int = 120 + VLLM_USE_CUDNN_PREFILL: bool = False + VLLM_ENABLE_CUDAGRAPH_GC: bool = False + VLLM_LOOPBACK_IP: str = "" def get_default_cache_root(): @@ -427,9 +432,10 @@ def get_vllm_port() -> Optional[int]: lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), # (CPU backend only) CPU key-value cache space. - # default is 4 GiB + # default is None and will be set as 4 GB "VLLM_CPU_KVCACHE_SPACE": - lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) + if "VLLM_CPU_KVCACHE_SPACE" in os.environ else None, # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. @@ -439,7 +445,8 @@ def get_vllm_port() -> Optional[int]: # (CPU backend only) CPU cores not used by OMP threads . # Those CPU cores will not be used by OMP threads of a rank. "VLLM_CPU_NUM_OF_RESERVED_CPU": - lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")), + lambda: int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0")) + if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None, # (CPU backend only) whether to use prepack for MoE layer. This will be # passed to ipex.llm.modules.GatedMLPMOE. On unsupported CPUs, you might @@ -513,6 +520,12 @@ def get_vllm_port() -> Optional[int]: "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10")), + # Maximum filesize in MB for a single audio file when processing + # speech-to-text requests. Files larger than this will be rejected. + # Default is 25 MB + "VLLM_MAX_AUDIO_CLIP_FILESIZE_MB": + lambda: int(os.getenv("VLLM_MAX_AUDIO_CLIP_FILESIZE_MB", "25")), + # Backend for Video IO # - "opencv": Default backend that uses OpenCV stream buffered backend. 
# @@ -722,14 +735,6 @@ def get_vllm_port() -> Optional[int]: lambda: maybe_convert_int( os.environ.get("VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB", None)), - # If set, when running in Quark emulation mode, do not dequantize the - # weights at load time. Instead, dequantize weights on-the-fly during - # kernel execution. - # This allows running larger models at the cost of slower inference. - # This flag has no effect when not running in Quark emulation mode. - "VLLM_QUARK_EMU_MEM_OPT": - lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))), - # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), @@ -785,19 +790,6 @@ def get_vllm_port() -> Optional[int]: "VLLM_CUDART_SO_PATH": lambda: os.getenv("VLLM_CUDART_SO_PATH", None), - # Contiguous cache fetching to avoid using costly gather operation on - # Gaudi3. This is only applicable to HPU contiguous cache. If set to true, - # contiguous cache fetch will be used. - "VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH": - lambda: os.environ.get("VLLM_CONTIGUOUS_PA", "true").lower() in - ("1", "true"), - - # Use delayed sampling for HPU to reduce host cpu overhead - # between each step. - "VLLM_HPU_USE_DELAYED_SAMPLING": - lambda: os.environ.get("VLLM_DELAYED_SAMPLING", "false").lower() in - ("1", "true"), - # Rank of the process in the data parallel setting "VLLM_DP_RANK": lambda: int(os.getenv("VLLM_DP_RANK", "0")), @@ -855,6 +847,12 @@ def get_vllm_port() -> Optional[int]: "VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1", + # Whether to turn on the outlines cache for V1 + # This cache is unbounded and on disk, so it's not safe to use in + # an environment with potentially malicious users. + "VLLM_V1_USE_OUTLINES_CACHE": + lambda: os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1", + # Gap between padding buckets for the forward pass. So we have # 8, we will run forward pass with [16, 24, 32, ...]. "VLLM_TPU_BUCKET_PADDING_GAP": @@ -867,6 +865,14 @@ def get_vllm_port() -> Optional[int]: "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))), + # Allow use of FlashInfer MoE kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE_FP8": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))), + + # Allow use of FlashInfer CUTLASS kernels for fused moe ops. + "VLLM_USE_FLASHINFER_MOE_FP4": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))), + # Control the cache sized used by the xgrammar compiler. The default # of 512 MB should be enough for roughly 1000 JSON schemas. # It can be changed with this variable if needed for some reason. @@ -953,7 +959,32 @@ def get_vllm_port() -> Optional[int]: # generations on machines < 100 for compressed-tensors # models "VLLM_USE_NVFP4_CT_EMULATIONS": - lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))) + lambda: bool(int(os.getenv("VLLM_USE_NVFP4_CT_EMULATIONS", "0"))), + + # Time (in seconds) after which the KV cache on the producer side is + # automatically cleared if no READ notification is received from the + # consumer. This is only applicable when using NixlConnector in a + # disaggregated decode-prefill setup. 
+ "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": + lambda: int(os.getenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", "120")), + + # Controls whether or not to use cudnn prefill + "VLLM_USE_CUDNN_PREFILL": + lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))), + + # If set to 1, use the TRTLLM Decode Attention backend in flashinfer. + "VLLM_USE_TRTLLM_DECODE_ATTENTION": + lambda: os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", None), + + # Controls garbage collection during CUDA graph capture. + # If set to 0 (default), enables GC freezing to speed up capture time. + # If set to 1, allows GC to run during capture. + "VLLM_ENABLE_CUDAGRAPH_GC": + lambda: bool(int(os.getenv("VLLM_ENABLE_CUDAGRAPH_GC", "0"))), + + # Used to force set up loopback IP + "VLLM_LOOPBACK_IP": + lambda: os.getenv("VLLM_LOOPBACK_IP", ""), } # --8<-- [end:env-vars-definition] diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 99e12201c96a..483fdb1486f7 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -4,6 +4,7 @@ import asyncio import time from abc import ABC, abstractmethod +from functools import cached_property from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union) @@ -15,7 +16,7 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.pooling_params import PoolingTask from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async from vllm.worker.worker_base import WorkerBase @@ -48,7 +49,6 @@ def __init__( self.scheduler_config = vllm_config.scheduler_config self.device_config = vllm_config.device_config self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self._init_executor() self.is_sleeping = False @@ -135,6 +135,11 @@ def rpc_func(worker: WorkerBase) -> _R: return self.collective_rpc(rpc_func) + @cached_property # Avoid unnecessary RPC calls + def supported_pooling_tasks(self) -> tuple[PoolingTask, ...]: + output = self.collective_rpc("get_supported_pooling_tasks") + return tuple({task for tasks in output for task in tasks}) + def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: @@ -164,35 +169,6 @@ def list_loras(self) -> Set[int]: assert s == sets[0], "All workers should have the same LORAs." return sets[0] - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - assert prompt_adapter_request.prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." - return all( - self.collective_rpc("add_prompt_adapter", - args=(prompt_adapter_request, ))) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." - return all( - self.collective_rpc("remove_prompt_adapter", - args=(prompt_adapter_id, ))) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - assert prompt_adapter_id > 0, \ - "prompt_adapter_id must be greater than 0." 
- return all( - self.collective_rpc("pin_prompt_adapter", - args=(prompt_adapter_id, ))) - - def list_prompt_adapters(self) -> Set[int]: - sets = self.collective_rpc("list_prompt_adapters") - for s in sets: - assert (s == sets[0] - ), "All workers should have the same prompt adapters." - return sets[0] - def start_profile(self) -> None: self.collective_rpc("start_profile") diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 84e8ddd8e274..e9ad62aeb99a 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio -import json import os from collections import defaultdict from dataclasses import dataclass @@ -20,6 +19,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform +from vllm.ray.ray_env import get_env_vars_to_copy from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, make_async) @@ -58,28 +58,20 @@ class RayDistributedExecutor(DistributedExecutorBase): "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" } - config_home = envs.VLLM_CONFIG_ROOT - # This file contains a list of env vars that should not be copied - # from the driver to the Ray workers. - non_carry_over_env_vars_file = os.path.join( - config_home, "ray_non_carry_over_env_vars.json") - if os.path.exists(non_carry_over_env_vars_file): - with open(non_carry_over_env_vars_file) as f: - non_carry_over_env_vars = set(json.load(f)) - else: - non_carry_over_env_vars = set() + # These non-vLLM env vars are copied from the driver to workers + ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"} uses_ray: bool = True def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1 and not current_platform.is_xpu(): + if envs.VLLM_USE_V1: # V1 uses SPMD worker and compiled DAG os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" - # For TPU, avoid compiling NVIDIA's NCCL - if current_platform.is_tpu(): + # For TPU or XPU, avoid compiling NVIDIA's NCCL + if current_platform.is_tpu() or current_platform.is_xpu(): os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm" # If the env var is set, it uses the Ray's compiled DAG API @@ -335,13 +327,11 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): } for (node_id, _) in worker_node_and_gpu_ids] # Environment variables to copy from driver to workers - env_vars_to_copy = [ - v for v in envs.environment_variables - if v not in self.WORKER_SPECIFIC_ENV_VARS - and v not in self.non_carry_over_env_vars - ] - - env_vars_to_copy.extend(current_platform.additional_env_vars) + env_vars_to_copy = get_env_vars_to_copy( + exclude_vars=self.WORKER_SPECIFIC_ENV_VARS, + additional_vars=set(current_platform.additional_env_vars).union( + self.ADDITIONAL_ENV_VARS), + destination="workers") # Copy existing env vars to each worker's args for args in all_args_to_update_environment_variables: @@ -350,15 +340,6 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): if name in os.environ: args[name] = os.environ[name] - logger.info("non_carry_over_env_vars from config: %s", - self.non_carry_over_env_vars) - logger.info( - "Copying the following environment variables to workers: %s", - [v for v in env_vars_to_copy 
if v in os.environ]) - logger.info( - "If certain env vars should NOT be copied to workers, add them to " - "%s file", self.non_carry_over_env_vars_file) - self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index c222f1609096..033ecc00853b 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -145,7 +145,9 @@ def override_env_vars(self, vars: Dict[str, str]): except ImportError as e: ray = None # type: ignore - ray_import_err = e + # only capture string to avoid variable references in the traceback that can + # prevent garbage collection in some cases + ray_import_err = str(e) RayWorkerWrapper = None # type: ignore @@ -157,8 +159,8 @@ def ray_is_available() -> bool: def assert_ray_available(): """Raise an exception if Ray is not available.""" if ray is None: - raise ValueError("Failed to import Ray, please install Ray with " - "`pip install ray`.") from ray_import_err + raise ValueError(f"Failed to import Ray: {ray_import_err}." + "Please install Ray with `pip install ray`.") def _verify_bundles(placement_group: "PlacementGroup", diff --git a/vllm/executor/uniproc_executor.py b/vllm/executor/uniproc_executor.py index 7ebeb4a22556..aabc9ed9b80a 100644 --- a/vllm/executor/uniproc_executor.py +++ b/vllm/executor/uniproc_executor.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, run_method) +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -62,6 +63,14 @@ def check_health(self) -> None: # it's running. return + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self.driver_worker.reinitialize_distributed(reconfig_request) + if reconfig_request.new_data_parallel_rank == \ + ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + self.shutdown() + return + UniProcExecutorAsync = UniProcExecutor diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index deda9bc23daf..de5dc0876651 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -13,7 +13,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import TokenizerGroup @@ -168,18 +167,6 @@ def _prepare_decoder_input_ids_for_generation( return decoder_input_ids - def _apply_prompt_adapter( - self, - prompt_token_ids: list[int], - prompt_adapter_request: Optional[PromptAdapterRequest], - ) -> list[int]: - if prompt_adapter_request: - prompt_token_ids = ( - [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens - + prompt_token_ids) - - return prompt_token_ids - def _get_tokenization_kw( self, overrides: Optional[dict[str, Any]] = None, @@ -786,15 +773,10 @@ async def _process_encoder_decoder_prompt_async( def _build_decoder_only_llm_inputs( self, prompt_inputs: DecoderOnlyInputs, - prompt_adapter_request: Optional[PromptAdapterRequest], ) -> DecoderOnlyInputs: if "prompt_token_ids" in prompt_inputs: prompt_inputs = cast(Union[TokenInputs, MultiModalInputs], prompt_inputs) # Needed for mypy - prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( - 
prompt_inputs["prompt_token_ids"], - prompt_adapter_request=prompt_adapter_request, - ) return prompt_inputs @@ -803,7 +785,6 @@ def _process_decoder_only_prompt( prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: """ @@ -815,7 +796,6 @@ def _process_decoder_only_prompt( * prompt: input prompt * lora_request - * prompt_adapter_request * return_mm_hashes Returns: @@ -830,17 +810,13 @@ def _process_decoder_only_prompt( return_mm_hashes=return_mm_hashes, ) - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) + return self._build_decoder_only_llm_inputs(prompt_comps) async def _process_decoder_only_prompt_async( self, prompt: SingletonPrompt, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, ) -> DecoderOnlyInputs: """ @@ -854,17 +830,13 @@ async def _process_decoder_only_prompt_async( return_mm_hashes=return_mm_hashes, ) - return self._build_decoder_only_llm_inputs( - prompt_comps, - prompt_adapter_request=prompt_adapter_request, - ) + return self._build_decoder_only_llm_inputs(prompt_comps) def preprocess( self, prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, ) -> ProcessorInputs: """Preprocess the input prompt.""" @@ -886,7 +858,6 @@ def preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, return_mm_hashes=return_mm_hashes, ) @@ -895,7 +866,6 @@ async def preprocess_async( prompt: PromptType, tokenization_kwargs: Optional[dict[str, Any]] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, return_mm_hashes: bool = False, ) -> ProcessorInputs: """ @@ -919,6 +889,5 @@ async def preprocess_async( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, return_mm_hashes=return_mm_hashes, ) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 2f0913683c6d..652136fbbfe7 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -5,15 +5,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union import torch -from packaging.version import Version from transformers import BatchFeature, PretrainedConfig, ProcessorMixin -from transformers import __version__ as TRANSFORMERS_VERSION from typing_extensions import TypeVar from vllm.jsontree import JSONTree, json_map_leaves from vllm.logger import init_logger from vllm.transformers_utils.processor import cached_processor_from_config -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import resolve_mm_processor_kwargs if TYPE_CHECKING: @@ -21,6 +18,14 @@ from vllm.multimodal import (MultiModalDataDict, MultiModalPlaceholderDict, MultiModalRegistry) from vllm.sequence import SequenceData + from vllm.transformers_utils.tokenizer import AnyTokenizer +else: + ModelConfig = Any + MultiModalDataDict = Any + MultiModalPlaceholderDict = Any + MultiModalRegistry = Any + SequenceData = Any + AnyTokenizer = Any _T = TypeVar("_T") _C = 
TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig) @@ -36,7 +41,7 @@ class InputContext: modify the inputs. """ - model_config: "ModelConfig" + model_config: ModelConfig """The configuration of the model.""" def get_hf_config( @@ -130,13 +135,9 @@ def get_hf_processor( /, **kwargs: object, ) -> _P: - # Transformers 4.53.0 has issue with passing tokenizer to - # initialize processor. We disable it for this version. - # See: https://github.com/vllm-project/vllm/issues/20224 - if Version(TRANSFORMERS_VERSION) < Version("4.53.0"): - kwargs["tokenizer"] = self.tokenizer return super().get_hf_processor( typ, + tokenizer=self.tokenizer, **kwargs, ) @@ -200,9 +201,9 @@ class DummyData(NamedTuple): Note: This is only used in V0. """ - seq_data: "SequenceData" - multi_modal_data: Optional["MultiModalDataDict"] = None - multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None + seq_data: SequenceData + multi_modal_data: Optional[MultiModalDataDict] = None + multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None class InputRegistry: @@ -212,9 +213,9 @@ class InputRegistry: def dummy_data_for_profiling( self, - model_config: "ModelConfig", + model_config: ModelConfig, seq_len: int, - mm_registry: "MultiModalRegistry", + mm_registry: MultiModalRegistry, is_encoder_data: bool = False, ) -> DummyData: """ diff --git a/vllm/logger.py b/vllm/logger.py index 0ddb83cb8ba7..69aaf4390a7d 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -53,6 +53,12 @@ } +@lru_cache +def _print_debug_once(logger: Logger, msg: str, *args: Hashable) -> None: + # Set the stacklevel to 2 to print the original caller's line info + logger.debug(msg, *args, stacklevel=2) + + @lru_cache def _print_info_once(logger: Logger, msg: str, *args: Hashable) -> None: # Set the stacklevel to 2 to print the original caller's line info @@ -74,6 +80,13 @@ class _VllmLogger(Logger): `intel_extension_for_pytorch.utils._logger`. """ + def debug_once(self, msg: str, *args: Hashable) -> None: + """ + As [`debug`][logging.Logger.debug], but subsequent calls with + the same message are silently dropped. + """ + _print_debug_once(self, msg, *args) + def info_once(self, msg: str, *args: Hashable) -> None: """ As [`info`][logging.Logger.info], but subsequent calls with @@ -132,6 +145,7 @@ def init_logger(name: str) -> _VllmLogger: logger = logging.getLogger(name) methods_to_patch = { + "debug_once": _print_debug_once, "info_once": _print_info_once, "warning_once": _print_warning_once, } diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3d0c58317502..c3512ec3dbd4 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -28,8 +28,6 @@ RowParallelLinear) # yapf: enable from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.rotary_embedding import ( - LinearScalingRotaryEmbedding, RotaryEmbedding) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.platforms import current_platform @@ -240,17 +238,19 @@ def set_lora( def forward(self, x: torch.Tensor) -> torch.Tensor: added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) - embeddings_indices = torch.narrow( - self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) - indices = embeddings_indices[1] + # NB: Don't use torch.narrow here. 
torch.narrow triggers some + # Dynamic Shape specialization in torch.compile + num_tokens = x.shape[0] + indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens] + indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens] + full_lora_a_embeddings = F.embedding( - x + indices, + x + indices_1, self.lora_a_stacked_2d, ) - indices = embeddings_indices[0] full_output = self.base_layer.forward(x + - (indices * added_tokens_mask)) + (indices_0 * added_tokens_mask)) full_output_org = full_output if full_output.ndim == 3: @@ -1162,10 +1162,6 @@ def _get_logits( posinf=pos_inf, neginf=neg_inf)) - # HPU needs special handling to prune out dummy samples. - if current_platform.is_hpu(): - lora_logits = lora_logits[:logits.shape[0], :] - logits[:, self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits @@ -1195,91 +1191,3 @@ def can_replace_layer( ) -> bool: # Special handling for the LogitsProcessor. return False - - -class LinearScalingRotaryEmbeddingWithLoRA(BaseLayerWithLoRA): - """Implements RoPE-scaled embeddings with linear scaling for - multiple LoRA adapters with a specialized kernel. - - Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding - which can handle multi lora adapters in a specialized kernel. - """ - - def __init__(self, base_layer: RotaryEmbedding) -> None: - super().__init__() - self.base_layer = base_layer - - @property - def scaling_factors(self): - return self.base_layer.scaling_factors - - @property - def rotary_dim(self): - return self.base_layer.rotary_dim - - def create_lora_weights( - self, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, - ) -> None: - scaling_factors = (list(lora_config.long_lora_scaling_factors) - if lora_config.long_lora_scaling_factors else []) - base_scaling_factor = (self.base_layer.scaling_factor if isinstance( - self.base_layer, LinearScalingRotaryEmbedding) else 1.0) - scaling_factors = sorted( - list(set([base_scaling_factor] + scaling_factors))) - self.base_layer = LinearScalingRotaryEmbedding( - self.base_layer.head_size, - self.base_layer.rotary_dim, - self.base_layer.max_position_embeddings, - self.base_layer.base, - self.base_layer.is_neox_style, - scaling_factors, - self.base_layer.dtype, - ) - - def reset_lora(self, index: int): - ... - - def set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: Optional[torch.Tensor], - bias: Optional[torch.Tensor] = None, - ): - ... 
- - def forward( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - return self.base_layer( - positions, - query, - key, - offsets=self.punica_wrapper.long_lora_indices, - ) - - @property - def scaling_factor_to_offset(self) -> dict[float, int]: - return self.base_layer.scaling_factor_to_offset - - @classmethod - def can_replace_layer( - cls, - source_layer: nn.Module, - lora_config: LoRAConfig, - packed_modules_list: list, - model_config: Optional[PretrainedConfig], - ) -> bool: - """Returns True if the layer can be replaced by this LoRA layer.""" - return (type(source_layer) is LinearScalingRotaryEmbedding - or type(source_layer) is RotaryEmbedding) - - def extra_repr(self) -> str: - return self.base_layer.extra_repr() diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 9e1ed3a77179..e6b19d4748f4 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,6 @@ import math import os from collections.abc import Sequence -from dataclasses import dataclass, field from typing import Any, Callable, Optional, Union import regex as re @@ -19,9 +18,7 @@ remove_adapter, set_adapter_mapping) from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import (BaseLayerWithLoRA, - LinearScalingRotaryEmbeddingWithLoRA, - LoRAMapping) +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.peft_helper import PEFTHelper from vllm.lora.punica_wrapper import get_punica_wrapper @@ -29,6 +26,7 @@ get_supported_lora_modules, is_regex_target_modules, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.interfaces import is_pooling_model @@ -42,24 +40,23 @@ _GLOBAL_LORA_ID = 0 -@dataclass -class LongContextLoRAContext: - """Context for lora adapters that support long context.""" - # The scaling factors to support long context lora fine tuned models. - scaling_factors: list[float] - # dimension to apply rotary embedding. - rot_dim: int - # offsets to the sin_cos_cache for each lora_id loaded. - # This value is dynamically modified. - offsets_by_lora_id: dict[int, int] = field(default_factory=dict) - - def get_lora_id(): global _GLOBAL_LORA_ID _GLOBAL_LORA_ID += 1 return _GLOBAL_LORA_ID +def is_moe_model(model: nn.Module) -> bool: + """Checks if the model contains FusedMoE layers and warns the user.""" + if any(isinstance(module, FusedMoE) for module in model.modules()): + logger.warning_once( + "For MoE models, vLLM currently does not support fused MoE LoRA " + "inference. Please ensure that the loaded LoRA model does not " + "contain expert weights.") + return True + return False + + class LoRAModel(AdapterModel): """A LoRA fine-tuned model.""" @@ -68,20 +65,16 @@ def __init__( lora_model_id: int, rank: int, loras: dict[str, LoRALayerWeights], - scaling_factor: Optional[float] = None, ) -> None: """ Args: lora_model_id: The integer id for the lora model. rank: lora rank. loras: module name -> weights for lora-replaced layers. - scaling_factor: Scaling factor to support long context lora model. - None if the lora is not tuned for long context support. + """ self.id = lora_model_id - # Scaling factor for long context lora model. 
None if it is not - # fine tuned for the long context. - self.scaling_factor = scaling_factor + assert ( lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" @@ -180,10 +173,7 @@ def from_lora_tensors( for lora in loras.values(): lora.optimize() - return cls(lora_model_id, - peft_helper.r, - loras, - scaling_factor=peft_helper.vllm_long_context_scaling_factor) + return cls(lora_model_id, peft_helper.r, loras) @classmethod def from_local_checkpoint( @@ -245,9 +235,10 @@ def check_unexpected_modules(modules: dict): lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir, "adapter_model.tensors") tensorizer_args = tensorizer_config._construct_tensorizer_args() - tensors = TensorDeserializer(lora_tensor_path, - dtype=tensorizer_config.dtype, - **tensorizer_args.deserializer_params) + tensors = TensorDeserializer( + lora_tensor_path, + dtype=tensorizer_config.dtype, + **tensorizer_args.deserialization_kwargs) check_unexpected_modules(tensors) elif os.path.isfile(lora_tensor_path): @@ -347,24 +338,17 @@ def __init__( self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size - self.long_lora_context: Optional[LongContextLoRAContext] = None self.punica_wrapper = get_punica_wrapper( max_num_batched_tokens, max_batches=self.max_num_seqs, device=self.device, max_loras=self.lora_config.max_loras) - # Scaling factor -> offset to the sin_cos_cache to it. - # Used for long context lora. - self.scaling_factor_to_offset: dict[float, int] = {} + super().__init__(model) self.supported_lora_modules = get_supported_lora_modules(self.model) assert self.supported_lora_modules, "No supported LoRA modules found in" f" {self.model.__class__.__name__}." - if lora_config.long_lora_scaling_factors: - # We need to replace rotary emb layer to do batch computation - # for long lora. - self.supported_lora_modules.append("rotary_emb") self.packed_modules_mapping = get_packed_modules_mapping(self.model) # Used to indicate whether the model is a multimodal model @@ -374,6 +358,7 @@ def __init__( # text modules (e.g. ChatGLM) and hasattr(self.model, "get_mm_mapping")) self.is_pooling_model = is_pooling_model(self.model) + self.is_moe_model = is_moe_model(self.model) self.packed_modules: dict[str, list[str]] = {} self.modules: dict[str, BaseLayerWithLoRA] = {} # Dict instead of a set for compatibility with LRUCache. 
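# Illustrative note (not part of the patch): the is_moe_model() helper added
# above walks the module tree once and warns that fused-MoE LoRA inference is
# unsupported. Below is a minimal, self-contained sketch of that detection
# pattern; the MoELayer stand-in and contains_moe_layers() name are
# assumptions for this example and do not mirror vLLM's actual FusedMoE class
# or API.
import logging

import torch.nn as nn

logger = logging.getLogger(__name__)


class MoELayer(nn.Module):
    """Stand-in for a fused MoE layer (illustration only)."""


def contains_moe_layers(model: nn.Module) -> bool:
    # Scan every submodule once; any MoE layer triggers a single warning.
    found = any(isinstance(m, MoELayer) for m in model.modules())
    if found:
        logger.warning("Model contains MoE layers; fused MoE LoRA is not "
                       "supported, so loaded adapters must not carry "
                       "expert weights.")
    return found


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(8, 8), MoELayer())
    print(contains_moe_layers(model))  # -> True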
@@ -440,25 +425,9 @@ def _deactivate_adapter(self, lora_id: int): except ValueError: pass - def _set_long_lora_context(self, lora: LoRAModel): - if self.long_lora_context is None: - return - - if lora.scaling_factor is None: - return - - if (lora.scaling_factor not in self.scaling_factor_to_offset): - raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" - " has not been initialized.") - - offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) - if offsets: - self.long_lora_context.offsets_by_lora_id[lora.id] = offsets - def _add_adapter(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) self._registered_adapters[lora.id] = lora - self._set_long_lora_context(lora) def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" @@ -474,7 +443,6 @@ def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: self.lora_slots + 1, self.vocab_size, self.lora_config.lora_extra_vocab_size, - self.long_lora_context, ) def remove_all_adapters(self): @@ -484,6 +452,14 @@ def remove_all_adapters(self): self._active_adapters.clear() def _create_lora_modules(self): + + def _parent_module(module_name: str) -> str: + # module name is a dot separated name. + # for example: + # - given an input 'x.y.z' return 'x.y' + # - given an input 'x' return '' + return module_name.rpartition('.')[0] + for module_name, module in self.model.named_modules( remove_duplicate=False): if isinstance(module, PPMissingLayer): @@ -506,19 +482,19 @@ def _create_lora_modules(self): from_layer(module, self.lora_slots, self.lora_config, packed_moduled_lst, self.model.config)) - # LinearScalingRotaryEmbeddingWithLoRA is used to handle - # long context lora. Register relevant metadata. - if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA): - self.long_lora_context = LongContextLoRAContext( - new_module.scaling_factors, new_module.rotary_dim) - self.scaling_factor_to_offset = \ - new_module.scaling_factor_to_offset # (yard1): TODO make this more robust if "lm_head" in module_name: + logits_processor_module_name = 'logits_processor' + parent_module = _parent_module(module_name) + if parent_module: + logits_processor_module_name = ( + f"{parent_module}.{logits_processor_module_name}") + logits_processor_module = self.model.get_submodule( - "logits_processor") + logits_processor_module_name) + new_module = replace_submodule( - self.model, "logits_processor", + self.model, logits_processor_module_name, from_layer_logits_processor(logits_processor_module, module, self.lora_slots, self.lora_config, @@ -545,15 +521,13 @@ def create_dummy_lora( self, lora_id: int, rank: int, - scaling_factor: Optional[float], embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel: """Create zero-initialized LoRAModel for warmup.""" - model = LoRAModel(lora_id, rank, {}, scaling_factor) + model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): bias_enabled = self.lora_config.bias_enabled if (not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) - or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA) or self._filter_unsupported_mm_module(module_name)): continue parts = module_name.split(".") @@ -694,11 +668,8 @@ def deactivate_adapter(self, adapter_id: int) -> bool: self._deactivate_adapter) def add_adapter(self, adapter: LoRAModel) -> bool: - logger.debug( - "Adding lora. 
Model id: %d, " - "int id: %d, " - "scaling factor: %s", adapter.id, adapter.id, - adapter.scaling_factor) + logger.debug("Adding lora. Model id: %d, " + "int id: %d", adapter.id, adapter.id) return add_adapter(adapter, self._registered_adapters, self.capacity, self._add_adapter) @@ -743,10 +714,8 @@ def list_adapters(self) -> dict[int, LoRAModel]: def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - logger.debug( - "Adding lora. Model id: %d, " - "int id: %d, " - "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) + logger.debug("Adding lora. Model id: %d, " + "int id: %d", lora.id, lora.id) if lora.id not in self._registered_adapters: self._add_adapter(lora) was_added = True diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index 9e1f90e757cd..b1ab84e08ba7 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -8,11 +8,11 @@ """ import torch -import triton -import triton.language as tl from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -283,6 +283,7 @@ def _lora_expand_fake( op_func=_lora_expand, mutates_args=["output_tensor"], fake_impl=_lora_expand_fake, + dispatch_key=current_platform.dispatch_key, ) lora_expand = torch.ops.vllm.lora_expand diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 3f9edfc6d655..1e7075ab0715 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -8,11 +8,11 @@ """ import torch -import triton -import triton.language as tl from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import direct_register_custom_op @@ -237,6 +237,7 @@ def _lora_shrink_fake( op_func=_lora_shrink, mutates_args=["output_tensor"], fake_impl=_lora_shrink_fake, + dispatch_key=current_platform.dispatch_key, ) lora_shrink = torch.ops.vllm.lora_shrink diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index 5857f7fecb5b..4c50fbd27051 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -35,7 +35,9 @@ def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device): lora_strides_d1.append(lora_a_weight.stride(1)) lora_strides_d2.append(lora_a_weight.stride(2)) if len(lora_a_weights) > 1: - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + lora_ptr_tensor = torch.tensor(tensor_ptrs, + device=device, + dtype=torch.uint64) else: lora_ptr_tensor = lora_a_weights[0] @@ -89,8 +91,12 @@ def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int, if len(lora_weights) > 1: # note these are device tensors - lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) - slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + lora_ptr_tensor = torch.tensor(tensor_ptrs, + device=device, + dtype=torch.uint64) + slice_start_tensor = torch.tensor(slice_offset_lst, + device=device, + dtype=torch.uint64) else: slice_start_tensor = slice_offset_lst[0] lora_ptr_tensor = lora_b_weight[0] diff --git a/vllm/lora/peft_helper.py b/vllm/lora/peft_helper.py index 
a20d73f0f725..8b8e5cb7d5fa 100644 --- a/vllm/lora/peft_helper.py +++ b/vllm/lora/peft_helper.py @@ -35,12 +35,9 @@ class PEFTHelper: use_rslora: bool = field(default=False) # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353) use_dora: bool = field(default=False) - # long context lora field - context_length: int = field(default=0) # Extra vllm field, start with 'vllm_' to avoid conflict vllm_lora_scaling_factor: float = field(default=1.0) vllm_max_position_embeddings: Optional[int] = field(default=False) - vllm_long_context_scaling_factor: Optional[float] = field(default=None) def _validate_features(self) -> list[str]: """ @@ -59,12 +56,6 @@ def __post_init__(self): self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r) else: self.vllm_lora_scaling_factor = self.lora_alpha / self.r - if self.context_length: - if self.vllm_max_position_embeddings is None: - self.vllm_max_position_embeddings = self.context_length - self.vllm_long_context_scaling_factor = float( - math.ceil(self.context_length / - self.vllm_max_position_embeddings)) @classmethod def from_dict(cls, config_dict: dict) -> "PEFTHelper": @@ -102,15 +93,15 @@ def from_local_dir( tensorizer_config = TensorizerConfig(**tensorizer_config_dict) tensorizer_args = tensorizer_config._construct_tensorizer_args() from tensorizer.stream_io import open_stream - lora_config_path = os.path.join(tensorizer_config.lora_dir, + lora_config_path = os.path.join(tensorizer_config.tensorizer_dir, "adapter_config.json") with open_stream(lora_config_path, mode="rb", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: config = json.load(f) logger.info("Successfully deserialized LoRA config from %s", - tensorizer_config.lora_dir) + tensorizer_config.tensorizer_dir) else: with open(lora_config_path) as f: diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 5b4902dcbeb3..b3413de1c816 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -17,7 +17,6 @@ if TYPE_CHECKING: # avoid circuit import from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext class PunicaWrapperABC(ABC): @@ -33,7 +32,6 @@ def update_metadata( max_loras: int, vocab_size: int, extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, **kwargs, ) -> None: """ @@ -144,14 +142,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, max_num_batched_tokens, dtype=torch.long, device=device) - self._long_lora_indices = torch.empty(max_num_batched_tokens, - dtype=torch.long, - device=device) - # 5 is the number of indices tensors. + # 4 is the number of indices tensors. 
# base_indices, sampler_indices, sampler_indices_padded, - # embeddings_indices,long_lora_indices - self.indices_len: list[Optional[int]] = [None] * 5 + # embeddings_indices + self.indices_len: list[Optional[int]] = [None] * 4 # these attributes are the information required for sgmv kernel self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, @@ -176,14 +171,12 @@ def _update_base_metadata( max_loras: int, vocab_size: int, extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, ): ( base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, - long_lora_offsets_tensor, indices_len, ) = convert_mapping( mapping, @@ -192,7 +185,6 @@ def _update_base_metadata( vocab_size, extra_vocab_size, self.device, - long_lora_context, ) self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) @@ -201,11 +193,7 @@ def _update_base_metadata( self._embeddings_indices[:embeddings_indices. shape[0], :embeddings_indices.shape[1]].copy_( embeddings_indices) - if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self._long_lora_indices.zero_() + self.indices_len[:] = indices_len def _update_prefill_metadata(self, @@ -312,28 +300,13 @@ def embeddings_indices(self) -> torch.Tensor: embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] - @property - def long_lora_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLoRA. - """ - long_lora_len = self.indices_len[4] - return self._long_lora_indices[:long_lora_len] - - def update_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - **kwargs): + def update_metadata(self, mapping: "LoRAMapping", + lora_index_to_id: list[Optional[int]], max_loras: int, + vocab_size: int, extra_vocab_size: int, **kwargs): self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) + vocab_size, extra_vocab_size) + if mapping.is_prefill: # Update metadata required for prefill-related operators. 
self._update_prefill_metadata(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 6b038309d55d..2db0e9fee142 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -7,7 +7,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import TYPE_CHECKING, Optional, Union, final +from typing import Optional, Union, final import torch @@ -21,10 +21,6 @@ from .punica_base import PunicaWrapperBase -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.models import LongContextLoRAContext - @final class PunicaWrapperGPU(PunicaWrapperBase): @@ -55,20 +51,13 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, max_num_prompts, device=device) - def update_metadata( - self, - mapping: LoRAMapping, - lora_index_to_id: list[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - **kwargs): + def update_metadata(self, mapping: LoRAMapping, + lora_index_to_id: list[Optional[int]], max_loras: int, + vocab_size: int, extra_vocab_size: int, **kwargs): self.is_prefill = mapping.is_prefill self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) + vocab_size, extra_vocab_size) # Prepare cuda kernel metadata tensors self.token_mapping_meta.prepare_tensors(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_hpu.py b/vllm/lora/punica_wrapper/punica_hpu.py deleted file mode 100644 index b20c9785a74c..000000000000 --- a/vllm/lora/punica_wrapper/punica_hpu.py +++ /dev/null @@ -1,145 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import TYPE_CHECKING, Optional, Union, final - -import torch -from vllm_hpu_extension.ops import (dispatch_bgmv_embedding, - dispatch_bgmv_linear) - -from .punica_base import PunicaWrapperBase -from .utils import convert_mapping - -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext - - -@final -class PunicaWrapperHPU(PunicaWrapperBase): - - def __init__(self, max_num_batched_tokens: int, max_batches: int, - device: Union[torch.device, str], **kwargs): - # Increasing max_num_batched_tokens by 3x to handle increase in - # tensor size due to padding. - PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens, - max_batches, device) - - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: list[Optional[int]], - max_loras: int, - vocab_size: int, - extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, - ): - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - long_lora_offsets_tensor, - indices_len, - ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size, - extra_vocab_size, self.device, None) - # Updating each element in `long_lora_offsets` with `lora_offset` slows - # down perf in HPU due to a series of `strided_insert` ops during lazy - # graph accumulation. Hence HPU appends `lora_offset` to a list and - # converts it to a tensor only after it is ready. 
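A simplified sketch (assumed names, illustration only) of the bookkeeping pattern used by PunicaWrapperBase above: index tensors are preallocated at a fixed capacity, and indices_len records how many entries of each of the four index tensors are valid for the current batch.

import torch

class IndexBuffers:
    def __init__(self, capacity: int, device: str = "cpu"):
        self._token_lora_indices = torch.empty(capacity,
                                               dtype=torch.long,
                                               device=device)
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices
        self.indices_len: list[int] = [0] * 4

    def update(self, base_indices: torch.Tensor) -> None:
        # Copy into the front of the buffer and remember the valid length.
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self.indices_len[0] = base_indices.shape[0]

    @property
    def token_lora_indices(self) -> torch.Tensor:
        return self._token_lora_indices[:self.indices_len[0]]

buffers = IndexBuffers(capacity=8)
buffers.update(torch.tensor([0, 0, 1, -1]))
print(buffers.token_lora_indices)  # tensor([ 0,  0,  1, -1])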
- if long_lora_context: - index_mapping_indices: list[int] = list( - mapping.index_mapping).copy() - long_lora_offsets: list[int] = [] - for i in range(len(index_mapping_indices)): - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets.append(lora_offset) - long_lora_offsets_tensor = torch.tensor(long_lora_offsets, - device=self.device, - dtype=torch.long) - indices_len[-1] = long_lora_offsets_tensor.shape[-1] - - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) - self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) - if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) - else: - self._long_lora_indices.zero_() - self.indices_len[:] = indices_len - - def add_lora_embedding(self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs) -> None: - dispatch_bgmv_embedding(y, x, lora_b_stacked, 0) - - def add_lora_linear(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: Optional[tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: - y_org = y - x = x.view(-1, x.shape[-1]) - y = y.view(-1, y.shape[-1]) - offset_left = 0 - - for slice_idx in range(len(output_slices)): - dispatch_bgmv_linear( - y[:, offset_left:offset_left + output_slices[slice_idx]], x, - lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale) - offset_left += output_slices[slice_idx] - y = y.view_as(y_org) - - def add_lora_logits(self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale) - y = y.view_as(y_org) - - def add_shrink( - self, - y: Union[tuple[torch.Tensor, ...], torch.Tensor], - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, - **kwargs, - ) -> None: - raise NotImplementedError - - def add_expand( - self, - y: torch.Tensor, - x: Union[tuple[torch.Tensor, ...], torch.Tensor], - lora_b_stacked: tuple[torch.Tensor, ...], - lora_bias_stacked: Optional[tuple[torch.Tensor, ...]], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs, - ) -> None: - raise NotImplementedError diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py index 6b48268c5006..07dc337a1cc8 100644 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -14,7 +14,6 @@ if TYPE_CHECKING: # avoid circuit import from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext from .punica_base import PunicaWrapperBase @@ -45,7 +44,6 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True) torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) - 
torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True) torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True) @@ -323,7 +321,6 @@ def _update_base_metadata( max_loras: int, vocab_size: int, extra_vocab_size: int, - long_lora_context: Optional["LongContextLoRAContext"] = None, ): # Make sure we don't accidentally collect outside operations xm.mark_step() @@ -339,7 +336,6 @@ def _update_base_metadata( sampler_indices, sampler_indices_padded, embeddings_indices, - long_lora_offsets_tensor, indices_len, ) = convert_mapping( mapping, @@ -348,7 +344,6 @@ def _update_base_metadata( vocab_size, extra_vocab_size, "cpu", - long_lora_context, ) self._token_lora_indices = self._pad_to_shape( base_indices, self._token_lora_indices.shape, @@ -362,15 +357,6 @@ def _update_base_metadata( self._embeddings_indices = self._pad_to_shape( embeddings_indices, self._embeddings_indices.shape, dims=2).to(self.device) - if long_lora_offsets_tensor is not None: - self._long_lora_indices = self._pad_to_shape( - long_lora_offsets_tensor, - self._long_lora_indices.shape, - dims=1).to(self.device) - else: - zeroed = torch.zeros_like(self._long_lora_indices.cpu(), - dtype=torch.int32) - self._long_lora_indices = zeroed.to(self.device) self.indices_len[:] = indices_len def _update_prefill_metadata(self, diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index 8430cb91865f..d22c29da1c61 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -8,7 +8,6 @@ if TYPE_CHECKING: # avoid circuit import from vllm.lora.layers import LoRAMapping - from vllm.lora.models import LongContextLoRAContext def compute_meta( @@ -49,9 +48,7 @@ def convert_mapping( vocab_size: int, extra_vocab_size: int, device: torch.device, - long_lora_context: Optional["LongContextLoRAContext"] = None, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], list[int]]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]: """Converts LoRAMapping to index tensors. Args: @@ -60,7 +57,6 @@ def convert_mapping( max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. - long_lora_context: Passed if there are long context lora in a batch. Returns: A tuple of tensors: @@ -78,21 +74,14 @@ def convert_mapping( requests to embedding indices. First row is for embeddings added by the LoRAs, second row is for the LoRA.lora_a embeddings. - long_lora_indices: Tensor of shape [batch_size] mapping - requests to RoPE offsets and rot dims for long LoRAs. - None if long context lora doesn't exist. indices_len: List of lengths of the above tensors. It contains (base_indices, sampler_indices, sampler_indices_padded, - embeddings_indices, long_lora_indices). + embeddings_indices). 
""" index_mapping_indices: list[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() - long_lora_offsets: Optional[torch.Tensor] = None - if long_lora_context: - long_lora_offsets = torch.zeros(len(index_mapping_indices), - device=device, - dtype=torch.long) + prompt_mapping: list[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping @@ -104,20 +93,13 @@ def convert_mapping( if index_mapping_indices[i] > 0 else -1) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx - if long_lora_context: - assert long_lora_offsets is not None - lora_offset: int = long_lora_context.offsets_by_lora_id.get( - index_mapping_indices[i], 0) - long_lora_offsets[i] = lora_offset indices_list: list[Union[list[int], torch.Tensor]] = [ index_mapping_indices, lora_indices, embedding_indices, ] - if long_lora_context: - assert long_lora_offsets is not None - indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device=device) prompt_mapping_tensor = torch.tensor(prompt_mapping, dtype=torch.long, @@ -136,11 +118,7 @@ def convert_mapping( sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) - long_lora_indices = None - long_lora_indices_len: Optional[int] = None - if long_lora_context: - long_lora_indices = indices[3] - long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. indices_len = [ base_indices.shape[-1], @@ -148,17 +126,11 @@ def convert_mapping( sampler_indices_padded.shape[-1], embeddings_indices.shape[-1], ] - if long_lora_indices_len is not None: - indices_len.append(long_lora_indices_len) - else: - # If long_lora doesn't exist,append None - indices_len.append(None) return ( base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, - long_lora_indices, indices_len, ) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index ee196e3f689a..ab0a9fbd255d 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import huggingface_hub import regex as re @@ -22,7 +22,6 @@ # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, - LinearScalingRotaryEmbeddingWithLoRA, LogitsProcessorWithLoRA, MergedColumnParallelLinearWithLoRA, MergedQKVParallelLinearWithLoRA, @@ -31,10 +30,14 @@ RowParallelLinearWithLoRA, VocabParallelEmbeddingWithLoRA) from vllm.model_executor.layers.linear import LinearBase + # yapf: enable -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.models.utils import WeightsMapper + +if TYPE_CHECKING: + from vllm.model_executor.layers.logits_processor import LogitsProcessor + from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead) + from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) @@ -52,7 +55,6 @@ MergedColumnParallelLinearWithShardedLoRA, MergedQKVParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA, - LinearScalingRotaryEmbeddingWithLoRA, } @@ -75,8 
+77,8 @@ def from_layer(layer: nn.Module, def from_layer_logits_processor( - layer: LogitsProcessor, - lm_head: ParallelLMHead, + layer: "LogitsProcessor", + lm_head: "ParallelLMHead", max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, @@ -98,8 +100,8 @@ def replace_submodule(model: nn.Module, module_name: str, def parse_fine_tuned_lora_name( - name: str, - weights_mapper: Optional[WeightsMapper] = None + name: str, + weights_mapper: Optional["WeightsMapper"] = None ) -> tuple[str, bool, bool]: """Parse the name of lora weights. @@ -184,16 +186,20 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]: """ In vLLM, all linear layers support LoRA. """ + supported_lora_modules: set[str] = set() - # step1: traverse the model to get all the linear subfixes. for name, module in model.named_modules(): + # get the embedding modules if the module's embedding_modules + # is not empty. + embedding_modules = getattr(module, "embedding_modules", None) + if embedding_modules is not None: + for name in embedding_modules: + supported_lora_modules.add(name) + + # get all the linear subfixes. if isinstance(module, (LinearBase, )): supported_lora_modules.add(name.split(".")[-1]) - # step 2: get the embedding modules if the model's mbedding_modules - # is not empty. - if model.embedding_modules: - for name in model.embedding_modules: - supported_lora_modules.add(name) + return list(supported_lora_modules) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 7a4af74cbeb1..248d2954f1ef 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -154,7 +154,7 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: lora_request.lora_int_id) else: dummy_lora = self._adapter_manager.create_dummy_lora( - lora_request.lora_int_id, rank, 1, self.embedding_modules) + lora_request.lora_int_id, rank, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) diff --git a/vllm/mocks/__init__.py b/vllm/mocks/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/mocks/mock_nixl_connector.py b/vllm/mocks/mock_nixl_connector.py new file mode 100644 index 000000000000..54e2c5ee3b0a --- /dev/null +++ b/vllm/mocks/mock_nixl_connector.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import uuid +from collections import defaultdict +from typing import Optional + + +class FakeNixlWrapper: + """Mock implementation of NixlWrapper for testing. + + We don't inherit from nixl._api.nixl_agent because nixl may not be + installed. 
+ """ + + AGENT_METADATA = b"fake_agent_metadata" + REMOTE_AGENT_NAME = "remote_agent" + + def __init__(self, agent_name: str, *args, **kwargs): + self._cycles_before_xfer_done = 0 + self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict( + lambda: 0) + + def get_reg_descs(self, caches_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in caches_data] + + def register_memory(self, descs) -> None: + pass + + def get_xfer_descs(self, blocks_data, memory_type: str) -> list: + return [str(uuid.uuid4()) for _ in blocks_data] + + def prep_xfer_dlist(self, agent_name: str, descs: list) -> int: + return uuid.uuid4().int + + def get_agent_metadata(self) -> bytes: + return self.AGENT_METADATA + + def add_remote_agent(self, agent_metadata: bytes) -> str: + return self.REMOTE_AGENT_NAME + + def get_new_notifs(self) -> dict[str, list[bytes]]: + # Used to collect done_sending, which we don't test yet. + return {} + + def check_xfer_state(self, handle: int) -> str: + if self._check_xfer_state_cycles[ + handle] >= self._cycles_before_xfer_done: + return "DONE" + self._check_xfer_state_cycles[handle] += 1 + return "PROC" + + def release_xfer_handle(self, handle: int) -> None: + pass + + def send_notif(self, agent_name: str, notif_msg: bytes) -> None: + pass + + def make_prepped_xfer(self, + xfer_type: str, + local_xfer_side_handle: int, + local_block_descs_ids: list[int], + remote_xfer_side_handle: int, + remote_block_descs_ids: list[int], + notif_msg: Optional[bytes] = None) -> int: + return uuid.uuid4().int + + def transfer(self, handle: int) -> str: + return "PROC" + + ############################################################ + # Follow are for changing the behavior during testing. + ############################################################ + + def set_cycles_before_xfer_done(self, cycles: int): + """Set the number of cycles before a transfer is considered done.""" + self._cycles_before_xfer_done = cycles diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 9c88721fb278..f6e79cd676f8 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -73,11 +73,6 @@ def forward_tpu(self, *args, **kwargs): # NOTE(woosuk): This is a placeholder for future extensions. return self.forward_native(*args, **kwargs) - def forward_hpu(self, *args, **kwargs): - # By default, we assume that Gaudi ops are compatible with the - # PyTorch-native implementation. - return self.forward_native(*args, **kwargs) - def forward_neuron(self, *args, **kwargs): # By default, we assume that Neuron ops are compatible with the # PyTorch-native implementation. 
@@ -106,8 +101,6 @@ def dispatch_forward(self): return self.forward_hip elif current_platform.is_cpu(): return self.forward_cpu - elif current_platform.is_hpu(): - return self.forward_hpu elif current_platform.is_tpu(): return self.forward_tpu elif current_platform.is_xpu(): diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 3c2998bece44..7540e6344a49 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -79,20 +79,33 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str, fallback_or_error( guided_params, "xgrammar does not support Lark grammars and the " - "grammar failed to convert to GBNF.", "outlines") + "grammar failed to convert to GBNF.", "guidance") # If the xgrammar module cannot be imported successfully, # we should still allow users to use guided decoding with a fallback. elif not xgr_installed: fallback_or_error( guided_params, - "xgrammar module cannot be imported successfully.", "outlines") - - if (guided_params.backend == "outlines" - and guided_params.json_object is not None): - # outlines doesn't support json_object, fallback to guidance - fallback_or_error(guided_params, - "outlines does not support json_object.", "guidance") + "xgrammar module cannot be imported successfully.", "guidance") + + if guided_params.backend == "outlines": + if guided_params.json_object is not None: + # outlines doesn't support json_object, fallback to guidance + fallback_or_error(guided_params, + "outlines does not support json_object.", + "guidance") + elif guided_params.grammar is not None: + # outlines grammar support has been removed, fallback to guidance + # if it is a lark-based grammar and xgrammar otherwise + if grammar_is_likely_lark(guided_params.grammar): + fallback_or_error(guided_params, + "outlines no longer supports grammars.", + "guidance") + else: + # The grammar is likely already GBNF format. 
+ fallback_or_error(guided_params, + "outlines no longer supports grammars.", + "xgrammar") return guided_params @@ -111,7 +124,6 @@ async def get_guided_decoding_logits_processor( guided_params = maybe_backend_fallback(guided_params) - # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa @@ -152,7 +164,6 @@ def get_local_guided_decoding_logits_processor( reasoning_backend) reasoner = reasoner_class(tokenizer) - # CFG grammar not supported by LMFE, so we use outlines instead if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 26c2d958e751..7e365b294438 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase from vllm.model_executor.guided_decoding.outlines_logits_processors import ( - CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) + JSONLogitsProcessor, RegexLogitsProcessor) from vllm.reasoning import ReasoningParser from vllm.sampling_params import GuidedDecodingParams @@ -21,36 +21,8 @@ class GuidedDecodingMode(Enum): JSON = "json" REGEX = "regex" CHOICE = "choice" - GRAMMAR = "grammar" -# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark -# the main difference is that we changed the start: value to -# start: object | array, so we are denying scalar values as the root of the -# JSON. Starting with scalars as the root seems to cause llama to generate -# without stop. -JSON_GRAMMAR = r""" -?start: object | array - -?value: object -| array -| UNESCAPED_STRING -| SIGNED_NUMBER -> number -| "true" -> true -| "false" -> false -| "null" -> null - -array : "[" [value ("," value)*] "]" -object : "{" [pair ("," pair)*] "}" -pair : UNESCAPED_STRING ":" value - -%import common.UNESCAPED_STRING -%import common.SIGNED_NUMBER -%import common.WS - -%ignore WS -""" - global_thread_pool = None # used for generating logits processor fsm # It's not yet clear that using more provides a benefit, and it could @@ -60,16 +32,12 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser], -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, - None]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser] +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. 
""" global global_thread_pool guide, mode = _get_guide_and_mode(guided_params) @@ -83,7 +51,6 @@ async def get_outlines_guided_decoding_logits_processor( global_thread_pool = concurrent.futures.ThreadPoolExecutor( max_workers=max_workers) loop = asyncio.get_running_loop() - return await loop.run_in_executor(global_thread_pool, _get_logits_processor, guide, tokenizer, mode, guided_params.whitespace_pattern, @@ -91,16 +58,12 @@ async def get_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser], -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, - None]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser] +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. - We cache logit processors by (guide, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying FSM. """ guide, mode = _get_guide_and_mode(guided_params) if not guide or not mode: @@ -130,9 +93,10 @@ def _get_guide_and_mode( choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE elif guided_params.grammar: - return guided_params.grammar, GuidedDecodingMode.GRAMMAR - elif guided_params.json_object: - return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR + raise ValueError( + "The `outlines` guided decoding backend no longer supports grammar " + "guided generation. Please use either the `xgrammar` or `guidance` " + "backend") else: return None, None @@ -143,13 +107,11 @@ def _get_logits_processor( mode: GuidedDecodingMode, whitespace_pattern: Union[str, None], reasoner: Optional[ReasoningParser], -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, reasoner) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, tokenizer, reasoner) - elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, tokenizer, reasoner) else: raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 4ef4db7c4a39..7f047a1df6a5 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -1,168 +1,124 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright 2024-present the Outlines developers +from __future__ import annotations -# Copyright 2024- the Outlines developers -# This file is adapted from -# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. import copy +import hashlib +import importlib.metadata import json -from collections import defaultdict -from functools import lru_cache -from typing import Callable, Optional, Union +import os +from typing import Optional, Union -import numpy as np +import regex as re import torch -from outlines import grammars -from outlines.caching import cache, disable_cache -from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide, - RegexGuide, Write) -from outlines.fsm.parsing import PartialLark -from outlines_core.fsm.json_schema import build_regex_from_schema +from cachetools import LRUCache +from diskcache import Cache +from outlines_core import Guide, Index, Vocabulary +from outlines_core.json_schema import build_regex_from_schema +from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel, + allocate_token_bitmask) from pydantic import BaseModel from transformers import PreTrainedTokenizerBase +from transformers.file_utils import SPIECE_UNDERLINE +from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode import vllm.envs as envs from vllm.logger import init_logger -from vllm.platforms import current_platform from vllm.reasoning import ReasoningParser +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -if envs.VLLM_V0_USE_OUTLINES_CACHE: - logger.warning("Enabling outlines cache. This is an unbounded on-disk " - "cache. It may consume a lot of disk space and should " - "not be used with untrusted clients.") -else: - disable_cache() +CACHE = None class BaseLogitsProcessor: - def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]): + def __init__(self, guide: Guide, eos_token_id: int, + reasoner: Optional[ReasoningParser]) -> None: self._guide: Guide = guide + self._eos_token_id: int = eos_token_id self._reasoner: Optional[ReasoningParser] = reasoner - # CFGState is used for the FSM state for CFGGuide - self._fsm_state: defaultdict[int, Union[int, - CFGState]] = defaultdict(int) - - def clone(self) -> "BaseLogitsProcessor": - cloned = copy.copy(self) - cloned._guide = self._guide.copy() - cloned._fsm_state = copy.deepcopy(self._fsm_state) - return cloned + self._mask: Optional[torch.Tensor] = None def __call__(self, input_ids: list[int], scores: torch.Tensor) -> torch.Tensor: - """Use the FSM to bias the logits before sampling the next token.""" + if self._mask is None: + self._mask = allocate_token_bitmask(scores.size(-1)) # Skip the structured logits processing if reasoning is not finished. # reasoner is not None only when `--reasoning-parser` is set. - if self._reasoner is not None: - if not self._reasoner.is_reasoning_end(input_ids): - return scores - else: - # Remove the reasoning tokens from the input_ids - # We need this because our implementation relies on the - # hash of the input_ids to store the FSM state. 
- input_ids = self._reasoner.extract_content_ids(input_ids) - - seq_id = hash(tuple(input_ids)) - - if len(input_ids) > 0: - last_token = input_ids[-1] - last_seq_id = hash(tuple(input_ids[:-1])) - self._fsm_state[seq_id] = self._guide.get_next_state( - state=self._fsm_state[last_seq_id], token_id=last_token) - else: - # Note: this is a hack. - # Lark pickling does not work properly (silent failure), - # which breaks the RPC (which uses python pickleing). - # We need to find a better solution. - # On the first time this is called, we simply re-create - # the Lark object. - if isinstance(self._guide, CFGGuide): - self._guide.parser = PartialLark( - self._guide.cfg_string, - parser="lalr", - import_paths=[grammars.GRAMMAR_PATH], - ) - self._fsm_state[seq_id] = CFGState( - parser_state=self._guide.parser.parse(""), prev_token=None) - - instruction = self._guide.get_next_instruction( - state=self._fsm_state[seq_id]) - - if type(instruction) == Generate: # noqa: E721 - allowed_tokens = instruction.tokens - elif type(instruction) == Write: # noqa: E721 - # TODO: support fast forward tokens - allowed_tokens = [instruction.tokens[0]] - else: - raise TypeError( - f"Unsupported instruction type {type(instruction)}") - - mask = torch.full((scores.shape[-1], ), - -torch.inf, - device=scores.device) - # The tokenizer may support more token ids than the model can generate, - # eg. Llama 3.2 Vision models have an `<|image|>` token with id 128256 - # but scores.shape == torch.Size([128256]) - # Using NumPy is faster for filtering token ids - allowed_tokens = np.array(allowed_tokens, dtype=np.int64) - allowed_tokens = torch.tensor(allowed_tokens, device=scores.device) - allowed_tokens = allowed_tokens.masked_select( - allowed_tokens < scores.shape[-1]) - mask.index_fill_(0, allowed_tokens, 0) - if current_platform.is_hpu(): - # Workaround for HPU bug where add_() raise RuntimeError: - # synNodeCreateWithId failed for node: strided_insert - # with synStatus 1 [Invalid argument], hopefully it will - # be fixed in the future releases of the HPU runtime. - scores = scores.add(mask) - else: - scores.add_(mask) + if self._reasoner is not None and not self._reasoner.is_reasoning_end( + input_ids): + return scores + + # Remove the reasoning tokens from the input_ids + # We need this because our implementation relies on the + # input_ids sequence to store the FSM state. + input_ids = (self._reasoner.extract_content_ids(input_ids) + if self._reasoner is not None else input_ids) + + # Vllm V0 engine has a weird bug where we have to repeat + # the eos token id twice for generation to stop, or at least + # that is what we have to do from here in any case. 
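A hedged, self-contained sketch of the bitmask-based masking flow used by the rewritten BaseLogitsProcessor in this file. The toy two-token vocabulary and model-free regex are assumptions for illustration; the Guide/Index/Vocabulary constructors and the kernel calls mirror the code in this diff.

import torch
from outlines_core import Guide, Index, Vocabulary
from outlines_core.kernels.torch import (_apply_token_bitmask_inplace_kernel,
                                         allocate_token_bitmask)

eos_token_id = 2
# Toy vocabulary: token 0 decodes to "a", token 1 to "b", token 2 is EOS.
vocabulary = Vocabulary(eos_token_id, {b"a": [0], b"b": [1]})
guide = Guide(Index(r"ab", vocabulary))

scores = torch.zeros(3)  # fake logits over {a, b, eos}
mask = allocate_token_bitmask(scores.size(-1))
guide.write_mask_into(data_ptr=mask.data_ptr(),
                      numel=mask.numel(),
                      element_size=mask.element_size())
_apply_token_bitmask_inplace_kernel(logits=scores.unsqueeze(dim=0), mask=mask)
print(scores)  # only token 0 ("a") should stay unmasked at the start of r"ab"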
+ # This is a patch until a better solution can be pushed + # to outlines_core + if input_ids and input_ids[-1] != self._eos_token_id: + self._guide.advance(token_id=input_ids[-1], return_tokens=False) + + self._guide.write_mask_into( + data_ptr=self._mask.data_ptr(), + numel=self._mask.numel(), + element_size=self._mask.element_size(), + ) + + # Any allowed tokens beyond the length of the scores will + # be ignored by the kernel, taking care of the issue with + # models such as Llama 3.2 Vision with an `<|image|>` token + # with id 128256, but scores.shape == torch.Size([128256]) + _apply_token_bitmask_inplace_kernel( + logits=scores.unsqueeze(dim=0), + # mask must be on same device + mask=self._mask.to(scores.device, non_blocking=True)) + self._mask.to("cpu", non_blocking=True) + return scores + def clone(self) -> BaseLogitsProcessor: + guide = copy.deepcopy(self._guide) + guide.reset() + return BaseLogitsProcessor(guide=guide, + eos_token_id=self._eos_token_id, + reasoner=self._reasoner) + class RegexLogitsProcessor(BaseLogitsProcessor): @classmethod - @cache() def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase) -> Guide: - tokenizer = _adapt_tokenizer(tokenizer) - return RegexGuide.from_regex(regex_string, tokenizer) - - def __init__( - self, - regex_string: str, - tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser], - ): - """Compile the FSM that drives the regex-structured generation. - - Parameters - ---------- - regex_string - A string that represents a regular expression - tokenizer - The model's tokenizer - - """ + global CACHE + if CACHE is None: + CACHE = get_cache() + vocabulary = get_vocabulary(tokenizer) # type: ignore[arg-type] + cache_key = f"{vocabulary._hash}_{regex_string}" + if CACHE is not None and cache_key in CACHE: + return Guide(CACHE[cache_key]) + + index = Index(regex_string, vocabulary.inner) + + if CACHE is not None: + CACHE[cache_key] = index + + return Guide(index) + + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase, + reasoner: Optional[ReasoningParser]) -> None: super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer), reasoner) + guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer), + eos_token_id=tokenizer.eos_token_id, # type: ignore + reasoner=reasoner) class JSONLogitsProcessor(RegexLogitsProcessor): @@ -170,22 +126,8 @@ class JSONLogitsProcessor(RegexLogitsProcessor): def __init__(self, schema: Union[str, dict, BaseModel], tokenizer: PreTrainedTokenizerBase, whitespace_pattern: Union[str, None], - reasoner: Optional[ReasoningParser]): - """Compile the FSM that drives the JSON-guided generation. - - Parameters - ---------- - schema - A JSON schema that encodes the structure we want the model to - generate - tokenizer - The model's tokenizer - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact - string literals) - Example: allow only a single space or newline with - `whitespace_pattern=r"[\n ]?"` - """ + reasoner: Optional[ReasoningParser]) -> None: + if isinstance(schema, type(BaseModel)): schema_str = json.dumps(schema.model_json_schema()) elif isinstance(schema, dict): @@ -197,63 +139,42 @@ def __init__(self, schema: Union[str, dict, BaseModel], f"Cannot parse schema {schema}. 
The schema must be either " f"a Pydantic object, a dictionary or a string that contains " f"the JSON Schema specification") + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) super().__init__(regex_string, tokenizer, reasoner) -class CFGLogitsProcessor(BaseLogitsProcessor): +class OutlinesVocabulary: + """ + Wrapper class for `outlines_core.Vocabulary`, + which allows us to store a hash with the vocabulary + """ + + def __init__(self, vocabulary: Vocabulary) -> None: + # Actual vocabulary object + self.inner = vocabulary + # Have to do abs(hash()) because python hashes can + # be negative, and we are using hash as a cache key. + hex_str = hashlib.sha256( + vocabulary.__repr__().encode('utf-8')).hexdigest() + hash_int = int(hex_str, 16) + self._hash = hash_int - @classmethod - @cache() - def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: - tokenizer = _adapt_tokenizer(tokenizer) - return CFGGuide(cfg, tokenizer) - - def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, - reasoner: Optional[ReasoningParser]): - """Compile the FSM that drives the context free grammar generation. - - Parameters - ---------- - cfg - A string that represents a context-free grammar - tokenizer - The model's tokenizer - - """ - super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer), - reasoner) - self._guide = self._guide.copy() - - def clone(self) -> "CFGLogitsProcessor": - cloned = copy.copy(self) - cloned._fsm_state = copy.deepcopy(self._fsm_state) - cloned._guide = self._guide.copy() - return cloned - - -@lru_cache(maxsize=32) -def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): - """Adapt vLLM's tokenizer to use to compile the FSM. - - The API of Outlines tokenizers is slightly different to that of - `transformers`. The decoder of outlines, returns a list whereas - the decode of vLLM returns an str. To sync the vLLM decoder with - outlines internal api, the decoder should be adapted. In addition - we need to handle the missing spaces to Llama's tokenizer to be - able to compile FSMs for this model. - """ - if getattr(tokenizer, "_outlines_adapted", False): - return tokenizer +re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") +re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$") - tokenizer = copy.deepcopy(tokenizer) - tokenizer.vocabulary = tokenizer.get_vocab() - tokenizer.special_tokens = set(tokenizer.all_special_tokens) +def _reduced_vocabulary(tokenizer: AnyTokenizer, + eos_token_id: int) -> dict[bytes, list[int]]: + """Create a map from vocabulary tokens to lists of equivalent token ids. + + Returns: + A Dict of token string -> equivalent token ids + """ + unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()} def convert_token_to_string(token: str) -> str: - from transformers.file_utils import SPIECE_UNDERLINE string = tokenizer.convert_tokens_to_string([token]) @@ -264,21 +185,123 @@ def convert_token_to_string(token: str) -> str: return string - def change_decoder( - decoder: Callable[[list[int]], - str]) -> Callable[[list[int]], list[str]]: - """Sync vLLM's decoder with the outlines by returning list.""" + vocabulary: dict[bytes, list[int]] = {} + empty_token_ids: list[int] = [] + for token, token_idx in tokenizer.get_vocab().items(): + if token in tokenizer.all_special_tokens: # type: ignore + continue + + token_str = convert_token_to_string(token) + if token_str: + if isinstance(token, (bytes, bytearray)): + # For BPE tokenizers where tokens are stored as bytes. 
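A small sketch of the schema-to-regex step used by JSONLogitsProcessor above (the schema is illustrative, and passing None for whitespace_pattern matches the call above): the JSON schema is compiled to a regular expression, which RegexLogitsProcessor then turns into an Index/Guide.

import json
from outlines_core.json_schema import build_regex_from_schema

schema_str = json.dumps({
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
})
# Same call as in JSONLogitsProcessor.__init__, with no custom whitespace.
regex_string = build_regex_from_schema(schema_str, None)
print(regex_string[:80])  # regex constraining output to {"name": "<string>"}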
+ + # safe to ignore since token_str is of type (bytearray, bytes) + # by this point. + token_bytes = bytes(token_str) # type: ignore[arg-type] + + elif "\ufffd" in token_str and not re_replacement_seq.match( + token_str): + # Handle tokens with invalid UTF-8 sequences. + if re_llama_byte_token.match(token): + # Llama-like tokenizers use <0xXX> for incomplete sequences. + token_bytes = bytes([int(token[3:5], 16)]) + else: + # GPT2 tokenizers: map each byte back using unicode_to_bytes + byte_vals = [unicode_to_bytes.get(c) for c in token] + if None in byte_vals: + raise RuntimeError( + f"Cannot convert token `{token}`" + f" ({token_idx}) to bytes: {token_str}") + # safe to ignore, since if None in byte_vals, + # an error is thrown. + token_bytes = bytes(byte_vals) # type: ignore[arg-type] + else: + token_bytes = token_str.encode('utf-8') - def new_decoder(inp_tokens: list[int]) -> list[str]: - if (isinstance(inp_tokens, list) and len(inp_tokens) == 1 - and isinstance(inp_tokens[0], list)): - inp_tokens = inp_tokens[0] - return [decoder(inp_tokens)] + if token_idx != eos_token_id: + vocabulary.setdefault(token_bytes, []).append(token_idx) + else: + empty_token_ids.append(token_idx) - return new_decoder + return vocabulary - tokenizer.convert_token_to_string = convert_token_to_string - tokenizer.decode = change_decoder(tokenizer.decode) - setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 - return tokenizer +def get_vocabulary(tokenizer: AnyTokenizer) -> Vocabulary: + """Get the `Vocabulary` object for a given tokenizer. + """ + if hasattr(tokenizer, "_outlines_vocabulary"): + return tokenizer._outlines_vocabulary # type: ignore + + try: + if hasattr( + tokenizer, + "eos_token_id", + ) and tokenizer.eos_token_id is not None: + eos_token_id = tokenizer.eos_token_id + else: + raise ValueError( + f"Error during guided decoding setup: Tokenizer" + f" ({type(tokenizer)}) has no `eos_token_id` property, " + "but `eos_token_id` is required for guided decoding" + " to work properly.") + + reduced_vocab = _reduced_vocabulary( + tokenizer, + eos_token_id #type: ignore + ) + vocabulary = OutlinesVocabulary(Vocabulary(eos_token_id, + reduced_vocab)) + tokenizer._outlines_vocabulary = vocabulary # type: ignore + + return vocabulary + except AttributeError as e: + raise ValueError(f"Cannot get the vocabulary of the tokenizer " + f"({type(tokenizer)}). The tokenizer should have a " + "get_vocab method.") from e + + +def get_cache_path() -> str: + """Get the context object that contains previously-computed return values""" + outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR") + xdg_cache_home = os.getenv("XDG_CACHE_HOME") + home_dir = os.path.expanduser("~") + + if outlines_cache_dir: + # OUTLINES_CACHE_DIR takes precedence + return outlines_cache_dir + elif xdg_cache_home: + return os.path.join(xdg_cache_home, ".cache", "outlines") + # If homedir is "/", we may be inside a container, and thus writing to + # root would be problematic, so we fallback to using a tempfile. + # Also validate the path exists, since os.path.expanduser does + # not garuntee existence. 
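A hedged, standalone illustration of the two byte-recovery cases handled in _reduced_vocabulary above: Llama-style "<0xXX>" byte tokens and GPT-2-style unicode-shifted byte tokens.

import regex as re
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")
unicode_to_bytes = {v: k for k, v in bytes_to_unicode().items()}

# Llama-like tokenizers spell the raw byte 0x0A (newline) as "<0x0A>".
assert re_llama_byte_token.match("<0x0A>")
assert bytes([int("<0x0A>"[3:5], 16)]) == b"\n"

# GPT-2-style tokenizers shift bytes into printable unicode; "Ġ" maps back
# to an ASCII space.
assert bytes(unicode_to_bytes[c] for c in "Ġ") == b" "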
+ elif os.path.isdir(home_dir) and home_dir != "/": + # Default Unix fallback: ~/.cache/outlines + return os.path.join(home_dir, ".cache", "outlines") + else: + import tempfile + + # home_dir may be / inside a docker container without existing user + tempdir = tempfile.gettempdir() + return os.path.join(tempdir, ".cache", "outlines") + + +def get_cache(): + """Get the Cache instance to be used for index caching""" + + cache_dir = get_cache_path() + if envs.VLLM_V0_USE_OUTLINES_CACHE: + logger.warning("Enabling outlines cache. This is an unbounded on-disk " + "cache. It may consume a lot of disk space and should " + "not be used with untrusted clients.") + cache = Cache(cache_dir, eviction_policy="none", cull_limit=0) + outlines_version = importlib.metadata.version("outlines_core") + + cached_version = cache.get('__version__', None) + if cached_version != outlines_version: + cache.clear() + cache.set('__version__', outlines_version) + return cache + else: + return LRUCache(maxsize=128) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a8788e340fc8..3ccddb52998b 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import _resize_cache from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked, + is_blackwell_deep_gemm_used) logger = init_logger(__name__) @@ -47,9 +51,11 @@ def _silu_mul_fp8_quant_deep_gemm( eps: tl.constexpr, fp8_min: tl.constexpr, fp8_max: tl.constexpr, + use_ue8m0: tl.constexpr, # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, ): G = H // GROUP_SIZE @@ -68,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm( cols = cols.to(tl.int64) mask_h = cols < BLOCK - t = tl.zeros([], tl.int64) - while t < n_tokens: + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): base_i_offset = (e * stride_i_e + t * stride_i_t + g * GROUP_SIZE * stride_i_h) base_yq_offset = (e * stride_yq_e + t * stride_yq_t + @@ -89,14 +94,14 @@ def _silu_mul_fp8_quant_deep_gemm( y = x * y2 _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max + scale_raw = _absmax / fp8_max + y_s = tl.math.exp2(tl.ceil( + tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) tl.store(y_s_ptr + base_ys_offset, y_s) - t += 1 - def silu_mul_fp8_quant_deep_gemm( y: torch.Tensor, # (E, T, 2*H) float32 @@ -171,8 +176,10 @@ def silu_mul_fp8_quant_deep_gemm( eps, fp8_min, fp8_max, + is_blackwell_deep_gemm_used(), BLOCK=group_size, - num_warps=4, + NUM_STAGES=8, + num_warps=1, ) return y_q, y_s @@ -217,6 +224,10 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return False + def finalize_weight_and_reduce_impl(self) -> 
mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, @@ -227,6 +238,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 # FIXME (varun): We should be able to dispatch only from the leader @@ -242,27 +254,21 @@ def workspace_shapes( output = (num_experts, max_num_tokens * num_dispatchers, K) return (workspace13, workspace2, output, a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): - import deep_gemm as dg + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + assert hidden_states.ndim == 3 assert self.block_shape is not None @@ -280,19 +286,11 @@ def apply( # for the M expectation of each batch, correctly setting this value # may lead to better performance. 
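A plain-Python sketch of the UE8M0 scale rounding added to _silu_mul_fp8_quant_deep_gemm above (448.0 is the float8_e4m3 maximum; the Triton kernel computes the same thing with tl.math.exp2/tl.ceil/tl.log2): when use_ue8m0 is set, each group scale is rounded up to the nearest power of two.

import math

def ue8m0_scale(absmax: float, fp8_max: float = 448.0) -> float:
    scale_raw = absmax / fp8_max
    # Round the scale up to a power of two: exp2(ceil(log2(x))).
    return 2.0 ** math.ceil(math.log2(scale_raw))

print(ue8m0_scale(100.0))  # 0.25, the smallest power of two >= 100 / 448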
expected_m = max_num_tokens + fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), + workspace1, expert_num_tokens, expected_m) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale), - (w1, w1_scale), - out=workspace1, - masked_m=expert_num_tokens, - expected_m=expected_m) - - assert expert_num_tokens is not None a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, expert_num_tokens) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), - (w2, w2_scale), - out=output, - masked_m=expert_num_tokens, - expected_m=expected_m) + fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output, + expert_num_tokens, expected_m) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 0d67b4a4a6d6..fc30e84e6656 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch @@ -37,7 +37,6 @@ def __init__(self, block_shape=block_shape, per_act_token_quant=per_act_token_quant, )) - self.allow_deep_gemm = allow_deep_gemm self.batched_triton_experts = BatchedTritonExperts( max_num_tokens=max_num_tokens, @@ -88,6 +87,25 @@ def supports_expert_map(self) -> bool: return ((bdge is None or bdge.supports_expert_map()) and (bte is None or bte.supports_expert_map())) + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + bdge = self.batched_deep_gemm_experts + bte = self.batched_triton_experts + bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None + bte_war = bte.finalize_weight_and_reduce_impl() if bte else None + is_bdge_war = bdge_war is not None + is_bte_war = bte_war is not None + + if is_bdge_war and is_bte_war: + assert bdge_war == bte_war, ( + "Both implementations should agree on WeightAndReduce impls. 
" + f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}") + + if bdge_war is not None: + return bdge_war + + assert bte_war is not None + return bte_war + def workspace_shapes( self, a: torch.Tensor, @@ -98,6 +116,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm @@ -105,36 +124,31 @@ def workspace_shapes( if self.allow_deep_gemm: assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts) + a, aq, M, N, K, topk, global_num_experts, local_num_experts, + expert_tokens_metadata) else: assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts) + a, aq, M, N, K, topk, global_num_experts, local_num_experts, + expert_tokens_metadata) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): experts = (self.batched_deep_gemm_experts if self.allow_deep_gemm else self.batched_triton_experts) assert experts is not None - experts.apply(output, hidden_states, w1, w2, topk_ids, activation, - global_num_experts, expert_map, w1_scale, w2_scale, - w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, - workspace2, expert_num_tokens) + experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids, + activation, global_num_experts, expert_map, w1_scale, + w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13, + workspace2, expert_tokens_meta, + apply_router_weight_on_input, extra_expert_args) diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 6c03732030d1..f5ed2861b8fc 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -15,6 +15,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.utils import cdiv +from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe logger = init_logger(__name__) @@ -50,11 +51,14 @@ def get_config_quant_dtype( use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, 
-) -> Optional[torch.dtype]: + use_mxfp4_w4a4: bool, +) -> Union[None, torch.dtype, str]: if use_fp8_w8a8: return torch.float8_e4m3fn elif use_int8_w8a8: return torch.int8 + elif use_mxfp4_w4a4: + return "mxfp4" return None @@ -126,6 +130,7 @@ def make( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, block_shape: Optional[list[int]] = None, @@ -144,6 +149,7 @@ def make( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, ) return FusedMoEQuantConfig( quant_dtype, @@ -183,72 +189,83 @@ def use_deepep_ll_kernels(self): return (self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") + @property + def use_flashinfer_cutlass_kernels(self): + return (envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe()) + @staticmethod def make(tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig": """ - Determine MoE parallel configuration. Based on the input tp_size_, - dp_size_, ep_size_ and vllm's parallel config, determine what + Determine MoE parallel configuration. Based on the input `tp_size_`, + `dp_size_` and vllm's parallel config, determine what level's of parallelism to use in the fused moe layer. Args: - tp_size_ (int): tp_size passed into the FusedMoE constructor. - dp_size_ (int): dp_size passed into the FusedMoE constructor. - ep_size_ (int): ep_size passed into the FusedMoE constructor. - vllm_parallel_config (ParallelConfig): vllm's parallel config - object. + tp_size_ (int): `tp_size` passed into the FusedMoE constructor. + dp_size_ (int): `dp_size` passed into the FusedMoE constructor. + vllm_parallel_config (ParallelConfig): vLLM's parallel config + object which contains the `enable_expert_parallel` flag. Examples: - When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1, - we simply return the sizes unaltered and the ranks set to 0. + When there is no parallelism requested, + i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes + unaltered and the ranks set to 0. + + Expert Parallelism is considered only when either `dp_size_` or + `tp_size_` is non trivial. - Expert Parallelism is considered only when either dp_size_ or tp_size_ - is non trivial. + When TP = 2, DP = 1 and EP = False, the configuration on different + devices: - When TP = 2, DP = 1 and EP = False, the configuration on different - devices, - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} // - legend : {size, rank} + legend : {size, rank} - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0} - Comment : Tensors are sharded across 2 devices. - When TP = 1, DP = 2 and EP = False, the configuration on different - devices, + When TP = 1, DP = 2 and EP = False, the configuration on different + devices: + - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0} - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0} - Comment: There are 2 engine instances and the tensors are sharded - across 2 decvices. + across 2 decvices. 
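A tiny illustration derived from the docstring examples above for the non-EP cases (illustration only, not the code of FusedMoEParallelConfig.make): the flattened tensor-parallel group has size tp_size_ * dp_size_, each device's TP rank is its global index, and the DP rank groups devices in blocks of tp_size_.

tp_size_, dp_size_ = 1, 2  # the "TP = 1, DP = 2, EP = False" example above
for device in range(tp_size_ * dp_size_):
    tp = (tp_size_ * dp_size_, device)   # {size, rank}
    dp = (dp_size_, device // tp_size_)
    ep = (1, 0)
    print(f"device {device} : TP = {tp} DP = {dp} EP = {ep}")
# device 0 : TP = (2, 0) DP = (2, 0) EP = (1, 0)
# device 1 : TP = (2, 1) DP = (2, 1) EP = (1, 0)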
+ + When TP = 2, DP = 2 and EP = False, the configuration on different + devices: - When TP = 2, DP = 2 and EP = False, the configuration on different - devices, - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0} - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0} - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0} - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0} - Comment: There are 2 engine instances and the tensors are sharded - across 4 devices. + across 4 devices. + + When, TP = 2, DP = 1 and EP = True, the configuration on different + devices: - When, TP = 2, DP = 1 and EP = True, the configuration on different - devices, - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0} - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1} - Comment: The experts are split between the 2 devices. - When, TP = 1, DP = 2 and EP = True, the configuration on different - devices, + When, TP = 1, DP = 2 and EP = True, the configuration on different + devices: + - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0} - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1} - Comment: There are 2 engine instances and the experts are split - between the 2 devices. + between the 2 devices. + + When TP = 2, DP = 2 and EP = True, the configuration on different + devices: - When TP = 2, DP = 2 and EP = True, the configuration on different - devices, - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0} - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1} - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2} - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3} - Comment: There are 2 engine instances and the experts are split - between the 4 devices. + between the 4 devices. """ def flatten_tp_across_dp(dp_rank: int): @@ -381,6 +398,10 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + @staticmethod def make( num_experts: int, @@ -424,6 +445,12 @@ def make( if quant_dtype is None and isinstance(quant_config, Fp8Config): quant_dtype = torch.float8_e4m3fn + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config) + if quant_dtype is None and isinstance(quant_config, + ModelOptNvFp4Config): + quant_dtype = torch.uint8 + if weight_quant is not None: per_out_ch_quant = ( weight_quant.strategy == QuantizationStrategy.CHANNEL) @@ -437,10 +464,11 @@ def make( ) else: _quant_config = FusedMoEQuantConfig() - logger.warning_once("MoE DP setup unable to determine " - "quantization scheme or unsupported " - "quantization type. This model will " - "not run with DP enabled.") + if moe_parallel_config.dp_size > 1: + logger.warning_once("MoE DP setup unable to determine " + "quantization scheme or unsupported " + "quantization type. 
This model will " + "not run with DP enabled.") else: _quant_config = quant_config diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..298a36175e60 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..0e210cb0f38d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..e4fa1e2e6e9b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..082456d319d3 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..c3b2e7fa91eb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, 
+ "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..bba1d21aa2b6 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json 
b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..de1c413b6e1a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 0f41414c4896..2585a2953c9d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Callable, Optional +from typing import Any, Callable, Optional import torch @@ -11,9 +11,12 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) +from 
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm, _fp8_quantize, - _resize_cache) + _resize_cache, + extract_required_args) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -180,7 +183,11 @@ def run_cutlass_moe_fp8( c2 = _resize_cache(workspace2, (M * topk, N)) c3 = _resize_cache(workspace13, (M * topk, K)) - c1.fill_(0) + if not per_act_token and (expert_map is not None or use_batched_format): + # this is necessary to avoid imprecise scale calculation caused by + # random data in the unused workspace. The workspace is unused when + # this rank handles only partial tokens, or when it is batched . + c1.fill_(0) ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, problem_sizes1, ab_strides1, ab_strides1, c_strides1, @@ -251,6 +258,10 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return not self.use_batched_format + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, @@ -261,6 +272,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () @@ -275,35 +287,33 @@ def workspace_shapes( (N // 2)) output = (self.max_experts_per_worker, padded_M, K) else: - workspace1 = (M * topk, max(2 * N, K)) - workspace2 = (M * topk, N) + workspace1 = (M * topk, max(N, K)) + workspace2 = (M * topk, N // 2) output = (M * topk, K) return (workspace1, workspace2, output, self.out_dtype if self.out_dtype is not None else a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" - activation_callable = lambda i, o: self.activation(activation, i, o) + + expert_num_tokens = None + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + + activation_callable = lambda o, i: self.activation(activation, o, i) + in_dtype = hidden_states.dtype run_cutlass_moe_fp8( output, 
hidden_states, w1, w2, topk_ids, activation_callable, @@ -407,13 +417,28 @@ def cutlass_moe_fp8( FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max -def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, - w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor, - w1_alphas: torch.Tensor, a2_gscale: torch.Tensor, - w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, - w2_alphas: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, m: int, n: int, k: int, e: int, - device: torch.device): +def run_cutlass_moe_fp4( + output: torch.Tensor, + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + apply_router_weight_on_input: bool = False, +) -> None: """ MoE implementation for FP4 Inputs @@ -453,16 +478,16 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", " between weights.") - assert (k_a // 2 == half_k_w1 + assert (k_a == half_k_w1 * 2 and k == k_w2), ("Hidden size mismatch between a, w1 and w2") - assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " - "expected `n`") + assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in " + "expected `n`") assert (m == m_a), "input shape mismatch" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert (topk_weights.size(0) == m and topk_ids.size(0) == m), ("topk must be provided for each row of a") - + topk = topk_ids.size(1) out_dtype = a.dtype num_topk = topk_ids.size(1) @@ -476,6 +501,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + if apply_router_weight_on_input: + # TODO: this only works for topK=1, will need to update for topK>1 + assert num_topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a.mul_(topk_weights.to(out_dtype)) + # problem shapes should have [m, n, k] # Note that problem sizes are based on logical number of elements. ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, @@ -483,7 +514,6 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, blockscale_offsets) a = ops.shuffle_rows(a, a_map) - rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( a, a1_gscale, @@ -491,50 +521,270 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, blockscale_offsets, num_topk, ) - - c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, - w1_blockscale, w1_alphas, problem_sizes1, - expert_offsets[:-1], blockscale_offsets[:-1], - out_dtype, device) + c1 = _resize_cache(workspace13, (m * topk, n * 2)) + c2 = _resize_cache(workspace2, (m * topk, n)) + c3 = _resize_cache(workspace13, (m * topk, k)) + ops.cutlass_fp4_moe_mm(c1, rep_a_fp4, w1_fp4, rep_a_blockscale, + w1_blockscale, w1_alphas, problem_sizes1, + expert_offsets[:-1], blockscale_offsets[:-1]) del rep_a_fp4, rep_a_blockscale - # hidden size dimension is split to one halfpytho sized tensor. 
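The rewritten run_cutlass_moe_fp4 above carves its intermediates c1/c2/c3 out of the caller-provided workspace13/workspace2 buffers rather than allocating fresh tensors. A minimal, self-contained sketch of that buffer-viewing idea follows; the sizes are hypothetical and the resize_cache helper only mirrors the role of vLLM's _resize_cache, it is not the actual implementation:

```python
# Illustrative sketch only: re-view a flat workspace buffer as the differently
# shaped intermediates (c1/c2/c3 above) without fresh allocations.
import math

import torch


def resize_cache(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Take the first prod(shape) elements of the flat buffer and view them
    # with the requested shape (mirroring the role of _resize_cache).
    numel = math.prod(shape)
    assert buf.numel() >= numel, "workspace too small"
    return buf.flatten()[:numel].view(shape)


m, topk, n, k = 4, 2, 256, 128
workspace13 = torch.empty(m * topk * max(2 * n, k), dtype=torch.bfloat16)
workspace2 = torch.empty(m * topk * n, dtype=torch.bfloat16)

c1 = resize_cache(workspace13, (m * topk, 2 * n))  # first grouped-GEMM output
c2 = resize_cache(workspace2, (m * topk, n))       # silu_and_mul output
c3 = resize_cache(workspace13, (m * topk, k))      # second grouped-GEMM output, reusing workspace13
print(c1.shape, c2.shape, c3.shape)
```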
- intermediate = torch.empty((m * num_topk, w1_fp4.size(1) // 2), - device=device, - dtype=out_dtype) - - torch.ops._C.silu_and_mul(intermediate, c1) - + torch.ops._C.silu_and_mul(c2, c1) int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( - intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk) + c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk) - c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, - w2_alphas, problem_sizes2, expert_offsets[:-1], - blockscale_offsets[:-1], out_dtype, device) + ops.cutlass_fp4_moe_mm(c3, int_fp4, w2_fp4, int_blockscale, w2_blockscale, + w2_alphas, problem_sizes2, expert_offsets[:-1], + blockscale_offsets[:-1]) del int_fp4, int_blockscale - c2 = ops.shuffle_rows(c2, c_map) - out = (c2.view(m, num_topk, k) * - topk_weights.view(m, num_topk, 1).half()).sum(dim=1) - return out.to(dtype=out_dtype) + c3 = ops.shuffle_rows(c3, c_map) + + assert output.dtype == out_dtype + if not apply_router_weight_on_input: + output.copy_( + (c3.view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).to(out_dtype)).sum(dim=1), + non_blocking=True) + else: + output.copy_(c3.view(m, num_topk, k).sum(dim=1), non_blocking=True) + return + + +class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_experts_per_worker: int, + out_dtype: torch.dtype, + per_act_token_quant: bool, + per_out_ch_quant: bool, + block_shape: Optional[list[int]] = None, + use_batched_format: bool = False, + ): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.uint8, + per_act_token_quant=per_act_token_quant, + per_out_ch_quant=per_out_ch_quant, + block_shape=block_shape, + )) + self.max_experts_per_worker = max_experts_per_worker + self.out_dtype = out_dtype + self.use_batched_format = use_batched_format + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + if self.use_batched_format: + return (mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts) + else: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_expert_map(self) -> bool: + return False + + def supports_chunking(self) -> bool: + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + workspace1: tuple[int, ...] = () + workspace2: tuple[int, ...] = () + output: tuple[int, ...] 
= () + if self.use_batched_format: + padded_M = aq.size(1) + workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) + workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) + output = (self.max_experts_per_worker, padded_M, K) + else: + workspace1 = (M * topk, max(2 * N, K)) + workspace2 = (M * topk, N) + output = (M, K) + return (workspace1, workspace2, output, + self.out_dtype if self.out_dtype is not None else a.dtype) + + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor, + w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): + required_keys = [ + "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k", + "e", "device" + ] + (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e, + device) = extract_required_args(extra_expert_args, required_keys) + run_cutlass_moe_fp4( + output=output, + a=hidden_states, + a1_gscale=a1_gscale, + w1_fp4=w1, + w1_blockscale=w1_scale, + w1_alphas=g1_alphas, + a2_gscale=a2_gscale, + w2_fp4=w2, + w2_blockscale=w2_scale, + w2_alphas=g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + workspace13=workspace13, + workspace2=workspace2, + m=m, + n=n, + k=k, + e=e, + device=device, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +def cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + g1_alphas: torch.Tensor, + g2_alphas: torch.Tensor, + a1_gscale: torch.Tensor, + a2_gscale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False) -> torch.Tensor: + assert expert_map is None, ("Expert Parallelism / expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE's cutlass_moe_fp4.") + fn = mk.FusedMoEModularKernel( + MoEPrepareAndFinalizeNoEP(), + CutlassExpertsFp4( + max_experts_per_worker=e, + out_dtype=a.dtype, + per_act_token_quant=False, + per_out_ch_quant=False, + use_batched_format=False, + ), + ) + extra_expert_args = { + 'g1_alphas': g1_alphas, + 'g2_alphas': g2_alphas, + 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, + 'm': m, + 'n': n, + 'k': k, + 'e': e, + 'device': device, + } + + # NVFP4 requires two levels of quantization, which involves computing some + # scaling factors dynamically. This makes it incompatible with the typical + # prepare -> MoE -> finalize pipeline. Move the quantization logic into the + # MoE body. + extra_prepare_args = { + 'skip_quant': True, + } + # Similar reason as above. 
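extract_required_args is used above to unpack the NVFP4-specific tensors from extra_expert_args, and the extra_finalize_args counterpart follows immediately below. A hedged sketch of the helper's assumed behavior (the real utility in fused_moe/utils.py may differ in its error handling):

```python
# Assumed behavior: pull a fixed set of keys out of the extra-args dict, in
# order, failing loudly if any are missing.
from typing import Any, Optional


def extract_required_args_sketch(extra_args: Optional[dict[str, Any]],
                                 required_keys: list[str]) -> tuple[Any, ...]:
    assert extra_args is not None, "extra args dict is required here"
    missing = [k for k in required_keys if k not in extra_args]
    assert not missing, f"missing required extra args: {missing}"
    return tuple(extra_args[k] for k in required_keys)


# Usage mirroring the apply() call above (values are placeholders):
args = {"g1_alphas": 1.0, "g2_alphas": 1.0, "m": 4, "n": 256, "k": 128}
g1_alphas, g2_alphas, m, n, k = extract_required_args_sketch(
    args, ["g1_alphas", "g2_alphas", "m", "n", "k"])
```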
+ extra_finalize_args = { + 'skip_weight_reduce': True, + } + return fn( + hidden_states=a, + w1=w1_fp4, + w2=w2_fp4, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + activation="silu", + global_num_experts=e, + expert_map=None, + w1_scale=w1_blockscale, + w2_scale=w2_blockscale, + a1_scale=None, + a2_scale=None, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args, + extra_prepare_args=extra_prepare_args, + extra_finalize_args=extra_finalize_args, + ) -def _valid_cutlass_block_scaled_grouped_gemm(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor) -> bool: +def _valid_cutlass_block_scaled_grouped_gemm( + w1: torch.Tensor, w2: torch.Tensor, inplace: bool, activation: str, + apply_router_weight_on_input: bool, + expert_map: Optional[torch.Tensor]) -> bool: - def _valid_cutlass_block_scaled_grouped_gemm_shape(M: int, N: int, K: int): - return M >= 128 and N % 128 == 0 and K % 128 == 0 + def _valid_cutlass_block_scaled_grouped_gemm_shape(N: int, K: int): + return N % 128 == 0 and K % 128 == 0 - m = hidden_states.size(0) _, K, N = w2.size() - if not _valid_cutlass_block_scaled_grouped_gemm_shape(m, N, K): - logger.debug( - "CutlassBlockScaledGroupedGemm disabled: unalinged problem size.") + if not _valid_cutlass_block_scaled_grouped_gemm_shape(N, K): + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled: unaligned problem size. " + "N: %s, K: %s", + N, + K, + ) return False if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): - logger.debug( - "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s).") + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled: invalid weight dtype(s). " + "w1.dtype: %s, w2.dtype: %s", + w1.dtype, + w2.dtype, + ) + return False + + if expert_map is not None: + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled: expert_parallel is" + " not supported.") + return False + + if activation != "silu": + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled: only activation silu is" + " supported.") + return False + + if apply_router_weight_on_input: + logger.debug_once("CutlassBlockScaledGroupedGemm disabled:" + " apply_router_weight_on_input is not supported.") + return False + + if inplace: + logger.debug_once( + "CutlassBlockScaledGroupedGemm disabled: inplace is not supported." 
+ ) return False return True diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 8ad57c237fed..b89e5ac6f093 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -1,20 +1,24 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools -from typing import Optional +from typing import Any, Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_permute) +from vllm.model_executor.layers.fused_moe.deep_gemm_utils import ( + compute_aligned_M, deepgemm_moe_permute, deepgemm_unpermute_and_reduce) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, per_token_group_quant_fp8) -from vllm.utils import has_deep_gemm, round_up +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) +from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import m_grouped_fp8_gemm_nt_contiguous logger = init_logger(__name__) @@ -40,23 +44,39 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, aligned by `dg.get_m_alignment_for_contiguous_layout()`. """ if not has_deep_gemm(): - logger.debug("DeepGemm disabled: deep_gemm not available.") + logger.debug_once("DeepGemm disabled: deep_gemm not available.") return False M = hidden_states.size(0) _, K, N = w2.size() if not _valid_deep_gemm_shape(M, N, K): - logger.debug("DeepGemm disabled: unaligned problem size.") + logger.debug_once( + "DeepGemm disabled: unaligned problem size. M: %s, N: %s, K: %s", + M, + N, + K, + ) return False if (w1.dtype != torch.float8_e4m3fn or w2.dtype != torch.float8_e4m3fn): - logger.debug("DeepGemm disabled: invalid weight dtype(s).") + logger.debug_once( + "DeepGemm disabled: invalid weight dtype(s). " + "w1.dtype: %s, w2.dtype: %s", + w1.dtype, + w2.dtype, + ) return False if (not hidden_states.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - logger.debug( - "DeepGemm disabled: weights or activations not contiguous.") + logger.debug_once( + "DeepGemm disabled: weights or activations not contiguous. 
" + "hidden_states.is_contiguous(): %s, w1.is_contiguous(): %s, " + "w2.is_contiguous(): %s", + hidden_states.is_contiguous(), + w1.is_contiguous(), + w2.is_contiguous(), + ) return False return True @@ -85,20 +105,30 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + def workspace_shapes( - self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int, - topk: int, global_num_experts: int, local_num_experts: int + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert self.block_shape is not None - # We use global_num_experts due to how moe_align_block_size handles - # expert_maps. - num_experts = global_num_experts block_m = self.block_shape[0] - M_sum = (M * topk) + num_experts * (block_m - 1) - M_sum = round_up(M_sum, block_m) - workspace1 = (M_sum, max(N * 2, K)) - workspace2 = (M_sum, max(N, K)) - output = (M * topk, K) + M_sum = compute_aligned_M(M, topk, local_num_experts, block_m, + expert_tokens_meta) + assert M_sum % block_m == 0 + + workspace1 = (M_sum, max(N, K)) + workspace2 = (M_sum, max(N // 2, K)) + output = (M, K) return (workspace1, workspace2, output, a.dtype) def apply( @@ -107,6 +137,7 @@ def apply( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, @@ -119,47 +150,48 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): - import deep_gemm as dg assert self.block_shape is not None + assert a1q_scale is not None a1q = hidden_states _, N, K = w1.size() + local_num_experts = w1.size(0) if global_num_experts == -1: - global_num_experts = w1.size(0) + global_num_experts = local_num_experts assert w2.size(1) == K - a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( - a1q, - a1q_scale, - topk_ids, - global_num_experts, - expert_map, - self.block_shape[0], - ) - - if expert_map is not None: - # DeepGemm (Grouped Contiguous) kernel needs a valid B index - # for all rows of A. To that effect, simply compute with - # the 0th weight matrix. - # Note that this relies on the fact that corresponding topk - # weights would be 0 during weight multiplication. - expert_ids = torch.where(expert_ids == -1, 0, expert_ids) - - # Note: M_sum is different than the pre-permuted shape of a1q. 
- M_sum = a1q.size(0) + M_sum = compute_aligned_M(M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=deep_gemm_block_shape()[0], + expert_tokens_meta=expert_tokens_meta) + a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), + (M_sum, K)) mm1_out = _resize_cache(workspace13, (M_sum, N)) act_out = _resize_cache(workspace2, (M_sum, N // 2)) quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) + a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( + aq=a1q, + aq_scale=a1q_scale, + topk_ids=topk_ids, + local_num_experts=local_num_experts, + expert_map=expert_map, + expert_tokens_meta=expert_tokens_meta, + aq_out=a1q_perm) + assert a1q.size(0) == M_sum + + m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), + mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) @@ -169,10 +201,18 @@ def apply( column_major_scales=True, out_q=quant_out) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) + m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), + mm2_out, expert_ids) + + if apply_router_weight_on_input: + topk_weights = torch.ones_like(topk_weights) - torch.index_select(mm2_out, 0, inv_perm, out=output) + deepgemm_unpermute_and_reduce(a=mm2_out, + topk_ids=topk_ids, + topk_weights=topk_weights, + inv_perm=inv_perm, + expert_map=expert_map, + output=output) def deep_gemm_moe_fp8( diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py new file mode 100644 index 000000000000..8cc5a747c673 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_utils.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Taken from https://github.com/ModelTC/LightLLM/blob/8ed97c74c18f11505b048b1ba00ba5c0cef8bff6/lightllm/common/fused_moe/deepep_scatter_gather.py +and updated to fit vllm needs and terminology. +""" + +import functools +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens +from vllm.triton_utils import tl, triton +from vllm.utils import round_up + + +@functools.cache +def deep_gemm_block_shape() -> list[int]: + # Lazy import to avoid CUDA initialization problems. + import deep_gemm as dg + block = dg.get_m_alignment_for_contiguous_layout() + return [block, block] + + +def expert_num_tokens_round_up_and_sum(expert_num_tokens: torch.Tensor, + alignment: int) -> int: + # Round up each element in expert_num_tokens to the nearest multiple of + # alignment. + ent = (expert_num_tokens.to(torch.int64) + + (alignment - 1)) // alignment * alignment + return torch.sum(ent).item() + + +def compute_aligned_M(M: int, num_topk: int, local_num_experts: int, + alignment: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata]): + + if ((expert_tokens_meta is not None) + and (expert_tokens_meta.expert_num_tokens_cpu is not None)): + return expert_num_tokens_round_up_and_sum( + expert_tokens_meta.expert_num_tokens_cpu, alignment=alignment) + + # expert_num_tokens information is not available on the cpu. + # compute the max required size. 
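A quick numeric illustration of the worst-case bound computed just below: with hypothetical M=7 tokens, topk=2, 4 local experts and a 128-row alignment, each expert can introduce up to 127 rows of padding, so 7*2 + 4*127 = 522 rows are needed, rounded up to 640.

```python
# Hypothetical sizes illustrating the worst-case M_sum bound.
M, num_topk, local_num_experts, alignment = 7, 2, 4, 128


def round_up(x: int, a: int) -> int:
    return (x + a - 1) // a * a


m_sum = round_up(M * num_topk + local_num_experts * (alignment - 1), alignment)
assert m_sum == 640  # 14 token rows + up to 127 padding rows per expert, aligned
```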
+ M_sum = (M * num_topk) + local_num_experts * (alignment - 1) + M_sum = round_up(M_sum, alignment) + return M_sum + + +@triton.jit +def apply_expert_map(expert_id, expert_map): + if expert_id != -1: + expert_id = tl.load(expert_map + expert_id).to(tl.int64) + return expert_id + + +@triton.jit +def round_up_128(x: int) -> int: + y = 128 + return ((x + y - 1) // y) * y + + +@triton.jit +def _fwd_kernel_ep_scatter_1( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts: tl.constexpr, + BLOCK_E: tl.constexpr, + BLOCK_EXPERT_NUM: tl.constexpr, +): + cur_expert = tl.program_id(0) + + offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM) + tokens_per_expert = tl.load(num_recv_tokens_per_expert + offset_cumsum, + mask=offset_cumsum < num_experts, + other=0) + tokens_per_expert = round_up_128(tokens_per_expert) + cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert + tl.store(expert_start_loc + offset_cumsum, + cumsum, + mask=offset_cumsum < num_experts) + + cur_expert_start = tl.load(expert_start_loc + cur_expert) + cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) + + m_indices_start_ptr = m_indices + cur_expert_start + off_expert = tl.arange(0, BLOCK_E) + + for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): + tl.store( + m_indices_start_ptr + start_m + off_expert, + cur_expert, + ) + + +@triton.jit +def _fwd_kernel_ep_scatter_2( + total_token_num, + expert_start_loc, + recv_x, + recv_x_stride0, + recv_x_stride1, + recv_x_scale, + recv_x_scale_stride0, + recv_x_scale_stride1, + recv_topk, + recv_topk_stride0, + recv_topk_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + output_tensor_scale, + output_tensor_scale_stride0, + output_tensor_scale_stride1, + output_index, + output_index_stride0, + output_index_stride1, + topk_num: tl.constexpr, + expert_map, + HAS_EXPERT_MAP: tl.constexpr, + HIDDEN_SIZE: tl.constexpr, + HIDDEN_SIZE_PAD: tl.constexpr, + SCALE_HIDDEN_SIZE: tl.constexpr, + SCALE_HIDDEN_SIZE_PAD: tl.constexpr, +): + start_token_id = tl.program_id(0) + grid_num = tl.num_programs(0) + + offset_in = tl.arange(0, HIDDEN_SIZE_PAD) + mask = offset_in < HIDDEN_SIZE + + offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) + mask_s = offset_in_s < SCALE_HIDDEN_SIZE + + for token_id in range(start_token_id, total_token_num, grid_num): + to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, + mask=mask) + to_copy_s = tl.load(recv_x_scale + token_id * recv_x_scale_stride0 + + offset_in_s, + mask=mask_s) + + for topk_index in tl.range(0, topk_num, 1, num_stages=4): + expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + + topk_index) + + if HAS_EXPERT_MAP: + expert_id = apply_expert_map(expert_id, expert_map) + + if expert_id >= 0: + dest_token_index = tl.atomic_add(expert_start_loc + expert_id, + 1) + tl.store( + output_index + token_id * output_index_stride0 + + topk_index, dest_token_index) + output_tensor_ptr = (output_tensor + + dest_token_index * output_tensor_stride0) + output_tensor_scale_ptr = ( + output_tensor_scale + + dest_token_index * output_tensor_scale_stride0) + tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) + tl.store(output_tensor_scale_ptr + offset_in_s, + to_copy_s, + mask=mask_s) + + +@torch.no_grad() +def ep_scatter( + recv_x: torch.Tensor, + recv_x_scale: torch.Tensor, + recv_topk: torch.Tensor, + num_recv_tokens_per_expert: torch.Tensor, + expert_map: Optional[torch.Tensor], + expert_start_loc: torch.Tensor, + output_tensor: torch.Tensor, + 
output_tensor_scale: torch.Tensor, + m_indices: torch.Tensor, + output_index: torch.Tensor, +): + BLOCK_E = 128 # token num of per expert is aligned to 128 + BLOCK_D = 128 # block size of quantization + num_warps = 8 + num_experts = num_recv_tokens_per_expert.shape[0] + hidden_size = recv_x.shape[1] + # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts) + grid = num_experts + + assert m_indices.shape[0] % BLOCK_E == 0 + + _fwd_kernel_ep_scatter_1[(grid, )]( + num_recv_tokens_per_expert, + expert_start_loc, + m_indices, + num_experts=num_experts, + num_warps=num_warps, + BLOCK_E=BLOCK_E, + BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), + ) + + grid = min(recv_topk.shape[0], 1024 * 8) + + _fwd_kernel_ep_scatter_2[(grid, )]( + recv_topk.shape[0], + expert_start_loc, + recv_x, + recv_x.stride(0), + recv_x.stride(1), + recv_x_scale, + recv_x_scale.stride(0), + recv_x_scale.stride(1), + recv_topk, + recv_topk.stride(0), + recv_topk.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + output_tensor_scale, + output_tensor_scale.stride(0), + output_tensor_scale.stride(1), + output_index, + output_index.stride(0), + output_index.stride(1), + topk_num=recv_topk.shape[1], + expert_map=expert_map, + HAS_EXPERT_MAP=expert_map is not None, + num_warps=num_warps, + HIDDEN_SIZE=hidden_size, + HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), + SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D, + SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D), + ) + return + + +@triton.jit +def _fwd_kernel_ep_gather( + total_token_num, + input_tensor, + input_tensor_stride0, + input_tensor_stride1, + recv_topk_ids, + recv_topk_ids_stride0, + recv_topk_ids_stride1, + recv_topk_weight, + recv_topk_weight_stride0, + recv_topk_weight_stride1, + input_index, + input_index_stride0, + input_index_stride1, + output_tensor, + output_tensor_stride0, + output_tensor_stride1, + topk_num: tl.constexpr, + expert_map, + HAS_EXPERT_MAP: tl.constexpr, + BLOCK_D: tl.constexpr, +): + cur_block = tl.program_id(0) + start_cur_token = tl.program_id(1) + grid_num = tl.num_programs(1) + + for cur_token in range(start_cur_token, total_token_num, grid_num): + off_d = tl.arange(0, BLOCK_D) + accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) + for topk_index in range(0, topk_num): + expert_id = tl.load(recv_topk_ids + + cur_token * recv_topk_ids_stride0 + topk_index) + + if HAS_EXPERT_MAP: + expert_id = apply_expert_map(expert_id, expert_map) + + if expert_id >= 0: + source_token_index = tl.load(input_index + + cur_token * input_index_stride0 + + topk_index) + acc_weight = tl.load(recv_topk_weight + + cur_token * recv_topk_weight_stride0 + + topk_index) + tmp = tl.load(input_tensor + + source_token_index * input_tensor_stride0 + + cur_block * BLOCK_D + off_d) + accumulator += tmp.to(tl.float32) * acc_weight + + tl.store( + output_tensor + cur_token * output_tensor_stride0 + + cur_block * BLOCK_D + off_d, + accumulator.to(output_tensor.dtype.element_ty), + ) + + +@torch.no_grad() +def ep_gather( + input_tensor: torch.Tensor, + recv_topk_ids: torch.Tensor, + recv_topk_weight: torch.Tensor, + input_index: torch.Tensor, + expert_map: Optional[torch.Tensor], + output_tensor: torch.Tensor, +): + num_warps = 2 + num_tokens = output_tensor.shape[0] + hidden_size = input_tensor.shape[1] + BLOCK_D = min(hidden_size, 1024) + assert hidden_size % BLOCK_D == 0 + grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024)) + + _fwd_kernel_ep_gather[grid]( + num_tokens, + input_tensor, + 
input_tensor.stride(0), + input_tensor.stride(1), + recv_topk_ids, + recv_topk_ids.stride(0), + recv_topk_ids.stride(1), + recv_topk_weight, + recv_topk_weight.stride(0), + recv_topk_weight.stride(1), + input_index, + input_index.stride(0), + input_index.stride(1), + output_tensor, + output_tensor.stride(0), + output_tensor.stride(1), + topk_num=recv_topk_ids.shape[1], + expert_map=expert_map, + HAS_EXPERT_MAP=expert_map is not None, + num_warps=num_warps, + BLOCK_D=BLOCK_D, + ) + return + + +def deepgemm_moe_permute(aq: torch.Tensor, + aq_scale: torch.Tensor, + topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + aq_out: Optional[torch.Tensor] = None): + + assert aq.ndim == 2 + assert topk_ids.dtype.is_signed, ( + "The kernel uses -1 to represent invalid topk_ids") + H = aq.size(1) + device = aq.device + + block_m = deep_gemm_block_shape()[0] + block_k = deep_gemm_block_shape()[1] + + M_sum = compute_aligned_M(M=topk_ids.size(0), + num_topk=topk_ids.size(1), + local_num_experts=local_num_experts, + alignment=block_m, + expert_tokens_meta=expert_tokens_meta) + + expert_start_loc = torch.empty((local_num_experts), + device=device, + dtype=torch.int32) + + assert aq_out is None or aq_out.shape == (M_sum, H) + if aq_out is None: + aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype) + + aq_scale_out = torch.empty((M_sum, H // block_k), + device=device, + dtype=torch.float32) + + maybe_has_empty_blocks = ((expert_tokens_meta is None) + or (expert_tokens_meta.expert_num_tokens_cpu + is None)) + expert_ids_init = torch.zeros if maybe_has_empty_blocks else torch.empty + + expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32) + inv_perm = torch.empty(topk_ids.shape, device=device, dtype=torch.int32) + + expert_num_tokens = None + if expert_tokens_meta is not None: + expert_num_tokens = expert_tokens_meta.expert_num_tokens + else: + expert_num_tokens = count_expert_num_tokens(topk_ids, + local_num_experts, + expert_map) + + ep_scatter(recv_x=aq, + recv_x_scale=aq_scale, + recv_topk=topk_ids, + num_recv_tokens_per_expert=expert_num_tokens, + expert_start_loc=expert_start_loc, + expert_map=expert_map, + output_tensor=aq_out, + output_tensor_scale=aq_scale_out, + m_indices=expert_ids, + output_index=inv_perm) + + return aq_out, aq_scale_out, expert_ids, inv_perm + + +def deepgemm_unpermute_and_reduce( + a: torch.Tensor, # Grouped gemm output + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + inv_perm: torch.Tensor, + expert_map: Optional[torch.Tensor], + output: torch.Tensor): + + return ep_gather(input_tensor=a, + recv_topk_ids=topk_ids, + recv_topk_weight=topk_weights, + input_index=inv_perm, + expert_map=expert_map, + output_tensor=output) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index b625c28d4070..7016ff34c3a8 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import deep_ep import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from 
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -62,8 +63,9 @@ def _do_dispatch(self, tokens: torch.Tensor, has_scales = token_scales is not None - (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, - is_token_in_rank, event) = self.buffer.get_dispatch_layout( + (num_tokens_per_rank, num_tokens_per_rdma_rank, + dispatch_expert_num_tokens, is_token_in_rank, + event) = self.buffer.get_dispatch_layout( topk_idx=rank_topk_ids, num_experts=num_experts, previous_event=None, @@ -83,7 +85,7 @@ def _do_dispatch(self, tokens: torch.Tensor, num_tokens_per_rank=num_tokens_per_rank, num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=expert_num_tokens, + num_tokens_per_expert=dispatch_expert_num_tokens, topk_idx=rank_topk_ids, topk_weights=rank_topk_weights, # expert_alignment rounds the number of tokens per expert @@ -115,22 +117,25 @@ def _do_dispatch(self, tokens: torch.Tensor, num_experts - 1 if self.rank_expert_offset == 0 else 0, expert_topk_ids + self.rank_expert_offset) - return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + # Makes a GPU-CPU copy. + # TODO (varun): Maybe it is better to re-compute the expert_num_tokens + # on GPU. + expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( + expert_num_tokens_per_expert_list, device=expert_x.device) + + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -149,7 +154,7 @@ def prepare( ) if a1q_scale is not None and a1q_scale.numel() == 1: a1q_scale = a1q_scale.view(1, 1) - (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1q, token_scales=a1q_scale, @@ -159,7 +164,7 @@ def prepare( else: # DeepEP kernels only support dispatching per-token-quant # quantization. dispatch in bfloat16. 
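(Illustrative aside, not part of the patch: the high-throughput dispatch above hands back a plain Python list of per-expert token counts, and the new mk.ExpertTokensMetadata.make_from_list wraps it so both a device tensor and a host copy of the counts are available downstream. A minimal sketch of that conversion with made-up counts:)

import torch

counts = [5, 0, 12, 3]                                  # hypothetical tokens per local expert
cpu_counts = torch.tensor(counts, device="cpu", dtype=torch.int32)
gpu_counts = cpu_counts.to("cuda", non_blocking=True)   # device-side copy used by the kernels
# meta = mk.ExpertTokensMetadata(expert_num_tokens=gpu_counts,
#                                expert_num_tokens_cpu=cpu_counts)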
- (expert_x, _, expert_num_tokens, expert_topk_ids, + (expert_x, _, expert_tokens_meta, expert_topk_ids, expert_topk_weights) = self._do_dispatch( tokens=a1, token_scales=None, @@ -176,48 +181,29 @@ def prepare( per_act_token_quant=False, block_shape=quant_config.block_shape) - return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, expert_topk_weights) - def _apply_weights_and_reduce(self, num_tokens: int, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - apply_router_weight_on_input: bool, - output_dtype: torch.dtype): - - hidden_dim = fused_expert_output.size(-1) - if fused_expert_output.ndim == 2: - fused_expert_output = fused_expert_output.view( - num_tokens, -1, hidden_dim) - - if not apply_router_weight_on_input: - # The DeepEP combine kernels don't do the topk weight - # multiplication. We multiply the weights locally. - m_x_topk = fused_expert_output.size(0) - fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1)) - - out = torch.empty((num_tokens, hidden_dim), - device=fused_expert_output.device, - dtype=output_dtype) - ops.moe_sum(fused_expert_output, out) - - return out - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> None: + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: assert self.handle is not None # fused_expert_output can have 0 tokens - This happens when none of the # tokens from the all2all reach this EP rank. if fused_expert_output.numel() != 0: - fused_expert_output = self._apply_weights_and_reduce( - num_tokens=topk_ids.size(0), + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + fused_expert_output = weight_and_reduce_impl.apply( + output=None, fused_expert_output=fused_expert_output, topk_weights=topk_weights, + topk_ids=topk_ids, apply_router_weight_on_input=apply_router_weight_on_input, - output_dtype=output.dtype) + ) combined_x, _, event = self.buffer.combine( x=fused_expert_output, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 78ac4acc495d..57871ca250ae 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -1,12 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional, Union +from typing import Any, Optional, Union import deep_ep import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape) @@ -109,18 +111,15 @@ def _do_quant( return x, x_scales def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: 
Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: hidden_size = a1.size(1) assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ @@ -158,12 +157,19 @@ def prepare( expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) - return (expert_x, expert_x_scale, expert_num_tokens, None, None) + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + + return (expert_x, expert_x_scale, expert_tokens_meta, None, None) def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool) -> None: - + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + assert isinstance( + weight_and_reduce_impl, TopKWeightAndReduceDelegate + ), ("Weight application and reduction happens in the combine kernel.") assert self.handle is not None combine_topk_weights = topk_weights diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py new file mode 100644 index 000000000000..3e79a1a8c24b --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) +from vllm.model_executor.layers.fused_moe.utils import extract_required_args +from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, + has_flashinfer_cutlass_fused_moe) + +logger = init_logger(__name__) + + +def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor) -> bool: + """ + Check if the given problem size is supported by the FlashInfer CUTLASS MoE + kernel. 
+ """ + if not has_flashinfer_cutlass_fused_moe(): + logger.debug_once("FlashInferExperts disabled: " + "flashinfer_cutlass_fused_moe not available.") + return False + # Data type checks + if (w1.dtype != torch.uint8 or w2.dtype != torch.uint8 + or hidden_states.dtype + not in [torch.float32, torch.float16, torch.bfloat16]): + logger.debug_once( + "FlashInferExperts disabled: w1/w2 must be torch.uint8 " + f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " + f"float32, float16, or bfloat16 (got {hidden_states.dtype}).") + return False + return True + + +class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + use_nvfp4_w4a4: bool = False, + use_fp8_w8a8: bool = False, + use_dp: bool = False, + ep_rank: int = 0, + ep_size: int = 1, + tp_rank: int = 0, + tp_size: int = 1, + num_dispatchers: Optional[int] = None, + use_batched_format: bool = False, + ): + super().__init__( + FusedMoEQuantConfig( + quant_dtype=torch.uint8, + per_act_token_quant=False, + block_shape=None, + )) + self.use_nvfp4_w4a4 = use_nvfp4_w4a4 + self.use_fp8_w8a8 = use_fp8_w8a8 + self.ep_rank = ep_rank + self.ep_size = ep_size + self.tp_rank = tp_rank + self.tp_size = tp_size + self.use_dp = use_dp + assert not use_batched_format or num_dispatchers is not None + self.num_dispatchers = num_dispatchers + + @property + def activation_formats( + self + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return (mk.FusedMoEActivationFormat.Standard, + mk.FusedMoEActivationFormat.Standard) + + def supports_expert_map(self) -> bool: + return False + + def supports_chunking(self) -> bool: + # This refers to TP chunking; DP chunking is handled separately. + return True + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + # We use global_num_experts due to how moe_align_block_size handles + # expert_maps. + """ + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. + + Returns a tuple of: + - workspace13 shape tuple: must be large enough to hold the + result of either expert gemm. + - workspace2 shape tuple: must be large enough to hold the + result of the activation function. + - output shape tuple: must be exact size of the final gemm output. + - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. + """ + assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " + "currently supported.") + aq_m, aq_n = aq.shape + workspace2 = () + output_shape = (aq_m, aq_n * 2) + workspace_dtype = a.dtype + workspace1 = output_shape + # The workspace is determined by `aq`, since it comes after any + # potential communication op and is involved in the expert computation. 
+ return (workspace1, workspace2, output_shape, workspace_dtype) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], # Not used + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: Optional[bool], + extra_expert_args: Optional[dict[str, Any]], + ): + assert extra_expert_args is not None, \ + "extra_expert_args must be provided" + required_keys = [ + 'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype' + ] + + g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = ( + extract_required_args(extra_expert_args, required_keys)) + + # Flashinfer CUTLASS kernel takes scalar global scales, + # min because inv_scale. + assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is " + "currently supported.") + + # Ensure w1_scale and w2_scale are not None before calling view + assert w1_scale is not None and w2_scale is not None, ( + "w1_scale and w2_scale must not " + "be None for FlashInferExperts") + + assert not apply_router_weight_on_input + + quant_scales = [ + a1_gscale, + w1_scale.view(torch.int32), + g1_alphas, + a2_gscale, + w2_scale.view(torch.int32), + g2_alphas, + ] + _ = flashinfer_cutlass_fused_moe( + input=hidden_states, + token_selected_experts=topk_ids.to(torch.int), + token_final_scales=topk_weights, + # FlashInfer API requires weight to be long for nvfp4 + fc1_expert_weights=w1.view(torch.long), + fc2_expert_weights=w2.view(torch.long), + output_dtype=out_dtype, + quant_scales=quant_scales, + input_sf=a1q_scale, + tp_size=self.tp_size, + tp_rank=self.tp_rank, + ep_size=self.ep_size, + ep_rank=self.ep_rank, + output=output, + ) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py new file mode 100644 index 000000000000..e658990e95e5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch + +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_dp_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.utils import ( + extract_required_args, moe_kernel_quantize_input) +from vllm.utils.flashinfer import block_scale_interleave + + +def get_local_sizes(local_tokens): + cu_sizes = get_forward_context().dp_metadata.cu_tokens_across_dp_cpu + sizes = [cu_sizes[0].item()] + for i in range(1, len(cu_sizes)): + sizes.append((cu_sizes[i] - cu_sizes[i - 1]).item()) + max_num_tokens = envs.VLLM_MOE_DP_CHUNK_SIZE + sizes_chunked = [max_num_tokens] * len(sizes) + if local_tokens < max_num_tokens: + # When the number of local tokens is less than max_num_tokens, all other + # ranks will also have fewer than max_num_tokens. 
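(Illustrative aside, not part of the patch: a worked example of the chunk sizing described in the surrounding comment, using assumed values:)

max_num_tokens = 256                    # assumed VLLM_MOE_DP_CHUNK_SIZE
sizes = [300, 260, 270]                 # assumed per-rank token totals
local_tokens = 44                       # this rank's final, partial chunk
sizes_chunked = ([x % max_num_tokens for x in sizes]
                 if local_tokens < max_num_tokens
                 else [max_num_tokens] * len(sizes))
assert sizes_chunked == [44, 4, 14]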
The remaining tokens + # are accounted for as residual. + sizes_chunked = [x % max_num_tokens for x in sizes] + + return sizes_chunked + + +class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + + def __init__( + self, + quant_dtype: Optional[torch.dtype] = None, + per_channel_quant: bool = False, + block_shape: Optional[list[int]] = None, + num_dispatchers: int = 1, + ): + super().__init__() + self.per_channel_quant = per_channel_quant + self.block_shape = block_shape + self.quant_dtype = quant_dtype + self.num_dispatchers_ = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + + def num_dispatchers(self) -> int: + return self.num_dispatchers_ + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], # Not used + a2_scale: Optional[torch.Tensor], # Not used + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + assert not apply_router_weight_on_input + + (a1_gscale, use_dp, local_tokens) = extract_required_args( + extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens']) + + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + a1_gscale, + quant_config.quant_dtype, + self.per_channel_quant, + self.block_shape, + is_fp4_scale_swizzled=not use_dp, # Swizzling after communication + ) + if use_dp: + topk_weights, topk_ids, a1q, a1q_scale = \ + get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale], # noqa: E501 + dim=0, + sizes=get_local_sizes(local_tokens)) + a1_m, a1_n = a1q.shape + a1q_scale = block_scale_interleave(a1q_scale) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + + (use_dp, + local_tokens) = extract_required_args(extra_finalize_args, + ['use_dp', 'local_tokens']) + if use_dp: + fused_expert_output = get_dp_group().reduce_scatterv( + fused_expert_output, + dim=0, + sizes=get_local_sizes(local_tokens), + ) + output.copy_(fused_expert_output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 0355abbf1d2b..9a5c85e120cc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1,21 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Fused batched MoE kernel.""" -from typing import Optional +from typing import Any, Optional import torch -import triton -import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) +from 
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, TopKWeightAndReduceNaiveBatched) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, moe_kernel_quantize_input, normalize_batched_scales_shape, normalize_scales_shape) from vllm.model_executor.layers.quantization.utils.quant_utils import ( group_broadcast) +from vllm.triton_utils import tl, triton @triton.jit @@ -495,18 +496,15 @@ def num_dispatchers(self) -> int: return self.num_dispatchers_ def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) @@ -587,34 +585,25 @@ def prepare( assert b_a1_scale is None or b_a1_scale.ndim == 3 - return b_a1, b_a1_scale, tokens_per_expert, None, None - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: - num_tokens = topk_ids.size(0) - num_local_experts = fused_expert_output.size(0) - K = fused_expert_output.size(-1) - assert output.size(0) == num_tokens and output.size(1) == K - - output.fill_(0) - - first_expert = num_local_experts * self.rank - last_expert = first_expert + num_local_experts - - for expert_id in range(first_expert, last_expert): - matching_tokens = topk_ids == expert_id - topks = torch.any(matching_tokens, dim=1).flatten() - rows = torch.count_nonzero(topks) - rhs = fused_expert_output[expert_id - first_expert, :rows, :] - if not apply_router_weight_on_input: - rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) - output[topks] = output[topks] + rhs + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None) + + return b_a1, b_a1_scale, expert_tokens_meta, None, None + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank) + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -632,6 +621,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, block_shape: Optional[list[int]] = None, 
per_act_token_quant: bool = False, ): @@ -641,12 +631,14 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + assert not use_mxfp4_w4a4, "NYI" self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers @@ -663,6 +655,10 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return False + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, @@ -673,6 +669,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.num_dispatchers @@ -691,28 +688,21 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: else: return t.to(f32) * group_broadcast(scale, t.shape) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): assert hidden_states.dim() == 3 - assert expert_num_tokens is not None + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens num_local_experts = w1.size(0) assert num_local_experts == w1.size(0), ( @@ -838,6 +828,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, ): @@ -847,18 +838,21 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) assert not use_int8_w8a8, "NYI" assert not use_int8_w8a16, "NYI" assert not use_int4_w4a16, "NYI" + assert not use_mxfp4_w4a4, "NYI" assert max_num_tokens > 0 assert num_dispatchers > 0 self.use_fp8_w8a8 = use_fp8_w8a8 self.use_int8_w8a8 = use_int8_w8a8 self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 + self.use_mxfp4_w4a4 = use_mxfp4_w4a4 self.max_num_tokens = max_num_tokens 
self.num_dispatchers = num_dispatchers @@ -875,6 +869,10 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return False + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + def workspace_shapes( self, a: torch.Tensor, @@ -885,6 +883,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: assert a.dim() == 2 num_dp = self.num_dispatchers @@ -895,26 +894,18 @@ def workspace_shapes( output = (num_experts, max_num_tokens * num_dp, K) return (workspace13, workspace2, output, a.dtype) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): # Check constraints. 
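# (Clarifying comment, not in the original patch: the packed int4 path below
# stores two weights per byte along the hidden dimension, which is why the
# shape check compares w1's last dim against hidden_size // 2 rather than
# hidden_size.)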
if self.use_int4_w4a16: assert hidden_states.size(-1) // 2 == w1.size(2), ( @@ -931,6 +922,9 @@ def apply( assert hidden_states.dtype in [ torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] + assert expert_tokens_meta is not None + + expert_num_tokens = expert_tokens_meta.expert_num_tokens E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( hidden_states, w1, w2, topk_ids) @@ -941,6 +935,7 @@ def apply( config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, dtype=hidden_states.dtype) config = try_get_optimal_moe_config( diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index fbbccbb34d90..c412f695ae76 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -25,11 +25,16 @@ moe_align_block_size) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP) from vllm.model_executor.layers.fused_moe.utils import ( - _resize_cache, moe_kernel_quantize_input) + _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + dequant_mxfp4) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton -from vllm.utils import direct_register_custom_op +from vllm.utils import direct_register_custom_op, is_torch_equal_or_newer +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled @@ -973,13 +978,16 @@ def get_config_dtype_str( dtype: torch.dtype, use_int4_w4a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False) -> Optional[str]: + use_fp8_w8a8: Optional[bool] = False, + use_mxfp4_w4a4: Optional[bool] = False) -> Optional[str]: if use_fp8_w8a8: return "fp8_w8a8" elif use_int8_w8a16: return "int8_w8a16" elif use_int4_w4a16: return "int4_w4a16" + elif use_mxfp4_w4a4: + return "mxfp4_w4a4" elif dtype == torch.float: # avoiding cases where kernel fails when float32 MoE # use fp16/bfloat16 configs @@ -998,6 +1006,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1011,9 +1020,9 @@ def inplace_fused_experts(hidden_states: torch.Tensor, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16, - per_channel_quant, global_num_experts, expert_map, - w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + use_mxfp4_w4a4, per_channel_quant, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, + a2_scale, block_shape) def inplace_fused_experts_fake( @@ -1028,6 +1037,7 @@ def inplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1046,6 +1056,105 @@ def inplace_fused_experts_fake( 
op_func=inplace_fused_experts, mutates_args=["hidden_states"], fake_impl=inplace_fused_experts_fake, + tags=(() if is_torch_equal_or_newer("2.7.0") else + (torch.Tag.needs_fixed_stride_order, )), +) + + +def next_positive_power_of_2(x: int) -> int: + if x < 1: + return 1 + return 1 << (x - 1).bit_length() + + +def _get_tile_tokens_dim(num_tokens, top_k, num_experts): + # Guess tokens per expert assuming perfect expert distribution first. + num_tokens_per_expert = (num_tokens * top_k) // num_experts + # And pad the number to the next power of 2. + tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert) + # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel. + tile_tokens_dim = min(max(tile_tokens_dim, 8), 64) + return tile_tokens_dim + + +def flashinfer_fused_moe_blockscale_fp8( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0) -> torch.Tensor: + from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe + assert top_k <= global_num_experts + assert top_k <= 8 + assert topk_group <= 4 + assert global_num_experts > num_expert_group + assert global_num_experts % num_expert_group == 0 + assert global_num_experts % 4 == 0 + assert top_k < (topk_group * global_num_experts / num_expert_group) + assert block_shape == [128, 128] + + a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) + # NOTE: scales of hidden states have to be transposed! 
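# (Clarifying comment, not in the original patch: assuming the per-token-group
# quant helper returns scales shaped (num_tokens, K // 128), the transpose
# below produces the (K // 128, num_tokens) layout the TRT-LLM block-scale
# kernel expects.)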
+ a_sf_t = a_sf.t().contiguous() + return flashinfer_trtllm_fp8_block_scale_moe( + routing_logits=routing_logits, + routing_bias=routing_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=w13_weight, + gemm1_weights_scale=w13_weight_scale_inv, + gemm2_weights=w2_weight, + gemm2_weights_scale=w2_weight_scale_inv, + num_experts=global_num_experts, + top_k=top_k, + n_group=num_expert_group, + topk_group=topk_group, + intermediate_size=intermediate_size, + local_expert_offset=expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling, + tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k, + global_num_experts), + routing_method_type=2, # DeepSeek-styled routing method + ) + + +def flashinfer_fused_moe_blockscale_fp8_fake( + routing_logits: torch.Tensor, + routing_bias: torch.Tensor, + x: torch.Tensor, + w13_weight: torch.Tensor, + w13_weight_scale_inv: torch.Tensor, + w2_weight: torch.Tensor, + w2_weight_scale_inv: torch.Tensor, + global_num_experts: int, + top_k: int, + num_expert_group: int, + topk_group: int, + intermediate_size: int, + expert_offset: int, + local_num_experts: int, + block_shape: list[int], + routed_scaling: float = 1.0) -> torch.Tensor: + return torch.empty_like(x) + + +direct_register_custom_op( + op_name="flashinfer_fused_moe_blockscale_fp8", + op_func=flashinfer_fused_moe_blockscale_fp8, + mutates_args=[], + fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, tags=(torch.Tag.needs_fixed_stride_order, ), ) @@ -1062,6 +1171,7 @@ def outplace_fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1075,10 +1185,10 @@ def outplace_fused_experts( return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, False, activation, apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, - use_int4_w4a16, per_channel_quant, - global_num_experts, expert_map, w1_scale, - w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, - block_shape) + use_int4_w4a16, use_mxfp4_w4a4, + per_channel_quant, global_num_experts, + expert_map, w1_scale, w2_scale, w1_zp, w2_zp, + a1_scale, a2_scale, block_shape) def outplace_fused_experts_fake( @@ -1092,6 +1202,7 @@ def outplace_fused_experts_fake( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1110,7 +1221,8 @@ def outplace_fused_experts_fake( op_func=outplace_fused_experts, mutates_args=[], fake_impl=outplace_fused_experts_fake, - tags=(torch.Tag.needs_fixed_stride_order, ), + tags=(() if is_torch_equal_or_newer("2.7.0") else + (torch.Tag.needs_fixed_stride_order, )), ) @@ -1145,6 +1257,7 @@ def fused_experts( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1159,9 +1272,15 @@ def fused_experts( allow_cutlass_block_scaled_grouped_gemm: bool = False) -> torch.Tensor: # For now, disable DeepGemm for small N (<= 512) until better # permute/unpermute ops are available. 
+ # However, on B200, we use DeepGemm for all cases because they only support + # E8M0 scale, which means we requantize the weight and input to the specific + # scale. Fallen back to cutlass or triton for some cases would cause + # accuracy issue. N = w1.size(1) - if (allow_deep_gemm and use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2)): + should_use_deep_gemm = ((N > 512 + and _valid_deep_gemm(hidden_states, w1, w2)) + or is_blackwell_deep_gemm_used()) + if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( hidden_states=hidden_states, @@ -1180,8 +1299,9 @@ def fused_experts( apply_router_weight_on_input=apply_router_weight_on_input, ) elif (allow_cutlass_block_scaled_grouped_gemm and use_fp8_w8a8 - and _valid_cutlass_block_scaled_grouped_gemm(hidden_states, w1, w2)): - assert apply_router_weight_on_input is False + and _valid_cutlass_block_scaled_grouped_gemm( + w1, w2, inplace, activation, apply_router_weight_on_input, + expert_map)): return run_cutlass_block_scaled_fused_experts( a=hidden_states, w1=w1, @@ -1203,6 +1323,7 @@ def fused_experts( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1228,6 +1349,7 @@ def fused_experts_impl( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1243,6 +1365,9 @@ def fused_experts_impl( if use_int4_w4a16: assert hidden_states.size(1) // 2 == w1.size(2), ( "Hidden size mismatch") + elif use_mxfp4_w4a4: + # 16bit activation and fp4x2 packed weight + assert hidden_states.size(1) // 2 == w1.size(2), "hidden size mismatch" else: assert hidden_states.size(1) == w1.size(2), ( f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}") @@ -1268,12 +1393,14 @@ def fused_experts_impl( config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, dtype=hidden_states.dtype) qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, - use_int4_w4a16=use_int4_w4a16) + use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4) get_config_func = functools.partial( try_get_optimal_moe_config, @@ -1313,6 +1440,13 @@ def fused_experts_impl( else: out_hidden_states = torch.empty_like(hidden_states) + if use_mxfp4_w4a4: + # Weight has to be dequantized for mxfp4 emulation. 
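# (Clarifying comment, not in the original patch: this path emulates mxfp4 by
# expanding the packed fp4 weights to the activation dtype once, dropping the
# weight scales, and then running the regular unquantized Triton MoE kernels.)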
+ w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype) + w1_scale = None + w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype) + w2_scale = None + for chunk in range((num_tokens // CHUNK_SIZE) + 1): begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, @@ -1336,7 +1470,6 @@ def fused_experts_impl( curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input( A=curr_hidden_states, A_scale=a1_scale, @@ -1429,6 +1562,7 @@ def fused_moe( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, @@ -1470,6 +1604,9 @@ def fused_moe( - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16 activation to compute the inner products for w1 and w2. Defaults to False. + - use_mxfp4_w4a4 (bool): If True, use matmul of OCP MXFP4 weight and + OCP MXFP4 activation to compute the inner products for w1 and w2. + Defaults to False. - global_num_experts (int): The total number of experts in the global expert space. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices @@ -1513,6 +1650,7 @@ def fused_moe( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, expert_map=expert_map, @@ -1533,6 +1671,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, ): @@ -1542,6 +1681,7 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) @@ -1550,6 +1690,7 @@ def __init__( self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a8 = use_int8_w8a8 self.use_int8_w8a16 = use_int8_w8a16 + self.use_mxfp4_w4a4 = use_mxfp4_w4a4 @property def activation_formats( @@ -1564,6 +1705,9 @@ def supports_chunking(self) -> bool: def supports_expert_map(self) -> bool: return True + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + def workspace_shapes( self, a: torch.Tensor, @@ -1574,10 +1718,11 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - workspace1 = (M, topk, max(N * 2, K)) - workspace2 = (M, topk, N) - output = (M, topk, K) + workspace1 = (M, topk, max(N // 2, K)) + workspace2 = (M, topk, max(N, K)) + output = (M, K) return (workspace1, workspace2, output, a.dtype) def apply( @@ -1586,6 +1731,7 @@ def apply( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, @@ -1598,7 +1744,9 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): # Check constraints. 
if self.use_int4_w4a16: @@ -1627,6 +1775,7 @@ def apply( config_dtype = get_config_dtype_str(use_fp8_w8a8=self.use_fp8_w8a8, use_int8_w8a16=self.use_int8_w8a16, use_int4_w4a16=self.use_int4_w4a16, + use_mxfp4_w4a4=self.use_mxfp4_w4a4, dtype=hidden_states.dtype) config = try_get_optimal_moe_config( @@ -1650,37 +1799,39 @@ def apply( raise ValueError( f"Unsupported compute_type: {hidden_states.dtype}") - # We can reuse the memory between these because by the time we need - # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, + # Note that the output tensor might be in workspace1 + intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) - intermediate_cache2 = _resize_cache(workspace2, + intermediate_cache2 = _resize_cache(workspace13, (num_tokens * top_k_num, N // 2)) + intermediate_cache3 = _resize_cache(workspace2, + (num_tokens, top_k_num, K)) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], global_num_experts, expert_map)) - invoke_fused_moe_kernel(hidden_states, - w1, - intermediate_cache1, - a1q_scale, - w1_scale, - w1_zp, - None, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=self.use_fp8_w8a8, - use_int8_w8a8=self.use_int8_w8a8, - use_int8_w8a16=self.use_int8_w8a16, - use_int4_w4a16=self.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape) + invoke_fused_moe_kernel( + hidden_states, + w1, + intermediate_cache1, + a1q_scale, + w1_scale, + w1_zp, + None, # topk_weights + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, # mul_routed_weights + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.use_fp8_w8a8, + use_int8_w8a8=self.use_int8_w8a8, + use_int8_w8a16=self.use_int8_w8a16, + use_int4_w4a16=self.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape) self.activation(activation, intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -1693,15 +1844,15 @@ def apply( invoke_fused_moe_kernel(qintermediate_cache2, w2, - output, + intermediate_cache3, a2q_scale, w2_scale, w2_zp, - None, + topk_weights, sorted_token_ids, expert_ids, num_tokens_post_padded, - False, + not apply_router_weight_on_input, 1, config, compute_type=compute_type, @@ -1712,12 +1863,15 @@ def apply( per_channel_quant=self.per_act_token_quant, block_shape=self.block_shape) + ops.moe_sum(intermediate_cache3, output) + def modular_triton_fused_moe( use_fp8_w8a8: bool, use_int8_w8a8: bool, use_int8_w8a16: bool, use_int4_w4a16: bool, + use_mxfp4_w4a4: bool, per_act_token_quant: bool, block_shape: Optional[list[int]] = None, ) -> mk.FusedMoEModularKernel: @@ -1728,6 +1882,7 @@ def modular_triton_fused_moe( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, ), diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 36ac75a8df4b..4a6a3b95ec7f 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -34,6 +34,7 @@ from vllm.platforms import current_platform from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx +from vllm.utils.flashinfer import has_flashinfer if current_platform.is_cuda_alike(): 
from .fused_batched_moe import BatchedTritonExperts @@ -45,6 +46,9 @@ from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE, DeepEPLLPrepareAndFinalize) + if has_flashinfer(): + from .flashinfer_cutlass_prepare_finalize import ( + FlashInferCutlassMoEPrepareAndFinalize) else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -81,15 +85,27 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, params_dtype: torch.dtype, **extra_weight_attrs): raise NotImplementedError - def init_prepare_finalize(self, moe: FusedMoEConfig, - quant_config: Optional[QuantizationConfig]): + def uses_weight_scale_2_pattern(self) -> bool: + """ + Returns True if this quantization method uses 'weight_scale_2' pattern + for per-tensor weight scales (e.g., FP4 variants), False otherwise. + + This method should be overridden by subclasses that use the + 'weight_scale_2' pattern instead of the standard 'weight_scale' pattern. + """ + return False + + @staticmethod + def maybe_make_prepare_finalize( + moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]: all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - self.moe = moe - prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None + if moe.use_flashinfer_cutlass_kernels: + prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=moe.quant_dtype, ) if moe.use_pplx_kernels: hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( moe.max_num_tokens, @@ -160,8 +176,6 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, and moe.quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE) - # Note (varun): Whether to use FP8 dispatch or not needs some - # profiling. Turning it off for now. 
prepare_finalize = DeepEPLLPrepareAndFinalize( handle, max_tokens_per_rank=moe.max_num_tokens, @@ -169,11 +183,18 @@ def init_prepare_finalize(self, moe: FusedMoEConfig, use_fp8_dispatch=use_fp8_dispatch, ) + return prepare_finalize + + def init_prepare_finalize(self, moe: FusedMoEConfig): + self.moe = moe + prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize( + self.moe) + self.topk_indices_dtype = None if prepare_finalize is not None: logger.debug("%s", prepare_finalize.__class__.__name__) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() - experts = self.select_gemm_impl(prepare_finalize, moe) + experts = self.select_gemm_impl(prepare_finalize, self.moe) self.fused_experts = FusedMoEModularKernel( prepare_finalize, experts, @@ -190,6 +211,12 @@ def select_gemm_impl( f"{self.__class__.__name__} must select appropriate gemm " "implementation based on the prepare_finalize") + def maybe_swap_experts_impl( + self, + moe_parallel_config: FusedMoEParallelConfig, + ): + pass + @abstractmethod def apply( self, @@ -238,9 +265,6 @@ def select_gemm_impl( prepare_finalize: FusedMoEPrepareAndFinalize, moe: FusedMoEConfig, ) -> FusedMoEPermuteExpertsUnpermute: - - assert self.fused_experts == fused_experts - if (prepare_finalize.activation_format == FusedMoEActivationFormat.BatchedExperts): logger.debug("BatchedTritonExperts %s", self.moe) @@ -348,8 +372,10 @@ def apply( logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: if enable_eplb: - raise NotImplementedError( - "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) return self.forward( x=x, @@ -366,7 +392,12 @@ def apply( scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input) + apply_router_weight_on_input=apply_router_weight_on_input, + enable_eplb=enable_eplb, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) def forward_cuda( self, @@ -385,6 +416,10 @@ def forward_cuda( e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( @@ -398,7 +433,12 @@ def forward_cuda( custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype) + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts( @@ -461,39 +501,6 @@ def forward_cpu( activation, ) - def forward_hpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: 
Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - ) -> torch.Tensor: - assert not use_grouped_topk - assert num_expert_group is None - assert topk_group is None - assert custom_routing_function is None - assert layer is not None - assert apply_router_weight_on_input is False - if scoring_func != "softmax": - raise NotImplementedError( - "Only softmax scoring function is supported for HPU.") - if e_score_correction_bias is not None: - raise NotImplementedError( - "Expert score correction bias is not supported for HPU.") - return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight, - router_logits, top_k) - def forward_tpu( self, layer: torch.nn.Module, @@ -702,9 +709,6 @@ def __init__( if self.scoring_func != "softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " "non-grouped topk.") - if current_platform.is_hpu(): - from vllm_hpu_extension.ops import DynamicFusedMOE - self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts) if vllm_config.model_config is not None: model_dtype = vllm_config.model_config.dtype @@ -739,7 +743,8 @@ def __init__( if self.enable_eplb: from vllm.model_executor.layers.quantization.fp8 import ( Fp8MoEMethod) - if not isinstance(quant_method, Fp8MoEMethod): + if not isinstance(quant_method, + (Fp8MoEMethod, UnquantizedFusedMoEMethod)): # TODO: Add support for additional quantization methods. # The implementation for other quantization methods does not # contain essential differences, but the current quant API @@ -766,12 +771,15 @@ def __init__( moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + if isinstance(self.quant_method, FusedMoEMethodBase): + self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config) # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or self.moe_parallel_config.use_flashinfer_cutlass_kernels): self.batched_hidden_states = torch.zeros( (moe.max_num_tokens, self.hidden_size), dtype=moe.in_dtype, @@ -823,6 +831,19 @@ def use_deepep_ht_kernels(self): def use_deepep_ll_kernels(self): return self.moe_parallel_config.use_deepep_ll_kernels + @property + def use_flashinfer_cutlass_kernels(self): + return self.moe_parallel_config.use_flashinfer_cutlass_kernels + + def update_expert_map(self): + # ep_size and ep_rank should already be updated + assert self.expert_map is not None + with self.expert_map.device: + self.local_num_experts, self.expert_map = determine_expert_map( + ep_size=self.ep_size, + ep_rank=self.ep_rank, + global_num_experts=self.global_num_experts) + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -883,14 +904,21 @@ def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, expert_data=expert_data, tp_rank=tp_rank) - def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, - shard_id: str, loaded_weight: torch.Tensor, tp_rank: int): + def _load_w13(self, + expert_data: torch.Tensor, + shard_dim: int, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + load_full: bool = False): # Index the loaded weight for 
tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, - shard_size) + if not load_full: + loaded_weight = loaded_weight.narrow(shard_dim, + shard_size * tp_rank, + shard_size) # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -998,6 +1026,27 @@ def weight_loader(self, param.data.copy_(loaded_weight) return True if return_success else None + # Case for BitsAndBytes + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + if use_bitsandbytes_4bit: + shard_dim = 0 + + expert_data = param.data[expert_id] + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + # BNB inflight quantization has already sharded the weights + full_load = True + self._load_w13( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank, + load_full=full_load, + ) + return True if return_success else None + # is_transposed: if the dim to shard the weight # should be flipped. Required by GPTQ, compressed-tensors # should be whatever dimension intermediate_size_per_partition is @@ -1049,12 +1098,23 @@ def weight_loader(self, # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern if "ModelOpt" in quant_method_name: - if ('weight_scale_2' in weight_name - or 'input_scale' in weight_name): - self._load_per_tensor_weight_scale(shard_id=shard_id, - param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) + # Determine per-tensor weight scale patterns based on variant + # Use the dedicated method instead of brittle string matching + uses_weight_scale_2 = self.quant_method.uses_weight_scale_2_pattern( + ) + + # For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale" + per_tensor_conditions = ( + "weight_scale_2" in weight_name if uses_weight_scale_2 else + "weight_scale" in weight_name) or "input_scale" in weight_name + + if per_tensor_conditions: + self._load_per_tensor_weight_scale( + shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id, + ) elif "weight" in weight_name: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, @@ -1385,9 +1445,9 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): final_hidden_states, non_blocking=True) ctx = get_forward_context() + # flashinfer_cutlass_kernels can handle: optional DP + TP/EP max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens - num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank): @@ -1407,13 +1467,20 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + # Route to the chunked forward path using the FlashInfer Cutlass kernel + # only when data parallelism (DP) is enabled. 
+ use_flashinfer_cutlass_kernels = ( + self.dp_size > 1 + and self.moe_parallel_config.use_flashinfer_cutlass_kernels) if (self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels): + or self.moe_parallel_config.use_deepep_ll_kernels + or use_flashinfer_cutlass_kernels): return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels) + and not self.moe_parallel_config.use_deepep_ht_kernels + and not self.moe_parallel_config.use_flashinfer_cutlass_kernels) if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) @@ -1443,7 +1510,6 @@ def forward_impl(self, hidden_states: torch.Tensor, if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): # Default set to False. (May have to add shared expert outputs. final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( @@ -1526,3 +1592,7 @@ def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, dispatch_key=current_platform.dispatch_key, tags=(torch.Tag.needs_fixed_stride_order, ), ) + +# Mark the FusedMoE weight_loader as supporting MoE-specific parameters +# to avoid expensive runtime reflection in model loading code +FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index f332b5168913..6262904e4dca 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -1,15 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from math import prod -from typing import Optional, final +from typing import Any, Optional, final import torch import vllm.envs as envs from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.utils import ( # yapf: disable + _resize_cache, count_expert_num_tokens) from vllm.utils import cdiv # @@ -21,7 +23,7 @@ # # [Router] → [Quantize-Dispatch] → [Permute-Experts-Unpermute] → [Combine] # -# Each component will be independent of the others except for +# Each component will be independent of (but may inform) the others except for # [Quantize-Dispatch] and `[Combine] (see below). The components can then be # mixed and matched with so that DP+EP can be supported easily for multiple # MoE kernel implementations. @@ -30,13 +32,19 @@ # * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE # inputs (e.g. quantization, distribution) and finalization of Moe outputs. # The prepare method must take care of any needed quantization and the -# finalize method must apply weights and do the final reduction of the output. +# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method, +# may apply weights and/or do the final reduction of the output. # * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused -# MoE operation. One important feature to note is that this class does not -# apply topk weights or reduce the final output. 
+# MoE operation, i.e matmul + act_mul + optionally quant + matmul. +# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do +# the weight application and/or reduction. The class communicates this +# to [Finalize] via a TopKWeightAndReduce object. # * FusedMoEModularKernel - an interface class that combines a # FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to # provide the standard fused MoE kernel interface. +# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen +# by the FusedMoEPermuteExpertsUnpermute implementation that is passed +# on to [Finalize]. # # [Quantize-Prepare] and [Finalize] functionality are bundled into a single # class `FusedMoEPrepareAndFinalize` since they could use collective @@ -95,6 +103,44 @@ class FusedMoEActivationFormat(Enum): BatchedExperts = "batched_experts", +@dataclass +class ExpertTokensMetadata: + """ + Metadata regarding expert-token routing. + """ + expert_num_tokens: torch.Tensor + expert_num_tokens_cpu: Optional[torch.Tensor] + + @staticmethod + def make_from_list(expert_num_tokens_list: list[int], + device: str) -> "ExpertTokensMetadata": + expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list, + device="cpu", + dtype=torch.int32) + return ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens_cpu.to(device, + non_blocking=True), + expert_num_tokens_cpu=expert_num_tokens_cpu) + + +class TopKWeightAndReduce(ABC): + """ + An abstract base class for weight application and reduction implementations. + """ + + @abstractmethod + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + """ + Apply topk_weights to the fused_experts_outputs and/or reduce. + If an output tensor is not passed, it will be created in the + function. + """ + raise NotImplementedError + + # TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ @@ -104,18 +150,15 @@ class FusedMoEPrepareAndFinalize(ABC): @abstractmethod def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -134,7 +177,8 @@ def prepare( Returns a tuple of: - quantized + dispatched a. - quantized + dispatched a1_scales. - - Optional tensor as big as number of local experts that contains the + - Optional ExpertTokensMetadata containing gpu/cpu tensors + as big as the number of local experts with the information about the number of tokens assigned to each local expert. 
- Optional dispatched expert topk IDs - Optional dispatched expert topk weight @@ -142,14 +186,11 @@ def prepare( raise NotImplementedError @abstractmethod - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: """ Perform any combine plus apply weights and perform a reduction on the fused experts output. @@ -160,6 +201,8 @@ def finalize( - topk_ids: The topk_ids. - apply_router_weight_on_input: When False, apply the weights to fused_expert_output. + - weight_and_reduce_impl: An optional TopKWeightAndReduce + implementation. """ raise NotImplementedError @@ -266,6 +309,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: """ Compute the shapes for the temporary and final outputs of the two gemms @@ -299,6 +343,9 @@ def enable_chunking(self): return envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and \ self.supports_chunking() + def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: + raise NotImplementedError + @abstractmethod def apply( self, @@ -306,6 +353,7 @@ def apply( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, activation: str, global_num_experts: int, @@ -318,7 +366,9 @@ def apply( a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], ): """ This function computes the intermediate result of a Mixture of Experts @@ -330,6 +380,8 @@ def apply( layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. + - topk_weights: A map of row to expert weights. Some implementations + choose to do weight application. - topk_ids (torch.Tensor): A map of row to expert id. - activation (str): The activation function to apply after the first MoE layer. @@ -351,8 +403,13 @@ def apply( must be large enough to hold output of either MoE gemm. - workspace2 (torch.Tensor): A scratch tensor used for the activation function. - - expert_num_tokens: An optional tensor containing the number of tokens - assigned to each expert when using batched experts format input. + - expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional + ExpertTokensMetadata object containing gpu/cpu tensors + as big as the number of local experts with the information about the + number of tokens assigned to each local expert. + - apply_router_weight_on_input: True if router weights are already + applied on the input. This is relevant if the implementation + chooses to do weight application. """ raise NotImplementedError @@ -396,6 +453,219 @@ def __init__( f"{fused_experts.__class__.__name__}." 
f"{fused_experts.activation_formats[0]}") + def _do_fused_experts( + self, fused_out: Optional[torch.Tensor], a1: torch.Tensor, + a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str, global_num_experts: int, local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor: + + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + (workspace13_shape, workspace2_shape, fused_out_shape, + workspace_dtype) = self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, + expert_tokens_meta) + + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. + workspace13 = torch.empty(prod(workspace13_shape), + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.empty(prod(workspace2_shape), + device=a1.device, + dtype=workspace_dtype) + + assert fused_out is None or fused_out.shape == fused_out_shape, ( + f"fused_out {fused_out.shape} but expected {fused_out_shape}") + if fused_out is None: + # reuse workspace13 for the output + fused_out = _resize_cache(workspace13, fused_out_shape) + + self.fused_experts.apply( + fused_out, + a1q, + w1, + w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) + + return fused_out + + def _maybe_chunk_fused_experts( + self, + a1: torch.Tensor, + a1q: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + expert_tokens_meta: Optional[ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]], + ) -> torch.Tensor: + + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + num_chunks = cdiv(M, CHUNK_SIZE) + + if not self.fused_experts.supports_chunking() or num_chunks == 1: + return self._do_fused_experts( + fused_out=None, + a1=a1, + a1q=a1q, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) + + # Chunking required case + assert num_chunks > 1 + + # Construct 
the entire output that can then be processed in chunks. + (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( + a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts, + expert_tokens_meta) + fused_out = torch.empty(fused_out_shape, + device=a1q.device, + dtype=a1.dtype) + + def slice_input_tensors( + chunk_idx: int + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[torch.Tensor], torch.Tensor, torch.Tensor]: + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M) + return (a1q[s:e], _chunk_scales(a1q_scale, s, e), + _chunk_scales(a2_scale, s, + e), topk_ids[s:e], topk_weights[s:e]) + + def slice_output_tensor(chunk_idx: int) -> torch.Tensor: + assert fused_out.size(0) % M == 0, ( + f"fused_out shape {fused_out.shape} vs M {M}") + factor = fused_out.size(0) // M + out_chunk_size = CHUNK_SIZE * factor + s = chunk_idx * out_chunk_size + e = min(s + out_chunk_size, fused_out.size(0)) + return fused_out[s:e] + + def slice_expert_tokens_metadata( + full_expert_tokens_meta: ExpertTokensMetadata, + chunk_topk_ids: torch.Tensor, local_num_experts: int, + expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata: + # The existing expert_num_tokens is for the entire a1q + # input. Chunking forces recomputation of the number + # of tokens assigned to each expert. + c_expert_num_tokens = count_expert_num_tokens( + chunk_topk_ids, local_num_experts, expert_map) + + c_expert_num_tokens_cpu = None + need_expert_num_tokens_cpu = ( + full_expert_tokens_meta.expert_num_tokens_cpu is not None) + if need_expert_num_tokens_cpu: + # This is blocking as some implementations need the count + # on the CPU to determine appropriate input/out fused-moe + # buffers + c_expert_num_tokens_cpu = c_expert_num_tokens.to( + "cpu", non_blocking=False) + + return ExpertTokensMetadata( + expert_num_tokens=c_expert_num_tokens, + expert_num_tokens_cpu=c_expert_num_tokens_cpu) + + m = None + if extra_expert_args is not None and 'm' in extra_expert_args: + m = extra_expert_args.get('m') + + if extra_expert_args is not None: + chunked_extra_expert_args = extra_expert_args + else: + chunked_extra_expert_args = {} + + for chunk_idx in range(num_chunks): + c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( + slice_input_tensors(chunk_idx)) + + c_expert_tokens_meta = None + if expert_tokens_meta is not None: + c_expert_tokens_meta = slice_expert_tokens_metadata( + expert_tokens_meta, c_topk_ids, local_num_experts, + expert_map) + + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M) + + if m is not None: + chunked_extra_expert_args['m'] = e - s + self._do_fused_experts( + fused_out=slice_output_tensor(chunk_idx), + a1=a1, + a1q=c_a1q, + w1=w1, + w2=w2, + topk_weights=c_topk_weights, + topk_ids=c_topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=c_a1q_scale, + a2_scale=c_a2_scale, + expert_tokens_meta=c_expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=chunked_extra_expert_args) + + return fused_out + def forward( self, hidden_states: torch.Tensor, @@ -414,6 +684,9 @@ def forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, + extra_expert_args: Optional[dict] = None, + extra_prepare_args: Optional[dict] = None, + extra_finalize_args: Optional[dict] = None, ) -> torch.Tensor: """ This 
function computes a Mixture of Experts (MoE) layer using two sets @@ -446,6 +719,12 @@ def forward( - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. + - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to + fused_experts.apply. + - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass + to prepare. + - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass + to finalize. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -458,7 +737,7 @@ def forward( if global_num_experts == -1: global_num_experts = local_num_experts - (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, + (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, _expert_topk_weights) = self.prepare_finalize.prepare( a1, a1_scale, @@ -469,6 +748,7 @@ def forward( expert_map, apply_router_weight_on_input, self.fused_experts.quant_config, + extra_prepare_args, ) # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. @@ -487,112 +767,31 @@ def forward( # and can never run into the tensor.numel() == 0 case. fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) else: - _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) - - if self.fused_experts.enable_chunking(): - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - num_chunks = cdiv(M, CHUNK_SIZE) - else: - CHUNK_SIZE = M - num_chunks = 1 - - if num_chunks == 1: - (workspace13_shape, workspace2_shape, fused_out_shape, - workspace_dtype) = self.fused_experts.workspace_shapes( - a1, a1q, M, N, K, top_k, global_num_experts, - local_num_experts) - else: - # Use the full M to get the final output shape. - _, _, fused_out_shape, _ = ( - self.fused_experts.workspace_shapes( - a1, a1q, M, N, K, top_k, global_num_experts, - local_num_experts)) - # Use the CHUNK_SIZE to get the workspace shapes. - workspace13_shape, workspace2_shape, _, workspace_dtype = ( - self.fused_experts.workspace_shapes( - a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts, - local_num_experts)) - - # We can reuse the memory between cache1 and cache3 because by the - # time we need cache3, we're done with cache1. - workspace13 = torch.empty(prod(workspace13_shape), - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.empty(prod(workspace2_shape), - device=a1.device, - dtype=workspace_dtype) - - if num_chunks == 1: - fused_out = _resize_cache(workspace13, fused_out_shape) - - self.fused_experts.apply( - fused_out, - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) - else: - # The leading output dimension may not be equal to M, so - # we compute output indices separately. 
- M_out = fused_out_shape[0] - assert M_out >= M - factor = M_out // M - assert factor > 0 - OUT_CHUNK_SIZE = CHUNK_SIZE * factor - - fused_out = torch.empty(fused_out_shape, - device=a1q.device, - dtype=workspace_dtype) - - assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, ( - f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}") - - for chunk in range(num_chunks): - begin_chunk_idx = chunk * CHUNK_SIZE - end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M) - begin_out_idx = chunk * OUT_CHUNK_SIZE - end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out) - curr_a1q = a1q[begin_chunk_idx:end_chunk_idx] - curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx, - end_chunk_idx) - curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx, - end_chunk_idx) - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - - self.fused_experts.apply( - fused_out[begin_out_idx:end_out_idx], - curr_a1q, - w1, - w2, - curr_topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=curr_a1q_scale, - a2_scale=curr_a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) - - self.prepare_finalize.finalize(output, fused_out, topk_weights, - topk_ids, apply_router_weight_on_input) + fused_out = self._maybe_chunk_fused_experts( + a1=a1, + a1q=a1q, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + expert_tokens_meta=expert_tokens_meta, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args) + + self.prepare_finalize.finalize( + output, fused_out, topk_weights, topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + extra_finalize_args) return output diff --git a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py index 3aae183dfa20..2c9ad509fa98 100644 --- a/vllm/model_executor/layers/fused_moe/moe_align_block_size.py +++ b/vllm/model_executor/layers/fused_moe/moe_align_block_size.py @@ -111,6 +111,8 @@ def moe_align_block_size_triton( dtype=torch.int32, device=topk_ids.device) tokens_per_thread = cdiv(numel, num_experts) + sorted_token_ids.fill_(numel) + expert_ids.zero_() moe_align_block_size_stage1[grid]( topk_ids, @@ -205,11 +207,8 @@ def moe_align_block_size( sorted_ids = torch.empty((max_num_tokens_padded, ), dtype=torch.int32, device=topk_ids.device) - sorted_ids.fill_(topk_ids.numel()) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) - # Expert ids must be zeroed out to prevent index out of bounds error while - # mapping global expert ids to local expert ids in expert parallelism. 
- expert_ids = torch.zeros((max_num_m_blocks, ), + expert_ids = torch.empty((max_num_m_blocks, ), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 20ee0d9f780a..d9059f50b445 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -76,43 +76,43 @@ def _moe_unpermute_and_reduce( def moe_permute( hidden_states: torch.Tensor, - topk_weights: torch.Tensor, + a1q_scale: Optional[torch.Tensor], topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - topk: int, n_expert: int, - n_local_expert: int, + n_local_expert: int = -1, expert_map: Optional[torch.Tensor] = None, align_block_size: Optional[int] = None, fill_invalid_expert: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor]: """ This function expands and permutes activation to gather uncontinuous tokens for each expert. Parameters: - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - topk_weights (torch.Tensor): topk expert route weight for each token. + - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states - topk_ids (torch.Tensor): topk expert route id for each token. - - token_expert_indices (torch.Tensor): indice for expanded hidden. - - topk (int): The number of top-k experts to select. - n_expert (int): The number of expert. - n_local_expert (int): The number of expert in current EP rank. - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - from the global expert space to the local expert space of the expert + from the global expert space to the local expert space of the expert parallel shard. - align_block_size (Optional[int]): align group gemm block size for deepgemm - fill_invalid_expert(int): fill expert id in m_indices for invalid expert to workaround DeepGemm unsupported -1 in m_indices Returns: - permuted_hidden_states (torch.Tensor): permuted activation. + - a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states - expert_first_token_offset (torch.Tensor): offset of the first token of each expert for standard grouped gemm. if enable 'align_block_size' expert_first_token_offset will align up to 'align_block_size'. - - src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute. + - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute. + - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden. 
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records the group which the j-th row of the LHS belong to.` """ n_token, n_hidden = hidden_states.size() + topk = topk_ids.size(1) assert (n_hidden * hidden_states.element_size() ) % 16 == 0, "permue kernel need hidden dim align to 16B" permuted_row_size = n_token * topk @@ -120,12 +120,19 @@ def moe_permute( permuted_row_size = (permuted_row_size + n_expert * (align_block_size - 1) + align_block_size - 1) // align_block_size * align_block_size - + if n_local_expert == -1: + n_local_expert = n_expert permuted_hidden_states = torch.empty( (permuted_row_size, n_hidden), dtype=hidden_states.dtype, device=hidden_states.device, ) + token_expert_indices = torch.arange(0, + n_token * topk, + dtype=torch.int32, + device=hidden_states.device).reshape( + (n_token, topk)) + m_indices = torch.full((permuted_row_size, ), fill_invalid_expert, dtype=torch.int32, @@ -133,57 +140,54 @@ def moe_permute( expert_first_token_offset = torch.empty(n_local_expert + 1, dtype=torch.int64, device=hidden_states.device) - src_row_id2dst_row_id_map = torch.empty((n_token, topk), - dtype=torch.int32, - device=hidden_states.device) - torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids, - token_expert_indices, expert_map, n_expert, - n_local_expert, topk, align_block_size, - permuted_hidden_states, - expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) - return (permuted_hidden_states, expert_first_token_offset, - src_row_id2dst_row_id_map, m_indices) + permuted_idx = torch.full((permuted_row_size, ), + n_token * topk, + dtype=torch.int32, + device=hidden_states.device) + inv_permuted_idx = torch.empty((n_token, topk), + dtype=torch.int32, + device=hidden_states.device) + topk_ids = topk_ids.to(torch.int32) + torch.ops._moe_C.moe_permute(hidden_states, topk_ids, token_expert_indices, + expert_map, n_expert, n_local_expert, topk, + align_block_size, permuted_hidden_states, + expert_first_token_offset, inv_permuted_idx, + permuted_idx, m_indices) + if a1q_scale is not None: + a1q_scale = a1q_scale[permuted_idx.clamp(max=n_token * topk - 1) // + topk] + return (permuted_hidden_states, a1q_scale, expert_first_token_offset, + inv_permuted_idx.flatten(), m_indices) def moe_unpermute( + out: torch.Tensor, permuted_hidden_states: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - src_row_id2dst_row_id_map: torch.Tensor, - expert_first_token_offset: torch.Tensor, - topk: int, - n_expert: int, - n_local_expert: int, -) -> torch.Tensor: + inv_permuted_idx: torch.Tensor, + expert_first_token_offset: Optional[torch.Tensor] = None, +) -> None: """ This function expands and permutes activation to gathering uncontinuous tokens for each expert. Parameters: + - out (torch.Tensor): output tensor - permuted_hidden_states (torch.Tensor): permuted activation. - topk_weights (torch.Tensor): topk expert route weight for each token. - - topk_ids (torch.Tensor): topk expert route id for each token. - - expert_first_token_offset (torch.Tensor): offset of the first token - of each expert for grouped gemm. - - topk (int): The number of top-k experts to select. - - n_expert (int): The number of expert. - - n_local_expert (int): The number of expert in current EP rank. + - inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute. + - expert_first_token_offset (Optional[torch.Tensor]): offset of the first + token of each expert for grouped gemm. Returns: - hidden_states (torch.Tensor): The reduced and unpermuted activation tensor. 
""" - n_token, n_hidden = topk_weights.size(0), permuted_hidden_states.size(-1) + topk = topk_weights.size(1) + n_hidden = permuted_hidden_states.size(-1) assert (n_hidden * permuted_hidden_states.element_size() ) % 16 == 0, "unpermue kernel need hidden dim align to 16B" - hidden_states = torch.empty((n_token, n_hidden), - dtype=permuted_hidden_states.dtype, - device=permuted_hidden_states.device) - torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights, - topk_ids, src_row_id2dst_row_id_map, - expert_first_token_offset, n_expert, - n_local_expert, topk, hidden_states) - return hidden_states + inv_permuted_idx, expert_first_token_offset, + topk, out) def moe_permute_unpermute_supported(): diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 112305a4f2d0..46931f2dd7c7 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -1,16 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import pplx_kernels as pplx import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( _validate_scale_shape, moe_kernel_quantize_input) from vllm.utils import cdiv, round_up +logger = init_logger(__name__) + def pplx_hidden_dim_scale_bytes( max_num_tokens: int, @@ -78,29 +83,34 @@ def max_num_tokens_per_rank(self) -> Optional[int]: return self.max_num_tokens def topk_indices_dtype(self) -> Optional[torch.dtype]: - return torch.uint32 + return torch.int32 def num_dispatchers(self) -> int: return self.num_dispatchers_ def prepare( - self, - a1: torch.Tensor, - a1_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: Optional[torch.Tensor], - apply_router_weight_on_input: bool, + self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor, + topk_ids: torch.Tensor, num_experts: int, + expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + extra_prepare_args: Optional[dict[str, Any]] + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K assert topk_ids.size(0) == num_tokens - # assert expert_map is None, "NYI" + # expert_map should be None because with expert map, -1 id is used for + # non-local token; this causes error when casting ids to the + # topk_indices_dtype() int32 + # + if expert_map is not None: + logger.warning_once( + "The PPLX backend does not support expert mapping. " + "The provided `expert_map` will be ignored.") + expert_map = None #noqa: F841 # Is this always going to be a1.device? 
device = a1.device @@ -190,7 +200,7 @@ def prepare( out_expert_x_scale=expert_x_scale, dp_x=a1q, dp_x_scale=a1q_scale, - indices=topk_ids, + indices=topk_ids.view(dtype=torch.uint32), bound_m=bound_m, ) @@ -198,16 +208,20 @@ def prepare( expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] assert expert_x_scale.ndim == 3 - return expert_x, expert_x_scale, expert_num_tokens, None, None + expert_tokens_meta = mk.ExpertTokensMetadata( + expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) + + return expert_x, expert_x_scale, expert_tokens_meta, None, None + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + assert isinstance( + weight_and_reduce_impl, TopKWeightAndReduceDelegate + ), ("Weight application and reduction happens in the combine kernel.") - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: # This argument is optional # There's not much point setting this unless it is != topk_ids.size(0) bound_m: Optional[torch.Tensor] = None @@ -227,7 +241,7 @@ def finalize( topk_weights = torch.ones_like(topk_weights) self.a2a.combine(out_tokens=output, - indices=topk_ids, + indices=topk_ids.view(dtype=torch.uint32), weights=topk_weights, expert_y=fused_expert_output, bound_m=bound_m) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index e1114efe5a3f..696c7cdba9a7 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -1,13 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( - _moe_unpermute_and_reduce) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, TopKWeightAndReduceDelegate) from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -38,8 +38,10 @@ def prepare( expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, quant_config: FusedMoEQuantConfig, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], - Optional[torch.Tensor], Optional[torch.Tensor]]: + extra_prepare_args: Optional[dict[str, Any]], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], + Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor], + Optional[torch.Tensor]]: if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -48,19 +50,32 @@ def prepare( "apply_router_weight_on_input is only implemented for topk=1" a1.mul_(topk_weights.to(a1.dtype)) + if (extra_prepare_args is not None + and extra_prepare_args.get("skip_quant", True)): + # Skip quantization if explicitly requested + return a1, None, None, None, None + a1q, a1q_scale = moe_kernel_quantize_input( a1, a1_scale, quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape) return a1q, a1q_scale, None, None, None - def 
finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - ) -> None: - _moe_unpermute_and_reduce(output, fused_expert_output, None, - topk_weights, apply_router_weight_on_input) + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + extra_finalize_args: Optional[dict[str, Any]]) -> None: + if (extra_finalize_args is not None + and extra_finalize_args.get("skip_weight_reduce", True)): + assert output.shape == fused_expert_output.shape + output.copy_(fused_expert_output) + else: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py new file mode 100644 index 000000000000..fb398eec119f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +import vllm.model_executor.layers.fused_moe.modular_kernel as mk + + +class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): + """ + Useful in the case when some FusedMoEPermuteExpertsUnpermute + implementation does not perform weight application and reduction + but cannot address the needs of all the compatible PrepareAndFinalize + implementations. + For example, BatchedTritonExperts is compatible with both + PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize + does the weight-application + reduction as part of the pplx combine kernel. + But the BatchedPrepareAndFinalize needs an implementation. To facilitate + this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate + so the PrepareAndFinalize implementations could choose how to + weight + reduce. + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceDelegate) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + raise RuntimeError("The caller is expected to choose an appropriate " + "TopKWeightAndReduce implementation.") + + +class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): + """ + The fused_experts outputs have already been weight applied and reduced. + This implementation is a no-op. + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceNoOP) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + # Weight application and reduction operations are already done. + if output is None: + return fused_expert_output + + # MoEPrepareAndFinalizeNoEP needs the output to be in the `output` + # tensor. 
+ assert output.size() == fused_expert_output.size(), ( + "output shape is expected to match the fused_expert_output shape. " + f"But got output={output.size()}, " + f"used_expert_output={fused_expert_output.size()}") + output.copy_(fused_expert_output, non_blocking=True) + return output + + +class TopKWeightAndReduceContiguous(mk.TopKWeightAndReduce): + """ + TopKWeightAndReduce implementation for a fused_experts output + of shape (m, topk, K) + """ + + def __eq__(self, other): + return isinstance(other, TopKWeightAndReduceContiguous) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + + m, num_topk = topk_ids.size() + k = fused_expert_output.size(-1) + if fused_expert_output.ndim == 2: + fused_expert_output = fused_expert_output.view(m, num_topk, k) + + assert fused_expert_output.size() == (m, num_topk, k), ( + f"Expected fused_expert_output size {(m, num_topk, k)}. But got " + f"{fused_expert_output.size()}") + + if not apply_router_weight_on_input: + fused_expert_output.mul_(topk_weights.view(m, -1, 1)) + + if output is None: + output = torch.empty((m, k), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype) + assert output.size() == (m, k), ( + f"Expected output size {(m, k)}. But got {output.size()}") + + ops.moe_sum(fused_expert_output, output) + return output + + +class TopKWeightAndReduceNaiveBatched(mk.TopKWeightAndReduce): + """ + TopKWeightAndReduce implementation for a fused_experts output + of shape (num_experts, batch_size, K) + """ + + def __init__(self, rank: int): + self.rank = rank + + def __eq__(self, other): + return (isinstance(other, TopKWeightAndReduceNaiveBatched) + and (other.rank == self.rank)) + + def apply(self, output: Optional[torch.Tensor], + fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> torch.Tensor: + assert fused_expert_output.ndim == 3 + num_tokens = topk_ids.size(0) + num_local_experts = fused_expert_output.size(0) + K = fused_expert_output.size(-1) + + if output is None: + output = torch.zeros((num_tokens, K), + device=fused_expert_output.device, + dtype=fused_expert_output.dtype) + else: + output.fill_(0) + + assert output.size() == (num_tokens, K), ( + f"Expected output size {(num_tokens, K)}, but got {output.size()}") + + first_expert = num_local_experts * self.rank + last_expert = first_expert + num_local_experts + + for expert_id in range(first_expert, last_expert): + matching_tokens = topk_ids == expert_id + topks = torch.any(matching_tokens, dim=1).flatten() + rows = torch.count_nonzero(topks) + rhs = fused_expert_output[expert_id - first_expert, :rows, :] + if not apply_router_weight_on_input: + rhs.mul_(topk_weights[matching_tokens].view(rhs.size(0), 1)) + output[topks] = output[topks] + rhs + + return output diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index e660376ebe6b..1b31368c79cd 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -1,14 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from typing import Any, Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from 
vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape) + DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape, + deep_gemm_block_shape) from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): @@ -19,6 +21,7 @@ def __init__( use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + use_mxfp4_w4a4: bool = False, per_act_token_quant: bool = False, block_shape: Optional[list[int]] = None, allow_deep_gemm: bool = False, @@ -29,6 +32,7 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, use_int4_w4a16=use_int4_w4a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, )) @@ -37,11 +41,14 @@ def __init__( use_int8_w8a8=use_int8_w8a8, use_int4_w4a16=use_int4_w4a16, use_int8_w8a16=use_int8_w8a16, + use_mxfp4_w4a4=use_mxfp4_w4a4, per_act_token_quant=per_act_token_quant, block_shape=block_shape, ) - self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant - and use_fp8_w8a8) + + self.allow_deep_gemm = (allow_deep_gemm and use_fp8_w8a8 and + self.block_shape == deep_gemm_block_shape()) + self.deep_gemm_expert = DeepGemmExperts( ) if self.allow_deep_gemm else None @@ -66,6 +73,25 @@ def supports_expert_map(self) -> bool: return ((dge is None or dge.supports_expert_map()) and (te is None or te.supports_expert_map())) + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + dge = self.deep_gemm_expert + te = self.triton_expert + dge_war = dge.finalize_weight_and_reduce_impl() if dge else None + te_war = te.finalize_weight_and_reduce_impl() if te else None + is_dge_war = dge_war is not None + is_te_war = te_war is not None + + if is_dge_war and is_te_war: + assert dge_war == te_war, ( + "Both implementations should agree on WeightAndReduce impls. " + f"Got dge_war: {dge_war}, and te_war: {te_war}") + + if dge_war is not None: + return dge_war + + assert te_war is not None + return te_war + def workspace_shapes( self, a: torch.Tensor, @@ -76,41 +102,38 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. 
- if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + if self.allow_deep_gemm and (_valid_deep_gemm_shape(M, N, K) + or is_blackwell_deep_gemm_used()): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( - a, aq, M, N, K, topk, global_num_experts, local_num_experts) + a, aq, M, N, K, topk, global_num_experts, local_num_experts, + expert_tokens_meta) else: return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk, global_num_experts, - local_num_experts) + local_num_experts, + expert_tokens_meta) - def apply( - self, - output: torch.Tensor, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ): + def apply(self, output: torch.Tensor, hidden_states: torch.Tensor, + w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, activation: str, global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: bool, + extra_expert_args: Optional[dict[str, Any]]): use_deep_gemm = (self.allow_deep_gemm - and _valid_deep_gemm(hidden_states, w1, w2)) + and (_valid_deep_gemm(hidden_states, w1, w2) + or is_blackwell_deep_gemm_used())) experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert assert experts is not None @@ -120,6 +143,7 @@ def apply( hidden_states, w1, w2, + topk_weights, topk_ids, activation, global_num_experts, @@ -132,5 +156,7 @@ def apply( a2_scale, workspace13, workspace2, - expert_num_tokens, + expert_tokens_meta, + apply_router_weight_on_input, + extra_expert_args, ) diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index a90cce719b48..966471b5c59b 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from math import prod -from typing import Optional +from typing import Any, Optional, Union import torch @@ -10,7 +10,83 @@ per_token_group_quant_fp8) from vllm.model_executor.layers.quantization.utils.int8_utils import ( per_token_group_quant_int8, per_token_quant_int8) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + quant_dequant_mxfp4) +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils import cdiv +from vllm.utils.flashinfer import fp4_quantize + + +@triton.jit +def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts, + topk_numel, expert_map, + HAS_EXPERT_MAP: tl.constexpr, + BLOCK_SIZE: tl.constexpr): + + curr_expert = tl.program_id(0) + + offsets = tl.arange(0, BLOCK_SIZE) + topk_ids_ptrs = topk_ids_ptr + offsets + + acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32) + for x in 
range(tl.cdiv(topk_numel, BLOCK_SIZE)): + mask = offsets < (topk_numel - x * BLOCK_SIZE) + expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1) + if HAS_EXPERT_MAP: + expert_map_ptrs = expert_map + expert_ids + expert_map_mask = expert_ids >= 0 + expert_ids = tl.load(expert_map_ptrs, + mask=expert_map_mask, + other=-1) + + has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0) + acc = acc + has_curr_expert + topk_ids_ptrs += BLOCK_SIZE + + if curr_expert < num_experts: + tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc)) + + +def count_expert_num_tokens( + topk_ids: torch.Tensor, num_local_experts: int, + expert_map: Optional[torch.Tensor]) -> torch.Tensor: + """ + Count the number of tokens assigned to each expert. + + Parameters: + - topk_ids (torch.Tensor): Tensor mapping each token to its + list of experts. + - num_local_experts (int): Number of experts in this rank. + - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices + from the global expert space to the local expert space of the expert + parallel shard. + + Returns: + A tensor of size num_local_experts, where tensor[i] holds the number + of tokens assigned to the ith expert. + """ + assert topk_ids.dtype.is_signed, ( + "The kernel uses -1 to represent invalid topk_ids") + expert_num_tokens = torch.empty((num_local_experts), + device=topk_ids.device, + dtype=torch.int32) + + grid = num_local_experts + BLOCK_SIZE = min(topk_ids.numel(), 1024) + BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE) + + _count_expert_num_tokens[(grid, )]( + topk_ids, + expert_num_tokens, + num_local_experts, + topk_ids.numel(), + expert_map, + HAS_EXPERT_MAP=expert_map is not None, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return expert_num_tokens def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: @@ -23,6 +99,16 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: return x.flatten()[:prod(v)].view(*v) +def _fp4_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + is_sf_swizzled_layout: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + return fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_sf_swizzled_layout) + + def _fp8_quantize( A: torch.Tensor, A_scale: Optional[torch.Tensor], @@ -34,6 +120,8 @@ def _fp8_quantize( is provided, the output will be blocked.
""" if block_shape is None: + # TODO(luka): use QuantFP8 custom op + # https://github.com/vllm-project/vllm/issues/20711 A, A_scale = ops.scaled_fp8_quant( A, A_scale, use_per_token_if_dynamic=per_act_token) else: @@ -74,17 +162,39 @@ def _int8_quantize( return A, A_scale +def _mxfp4_quantize( + A: torch.Tensor, + A_scale: Optional[torch.Tensor], + per_act_token_quant: bool, + block_shape: Optional[list[int]] = None, +) -> tuple[torch.Tensor, None]: + assert block_shape is None + if not current_platform.supports_mx(): + A = quant_dequant_mxfp4(A) + else: + raise NotImplementedError() + + return A, None + + def moe_kernel_quantize_input( A: torch.Tensor, A_scale: Optional[torch.Tensor], - quant_dtype: Optional[torch.dtype], + quant_dtype: Union[None, torch.dtype, str], per_act_token_quant: bool, block_shape: Optional[list[int]] = None, + is_fp4_scale_swizzled: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if quant_dtype == torch.float8_e4m3fn: return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) + elif quant_dtype == torch.uint8: # nvfp4 + return _fp4_quantize(A, + A_scale, + is_sf_swizzled_layout=is_fp4_scale_swizzled) + elif quant_dtype == "mxfp4": + return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape) else: return A, A_scale @@ -142,3 +252,17 @@ def _validate_scale_shape( assert block_shape is not None expected = (a.shape[0], cdiv(a.shape[1], block_shape[1])) assert a_scale.shape == expected, f"{a_scale.shape} == {expected}" + + +def extract_required_args( + extra_args: Optional[dict[str, Any]], + required_keys: list[str], +) -> tuple[Any, ...]: + if extra_args is None: + raise ValueError("`extra_args` must be provided.") + + missing_keys = [k for k in required_keys if k not in extra_args] + if missing_keys: + raise ValueError(f"Missing keys in `extra_args`: {missing_keys}") + + return tuple(extra_args[k] for k in required_keys) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e8d1fd635505..a5fc1db2dc10 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -170,26 +170,6 @@ def forward_cuda( else: return norm_func(x, self.weight.data, self.variance_epsilon) - def forward_hpu( - self, - x: torch.Tensor, - residual: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - from vllm_hpu_extension.kernels import rms_norm - HPUFusedRMSNorm = rms_norm() - if HPUFusedRMSNorm is None: - return self.forward_native(x, residual) - if residual is not None: - orig_shape = x.shape - residual += x.view(residual.shape) - # Note: HPUFusedRMSNorm requires 3D tensors as inputs - x = HPUFusedRMSNorm.apply(residual, self.weight, - self.variance_epsilon) - return x.view(orig_shape), residual - - x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon) - return x - def forward_xpu( self, x: torch.Tensor, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index a05ae0edbd77..bb81a663d454 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -259,6 +259,8 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + self.quant_config = quant_config + self.prefix = prefix if quant_config is None: self.quant_method: Optional[ QuantizeMethodBase] = UnquantizedLinearMethod() @@ -300,6 +302,12 
@@ def __init__( *, return_bias: bool = True, ): + # If MergedReplicatedLinear, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = self.output_sizes + else: + self.output_partition_sizes = [output_size] + super().__init__(input_size, output_size, skip_bias_add, @@ -311,7 +319,8 @@ def __init__( # All the linear layer supports quant method. assert self.quant_method is not None self.quant_method.create_weights(self, - self.input_size, [self.output_size], + self.input_size, + self.output_partition_sizes, self.input_size, self.output_size, self.params_dtype, @@ -367,6 +376,73 @@ def extra_repr(self) -> str: return s +class MergedReplicatedLinear(ReplicatedLinear): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): + self.output_sizes = output_sizes + super().__init__(input_size, + sum(output_sizes), + bias, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + return_bias=return_bias) + + def weight_loader(self, + param: Union[Parameter, BasevLLMParameter], + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + assert loaded_shard_id is not None + assert loaded_shard_id < len(self.output_sizes) + + if isinstance(param, BlockQuantScaleParameter): + from vllm.model_executor.layers.quantization.fp8 import ( + Fp8LinearMethod, Fp8MoEMethod) + assert self.quant_method is not None + assert isinstance(self.quant_method, + (Fp8LinearMethod, Fp8MoEMethod)) + weight_block_size = self.quant_method.quant_config.weight_block_size + assert weight_block_size is not None + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // + block_n) + shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) // + block_n) + elif isinstance(param, PerTensorScaleParameter): + shard_offset = loaded_shard_id + shard_size = 1 + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] + + param[shard_offset:shard_offset + shard_size] = loaded_weight + + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. 
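As a usage illustration (not part of the patch), the new MergedReplicatedLinear can be exercised as in the sketch below. The sizes, the mlp.gate_up_proj prefix, and the random checkpoint tensors are assumptions for illustration only; the sketch also assumes the default unquantized path registers a `weight` Parameter and that the layer returns an (output, bias) tuple when return_bias is left at its default, as the other linear layers in this file do.

import torch

from vllm.model_executor.layers.linear import MergedReplicatedLinear

hidden, inter = 1024, 4096
# gate_proj and up_proj fused into a single replicated (non-TP-sharded) weight.
gate_up = MergedReplicatedLinear(input_size=hidden,
                                 output_sizes=[inter, inter],
                                 bias=False,
                                 quant_config=None,
                                 prefix="mlp.gate_up_proj")  # prefix is illustrative

# Each checkpoint shard is copied into its slice of the fused parameter using
# the per-shard offsets computed in weight_loader above.
gate_w = torch.randn(inter, hidden)  # stand-ins for real checkpoint tensors
up_w = torch.randn(inter, hidden)
gate_up.weight_loader(gate_up.weight, gate_w, loaded_shard_id=0)
gate_up.weight_loader(gate_up.weight, up_w, loaded_shard_id=1)

out, _ = gate_up(torch.randn(2, hidden))  # assumed (output, bias) return
assert out.shape == (2, 2 * inter)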
@@ -452,8 +528,10 @@ def __init__( else: self.register_parameter("bias", None) + self.tp_rank = get_tensor_model_parallel_rank() + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) is_sharded_weight = getattr(param, "is_sharded_weight", False) @@ -472,15 +550,15 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if is_gguf_weight and isinstance(param, UninitializedParameter): final_shape = list(loaded_weight.shape) if output_dim is not None: - tp_size = get_tensor_model_parallel_world_size() - assert final_shape[output_dim] % tp_size == 0 - final_shape[output_dim] = final_shape[output_dim] // tp_size + assert final_shape[output_dim] % self.tp_size == 0 + final_shape[output_dim] = (final_shape[output_dim] // + self.tp_size) param.materialize(final_shape, dtype=loaded_weight.dtype) param_data = param.data if output_dim is not None and not is_sharded_weight: shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -565,8 +643,11 @@ def __init__( return_bias: bool = True, ): self.output_sizes = output_sizes - tp_size = get_tensor_model_parallel_world_size() - assert all(output_size % tp_size == 0 for output_size in output_sizes) + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert all(output_size % self.tp_size == 0 + for output_size in output_sizes) super().__init__(input_size=input_size, output_size=sum(output_sizes), bias=bias, @@ -598,12 +679,10 @@ def weight_loader(self, return if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) - shard_size = loaded_weight.size(output_dim) // tp_size - start_idx = tp_rank * shard_size + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size if loaded_shard_id is not None: loaded_weight = loaded_weight.narrow(output_dim, start_idx, @@ -669,11 +748,10 @@ def weight_loader(self, return assert loaded_shard_id < len(self.output_sizes) - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = (sum(self.output_sizes[:loaded_shard_id]) // + self.tp_size) + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. 
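Illustrative sketch (not part of the patch): the per-rank shard arithmetic that MergedColumnParallelLinear.weight_loader above now derives from the cached self.tp_size / self.tp_rank instead of re-querying the distributed state. The numbers are made up and assume no quantization packing.

output_sizes = [4096, 4096]   # two projections merged column-wise
tp_size, tp_rank = 4, 1

assert all(size % tp_size == 0 for size in output_sizes)

for loaded_shard_id, full_size in enumerate(output_sizes):
    # where this projection's slice starts inside the merged, TP-sharded param
    shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
    # how many output rows of this projection live on each rank
    shard_size = full_size // tp_size
    # which rows of the full checkpoint tensor this rank copies (narrow start)
    start_idx = tp_rank * shard_size
    print(loaded_shard_id, shard_offset, shard_size, start_idx)
# -> 0 0 1024 1024
# -> 1 1024 1024 1024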
@@ -701,7 +779,7 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size if not is_sharded_weight: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -991,12 +1069,9 @@ def weight_loader(self, return if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - output_dim = getattr(param, "output_dim", None) - shard_size = loaded_weight.size(output_dim) // tp_size - start_idx = tp_rank * shard_size + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size if loaded_shard_id is not None: loaded_weight = loaded_weight.narrow(output_dim, start_idx, @@ -1071,7 +1146,6 @@ def weight_loader(self, self.weight_loader(param, loaded_weight_shard, shard_id) return - tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] # If output dim is defined, use the default loading process. @@ -1123,9 +1197,9 @@ def weight_loader(self, param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": - shard_id = tp_rank + shard_id = self.tp_rank else: - shard_id = tp_rank // self.num_kv_head_replicas + shard_id = self.tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size if not is_sharded_weight: @@ -1245,8 +1319,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) is_sharded_weight = getattr(param, "is_sharded_weight", False) @@ -1264,13 +1336,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if is_gguf_weight and isinstance(param, UninitializedParameter): weight_shape = list(loaded_weight.shape) if input_dim: - weight_shape[input_dim] = weight_shape[input_dim] // tp_size + weight_shape[input_dim] = (weight_shape[input_dim] // + self.tp_size) param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data if input_dim is not None and not is_sharded_weight: shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 3d01253447c0..e93be9bfb165 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -59,11 +59,12 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: Optional[SamplingMetadata] = None, embedding_bias: Optional[torch.Tensor] = None, + prune_hidden_states: bool = True, ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: - if sampling_metadata is not None: + if sampling_metadata is not None and prune_hidden_states: hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) diff --git a/vllm/model_executor/layers/mamba/abstract.py b/vllm/model_executor/layers/mamba/abstract.py new file mode 100644 index 000000000000..4c4997b4894a --- /dev/null +++ b/vllm/model_executor/layers/mamba/abstract.py @@ -0,0 +1,29 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the 
vLLM project +from abc import ABC, abstractmethod +from collections.abc import Iterable + +import torch + + +class MambaBase(ABC): + """ + Base class for Mamba-like layers which support the v1 engine. + Inherit from this class if you implement a custom layer. + """ + + # Contains the KV cache (mamba state) for the layer + # in the shape specified by `self.get_state_shape`. + # The outer list is for v0 PP virtual engine. Though this code path + # only runs for v1, we have to do this to unify with the interface + # of Attention + v0 PP. + kv_cache: list[Iterable[torch.Tensor]] + + @abstractmethod + def get_state_shape(self) -> Iterable[tuple[int, ...]]: + """ + Defines the shape of the state. + For mamba layers this is usually a (conv_state, ssm_state) tuple. + In this case, returns (conv_state_shape, ssm_state_shape). + """ + pass diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 88053faf9e52..0a836fd17533 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math from dataclasses import dataclass +from typing import Optional, Union +import numpy as np import torch from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.placeholder_attn import ( PlaceholderAttentionMetadata) +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.platforms import current_platform +from vllm.v1.attention.backends.mamba_attn import ( + Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets) @dataclass @@ -21,6 +25,29 @@ class Mamba2Metadata: seq_idx: torch.Tensor chunk_indices: torch.Tensor chunk_offsets: torch.Tensor + """ + With continuous batching layout of `x` in vLLM, to enable a Triton program + to handle a request in parallel, two supporting tensors are used + (batch_ptr, token_chunk_offset_ptr) + BLOCK_M = the # tokens to be handled by a Triton program + (can be customized for different hardware) + + nums_dict: + tracks the data associated with a given value of BLOCK_M + BLOCK_M = #tokens handled by a Triton program + cu_seqlen: total tokens per batch + (used as flag to update other data at each new input) + batch_ptr: tracks batch-id handled by the Triton program + token_chunk_offset_ptr: tracks token group_idx handled by the Triton program + (Triton implementation of causal_conv1d handles parallelism in 3-axes + - feature-axis + - batch-axis + - sequence-axis) + """ + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: @@ -38,45 +65,10 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: f"Unsupported platform for Mamba2: {current_platform.device_type}") -def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, - chunk_size: int, - total_seqlens: int): - - cu_seqlens = query_start_loc[1:] # remove prepended 0 - - # outputs will have length expansion of chunks that do not divide - # chunk_size - N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size - > 0).sum() - chunk_indices = torch.arange(N, - dtype=torch.int, - device=query_start_loc.device) - chunk_offsets = torch.zeros((N, ), - dtype=torch.int, - 
device=query_start_loc.device) - - p = 0 # num of insertions - for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): - - # if does not divide chunk_size, then there is one chunk insertion - p += (s % chunk_size > 0) - - # get the dimensions - # - the + 1 for _e is to shift the boundary by one chunk - # - this shifting is not needed if chunk_size divides e - _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size - > 0) - - # adjust inidces and offsets - chunk_indices[_s:_e] -= p - chunk_offsets[_s] = s % chunk_size - - return chunk_indices, chunk_offsets - - def prepare_mamba2_metadata( chunk_size: int, attn_metadata: AttentionMetadata, + mamba2_metadata=None, ) -> Mamba2Metadata: # compute number of prefill and decode requests @@ -96,12 +88,12 @@ def prepare_mamba2_metadata( attn_metadata_instances = get_platform_metadata_classes() if (isinstance(attn_metadata, attn_metadata_instances) and attn_metadata.context_lens_tensor is not None): - has_initial_states = \ - attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] - # precompute flag to avoid device syncs in mamba2 layer forwards + # precompute flag to avoid device syncs later in mamba2 layer + # forwards # prep is only needed for mamba2 ssd prefill processing - prep_initial_states = torch.any(has_initial_states).item() - + has_initial_states = attn_metadata.context_lens_tensor > 0 + prep_initial_states = torch.any( + has_initial_states[:num_prefills]).item() query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] seq_idx = torch.repeat_interleave(torch.arange( num_prefills, dtype=torch.int32, device=query_start_loc.device), @@ -117,9 +109,78 @@ def prepare_mamba2_metadata( _query_start_loc_to_chunk_indices_offsets( query_start_loc, chunk_size, num_prefill_tokens) + if mamba2_metadata is not None: + mamba2_metadata.has_initial_states = has_initial_states + mamba2_metadata.prep_initial_states = prep_initial_states + mamba2_metadata.chunk_size = chunk_size + mamba2_metadata.seq_idx = seq_idx + mamba2_metadata.chunk_indices = chunk_indices + mamba2_metadata.chunk_offsets = chunk_offsets + # We use 1 reset flag: + # * mamba2_metadata.cu_seqlen is None + # update config specific to (each input) + # (become available at first layer, e.g. 
conv_weights) + mamba2_metadata.cu_seqlen = None # suppose to be updated at each input + + return mamba2_metadata return Mamba2Metadata(has_initial_states=has_initial_states, prep_initial_states=prep_initial_states, chunk_size=chunk_size, seq_idx=seq_idx, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets) + + +def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, + mamba2_metadata: Union[Mamba2Metadata, + Mamba2AttentionMetadata]): + """ + this is triggered upon handling a new input at the first layer + """ + dim, cu_seqlen = x.shape + mamba2_metadata.cu_seqlen = cu_seqlen + seqlens = np.diff(query_start_loc.to('cpu')) + nums_dict = {} # type: ignore + for BLOCK_M in [8]: # cover all BLOCK_M values + nums = -(-seqlens // BLOCK_M) + nums_dict[BLOCK_M] = {} + nums_dict[BLOCK_M]['nums'] = nums + nums_dict[BLOCK_M]['tot'] = nums.sum().item() + mlist = torch.from_numpy(np.repeat(np.arange(len(nums)), nums)) + nums_dict[BLOCK_M]['mlist'] = mlist + mlist_len = len(nums_dict[BLOCK_M]['mlist']) + nums_dict[BLOCK_M]['mlist_len'] = mlist_len + MAX_NUM_PROGRAMS = max(1024, mlist_len) * 2 + offsetlist = [] # type: ignore + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) + offsetlist = torch.tensor(offsetlist, dtype=torch.int32) + nums_dict[BLOCK_M]['offsetlist'] = offsetlist + + if mamba2_metadata.batch_ptr is None: + # Update default value after class definition + #mamba2_metadata.MAX_NUM_PROGRAMS *= 2 + mamba2_metadata.batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + mamba2_metadata.token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device='cuda') + else: + if mamba2_metadata.batch_ptr.nelement() < MAX_NUM_PROGRAMS: + mamba2_metadata.batch_ptr.resize_(MAX_NUM_PROGRAMS).fill_( + PAD_SLOT_ID) + mamba2_metadata.token_chunk_offset_ptr.resize_( # type: ignore + MAX_NUM_PROGRAMS).fill_(PAD_SLOT_ID) + + mamba2_metadata.batch_ptr[0:mlist_len].copy_(mlist) + mamba2_metadata.token_chunk_offset_ptr[ # type: ignore + 0:mlist_len].copy_(offsetlist) + nums_dict[BLOCK_M]['batch_ptr'] = mamba2_metadata.batch_ptr + nums_dict[BLOCK_M]['token_chunk_offset_ptr'] = ( + mamba2_metadata.token_chunk_offset_ptr) # type: ignore + mamba2_metadata.nums_dict = nums_dict + return mamba2_metadata diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 118bd8d55c1d..796c8d937572 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -159,7 +159,7 @@ def forward_cuda(self, hidden_states: torch.Tensor, hidden_states = causal_conv1d_fn( hidden_states, conv_weights, - self.conv1d.bias, + bias=self.conv1d.bias, activation=self.activation, conv_states=mamba_cache_params.conv_state, has_initial_state=attn_metadata.context_lens_tensor > 0, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 9dcbcb2e6f2b..e32b2be4d40e 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -13,11 +13,15 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.forward_context import get_forward_context +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) 
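Illustrative sketch (not part of the patch): how the update_metadata helper above turns query_start_loc into the (batch_ptr, token_chunk_offset_ptr) pairs consumed by the Triton causal-conv1d kernel. BLOCK_M = 8 matches the value hard-coded above; the sequence lengths are made up.

import numpy as np

query_start_loc = np.array([0, 5, 6, 7, 8, 28])   # 5 sequences, 28 tokens total
seqlens = np.diff(query_start_loc)                 # [5, 1, 1, 1, 20]
BLOCK_M = 8

nums = -(-seqlens // BLOCK_M)                      # ceil-div: programs per sequence -> [1, 1, 1, 1, 3]
batch_ptr = np.repeat(np.arange(len(nums)), nums)  # sequence index each program handles
token_chunk_offset_ptr = np.concatenate([np.arange(n) for n in nums])

print(batch_ptr)               # [0 1 2 3 4 4 4]
print(token_chunk_offset_ptr)  # [0 0 0 0 0 1 2]
# Program i works on tokens
#   [BLOCK_M * offset, min(BLOCK_M * (offset + 1), seqlen)) of sequence batch_ptr[i].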
-from vllm.model_executor.layers.mamba.mamba2_metadata import Mamba2Metadata +from vllm.model_executor.layers.mamba.abstract import MambaBase +from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, + update_metadata) +from vllm.model_executor.layers.mamba.mamba_utils import ( + extra_groups_for_head_shards, get_mamba_state_shape) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -29,6 +33,8 @@ LoaderFunction, composed_weight_loader, sharded_weight_loader) from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata # Added by the IBM Team, 2024 @@ -144,26 +150,14 @@ def forward_cuda( return out -def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" - - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 - - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups - - def mamba_v2_sharded_weight_loader( shard_spec: list[tuple[int, int, float]], tp_size: int, tp_rank: int, ) -> LoaderFunction: - """Create a weight loader for mamba v2. This ensures that the projections - are correctly sharded so that they can be split into x, B, C. It also - ensures that all the groups corresponding to a head shard is placed + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures that all the groups corresponding to a head shard is placed together with it. """ @@ -218,7 +212,7 @@ def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer @CustomOp.register("mamba_mixer2") -class MambaMixer2(CustomOp): +class MambaMixer2(MambaBase, CustomOp): """ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. A, D are input independent @@ -230,22 +224,21 @@ class MambaMixer2(CustomOp): """ def __init__( - self, - hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, - use_conv_bias: bool, - use_bias: bool, - n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, - rms_norm_eps: float = 1e-5, - activation: str = "silu", - use_rms_norm: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - chunk_size: int = -1, # the chunk size used by v1 + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -427,23 +420,42 @@ def __init__( # of Attention + v0 PP. 
# The inner tuple is (conv_state, ssm_state) self.kv_cache = [(torch.tensor([]), torch.tensor([]))] - assert chunk_size != -1, "chunk_size must be set for v1" - # NOTE: chunk_size may be -1 for models without v1 support - self.chunk_size = chunk_size self.prefix = prefix def forward_native( self, hidden_states: torch.Tensor, - conv_state: torch.Tensor, - ssm_state: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + mup_vector: Optional[torch.Tensor] = None, ): pass + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mamba_cache_params: MambaCacheParams, + mamba2_metadata: Mamba2Metadata, + mup_vector: Optional[torch.Tensor] = None, + ): + if not envs.VLLM_USE_V1: + CustomOp.forward(self, hidden_states, output, mamba_cache_params, + mamba2_metadata, mup_vector) + else: + torch.ops.vllm.mamba_mixer2( + hidden_states, + output, + self.prefix, + mup_vector, + ) + def forward_cuda( self, hidden_states: torch.Tensor, + output: torch.Tensor, mamba_cache_params: MambaCacheParams, mamba2_metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, @@ -458,9 +470,11 @@ def forward_cuda( if attn_metadata is not None: assert isinstance(attn_metadata, dict) attn_metadata = attn_metadata[self.prefix] + mamba2_metadata = attn_metadata assert isinstance(attn_metadata, Mamba2AttentionMetadata) self_kv_cache = self.kv_cache[forward_context.virtual_engine] - conv_state = self_kv_cache[0] + # conv_state = (..., dim, width-1) yet contiguous along 'dim' + conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] state_indices_tensor = attn_metadata.state_indices_tensor has_initial_states_p = attn_metadata.has_initial_states @@ -527,24 +541,26 @@ def forward_cuda( num_prefill_tokens = attn_metadata.num_prefill_tokens # token count has_prefill = num_prefills > 0 has_decode = num_decodes > 0 + num_actual_tokens = num_prefill_tokens + num_decodes # NOTE: V0 put prefill before decode, v1 puts decode before prefill # Separate prefill and decode by splitting varlen input # Split along token dimension + # NOTE: V0 put prefill before decode, v1 puts decode before prefill if envs.VLLM_USE_V1: hidden_states_B_C_d, hidden_states_B_C_p = torch.split( - hidden_states_B_C, + hidden_states_B_C[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0, ) dt_d, dt_p = torch.split( - dt, + dt[:num_actual_tokens], [num_decodes, num_prefill_tokens], dim=0, ) # Split along batch dimension state_indices_tensor_d, state_indices_tensor_p = torch.split( - state_indices_tensor, + state_indices_tensor[:num_actual_tokens], [num_decodes, num_prefills], dim=0, ) @@ -579,19 +595,23 @@ def forward_cuda( # 2. Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" + x = hidden_states_B_C_p.transpose( + 0, 1) # this is the form that causal-conv see + if mamba2_metadata.cu_seqlen is None: + mamba2_metadata = update_metadata(x, query_start_loc_p, + mamba2_metadata) hidden_states_B_C_p = causal_conv1d_fn( - hidden_states_B_C_p.transpose(0, 1), + x, conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, has_initial_state=has_initial_states_p, cache_indices=state_indices_tensor_p, + metadata=mamba2_metadata, query_start_loc=query_start_loc_p).transpose( 0, 1)[:num_prefill_tokens] - # TODO: Why is this needed? 
- hidden_states_B_C_p = hidden_states_B_C_p.contiguous() hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn( hidden_states_B_C_p) @@ -599,9 +619,14 @@ def forward_cuda( initial_states = None if (has_initial_states_p is not None and prep_initial_states): # making a copy of the states - initial_states = torch.where( - has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) + if envs.VLLM_USE_V1: + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) + else: + initial_states = torch.where( + has_initial_states_p[:num_prefills, None, None, None], + ssm_state[state_indices_tensor_p], 0) scan_output, varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, @@ -696,36 +721,51 @@ def forward_cuda( # GatedRMSNorm internally applying SiLU to the gate # SiLU is applied internally before normalization, unlike standard # norm usage - hidden_states = self.norm(hidden_states, gate) + hidden_states = self.norm(hidden_states, gate[:num_actual_tokens]) # 5. Final linear projection - out, _ = self.out_proj(hidden_states) - return out + output[:num_actual_tokens], _ = self.out_proj(hidden_states) def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - world_size = get_tensor_model_parallel_world_size() - - conv_state_shape, temporal_state_shape = None, None - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.n_groups + - extra_groups_for_head_shards(self.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (self.intermediate_size + - 2 * n_groups * self.ssm_state_size) - conv_state_shape = ( - divide(conv_dim, world_size), - self.conv_kernel_size - 1, + return get_mamba_state_shape( + intermediate_size=self.intermediate_size, + tp_world_size=get_tensor_model_parallel_world_size(), + n_groups=self.n_groups, + num_heads=self.num_heads, + head_dim=self.head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, ) - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.num_heads, world_size), - self.head_dim, - self.ssm_state_size, - ) - return conv_state_shape, temporal_state_shape + +def mamba_mixer2( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, + mup_vector: Optional[torch.Tensor] = None, +) -> None: + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + self.forward_cuda(hidden_states=hidden_states, + output=output, + mamba_cache_params=None, + mamba2_metadata=None, + mup_vector=mup_vector) + + +def mamba_mixer2_fake( + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_name: str, + mup_vector: Optional[torch.Tensor] = None, +) -> None: + return + + +direct_register_custom_op( + op_name="mamba_mixer2", + op_func=mamba_mixer2, + mutates_args=["output"], + fake_impl=mamba_mixer2_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py new file mode 100644 index 000000000000..99a582066c0d --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.distributed 
import divide + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups + + +def get_mamba_state_shape( + intermediate_size: int, + tp_world_size: int, + n_groups: int, + num_heads: int, + head_dim: int, + state_size: int, + conv_kernel: int, + use_v1: bool = True, +) -> tuple[tuple[int, int], tuple[int, int, int]]: + """ Get the shape of mamba state.""" + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (n_groups + + extra_groups_for_head_shards(n_groups, tp_world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + 2 * n_groups * state_size) + # contiguous along 'dim' axis + conv_state_shape = ( + conv_kernel - 1, + divide(conv_dim, tp_world_size), + ) + + if not use_v1: + conv_state_shape = (conv_state_shape[1], conv_state_shape[0]) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) + temporal_state_shape = ( + divide(num_heads, tp_world_size), + head_dim, + state_size, + ) + + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index a10c5ab69787..b8d4bbc37105 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -4,102 +4,942 @@ # Copyright (c) 2024, Tri Dao. # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -from typing import Optional +from typing import Optional, Union +import numpy as np import torch -from vllm import _custom_ops as ops from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.triton_utils import tl, triton -def causal_conv1d_fn(x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID): - """ - x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + batch_ptr, + token_chunk_offset_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + batch: tl.int32, # actually padded_batch + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl. 
+ constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.load(batch_ptr + tl.program_id(0)) + chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + # base of the sequence + x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( + tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = (conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. 
update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to( + tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - + 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 5: # STRATEGY1 + col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange( + 0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = x_ptr + ( + (sequence_start_index + idx_tokens_last) * + stride_x_token)[:, None] + ( + idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,] + mask_x = ((idx_tokens_last >= 0)[:, None] & + (idx_tokens_last < seqlen)[:, None] & + (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = conv_states_base[None, :] + ( + idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats + < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. 
take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, + None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ( + (idx_tokens_conv - VAL) * + stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & + (idx_tokens_conv - VAL < seqlen)[:, None] & + (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + + tl.debug_barrier( + ) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = conv_states_base + ( + idx_tokens_conv * + stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv + < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ( + (idx_tokens_conv - VAL) * + stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] & + (idx_tokens_conv - VAL < seqlen)[:, None] & + (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = conv_states_base + ( + idx_tokens_conv * + stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv + < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + 
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca') + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & ( + idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token + ) * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch sequences are concatenated from left to right for varlen weight: (dim, width) - bias: (dim,) + conv_states: (...,dim,width - 1) itype + updated inplace if provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] query_start_loc: (batch + 1) int32 The cumulative sequence 
lengths of the sequences in the batch, used to index into sequence. prepended by 0. - for example: query_start_loc = torch.Tensor([0,10,16,17]), + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 - indicates the corresponding state index, + indicates the corresponding state index, like so: conv_state = conv_states[cache_indices[batch_id]] has_initial_state: (batch) bool - indicates whether should the kernel take the current state as initial + indicates whether should the kernel take the current state as initial state for the calculations - conv_states: (...,dim,width - 1) itype - updated inplace if provided - activation: either None or "silu" or "swish" + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 - out: (batch, dim, seqlen) + out: same shape as `x` """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(-1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - - ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, - cache_indices, has_initial_state, activation - in ["silu", "swish"], pad_slot_id) - return x - - -def causal_conv1d_update(x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID): + if isinstance(activation, bool) and activation: + activation = "silu" + + args = None + out = torch.empty_like(x) + if metadata is not None: + cu_seqlen = metadata.cu_seqlen + nums_dict = metadata.nums_dict + #x = metadata.x + args = nums_dict + batch_ptr = metadata.batch_ptr + token_chunk_offset_ptr = metadata.token_chunk_offset_ptr + else: + seqlens = np.diff(query_start_loc.to('cpu')) + args = seqlens + MAX_NUM_PROGRAMS = 1024 + + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=x.device + ) # tracking which seq-idx the Triton program is handling + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS, ), + PAD_SLOT_ID, + dtype=torch.int32, + device=x.device + ) # tracking BLOCK_M-based index in the sequence the Triton program is handling + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + padded_batch = query_start_loc.size(0) - 1 + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + 
stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert (num_cache_lines, dim, width - 1) == conv_states.shape + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch, ) + assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + if metadata is None: + + def num_program(META, seqlens): + tot = 0 + + mlist = [] + offsetlist = [] # type: ignore + + nums = -(-seqlens // META["BLOCK_M"]) + + tot = nums.sum().item() + mlist = np.repeat(np.arange(len(nums)), nums) + for idx, num in enumerate(nums): + offsetlist.extend( + range(num) + ) # chunk-idx if a sequence is split into multiple chunks + + if META["batch_ptr"].nelement() < len(mlist): + newlen = len(mlist) + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_( + PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= len(mlist): + META["batch_ptr"][0:len(mlist)].copy_( + torch.from_numpy(np.array(mlist))) + META["token_chunk_offset_ptr"][0:len(mlist)].copy_( + torch.from_numpy(np.array(offsetlist))) + + META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) + META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( + META["x_ptr"].device) + return tot + else: + + def num_program(META, nums_dict): + tot = nums_dict[META["BLOCK_M"]]['tot'] + + mlist = nums_dict[META["BLOCK_M"]]['mlist'] + mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len'] + + offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist'] + + if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: + META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[ + META["BLOCK_M"]]["token_chunk_offset_ptr"] + else: + if META["batch_ptr"].nelement() < mlist_len: + newlen = mlist_len + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_( + PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= mlist_len: + META["batch_ptr"][0:mlist_len].copy_(mlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_( + offsetlist) + return tot + + def grid(META): + return ( + num_program(META, args), + triton.cdiv(dim, META["BLOCK_N"]), + ) + + if batch_ptr.device 
!= x.device: + batch_ptr = batch_ptr.to(x.device) + token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + batch_ptr, + token_chunk_offset_ptr, + out, + # Matrix dimensions + padded_batch, + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + #launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to( + tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + # STEP 1: READ init_state data + conv_states_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + 
conv_state_ptrs_source = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim + ) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ( + (idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens - VAL >= 0)[:, None] & + (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = conv_state_base + ( + idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < 
dim + ) # token-index # feature-index + o_ptrs = o_ptr + ( + idx_seq) * stride_o_seq + idx_token * stride_o_token + ( + idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): """ x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) cache_seqlens: (batch,), dtype int32. If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state + The conv_state will be updated by copying x to the conv_state starting at the index @cache_seqlens % state_len. conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, + If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 out: (batch, dim) or (batch, dim, seqlen) """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - activation_val = activation in ["silu", "swish"] + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] unsqueeze = x.dim() == 2 if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(-1) - ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, - cache_seqlens, conv_state_indices, pad_slot_id) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert conv_state.stride( + -2 + ) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch, ) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x 
+ stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride( + ) # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + ) + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) if unsqueeze: - x = x.squeeze(-1) - return x + out = out.squeeze(-1) + return out diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index d864a915a073..c06cca080227 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -1,25 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from abc import ABC, abstractmethod +from collections.abc import Mapping, Set +from dataclasses import dataclass from enum import IntEnum -from typing import Optional, Union +from itertools import groupby +from typing import Callable, Optional, TypeVar, Union import torch import torch.nn as nn import torch.nn.functional as F -from typing_extensions import assert_never +from transformers import PretrainedConfig from vllm.config import ModelConfig, PoolerConfig from vllm.model_executor.pooling_metadata import ( # noqa: E501 PoolingMetadata as V0PoolingMetadata) from vllm.model_executor.pooling_metadata import PoolingTensors +from vllm.pooling_params import PoolingParams, PoolingTask from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput -from vllm.transformers_utils.config import ( - get_classification_activation_function, - get_cross_encoder_activation_function) +from vllm.utils import resolve_obj_by_qualname from vllm.v1.pool.metadata import PoolingMetadata as V1PoolingMetadata PoolingMetadata = Union[V0PoolingMetadata, V1PoolingMetadata] +PoolingFn = Callable[ + [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata], + Union[torch.Tensor, list[torch.Tensor]]] +ClassifierFn = Callable[[torch.Tensor], torch.Tensor] class PoolingType(IntEnum): @@ -31,165 +37,357 @@ class PoolingType(IntEnum): MEAN = 4 -class SimplePooler(nn.Module): - """A layer that pools specific information from hidden states. +@dataclass(frozen=True) +class ResolvedPoolingConfig: + pooling_type: PoolingType - This layer does the following: - 1. Extracts specific tokens or aggregates data based on pooling method. - 2. Normalizes output if specified. - 3. Returns structured results as `PoolerOutput`. + normalize: bool + softmax: bool + step_tag_id: Optional[int] + returned_token_ids: Optional[list[int]] - Attributes: - pooling_type: The type of pooling to use. - normalize: Whether to normalize the pooled data. 
- """ - - @staticmethod - def from_pooling_type( + @classmethod + def from_config_with_defaults( + cls, + pooler_config: PoolerConfig, pooling_type: PoolingType, - *, normalize: bool, softmax: bool, step_tag_id: Optional[int] = None, returned_token_ids: Optional[list[int]] = None, - ) -> "SimplePooler": + ) -> "ResolvedPoolingConfig": + return cls( + pooling_type=PoolingType[pooler_config.pooling_type] + if pooler_config.pooling_type is not None else pooling_type, + normalize=pooler_config.normalize + if pooler_config.normalize is not None else normalize, + softmax=pooler_config.softmax + if pooler_config.softmax is not None else softmax, + step_tag_id=pooler_config.step_tag_id + if pooler_config.step_tag_id is not None else step_tag_id, + returned_token_ids=pooler_config.returned_token_ids + if pooler_config.returned_token_ids is not None else + returned_token_ids, + ) + + +@dataclass(frozen=True) +class PoolingParamsUpdate: + requires_token_ids: bool = False + """Set this flag to enable `get_prompt_token_ids` for your pooler.""" + + def apply(self, params: PoolingParams) -> None: + params.requires_token_ids = self.requires_token_ids + + +class Pooler(nn.Module, ABC): + """The interface required for all poolers used in pooling models in vLLM.""" + + @staticmethod + def for_encode( + pooler_config: PoolerConfig, + *, + default_pooling_type: PoolingType = PoolingType.ALL, + default_normalize: bool = False, + default_softmax: bool = False, + default_step_tag_id: Optional[int] = None, + default_returned_token_ids: Optional[list[int]] = None, + ): + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + step_tag_id=default_step_tag_id, + returned_token_ids=default_returned_token_ids, + ) + + if resolved_config.pooling_type == PoolingType.STEP: + return StepPooler.from_config(resolved_config) + + return SimplePooler.from_config(resolved_config) + + @staticmethod + def for_embed( + pooler_config: PoolerConfig, + *, + default_pooling_type: PoolingType = PoolingType.LAST, + default_normalize: bool = True, + default_softmax: bool = False, + ): + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) + + return SimplePooler.from_config(resolved_config) + + @staticmethod + def for_classify( + pooler_config: PoolerConfig, + classifier: Optional[ClassifierFn], + *, + default_pooling_type: PoolingType = PoolingType.LAST, + default_normalize: bool = False, + default_softmax: bool = True, + ): + resolved_config = ResolvedPoolingConfig.from_config_with_defaults( + pooler_config=pooler_config, + pooling_type=default_pooling_type, + normalize=default_normalize, + softmax=default_softmax, + ) + base_pooler = SimplePooler.from_config(resolved_config) + if classifier is None: + return base_pooler + + return ClassifierPooler( + pooling=base_pooler.pooling, + classifier=classifier, + act_fn=base_pooler.head.activation, + ) + + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + """Determine which pooling tasks are supported.""" + raise NotImplementedError + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + """ + Construct the updated pooling parameters to use for a supported task. 
+ """ + return PoolingParamsUpdate() + + @abstractmethod + def forward( + self, + hidden_states: Union[list[torch.Tensor], torch.Tensor], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + raise NotImplementedError + + +def get_prompt_lens( + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, +) -> torch.Tensor: + if isinstance(pooling_metadata, V1PoolingMetadata): + return pooling_metadata.prompt_lens + + return PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states[0].device).prompt_lens + + +def get_prompt_token_ids( + pooling_metadata: PoolingMetadata) -> list[torch.Tensor]: + if isinstance(pooling_metadata, V1PoolingMetadata): + assert pooling_metadata.prompt_token_ids is not None, ( + "Please set `requires_token_ids=True` in `get_pooling_updates`") + + return [ + pooling_metadata.prompt_token_ids[i, :num] + for i, num in enumerate(pooling_metadata.prompt_lens) + ] + + return [ + torch.tensor(seq_data_i.prompt_token_ids) + for seq_data_i in pooling_metadata.seq_data.values() + ] + + +def get_tasks(pooling_metadata: PoolingMetadata) -> list[PoolingTask]: + if isinstance(pooling_metadata, V0PoolingMetadata): + pooling_params = [p for _, p in pooling_metadata.seq_groups] + else: + pooling_params = pooling_metadata.pooling_params + + tasks: list[PoolingTask] = [ + task for pooling_param in pooling_params + if (task := pooling_param.task) is not None + ] + assert len(pooling_params) == len(tasks) + + return tasks + + +def get_classification_activation_function(config: PretrainedConfig): + return PoolerClassify() + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + function_name: Optional[str] = None + if (hasattr(config, "sentence_transformers") + and "activation_fn" in config.sentence_transformers): + function_name = config.sentence_transformers["activation_fn"] + elif (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + function_name = config.sbert_ce_default_activation_function + + if function_name is not None: + assert function_name.startswith("torch.nn.modules."), ( + "Loading of activation functions is restricted to " + "torch.nn.modules for security reasons") + fn = resolve_obj_by_qualname(function_name)() + return PoolerActivation.wraps(fn) + + return PoolerScore() + + +def build_output( + all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput: + all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data] + return PoolerOutput(outputs=all_outputs) + + +class PoolingMethod(nn.Module, ABC): + + @staticmethod + def from_pooling_type(pooling_type: PoolingType) -> "PoolingMethod": if pooling_type == PoolingType.LAST: - assert step_tag_id is None and returned_token_ids is None - return LastPool(normalize=normalize, softmax=softmax) + return LastPool() if pooling_type == PoolingType.ALL: - assert step_tag_id is None and returned_token_ids is None - return AllPool(normalize=normalize, softmax=softmax) + return AllPool() if pooling_type == PoolingType.CLS: - assert step_tag_id is None and returned_token_ids is None - return CLSPool(normalize=normalize, softmax=softmax) + return CLSPool() if pooling_type == PoolingType.MEAN: - assert step_tag_id is None and returned_token_ids is None - return MeanPool(normalize=normalize, softmax=softmax) - if pooling_type == PoolingType.STEP: - return StepPool(normalize=normalize, - softmax=softmax, - step_tag_id=step_tag_id, - returned_token_ids=returned_token_ids) + return 
MeanPool() - assert_never(pooling_type) + raise NotImplementedError(f"Unsupported method: {pooling_type}") - def __init__(self, *, normalize: bool, softmax: bool) -> None: - super().__init__() + @abstractmethod + def get_supported_tasks(self) -> Set[PoolingTask]: + raise NotImplementedError - self.head = PoolerHead(normalize=normalize, softmax=softmax) + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate() - def get_prompt_lens( + @abstractmethod + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - assert isinstance(hidden_states, torch.Tensor) - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens + """ + Note: + `prompt_len=None` means `prompt_len=len(hidden_states)`. + """ + raise NotImplementedError - def extract_states( + @abstractmethod + def forward_all( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, ) -> Union[list[torch.Tensor], torch.Tensor]: raise NotImplementedError - def build_output(self, data: torch.Tensor) -> PoolingSequenceGroupOutput: - return PoolingSequenceGroupOutput(data) - def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - pooled_data = self.extract_states(hidden_states, pooling_metadata) - pooled_data = self.head(pooled_data, pooling_metadata) - pooled_outputs = [self.build_output(data) for data in pooled_data] - return PoolerOutput(outputs=pooled_outputs) + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + if isinstance(hidden_states, list): + return [ + self.forward_one(h, prompt_len) + for h, prompt_len in zip(hidden_states, prompt_lens) + ] -class CLSPool(SimplePooler): + return self.forward_all(hidden_states, prompt_lens) - def extract_states( + +class CLSPool(PoolingMethod): + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} + + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with CLS pooling" - if isinstance(hidden_states, list): - result = [] - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with CLS pooling" - result.append(req_state[0]) - return result + return hidden_states[0] + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: first_token_flat_indices = torch.zeros_like(prompt_lens) first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1] return hidden_states[first_token_flat_indices] -class LastPool(SimplePooler): +class LastPool(PoolingMethod): - def extract_states( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - 
pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - if isinstance(hidden_states, list): - return [h[-1] for h in hidden_states] + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + def forward_one( + self, + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return hidden_states[-1] + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 return hidden_states[last_token_flat_indices] -class AllPool(SimplePooler): +class AllPool(PoolingMethod): - def extract_states( + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode"} + + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with ALL pooling" - if isinstance(hidden_states, list): - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with ALL pooling" - return hidden_states + return hidden_states - offset = 0 - pooled_data = list[torch.Tensor]() - for prompt_len in prompt_lens: - pooled_data.append(hidden_states[offset:offset + prompt_len]) - offset += prompt_len + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: + return list(hidden_states.split_with_sizes(prompt_lens.tolist())) - return pooled_data +class MeanPool(PoolingMethod): -class MeanPool(SimplePooler): + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed", "classify", "score"} - def extract_states( + def forward_one( self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + hidden_states: torch.Tensor, + prompt_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with MEAN pooling" - if isinstance(hidden_states, list): - result = [] - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with mean pooling" - result.append(torch.mean(req_state, dim=0, - dtype=torch.float32)) - return result + return hidden_states.mean(dim=0, dtype=torch.float32) + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: # Use float32 for torch.cumsum in MeanPool, # otherwise precision will be lost significantly. 
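# --- Editorial sketch (not part of the patch) ---------------------------------
# How the forward_all() variants in this hunk index into a batch whose requests
# are flattened along dim 0: last-token pooling uses cumsum(prompt_lens) - 1,
# and mean pooling can be expressed as cumsum differences computed in float32
# (the reason for the precision comment immediately above). Illustration only;
# not the exact vLLM code path.
import torch

prompt_lens = torch.tensor([3, 2, 4])
hidden_states = torch.randn(int(prompt_lens.sum()), 8)

# Last token of each request: rows 2, 4 and 8.
last_idx = torch.cumsum(prompt_lens, dim=0) - 1
last_pooled = hidden_states[last_idx]                       # (3, 8)
assert torch.equal(
    last_pooled,
    torch.stack([h[-1] for h in hidden_states.split(prompt_lens.tolist())]))

# Per-request mean via cumulative sums (float32 accumulator).
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)
end_idx = torch.cumsum(prompt_lens, dim=0) - 1
start = torch.cat([hidden_states.new_zeros(1, 8, dtype=torch.float32),
                   cumsum[end_idx[:-1]]])
mean_pooled = (cumsum[end_idx] - start) / prompt_lens.unsqueeze(1)

# Cross-check against the straightforward per-request computation.
ref = torch.stack([h.mean(dim=0, dtype=torch.float32)
                   for h in hidden_states.split(prompt_lens.tolist())])
assert torch.allclose(mean_pooled, ref, atol=1e-5)
# -------------------------------------------------------------------------------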
cumsum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32) @@ -203,78 +401,108 @@ def extract_states( hidden_states[start_indices]) / prompt_lens.unsqueeze(1) -class StepPool(SimplePooler): +_T = TypeVar("_T", torch.Tensor, list[torch.Tensor]) - def __init__( - self, - *, - normalize: bool, - softmax: bool, - step_tag_id: Optional[int] = None, - returned_token_ids: Optional[list[int]] = None, - ): - super().__init__(normalize=normalize, softmax=softmax) - self.step_tag_id = step_tag_id - self.returned_token_ids = returned_token_ids +class BasePoolerActivation(nn.Module, ABC): - def get_prompt_token_ids( - self, - pooling_metadata: PoolingMetadata, - ) -> list[torch.Tensor]: - if isinstance(pooling_metadata, V1PoolingMetadata): - return [ - pooling_metadata.prompt_token_ids[i, :num] - for i, num in enumerate(pooling_metadata.prompt_lens) - ] - return [ - torch.tensor(seq_data_i.prompt_token_ids) - for seq_data_i in pooling_metadata.seq_data.values() - ] + @abstractmethod + def forward(self, pooled_data: _T) -> _T: + # shape: + # classify (& score) -> (batch_size, num_classes) + # embed -> (batch_size, embedding_dim) or list(embedding_dim) + # (batch_size, dimensions) or list(dimensions) if using MRL + raise NotImplementedError - def extract_states( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> Union[list[torch.Tensor], torch.Tensor]: - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) - prompt_token_ids = self.get_prompt_token_ids(pooling_metadata) - pooled_data_lst = list[torch.Tensor]() - if isinstance(hidden_states, list): - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with step pooling" - pooled_data_lst = hidden_states - else: - offset = 0 - for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] - offset += prompt_len - pooled_data_lst.append(pooled_data_i) +class PoolerActivation(BasePoolerActivation): - pooled_data = list[torch.Tensor]() - returned_token_ids = self.returned_token_ids - step_tag_id = self.step_tag_id + @staticmethod + def wraps(module: nn.Module): + if isinstance(module, nn.Identity): + return PoolerIdentity() + if isinstance(module, (nn.Sigmoid, nn.Softmax)): + return PoolerClassify() - for data, token_id in zip(pooled_data_lst, prompt_token_ids): - if returned_token_ids is not None and len(returned_token_ids) > 0: - data = data[:, returned_token_ids] + return LambdaPoolerActivation(module) - if step_tag_id is not None: - data = data[token_id == step_tag_id] - pooled_data.append(data) + @abstractmethod + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def forward(self, pooled_data: _T) -> _T: + if isinstance(pooled_data, list): + return [self.forward_chunk(data) for data in pooled_data] + return self.forward_chunk(pooled_data) + + +class PoolerIdentity(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: return pooled_data +class PoolerNormalize(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + x = F.normalize(pooled_data.float(), p=2, dim=-1) + return x.to(pooled_data.dtype) + + +class PoolerClassify(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + num_labels = pooled_data.shape[-1] + if num_labels < 2: + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + + return 
F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype) + + +class PoolerScore(PoolerActivation): + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + num_labels = pooled_data.shape[-1] + if num_labels < 2: + return F.sigmoid(pooled_data.float()).to(pooled_data.dtype) + + return pooled_data + + +class LambdaPoolerActivation(PoolerActivation): + + def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]): + super().__init__() + + self.fn = fn + + def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor: + return self.fn(pooled_data) + + class PoolerHead(nn.Module): - def __init__(self, *, normalize: bool, softmax: bool) -> None: + @classmethod + def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "PoolerHead": + if pooler_config.normalize and pooler_config.softmax: + raise ValueError("`normalize=True` and `softmax=True` should not " + "be set together") + + activation: PoolerActivation + if pooler_config.normalize: + activation = PoolerNormalize() + elif pooler_config.softmax: + activation = PoolerClassify() + else: + activation = PoolerIdentity() + + return cls(activation) + + def __init__(self, activation: PoolerActivation) -> None: super().__init__() - self.normalize = normalize - self.softmax = softmax + self.activation = activation def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], pooling_metadata: PoolingMetadata): @@ -312,162 +540,214 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor], for vecs, d in zip(pooled_data, dimensions_list) ] - if self.normalize: - if isinstance(pooled_data, list): - pooled_data = [ - F.normalize(data, p=2, dim=-1) for data in pooled_data - ] - else: - pooled_data = F.normalize(pooled_data, p=2, dim=-1) - - if self.softmax: - if isinstance(pooled_data, list): - pooled_data = [ - F.softmax(data, dim=-1) - if data.shape[-1] >= 2 else F.sigmoid(data) - for data in pooled_data - ] - else: - if pooled_data.shape[-1] >= 2: - pooled_data = F.softmax(pooled_data, dim=-1) - else: - pooled_data = F.sigmoid(pooled_data) + return self.activation(pooled_data) - # shape: - # classify (& score) -> (batch_size, num_classes) - # embed -> (batch_size, embedding_dim) or list(embedding_dim) - # (batch_size, dimensions) or list(dimensions) if using MRL - return pooled_data +class SimplePooler(Pooler): + """A layer that pools specific information from hidden states. -class Pooler(nn.Module): + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. 
+ """ @classmethod - def from_config_with_defaults( + def from_config( cls, - pooler_config: PoolerConfig, - pooling_type: PoolingType, - normalize: bool, - softmax: bool, + pooler_config: ResolvedPoolingConfig, + ) -> "SimplePooler": + pooling = PoolingMethod.from_pooling_type(pooler_config.pooling_type) + head = PoolerHead.from_config(pooler_config) + + return cls(pooling, head) + + def __init__(self, pooling: PoolingMethod, head: PoolerHead) -> None: + super().__init__() + + self.pooling = pooling + self.head = head + + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.pooling.get_pooling_updates(task) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) + + +class StepPooler(Pooler): + + @classmethod + def from_config(cls, pooler_config: ResolvedPoolingConfig) -> "StepPooler": + assert pooler_config.pooling_type == PoolingType.STEP + + return cls( + PoolerHead.from_config(pooler_config), + step_tag_id=pooler_config.step_tag_id, + returned_token_ids=pooler_config.returned_token_ids, + ) + + def __init__( + self, + head: PoolerHead, + *, step_tag_id: Optional[int] = None, returned_token_ids: Optional[list[int]] = None, - ) -> SimplePooler: - return SimplePooler.from_pooling_type( - pooling_type=PoolingType[pooler_config.pooling_type] - if pooler_config.pooling_type is not None else pooling_type, - normalize=pooler_config.normalize - if pooler_config.normalize is not None else normalize, - softmax=pooler_config.softmax - if pooler_config.softmax is not None else softmax, - step_tag_id=pooler_config.step_tag_id - if pooler_config.step_tag_id is not None else step_tag_id, - returned_token_ids=pooler_config.returned_token_ids - if pooler_config.returned_token_ids is not None else - returned_token_ids, - ) + ) -> None: + super().__init__() + + self.pooling = AllPool() + self.head = head + self.step_tag_id = step_tag_id + self.returned_token_ids = returned_token_ids + + def extract_states( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + pooled_data_lst = self.pooling(hidden_states, pooling_metadata) + prompt_token_ids = get_prompt_token_ids(pooling_metadata) + + pooled_data = list[torch.Tensor]() + returned_token_ids = self.returned_token_ids + step_tag_id = self.step_tag_id + + for data, token_id in zip(pooled_data_lst, prompt_token_ids): + if returned_token_ids is not None and len(returned_token_ids) > 0: + data = data[:, returned_token_ids] + + if step_tag_id is not None: + data = data[token_id == step_tag_id] + pooled_data.append(data) + + return pooled_data + + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.extract_states(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) -class ClassifierPooler(nn.Module): +class 
ClassifierPooler(Pooler): """A pooling layer for classification tasks. This layer does the following: 1. Applies a classification layer to the hidden states. 2. Optionally applies a pooler layer. - 3. Applies an activation function to the output. In the case of - classification models it is either sigmoid or softmax. In the - case of scoring models, the same behavior is configuration - dependent, as in the sentence-transformers library. + 3. Applies an activation function to the output. """ + @staticmethod + def act_fn_for_seq_cls(config: ModelConfig): + return get_classification_activation_function(config.hf_config) + + @staticmethod + def act_fn_for_cross_encoder(config: ModelConfig): + return get_cross_encoder_activation_function(config.hf_config) + def __init__( self, - config: ModelConfig, - classifier: nn.Module, - pooler: Optional[nn.Module] = None, - ): + pooling: PoolingFn, + classifier: ClassifierFn, + act_fn: PoolerActivation, + ) -> None: super().__init__() - self.classifier = classifier - self.pooler = pooler - self.classification_act_fn = get_classification_activation_function( - config.hf_config) - self.cross_encoder_act_fn = get_cross_encoder_activation_function( - config.hf_config) - - def _get_act_fn(self, use_cross_encoder: bool): - return (self.cross_encoder_act_fn - if use_cross_encoder else self.classification_act_fn) + self.pooling = pooling + self.classifier = classifier + self.act_fn = act_fn - def get_prompt_lens( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> torch.Tensor: - if isinstance(pooling_metadata, V1PoolingMetadata): - return pooling_metadata.prompt_lens - assert isinstance(hidden_states, torch.Tensor) - return PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"classify", "score"} def forward( self, hidden_states: Union[torch.Tensor, list[torch.Tensor]], pooling_metadata: PoolingMetadata, ) -> PoolerOutput: - """Pools sentence pair scores from the hidden_states.""" - prompt_lens = self.get_prompt_lens(hidden_states, pooling_metadata) + pooled_data = self.pooling(hidden_states, pooling_metadata) - pooled_data = list[torch.Tensor]() - if isinstance(hidden_states, list): - for req_state, prompt_len in zip(hidden_states, prompt_lens): - assert prompt_len == req_state.shape[0], \ - "partial prefill not supported with classifier" - pooled_data = hidden_states + # apply classifier once on the full batch if possible + if isinstance(pooled_data, torch.Tensor): + pooled_output = self.classifier(pooled_data) + elif len({data.shape for data in pooled_data}) <= 1: + pooled_output = self.classifier(torch.stack(pooled_data)) else: - offset = 0 - for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] - offset += prompt_len - pooled_data.append(pooled_data_i) + pooled_output = [self.classifier(data) for data in pooled_data] - pooled_data_lst = [] - for pooled_data_i in pooled_data: + scores = self.act_fn(pooled_output) - if self.pooler is not None: - final_shape_tensor = self.pooler(pooled_data_i) - else: - final_shape_tensor = self.classifier(pooled_data_i) + return build_output(scores) - pooled_data_lst.append(final_shape_tensor) - pooled_output = torch.stack(pooled_data_lst) +class DispatchPooler(Pooler): + """Dispatches calls to a sub-pooler based on the pooling task.""" - if self.pooler is not None: - # apply classifier once on the full batch if possible - 
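# --- Editorial sketch (not part of the patch) ---------------------------------
# The new ClassifierPooler.forward in this hunk applies the classifier a single
# time when it can: either the pooled data is already one batched tensor, or it
# is a list of tensors that all share a shape and can be stacked. Only ragged
# outputs (e.g. ALL pooling with different prompt lengths) fall back to a
# per-item loop. A toy illustration of that decision with a linear classifier:
import torch

classifier = torch.nn.Linear(8, 3)


def classify(pooled):
    if isinstance(pooled, torch.Tensor):
        return classifier(pooled)                  # already batched
    if len({t.shape for t in pooled}) <= 1:
        return classifier(torch.stack(pooled))     # one batched call
    return [classifier(t) for t in pooled]         # ragged fallback


same_shape = [torch.randn(8), torch.randn(8)]
ragged = [torch.randn(4, 8), torch.randn(7, 8)]
assert classify(same_shape).shape == (2, 3)
assert [t.shape for t in classify(ragged)] == [(4, 3), (7, 3)]
# -------------------------------------------------------------------------------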
pooled_output = self.classifier(pooled_output) + def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None: + super().__init__() - if isinstance(pooling_metadata, V0PoolingMetadata): - use_cross_encoder_list = [ - pooling_param.use_cross_encoder - for _, pooling_param in pooling_metadata.seq_groups - ] - else: - use_cross_encoder_list = [ - pooling_param.use_cross_encoder - for pooling_param in pooling_metadata.pooling_params - ] + for task, pooler in poolers_by_task.items(): + if task not in pooler.get_supported_tasks(): + raise ValueError( + f"{pooler=} does not support {task=}. " + f"Supported tasks: {pooler.get_supported_tasks()}") + + self.poolers_by_task = poolers_by_task + + def get_supported_tasks(self) -> Set[PoolingTask]: + return set(self.poolers_by_task) - # shape of scores: (batch_size, num_labels) - if all(use_cross_encoder == use_cross_encoder_list[0] - for use_cross_encoder in use_cross_encoder_list): - act_fn = self._get_act_fn(use_cross_encoder_list[0]) - scores = act_fn(pooled_output) + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.poolers_by_task[task].get_pooling_updates(task) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + poolers_by_task = self.poolers_by_task + + if isinstance(hidden_states, list): + hidden_states_lst = hidden_states else: - scores = torch.stack([ - self._get_act_fn(use_cross_encoder)(vecs) - for use_cross_encoder, vecs in zip(use_cross_encoder_list, - pooled_output) - ]) - - pooled_outputs = [PoolingSequenceGroupOutput(data) for data in scores] - return PoolerOutput(outputs=pooled_outputs) + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + hidden_states_lst = list(hidden_states.split(prompt_lens.tolist())) + + outputs = list[PoolingSequenceGroupOutput]() + offset = 0 + for task, group in groupby(get_tasks(pooling_metadata)): + if not (pooler := poolers_by_task.get(task)): + raise ValueError( + f"Unsupported task: {task} " + f"Supported tasks: {self.get_supported_tasks()}") + + num_items = len(list(group)) + group_output: PoolerOutput = pooler( + hidden_states_lst[offset:offset + num_items], + pooling_metadata[offset:offset + num_items], + ) + + outputs.extend(group_output.outputs) + offset += num_items + + return PoolerOutput(outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 60217ee86ad1..95aea912a150 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -36,6 +36,7 @@ "torchao", "auto-round", "rtn", + "inc", ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -104,6 +105,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .gptq_marlin import GPTQMarlinConfig from .gptq_marlin_24 import GPTQMarlin24Config from .hqq_marlin import HQQMarlinConfig + from .inc import INCConfig from .ipex_quant import IPEXConfig from .marlin import MarlinConfig from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config @@ -144,7 +146,8 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, "auto-round": AutoRoundConfig, - "rtn": RTNConfig + "rtn": RTNConfig, + "inc": INCConfig, } # Update the `method_to_config` with customized quantization methods. 
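# --- Editorial sketch (not part of the patch) ---------------------------------
# With "inc" added to QuantizationMethods and to the method_to_config mapping
# in this hunk, the Intel Neural Compressor config class can be looked up by
# name like any other backend. Assumes a vLLM build that already contains this
# change; unknown names are still rejected with an error.
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     get_quantization_config)

assert "inc" in QUANTIZATION_METHODS
inc_config_cls = get_quantization_config("inc")   # -> INCConfig
print(inc_config_cls.__name__)
# -------------------------------------------------------------------------------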
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) @@ -157,4 +160,4 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "QuantizationMethods", "get_quantization_config", "QUANTIZATION_METHODS", -] \ No newline at end of file +] diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 1ed3ef8d2173..a96f3ee5c301 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,16 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Any, Callable, Optional, Union import torch +from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, + FusedMoEMethodBase) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod, set_weight_attrs) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op @@ -120,12 +123,15 @@ def get_safe_value(config, keys, default_value=None): llm_int8_skip_modules=llm_int8_skip_modules, llm_int8_threshold=llm_int8_threshold) - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["LinearMethodBase"]: + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["LinearMethodBase", "BitsAndBytesMoEMethod"]]: if isinstance(layer, LinearBase): if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules): return UnquantizedLinearMethod() return BitsAndBytesLinearMethod(self) + elif isinstance(layer, FusedMoE): + return BitsAndBytesMoEMethod(self) return None @@ -146,6 +152,13 @@ def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: list[str]): return substr_check or prefix_check +def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + class BitsAndBytesLinearMethod(LinearMethodBase): """Linear method for BitsAndBytes. @@ -173,12 +186,6 @@ def create_weights(self, layer: torch.nn.Module, **extra_weight_attrs): from bitsandbytes.nn import Int8Params - def calculate_quant_ratio(dtype): - if dtype.is_floating_point: - return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits - else: - return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits - def create_qweight_for_8bit(): qweight = Int8Params( data=torch.empty(sum(output_partition_sizes), @@ -384,13 +391,220 @@ def _apply_bnb_4bit_fake( try: - direct_register_custom_op( - op_name="apply_bnb_4bit", - op_func=_apply_bnb_4bit, - mutates_args=["out"], - fake_impl=_apply_bnb_4bit_fake, - ) + direct_register_custom_op(op_name="apply_bnb_4bit", + op_func=_apply_bnb_4bit, + mutates_args=["out"], + fake_impl=_apply_bnb_4bit_fake, + dispatch_key=current_platform.dispatch_key) apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit except AttributeError as error: raise error + + +class BitsAndBytesMoEMethod(FusedMoEMethodBase): + """MoE method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. 
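# --- Editorial sketch (not part of the patch) ---------------------------------
# Worked example of the packed 4-bit shapes created by _create_weights_4bit
# below. calculate_quant_ratio(torch.bfloat16) is
# torch.finfo(torch.bfloat16).bits // torch.iinfo(torch.uint8).bits == 2, so
# the uint8 buffer stores the element count divided by 2. The sizes used here
# are illustrative only.
import torch

num_experts, hidden_size, intermediate = 8, 2048, 1408
quant_ratio = torch.finfo(torch.bfloat16).bits // torch.iinfo(torch.uint8).bits
assert quant_ratio == 2

# Fused gate_up_proj: logical experts_shape (num_experts, 2 * intermediate,
# hidden_size), stored packed as (num_experts, elems // quant_ratio, 1) uint8.
w13_packed = (num_experts, (hidden_size * 2 * intermediate) // quant_ratio, 1)
# down_proj: logical experts_shape (num_experts, hidden_size, intermediate).
w2_packed = (num_experts, (hidden_size * intermediate) // quant_ratio, 1)
print(w13_packed, w2_packed)  # (8, 2883584, 1) (8, 1441792, 1)
# At apply() time these buffers are dequantized back to the experts_shape
# recorded in the weight attrs and handed to the unquantized fused_experts path.
# -------------------------------------------------------------------------------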
+ """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.45.3": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.45.3.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.45.3 via " + "`pip install bitsandbytes>=0.45.3` to use " + "bitsandbytes quantizer.") from err + self.topk_indices_dtype = None + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + if self.quant_config.load_in_8bit: + call_fun = self._create_weights_8bit + else: + call_fun = self._create_weights_4bit + call_fun( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `BitsAndBytesMoEMethod` yet.") + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype) + if self.quant_config.load_in_8bit: + w13, w2 = self._apply_8bit_dequant(layer) + else: + w13, w2 = self._apply_4bit_dequnt(layer) + return fused_experts( + hidden_states=x, + w1=w13, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + ) + + def _create_weights_4bit( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + quant_ratio = calculate_quant_ratio(params_dtype) + # Fused gate_up_proj (column parallel) + w13_total_size = (hidden_size * 2 * + intermediate_size_per_partition) // quant_ratio + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + w13_total_size, + 1, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + set_weight_attrs( + w13_qweight, + { + "num_experts": + num_experts, + "input_dim": + hidden_size, + "output_dim": + 2 * intermediate_size_per_partition, + "experts_shape": ( + num_experts, + intermediate_size_per_partition * 2, + 
hidden_size, + ), + "pack_factor": + quant_ratio, + "use_bitsandbytes_4bit": + True, + }, + ) + # down_proj (row parallel) + w2_total_size = (hidden_size * + intermediate_size_per_partition) // quant_ratio + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + w2_total_size, + 1, + dtype=torch.uint8, + ), + requires_grad=False, + ) + set_weight_attrs( + w2_qweight, + { + "num_experts": + num_experts, + "input_dim": + intermediate_size_per_partition, + "output_dim": + hidden_size, + "experts_shape": ( + num_experts, + hidden_size, + intermediate_size_per_partition, + ), + "pack_factor": + quant_ratio, + "use_bitsandbytes_4bit": + True, + }, + ) + layer.register_parameter("w2_weight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + def _create_weights_8bit( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def _apply_4bit_dequnt( + self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + from bitsandbytes.functional import dequantize_4bit + w13 = dequantize_4bit( + layer.w13_weight.reshape(-1, 1), + layer.w13_weight.bnb_quant_state, + ) + w2 = dequantize_4bit( + layer.w2_weight.reshape(-1, 1), + layer.w2_weight.bnb_quant_state, + ) + w13 = w13.reshape(layer.w13_weight.experts_shape) + w2 = w2.reshape(layer.w2_weight.experts_shape) + return w13, w2 + + def _apply_8bit_dequant( + self, layer: torch.nn.Module) -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index e7f65d13181d..90b45e32a688 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -332,6 +332,12 @@ def _is_fp8_w8a8_sm90(self, weight_quant: BaseModel, return (self._check_scheme_supported(90, error=False, match_exact=True) and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a8_sm100(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + return (self._check_scheme_supported( + 100, error=False, match_exact=True) + and self._is_fp8_w8a8(weight_quant, input_quant)) + def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool: # Confirm weights quantized. 
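# --- Editorial sketch (not part of the patch) ---------------------------------
# What the _apply_4bit_dequnt path of BitsAndBytesMoEMethod above boils down
# to: the packed uint8 expert weights are dequantized with bitsandbytes and
# reshaped back to their logical expert shape before being fed to the plain
# fused_experts kernel. A stand-alone round trip (requires a CUDA device and
# bitsandbytes >= 0.45.3; block size and quant type chosen for illustration):
import torch
import bitsandbytes.functional as bnbF

w = torch.randn(16, 64, dtype=torch.bfloat16, device="cuda")
packed, quant_state = bnbF.quantize_4bit(w.reshape(-1, 1), quant_type="nf4")
w_deq = bnbF.dequantize_4bit(packed, quant_state).reshape(w.shape)
print((w - w_deq).abs().max())   # small quantization error, not exact
# In vLLM the quant_state is attached to the parameter at weight-loading time
# (the w13_weight.bnb_quant_state / w2_weight.bnb_quant_state used above), so
# apply() only needs the dequantize step.
# -------------------------------------------------------------------------------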
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ef67cc0eda46..7da52ce6ff8c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -83,7 +83,8 @@ def get_moe_method( return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): return CompressedTensorsW4A4MoeMethod() - elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant): + elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) + or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)): return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): return CompressedTensorsW8A8Fp8MoEMethod(quant_config) @@ -295,6 +296,7 @@ def apply( if enable_eplb: raise NotImplementedError("EPLB not supported for " "`CompressedTensorsW4A4MoeMethod` yet.") + assert activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -326,10 +328,6 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) - assert activation == "silu", "Only SiLU activation is supported." - assert not apply_router_weight_on_input, ( - "Router weight on input is not " - "supported for CompressedTensorsW4A4MoeMethod.") assert expert_map is None, ("Expert Parallelism / expert_map " "is currently not supported for " "CompressedTensorsW4A4MoeMethod.") @@ -339,22 +337,25 @@ def apply( # Cutlass moe takes in activations in BF16/Half precision # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4(a=x, - w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, - w2_fp4=layer.w2_weight, - w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, - topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - device=x.device).to(x.dtype) + return cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w2_blockscale=layer.w2_blockscale_swizzled, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + device=x.device, + apply_router_weight_on_input=apply_router_weight_on_input).to( + x.dtype) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): @@ -737,11 +738,11 @@ def __init__( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization.") - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) self.topk_indices_dtype = None - self.fused_experts = cutlass_moe_fp8 # type: ignore + self.fused_experts = None # type: ignore self.disable_expert_map = False + self.is_fp8_w8a8_sm100 = self.quant_config._is_fp8_w8a8_sm100( + self.weight_quant, self.input_quant) def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, 
intermediate_size_per_partition: int, @@ -929,24 +930,67 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - ) + e_score_correction_bias=e_score_correction_bias) - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) + per_act_token = ( + self.input_quant.strategy == QuantizationStrategy.TOKEN) + per_channel_quant = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL) + # Triton fused_experts is faster in small batch sizes on SM100. + # Fall back to fused_experts in small batch sizes. + if self.is_fp8_w8a8_sm100 and topk_ids.shape[0] <= 8: + from vllm.model_executor.layers.fused_moe import fused_experts + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + use_fp8_w8a8=True, + per_channel_quant=per_channel_quant, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + if self.fused_experts is None: + # If no modular kernel is provided, use cutlass_moe_fp8 + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp8) + return cutlass_moe_fp8( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + per_act_token=per_act_token, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + else: + return self.fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=None if self.disable_expert_map else expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py index 30ed55aee04f..168b221a9cfe 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py @@ -15,6 +15,9 @@ QKVParallelLinear) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, sparse_cutlass_supported) from vllm.model_executor.parameter import (BasevLLMParameter, @@ -24,6 
+27,8 @@ __all__ = ["CompressedTensors24"] +from vllm.platforms import current_platform + class CompressedTensors24(CompressedTensorsScheme): @@ -45,6 +50,12 @@ def __init__( and self.model_compressor.sparsity_config.format == CompressionFormat.sparse_24_bitmask.value) + if quantized and input_quant is not None and \ + self._get_quant_dtype() == current_platform.fp8_dtype(): + static = not input_quant.dynamic + g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN + self.quant_fp8 = QuantFP8(static, g_shape) + @classmethod def get_min_capability(cls) -> int: # Only cutlass 3.x kernels are implemented so far @@ -232,9 +243,7 @@ def apply_weights( :return: The output tensor of the layer """ if self.quantized: - scale = None - if hasattr(layer, "input_scale"): - scale = layer.input_scale + scale = getattr(layer, 'input_scale', None) if self.weights_dtype == torch.int8: ops_output = ops.scaled_int8_quant(x, scale=scale) @@ -242,11 +251,7 @@ def apply_weights( input_scale = ops_output[1] else: assert self.weights_dtype == torch.float8_e4m3fn - if scale is not None: - q_input, input_scale = ops.scaled_fp8_quant(x, scale=scale) - else: - q_input, input_scale = ops.scaled_fp8_quant( - x, use_per_token_if_dynamic=True) + q_input, input_scale = self.quant_fp8(x, scale=scale) else: # Not quantized, nothing to do with the input_scales, use as is @@ -269,7 +274,10 @@ def apply_weights( def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype: if not self.quantized: return params_dtype + return self._get_quant_dtype() + def _get_quant_dtype(self) -> torch.dtype: + assert self.quantized assert self.weight_quant is not None assert self.input_quant is not None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 1e61e058cb84..d984e89d9e02 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -9,6 +9,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) @@ -26,7 +28,11 @@ def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.act_q_group_shape = GroupShape.PER_TENSOR \ + if is_static_input_scheme else GroupShape.PER_TOKEN + self.fp8_linear = Fp8LinearOp( + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_q_group_shape) @classmethod def get_min_capability(cls) -> int: diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 5903976eaf6b..d26a932eddb2 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -6,10 +6,8 @@ from vllm.platforms import current_platform from vllm.triton_utils import triton -from vllm.utils import direct_register_custom_op, has_deep_gemm - -if has_deep_gemm(): - 
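# --- Editorial sketch (not part of the patch) ---------------------------------
# The w8a8 FP8 scheme just above and Fp8LinearMethod in fp8.py further below
# now tell Fp8LinearOp how to quantize activations explicitly: a static input
# scale means one PER_TENSOR scale baked into the checkpoint, while dynamic
# quantization uses PER_TOKEN scales (in fp8.py, only when cutlass FP8 kernels
# are available). A small helper mirroring that selection rule; the enum below
# is a stand-in for quant_utils.GroupShape, not vLLM code.
from enum import Enum


class _GroupShape(str, Enum):
    PER_TENSOR = "per_tensor"
    PER_TOKEN = "per_token"


def select_act_quant_group_shape(is_static_input_scheme: bool,
                                 cutlass_fp8_supported: bool = True):
    if is_static_input_scheme:
        return _GroupShape.PER_TENSOR
    return (_GroupShape.PER_TOKEN
            if cutlass_fp8_supported else _GroupShape.PER_TENSOR)


assert select_act_quant_group_shape(True) is _GroupShape.PER_TENSOR
assert select_act_quant_group_shape(False) is _GroupShape.PER_TOKEN
assert (select_act_quant_group_shape(False, cutlass_fp8_supported=False)
        is _GroupShape.PER_TENSOR)
# -------------------------------------------------------------------------------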
import deep_gemm +from vllm.utils import direct_register_custom_op +from vllm.utils.deep_gemm import fp8_gemm_nt logger = logging.getLogger(__name__) @@ -57,7 +55,7 @@ def w8a8_block_fp8_matmul_deepgemm( output_dtype) # Deepgemm only supports output tensor type as bfloat16 assert C.dtype == torch.bfloat16 - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + fp8_gemm_nt((A, As), (B, Bs), C) return C diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 3e465ee2cdd2..b2cab7d4614a 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, @@ -37,7 +37,6 @@ def __init__(self, ignore_list: list[str], input_scale_ub: float): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = not current_platform.has_device_capability(89) - self.fp8_linear = Fp8LinearOp() @classmethod def get_name(cls) -> QuantizationMethods: @@ -76,7 +75,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN) self.out_dtype = torch.get_default_dtype() def create_weights( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5a1a427d7d72..75f8adf34f7d 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,11 +23,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + get_col_major_tma_aligned_tensor, requant_weight_ue8m0_inplace) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, all_close_1d, cutlass_block_fp8_supported, cutlass_fp8_supported, maybe_create_device_identity, @@ -40,6 +42,8 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.utils import has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used +from vllm.utils.flashinfer import has_flashinfer_moe if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -49,6 +53,11 @@ logger = init_logger(__name__) +def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, 2, x.shape[-2] // 2, + x.shape[-1]).flip(dims=[1]).reshape(x.shape) + + def _is_col_major(x: torch.Tensor) -> 
bool: assert x.dim() == 3 b, m, n = x.shape @@ -199,9 +208,17 @@ def __init__(self, quant_config: Fp8Config): and current_platform.is_fp8_fnuz()) self.block_quant = self.quant_config.weight_block_size is not None + self.act_q_static = self.quant_config.activation_scheme == "static" + # Use per-token quantization for better perf if dynamic and cutlass + if not self.act_q_static and cutlass_fp8_supported(): + self.act_q_group_shape = GroupShape.PER_TOKEN + else: + self.act_q_group_shape = GroupShape.PER_TENSOR + self.fp8_linear = Fp8LinearOp( - # Default to using per_token quantization if cutlass is supported - use_per_token_if_dynamic=cutlass_fp8_supported()) + act_quant_static=self.act_q_static, + act_quant_group_shape=self.act_q_group_shape, + cutlass_fp8_supported=cutlass_fp8_supported()) def create_weights( self, @@ -240,9 +257,16 @@ def create_weights( f"{input_size_per_partition} is not divisible by " f"weight quantization block_k = {block_k}.") # Required by column parallel or enabling merged weights - if (tp_size > 1 and output_size // output_size_per_partition - == tp_size) or len(output_partition_sizes) > 1: - for output_partition_size in output_partition_sizes: + is_tp_split = (tp_size > 1 and + output_size // output_size_per_partition == tp_size) + is_merged_gemm = len(output_partition_sizes) > 1 + if is_tp_split or is_merged_gemm: + sizes_to_check = output_partition_sizes + if not is_tp_split and is_merged_gemm: + # In case of merged matrices, we allow the last + # matrix to not be a multiple of block size + sizes_to_check = output_partition_sizes[:-1] + for output_partition_size in sizes_to_check: if output_partition_size % block_n != 0: raise ValueError( f"Weight output_partition_size = " @@ -393,6 +417,19 @@ def process_weights_after_loading(self, layer: Module) -> None: # Activations not quantized for marlin. del layer.input_scale + # On B200, DeepGemm only support E8M0 scale, which means we need to + # requantize the weight and input to the specific scale + # at the same time. + if is_blackwell_deep_gemm_used(): + assert layer.weight_block_size is not None + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.weight.data, + layer.weight_scale_inv.data if hasattr( + layer, "weight_scale_inv") else layer.weight_scale.data, + block_sz, + ) + def apply(self, layer: torch.nn.Module, x: torch.Tensor, @@ -449,6 +486,11 @@ def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + self.flashinfer_moe_enabled = False + if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): + logger.info_once( + "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.") + self.flashinfer_moe_enabled = True # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) @@ -464,11 +506,16 @@ def __init__(self, quant_config: Fp8Config): logger.warning_once("Failed to import DeepGemm kernels.") elif not self.block_quant: logger.warning_once("Model is not block quantized. 
Not using " - " DeepGemm kernels") + "DeepGemm kernels") elif (current_platform.is_cuda() - and current_platform.has_device_capability(90)): + and current_platform.is_device_capability(90)): logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") self.allow_deep_gemm = True + elif (current_platform.is_cuda() + and is_blackwell_deep_gemm_used()): + logger.info_once("Using DeepGemm SM100 kernels for " + "Fp8MoEMethod.") + self.allow_deep_gemm = True else: logger.warning_once( "DeepGemm not supported on the current platform.") @@ -476,10 +523,10 @@ def __init__(self, quant_config: Fp8Config): # Check for CutlassBlockScaledGroupedGemm support. self.allow_cutlass_block_scaled_grouped_gemm = False if not self.block_quant: - logger.warning_once("Model is not block quantized. Not using " - "CutlassBlockScaledGroupedGemm kernels") + logger.debug_once("Model is not block quantized. Not using " + "CutlassBlockScaledGroupedGemm kernels") elif (current_platform.is_cuda() - and current_platform.has_device_capability(100)): + and current_platform.is_device_capability(100)): logger.info_once( "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod." ) @@ -645,6 +692,14 @@ def process_weights_after_loading(self, layer: Module) -> None: normalize_e4m3fn_to_e4m3fnuz( layer.w2_weight, layer.w2_weight_scale_inv, layer.w2_input_scale) + elif self.flashinfer_moe_enabled: + # NOTE: weights have to be swapped since the activation is + # applied on different half for flashinfer vs vllm + w13_weight = _swap_w13_to_w31(layer.w13_weight.data) + w13_weight_scale_inv = _swap_w13_to_w31( + layer.w13_weight_scale_inv.data) + w2_weight = layer.w2_weight.data + w2_weight_scale_inv = layer.w2_weight_scale_inv.data else: w13_weight = layer.w13_weight.data w13_weight_scale_inv = layer.w13_weight_scale_inv.data @@ -670,15 +725,14 @@ def process_weights_after_loading(self, layer: Module) -> None: # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. - if self.allow_deep_gemm: + if self.allow_deep_gemm and not is_blackwell_deep_gemm_used(): # Lazy import to avoid CUDA initialization problems. - import deep_gemm as dg if _is_col_major(layer.w13_weight_scale_inv): layer.w13_weight_scale_inv = \ - dg.get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() + get_col_major_tma_aligned_tensor(layer.w13_weight_scale_inv).contiguous() if _is_col_major(layer.w2_weight_scale_inv): layer.w2_weight_scale_inv = \ - dg.get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() + get_col_major_tma_aligned_tensor(layer.w2_weight_scale_inv).contiguous() # If checkpoint is fp16, quantize in place. elif not self.quant_config.is_checkpoint_fp8_serialized: @@ -797,6 +851,29 @@ def process_weights_after_loading(self, layer: Module) -> None: del layer.w13_input_scale del layer.w2_input_scale + if is_blackwell_deep_gemm_used(): + assert layer.weight_block_size is not None + # Re-quantise the expert weights so their scales are UE8M0. + block_sz = tuple(layer.weight_block_size) + requant_weight_ue8m0_inplace( + layer.w13_weight.data, + layer.w13_weight_scale_inv.data, + block_sz, + ) + requant_weight_ue8m0_inplace( + layer.w2_weight.data, + layer.w2_weight_scale_inv.data, + block_sz, + ) + + # Ensure column-major TMA alignment expected by DeepGEMM. 
+ if _is_col_major(layer.w13_weight_scale_inv): + layer.w13_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w13_weight_scale_inv).contiguous() + if _is_col_major(layer.w2_weight_scale_inv): + layer.w2_weight_scale_inv = get_col_major_tma_aligned_tensor( + layer.w2_weight_scale_inv).contiguous() + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -864,25 +941,25 @@ def apply( assert logical_to_physical_map is not None assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) + if not self.flashinfer_moe_enabled: + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 @@ -920,6 +997,31 @@ def apply( apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) + elif self.flashinfer_moe_enabled: + # Currently only work with DS models + assert self.block_quant + assert (renormalize and use_grouped_topk + and scoring_func == 'sigmoid' + and custom_routing_function is None) + assert activation == "silu" + return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( + routing_logits=router_logits.to(torch.float32), + routing_bias=e_score_correction_bias, + x=x, + w13_weight=layer.w13_weight, + w13_weight_scale_inv=layer.w13_weight_scale_inv, + w2_weight=layer.w2_weight, + w2_weight_scale_inv=layer.w2_weight_scale_inv, + global_num_experts=global_num_experts, + top_k=top_k, + num_expert_group=num_expert_group, + topk_group=topk_group, + intermediate_size=layer.intermediate_size_per_partition, + expert_offset=layer.ep_rank * layer.local_num_experts, + local_num_experts=layer.local_num_experts, + block_shape=self.quant_config.weight_block_size, + routed_scaling=1.0, + ) else: return self.fused_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/inc.py b/vllm/model_executor/layers/quantization/inc.py new file mode 100644 index 000000000000..8aa1f1a14bfc --- /dev/null +++ b/vllm/model_executor/layers/quantization/inc.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Intel Gaudi supports quantization of various modules and functions, +# including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. 
+# During model loading, +# INC will patch layers with quantization/dequantization operators. +# Meanwhile, INC will convert original weight to target datatype +# and loading to target device. +# static scaling should be provided through Quant_CONFIG: +# `QUANT_CONFIG` is an environment variable, +# that points to the measurement or quantization JSON config file. +# The measurement configuration file is used during the calibration procedure, +# to collect measurements for a given model. +# The quantization configuration is used during inference. +# For more information, please refer to: +# https://docs.habana.ai/en/v1.21.1/PyTorch/vLLM_Inference/vLLM_FP8_Inference.html + +from typing import Any, Optional + +import torch + +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, UnquantizedFusedMoEMethod) +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization import QuantizationMethods +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) + + +class INCConfig(QuantizationConfig): + """Config class for FP8 using Intel Neural Compressor.""" + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "inc" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "INCConfig": + raise AssertionError + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + elif isinstance(layer, FusedMoE): + return UnquantizedFusedMoEMethod(layer.moe_config) + return None + + @classmethod + def get_min_capability(cls) -> int: + raise AssertionError + + @staticmethod + def get_config_filenames() -> list[str]: + return [] diff --git a/vllm/model_executor/layers/quantization/input_quant_fp8.py b/vllm/model_executor/layers/quantization/input_quant_fp8.py new file mode 100644 index 000000000000..e1a9bdde9334 --- /dev/null +++ b/vllm/model_executor/layers/quantization/input_quant_fp8.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) +from vllm.platforms import current_platform + +# Using the default value (240.0) from pytorch will cause accuracy +# issue on dynamic quantization models. Here use 224.0 for fnuz on ROCm. +_FP8_DTYPE = current_platform.fp8_dtype() +_FP8_FINFO = torch.finfo(_FP8_DTYPE) +_FP8_MAX = 224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.max +_FP8_MIN = -224.0 if current_platform.is_fp8_fnuz() else _FP8_FINFO.min +_FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) + + +@CustomOp.register("quant_fp8") +class QuantFP8(CustomOp): + """ + Quantize input tensor to per-tensor or per-token FP8. + This CustomOp supports both static and dynamic quantization. 
+ """ + + def __init__(self, + static: bool, + group_shape: GroupShape, + num_token_padding: Optional[int] = None): + """ + + :param static: static or dynamic quantization + :param group_shape: quantization group shape (PER_TOKEN or PER_TENSOR) + :param num_token_padding: Pad the token dimension of output to this size + """ + super().__init__() + self.num_token_padding = num_token_padding + assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR} + assert not static or group_shape == GroupShape.PER_TENSOR, \ + "Only per-tensor scales supported for static quantization." + self.static = static + self.group_shape = group_shape + self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN + + def forward_cuda( + self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + assert (scale is not None) == self.static + assert scale_ub is None or (not self.static and self.group_shape + == GroupShape.PER_TOKEN + and scale_ub.numel() == 1) + + return ops.scaled_fp8_quant( + x, + scale, + num_token_padding=self.num_token_padding, + scale_ub=scale_ub, + use_per_token_if_dynamic=self.use_per_token_if_dynamic) + + def forward_native( + self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None, + scale_ub: Optional[torch.Tensor] = None, + ): + assert (scale is not None) == self.static + assert scale_ub is None or (not self.static and self.group_shape + == GroupShape.PER_TOKEN + and scale_ub.numel() == 1) + + if scale is None: + if self.group_shape == GroupShape.PER_TOKEN: + x_max, _ = x.abs().max(dim=-1) + x_max = x_max.unsqueeze(-1).to(torch.float32) + if scale_ub is not None: + x_max = x_max.clamp(max=scale_ub) + else: + x_max = x.abs().max().unsqueeze(-1).to(torch.float32) + + scale = x_max / _FP8_MAX + scale = scale.clamp(min=_FP8_MIN_SCALING_FACTOR) + + # Even for dynamic per-token scales, + # reciprocal performs slightly better than division + out = x.to(torch.float32) * scale.reciprocal() + out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) + + # This currently generates an extra Triton kernel in compilation. + # Fortunately, we don't use padding if compiling. + # TODO(luka): benchmark torch._scaled_mm to hopefully remove padding + # in general. 
+ if self.num_token_padding is not None: + padding = max(self.num_token_padding - out.size(0), 0) + out = F.pad(out, (0, 0, 0, padding), "constant", 0.0) + + return out, scale diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py index 0bf0d530d235..21e5ae793c3f 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -8,6 +8,8 @@ AllSparkLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.bitblas import ( # noqa: E501 BitBLASLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.conch import ( # noqa: E501 + ConchLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 ExllamaLinearKernel) from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 @@ -24,6 +26,7 @@ AllSparkLinearKernel, MarlinLinearKernel, BitBLASLinearKernel, + ConchLinearKernel, ExllamaLinearKernel, ] @@ -80,4 +83,4 @@ def choose_mp_linear_kernel( raise ValueError( "Failed to find a kernel that can implement the "\ "WNA16 linear layer. Reasons: \n" - + '\n'.join(failure_reasons)) \ No newline at end of file + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py new file mode 100644 index 000000000000..f80af548f019 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from importlib.util import find_spec +from typing import Final, Optional + +import torch + +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + +_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [ + scalar_types.uint4, scalar_types.uint8, scalar_types.uint4b8, + scalar_types.uint8b128 +] +_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128] + + +class ConchLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> tuple[bool, Optional[str]]: + if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES: + error_msg = f"Weight type ({c.weight_type}) not supported by "\ + "ConchLinearKernel, supported types are: " \ + f"{_CONCH_SUPPORTED_WEIGHT_TYPES}" + return False, error_msg + + if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES: + error_msg = f"Group size ({c.group_size}) not supported by "\ + "ConchLinearKernel, supported group sizes are: " \ + f"{_CONCH_SUPPORTED_GROUP_SIZES}" + return False, error_msg + + if find_spec("conch") is None: + error_msg = "conch-triton-kernels is not installed, please "\ + "install it via `pip install conch-triton-kernels` "\ + "and try again!" 
+ return False, error_msg + + return True, None + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = x.data.contiguous() + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + from conch.ops.quantization.gemm import mixed_precision_gemm + + w_q, w_s, w_zp, _ = self._get_weight_params(layer) + + output = mixed_precision_gemm( + x=x, + w_q_packed=w_q.data, + w_s=w_s.data, + w_zp=w_zp.data if w_zp is not None else None, + weight_size_bits=self.config.weight_type.size_bits, + weight_bias=self.config.weight_type.bias, + group_size=self.config.group_size, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py index ed81b02bc4a1..da951ddab2e4 100644 --- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -126,6 +126,11 @@ def apply_weights(self, if c.has_g_idx: x_2d = self.act_perm(x_2d) + if c.zero_points: + assert w_zp is not None + else: + w_zp = None + output = ops.machete_mm(a=x_2d, b_q=w_q, b_type=c.weight_type, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py index 165548a06012..7f808fa92a9a 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py @@ -8,11 +8,55 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op from .cutlass import CutlassScaledMMLinearKernel from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig +def rocm_aiter_gemm_w8a8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + + from aiter import gemm_a8w8_CK + + # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects + # a to be [M, K] + # b to be [N, K] + # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format + return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) + + +def rocm_aiter_gemm_w8a8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: Optional[torch.Tensor] = None, + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + + m = A.shape[0] + n = B.shape[0] + Y = torch.empty(m, n, dtype=output_dtype, device=A.device) + return Y + + +if current_platform.is_rocm(): + direct_register_custom_op( + op_name="rocm_aiter_gemm_w8a8", + op_func=rocm_aiter_gemm_w8a8_impl, + mutates_args=[], + fake_impl=rocm_aiter_gemm_w8a8_fake, + 
dispatch_key=current_platform.dispatch_key, + ) + + class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): @classmethod @@ -111,10 +155,9 @@ def apply_weights(self, " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + "does not support AITER block scaled GEMM.") - from aiter import gemm_a8w8_CK - # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects # a to be [M, K] # b to be [N, K] # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format - return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) + return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s, + bias, out_dtype) diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py index 3de28af40aaa..0b931b2d8b81 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -90,16 +90,15 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: w_q, w_s, _, _, _ = self._get_weight_params(layer) - import torch_xla.experimental.xla_quantized_matmul # noqa: F401 - out = torch.ops.xla.quantized_matmul(x, - w_q, - w_s, - zero_point=None, - block_size=-1, - int4_weight=False, - quantize_activation=True) - # `quantized_matmul` output is fp32, cast it down to bf16 for perf - out = out.to(x.dtype) + # Required to register custom ops. + import torch_xla.experimental.custom_kernel # noqa: F401 + out = torch.ops.xla.quantized_matmul_int8( + x, + w_q, + w_s, + quantize_activation=True, + ) + # Explicitly capture control flow to make dynamo happy. # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 9db875330230..460334d77f0a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -7,9 +7,15 @@ from torch.nn import Module from torch.nn.parameter import Parameter +import vllm.envs as envs +import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) +from vllm.distributed import get_ep_group from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig +from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 + FlashInferCutlassMoEPrepareAndFinalize) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, @@ -22,7 +28,7 @@ apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale) from vllm.model_executor.parameter import (ModelWeightParameter, @@ -42,9 +48,13 @@ class ModelOptFp8Config(QuantizationConfig): def __init__( self, is_checkpoint_fp8_serialized: bool = False, + kv_cache_quant_method: Optional[str] = None, + exclude_modules: Optional[list[str]] = None, ) -> None: super().__init__() 
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + self.kv_cache_quant_method = kv_cache_quant_method + self.exclude_modules = exclude_modules if is_checkpoint_fp8_serialized: logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" " the format is experimental and could change.") @@ -65,44 +75,118 @@ def get_min_capability(cls) -> int: def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + """Detect if this ModelOpt config should be used based on + quantization config.""" + + if hf_quant_cfg is None: + return None + + # Use the community standard 'quant_method' + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + # Only proceed if the method is explicitly "modelopt" + if quant_method != "modelopt": + return None + + # Look for ModelOpt-specific config structure + if "quantization" in hf_quant_cfg: + quant_config = hf_quant_cfg["quantization"] + if isinstance(quant_config, dict): + quant_algo = quant_config.get("quant_algo", "") + if "FP8" in quant_algo: + return "modelopt" + else: + # Check for compressed-tensors style config with specific quant_algo + quant_algo = hf_quant_cfg.get("quant_algo", "") + if isinstance(quant_algo, str) and "FP8" in quant_algo: + return "modelopt" + + return None + @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config": - quant_config = cls.get_from_keys(config, ["quantization"]) - quant_method = quant_config["quant_algo"] + # Handle both ModelOpt format and compressed-tensors style format + if "quantization" in config: + # ModelOpt format: {"quantization": {"quant_algo": "..."}} + quant_config = cls.get_from_keys(config, ["quantization"]) + if not isinstance(quant_config, dict): + raise ValueError( + "Expected 'quantization' to be a dictionary in config") + quant_method = quant_config.get("quant_algo", "") + if not quant_method: + raise ValueError("Missing 'quant_algo' in quantization config") + kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") + exclude_modules = quant_config.get("exclude_modules") + else: + # Compressed-tensors style format: + # {"quant_algo": "...", "quant_method": "modelopt"} + quant_method = config.get("quant_algo", "") + kv_cache_quant_method = config.get("kv_cache_quant_algo") + exclude_modules = config.get("exclude_modules") + if quant_method not in QUANT_ALGOS: - raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" - " quantizations in vLLM. Please check the " - "`hf_quant_config.json` file for your model's " - "quant configuration.") + raise ValueError( + f"ModelOpt currently only supports: {QUANT_ALGOS} " + "quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") is_checkpoint_fp8_serialized = ("FP8" in quant_method) - return cls(is_checkpoint_fp8_serialized) + return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, + exclude_modules) + + def is_layer_excluded(self, prefix: str) -> bool: + """ + Check if a layer should be excluded from quantization. + + This method handles both regular models and multimodal models that use + the language_model prefix. For multimodal models, it checks if the + module name (without the language_model prefix) is in the exclude list. 
+ """ + if self.exclude_modules is None: + return False + + # Check if any excluded module matches the prefix + for module in self.exclude_modules: + if (module in prefix + or (prefix.startswith("language_model.") + and module in prefix.removeprefix("language_model."))): + return True + return False def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): + if self.is_layer_excluded(prefix): + return UnquantizedLinearMethod() return ModelOptFp8LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FusedMoE): + return ModelOptFp8MoEMethod(self) return None class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and - activation scale. Future support might be added for dynamic + activation scale. Future support might be added for dynamic scales. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. - 2. Only support float8_e4m3fn datatype + 2. Only support float8_e4m3fn datatype Args: quant_config: The ModelOpt quantization config. """ def __init__(self, quant_config: ModelOptFp8Config): self.quant_config = quant_config - self.fp8_linear = Fp8LinearOp() + self.fp8_linear = Fp8LinearOp( + act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR) def create_weights( self, @@ -171,13 +255,230 @@ def apply( bias=bias) +class ModelOptFp8MoEMethod(FusedMoEMethodBase): + """MoE method for ModelOpt FP8. + Supports loading FP8 checkpoints with static weight scale and + activation scale. + Args: + quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + cutlass_fp8_supported) + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + # Use FP8 dtype if checkpoint is serialized + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight_loader = extra_weight_attrs.get("weight_loader") + + w13_weight = ModelWeightParameter( + data=torch.empty(num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=weight_dtype), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight", w13_weight) + + w2_weight = ModelWeightParameter( + data=torch.empty(num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=weight_dtype), + input_dim=2, + output_dim=1, + weight_loader=weight_loader, + ) + layer.register_parameter("w2_weight", w2_weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALES - Per-tensor scaling for ModelOpts + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = PerTensorScaleParameter( + data=torch.full( + (num_experts, 2), + 1.0, + dtype=torch.float32, + ), + weight_loader=weight_loader, + ) + w2_weight_scale = PerTensorScaleParameter( + data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + # Set weight loader attributes for scales + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + + # INPUT SCALES - Per-tensor scaling for ModelOpt + w13_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + w2_input_scale = PerTensorScaleParameter( + data=torch.full((num_experts, ), 1.0, dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Process FP8 MoE weights after loading from serialized checkpoint. + Only supports pre-quantized checkpoints with FP8 weights and scales. + """ + + layer.w13_weight = Parameter(layer.w13_weight.data, + requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) + + from vllm._custom_ops import scaled_fp8_quant + from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + per_tensor_dequantize) + + # Handle scale parameters + if hasattr(layer, + "w13_weight_scale") and layer.w13_weight_scale is not None: + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max of the w1 and w3 scales + # then dequant and requant each expert. + if layer.w13_weight_scale.dim() == 2: + + # Get the maximum scale across w1 and w3 for each expert + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + + # Requantize each expert's weights using the combined scale + # w13_weight (num_experts, 2 * intermediate_size, hidden_size) + # where the first intermediate_size rows are w1, the next are w3 + intermediate_size = layer.w13_weight.shape[1] // 2 + for expert_id in range(layer.w13_weight.shape[0]): + start = 0 + for shard_id in range(2): # w1 and w3 + # Dequantize using the original scale for this shard + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + intermediate_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + # Requantize using the combined max scale + + ( + layer.w13_weight[expert_id][start:start + + intermediate_size, :], + _, + ) = scaled_fp8_quant(dq_weight, + max_w13_scales[expert_id]) + + start += intermediate_size + + # Update the scale parameter to be per-expert + layer.w13_weight_scale = Parameter(max_w13_scales, + requires_grad=False) + else: + layer.w13_weight_scale = Parameter(layer.w13_weight_scale.data, + requires_grad=False) + + if hasattr(layer, + "w2_weight_scale") and layer.w2_weight_scale is not None: + layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data, + requires_grad=False) + # Input scales must be equal for each expert in fp8 MoE layers. 
+ if hasattr(layer, + "w13_input_scale") and layer.w13_input_scale is not None: + layer.w13_input_scale = Parameter(layer.w13_input_scale.max(), + requires_grad=False) + if hasattr(layer, + "w2_input_scale") and layer.w2_input_scale is not None: + layer.w2_input_scale = Parameter(layer.w2_input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `ModelOptFp8MoEMethod` yet.") + + # Expert selection + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + per_channel_quant=False, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + class ModelOptNvFp4Config(QuantizationConfig): """Config class for ModelOpt FP4.""" def __init__( self, is_checkpoint_nvfp4_serialized: bool, - kv_cache_quant_algo: str, + kv_cache_quant_algo: Optional[str], exclude_modules: list[str], group_size: int = 16, ) -> None: @@ -208,24 +509,138 @@ def get_min_capability(cls) -> int: def get_config_filenames(cls) -> list[str]: return ["hf_quant_config.json"] + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: + """Detect if this ModelOpt FP4 config should be used based on + quantization config.""" + if hf_quant_cfg is None: + return None + + # Use the community standard 'quant_method' + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + # Only proceed if the method is explicitly "modelopt" + if quant_method != "modelopt": + return None + + # Look for ModelOpt-specific config structure + if "quantization" in hf_quant_cfg: + quant_config = hf_quant_cfg["quantization"] + if isinstance(quant_config, dict): + quant_algo = quant_config.get("quant_algo", "") + if "NVFP4" in quant_algo: + return "modelopt_fp4" + else: + # Check for compressed-tensors style config with specific + # quant_algo field + quant_algo = hf_quant_cfg.get("quant_algo", "") + if isinstance(quant_algo, str) and "FP4" in quant_algo.upper(): + return "modelopt_fp4" + + return None + @classmethod def 
from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config": - quant_config = cls.get_from_keys(config, ["quantization"]) - quant_method = quant_config["quant_algo"] + # Handle both traditional ModelOpt format and compressed-tensors + # style format + if "quantization" in config: + # Traditional ModelOpt format: + # {"quantization": {"quant_algo": "..."}} + quant_config = cls.get_from_keys(config, ["quantization"]) + if not isinstance(quant_config, dict): + raise ValueError( + "Expected 'quantization' to be a dictionary in config") + + quant_method = quant_config.get("quant_algo", "") + if not quant_method: + raise ValueError("Missing 'quant_algo' in quantization config") + + # Handle kv_cache_quant_algo with proper type validation + kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo") + if kv_cache_quant_algo_raw is None: + # No KV cache quantization by default + kv_cache_quant_algo = None + elif isinstance(kv_cache_quant_algo_raw, str): + kv_cache_quant_algo = kv_cache_quant_algo_raw + else: + raise ValueError(f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_algo_raw)}") + + # Handle group_size with proper type validation + group_size_raw = quant_config.get("group_size") + if group_size_raw is None: + group_size = 16 # Default value + elif isinstance(group_size_raw, int): + group_size = group_size_raw + else: + try: + group_size = int(group_size_raw) + except (ValueError, TypeError): + raise ValueError(f"group_size must be an integer, got " + f"{type(group_size_raw)}") from None + + exclude_modules = quant_config.get("exclude_modules", []) + if not isinstance(exclude_modules, list): + raise ValueError(f"exclude_modules must be a list, got " + f"{type(exclude_modules)}") + else: + # Compressed-tensors style format: + # {"quant_algo": "...", "quant_method": "modelopt"} + quant_method = config.get("quant_algo", "") + + # Handle kv_cache_quant_algo with proper type validation + kv_cache_quant_algo_raw = config.get("kv_cache_quant_algo") + if kv_cache_quant_algo_raw is None: + # No KV cache quantization by default + kv_cache_quant_algo = None + elif isinstance(kv_cache_quant_algo_raw, str): + kv_cache_quant_algo = kv_cache_quant_algo_raw + else: + raise ValueError(f"kv_cache_quant_algo must be a string, got " + f"{type(kv_cache_quant_algo_raw)}") + + # Handle group_size with proper type validation + group_size_raw = config.get("group_size") + if group_size_raw is None: + group_size = 16 # Default value + elif isinstance(group_size_raw, int): + group_size = group_size_raw + else: + try: + group_size = int(group_size_raw) + except (ValueError, TypeError): + raise ValueError(f"group_size must be an integer, got " + f"{type(group_size_raw)}") from None + + exclude_modules = config.get("exclude_modules", []) + if not isinstance(exclude_modules, list): + raise ValueError(f"exclude_modules must be a list, got " + f"{type(exclude_modules)}") + if quant_method not in QUANT_ALGOS: - raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" - " quantizations in vLLM. Please check the " - "`hf_quant_config.json` file for your model's " - "quant configuration.") + raise ValueError( + f"ModelOpt currently only supports: {QUANT_ALGOS} " + "quantizations in vLLM. 
Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) - if ("group_size" and "kv_cache_quant_algo" - and "exclude_modules") not in quant_config: - raise ValueError("NVFP4 quantization requires group size and " - "kv_cache_quant_algo specified in " - "hf_quant_config.json") - kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] - group_size = quant_config["group_size"] - exclude_modules = quant_config["exclude_modules"] + + # For FP4, these fields are required + if is_checkpoint_nvfp4_serialized and "quantization" in config: + # Check if required fields are present in the quantization config + quant_config = config["quantization"] + required_fields = [ + "group_size", "kv_cache_quant_algo", "exclude_modules" + ] + missing_fields = [ + field for field in required_fields if field not in quant_config + ] + if missing_fields: + raise ValueError( + f"NVFP4 quantization requires the following fields in " + f"hf_quant_config.json: {missing_fields}") + return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, exclude_modules, group_size) @@ -273,7 +688,7 @@ def __init__(self, quant_config: Union[ModelOptFp8Config, class ModelOptNvFp4LinearMethod(LinearMethodBase): """Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure: - + input_scale: torch.float32, scalar , weight: NVFP4(represented as byte) Shape: [1, X, y/2] weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, @@ -454,7 +869,7 @@ def apply( class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. - Args: + Args: quant_config: NVFP4 Quant Config """ @@ -462,6 +877,18 @@ def __init__(self, quant_config: ModelOptNvFp4Config): self.quant_config = quant_config self.cutlass_nvfp4_supported = cutlass_fp4_supported() self.use_marlin = False + self.allow_flashinfer_cutlass = False + + if envs.VLLM_USE_FLASHINFER_MOE_FP4: + if self.cutlass_nvfp4_supported and current_platform.is_cuda() \ + and current_platform.is_device_capability(100): + logger.info_once( + "Using FlashInfer kernels for ModelOptNvFp4FusedMoE.") + self.allow_flashinfer_cutlass = True + else: + logger.warning_once( + "Flashinfer CUTLASS Fused MoE not supported " + "or found on the current platform.") if not self.cutlass_nvfp4_supported: if is_fp4_marlin_supported(): @@ -471,6 +898,78 @@ def __init__(self, quant_config: ModelOptNvFp4Config): " quantization. 
Please use Blackwell and" " above.") + self.fused_experts = None # type: ignore + + def maybe_swap_experts_impl( + self, + moe_parallel_config: FusedMoEParallelConfig, + ): + if not self.allow_flashinfer_cutlass: + return + + logger.debug_once("FlashInferExperts") + # default to TP/EP case only + + experts_kwargs: dict[str, Any] = { + "use_nvfp4_w4a4": True, + "use_dp": moe_parallel_config.dp_size > 1, + "ep_rank": moe_parallel_config.ep_rank, + "ep_size": moe_parallel_config.ep_size, + "tp_rank": moe_parallel_config.tp_rank, + "tp_size": moe_parallel_config.tp_size, + } + + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + experts = FlashInferExperts(**experts_kwargs) + self.fused_experts = mk.FusedMoEModularKernel( + FlashInferCutlassMoEPrepareAndFinalize( + quant_dtype=torch.uint8, + #meaning 2x e2m1 packed in one, kernel requirement + ), + experts, + ) + + # This method update self.fused_experts + # only prepare_finalize is not None call select_gemm_impl + # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert + # when it's not called(TP case), we still have 2 kernels to use. + def select_gemm_impl(self, prepare_finalize, + moe) -> mk.FusedMoEPermuteExpertsUnpermute: + + assert moe is not None + assert prepare_finalize is not None + experts = None + all2all_manager = get_ep_group().device_communicator.all2all_manager + assert all2all_manager is not None + if self.allow_flashinfer_cutlass: + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + FlashInferExperts) + logger.debug_once("Using FlashInferExperts") + experts = FlashInferExperts( + use_nvfp4_w4a4=True, + use_dp=moe.moe_parallel_config.dp_size > 1, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) + else: + assert moe.dp_size > 1 + logger.debug_once("Using CutlassExpertsFp4") + # Currently CutlassExpertsFp4 doesn't support DP + raise ValueError("CutlassExpertsFp4 doesn't support DP. " + "Use flashinfer CUTLASS FusedMoE backend instead " + "(set VLLM_USE_FLASHINFER_MOE_FP4=1)") + + return experts + + def uses_weight_scale_2_pattern(self) -> bool: + """ + FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. + """ + return True + def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): @@ -585,8 +1084,30 @@ def swizzle_blockscale(self, scale: torch.tensor): if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # GEMM 1 + # The FlashInfer Cutlass fused MoE kernel expects the combined weights + # to be ordered as [w3, w1], unlike the standard [w1, w3] layout. 
+ gemm1_weight = layer.w13_weight.data + gemm1_weight_scale = layer.w13_weight_scale.data + + if self.allow_flashinfer_cutlass: + dim = -2 + size = gemm1_weight.size(dim) + assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}" + half = size // 2 + + # Reorder weight + w1, w3 = gemm1_weight.split(half, dim=dim) + gemm1_weight = torch.cat([w3, w1], dim=dim).contiguous() + + # Reorder scale + s1, s3 = gemm1_weight_scale.split(half, dim=dim) + gemm1_weight_scale = torch.cat([s3, s1], dim=dim).contiguous() + + layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) + layer.w13_weight_scale = Parameter(gemm1_weight_scale, + requires_grad=False) + if not torch.allclose(layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]): logger.warning_once( @@ -617,9 +1138,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_input_scale_quant = Parameter( (1 / w13_input_scale).to(torch.float32), requires_grad=False) - layer.w13_weight = Parameter(layer.w13_weight.data, - requires_grad=False) - # GEMM 2 layer.g2_alphas = Parameter( (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), @@ -673,21 +1191,21 @@ def apply( if enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptNvFp4FusedMoE` yet.") + assert activation == "silu", "Only SiLU activation is supported." - if self.use_marlin: - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + if self.use_marlin: return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, @@ -704,44 +1222,74 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map) - assert activation == "silu", "Only SiLU activation is supported." 
- assert not apply_router_weight_on_input, ( - "Router weight on input is not " - "supported for ModelOptNvFp4FusedMoE.") - assert expert_map is None, ("Expert Parallelism / expert_map " - "is currently not supported for " - "ModelOptNvFp4FusedMoE.") - - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias) - - from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp4) - - # Cutlass moe takes in activations in BF16/Half precision - # and fp4 quantized weights loaded from the checkpoint - return cutlass_moe_fp4(a=x, - w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_blockscale_swizzled, - w1_alphas=layer.g1_alphas, - w2_fp4=layer.w2_weight, - w2_blockscale=layer.w2_blockscale_swizzled, - w2_alphas=layer.g2_alphas, - topk_weights=topk_weights, - topk_ids=topk_ids, - m=x.shape[0], - n=layer.w2_weight.shape[2] * 2, - k=x.shape[1], - e=layer.w13_weight.shape[0], - a1_gscale=layer.w13_input_scale_quant, - a2_gscale=layer.w2_input_scale_quant, - device=x.device).to(x.dtype) + if self.fused_experts is None: + # If no modular kernel is provided, use cutlass_moe_fp4 for TP case + # only (no EP). + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + out = cutlass_moe_fp4( + a=x, + w1_fp4=layer.w13_weight, + w2_fp4=layer.w2_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w2_blockscale=layer.w2_blockscale_swizzled, + g1_alphas=layer.g1_alphas, + g2_alphas=layer.g2_alphas, + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + device=x.device, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input) + else: + # TP or DP case + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + is_valid_flashinfer_cutlass_fused_moe) + assert is_valid_flashinfer_cutlass_fused_moe( + x, layer.w13_weight, layer.w2_weight), ( + "Flashinfer CUTLASS Fused MoE not applicable!") + + a1_gscale = torch.min(layer.w13_input_scale_quant) + a2_gscale = torch.min(layer.w2_input_scale_quant) + extra_expert_args = { + 'g1_alphas': layer.g1_alphas, + 'g2_alphas': layer.g2_alphas, + 'out_dtype': x.dtype, + # Avoid confusion with a1_scale and a2_scale + # where are batch size related. 
+ 'a1_gscale': a1_gscale, + 'a2_gscale': a2_gscale, + } + extra_prepare_args = { + 'use_dp': layer.dp_size > 1, + 'local_tokens': x.shape[0], + 'a1_gscale': a1_gscale, + } + extra_finalize_args = { + 'use_dp': layer.dp_size > 1, + 'local_tokens': x.shape[0], + } + + out = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, # TODO(shuw): fix later, now output is high prec + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=layer.w13_blockscale_swizzled, + w2_scale=layer.w2_blockscale_swizzled, + apply_router_weight_on_input=apply_router_weight_on_input, + extra_expert_args=extra_expert_args, + extra_prepare_args=extra_prepare_args, + extra_finalize_args=extra_finalize_args, + ) + return out diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 32ba1055f9c8..d11cba2caba8 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -17,7 +17,7 @@ Fp8KVCacheMethod, Fp8LinearMethod) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped) + GroupShape, is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp) from vllm.platforms import current_platform @@ -95,8 +95,10 @@ def __init__(self, quant_config: PTPCFp8Config): super().__init__(quant_config=quant_config) # Force weight quantization self.quant_config.is_checkpoint_fp8_serialized = False - self.fp8_linear = Fp8LinearOp(cutlass_fp8_supported=False, - use_per_token_if_dynamic=True) + self.fp8_linear = Fp8LinearOp( + act_quant_static=False, + cutlass_fp8_supported=False, + act_quant_group_shape=GroupShape.PER_TOKEN) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = torch.nn.Parameter(layer.weight.data, diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 05dff4bae395..b67ee5cf453d 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -237,12 +237,6 @@ def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]], "Quark model is not in MX-FP4 format: not group_size=32") return False - # Weights need to use static quantization. - if weight_quant.get("is_dynamic") is True: - logger.debug( - "Quark model is not in MX-FP4 format: not weight static") - return False - # Activations need to use dynamic quantization. 
if input_quant.get("is_dynamic") is False: logger.debug( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index a040c430cbca..6f69210d0861 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -5,11 +5,12 @@ import torch -import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + OCP_MX_BLOCK_SIZE) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -17,7 +18,9 @@ logger = init_logger(__name__) -__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] +__all__ = [ + "QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod" +] class QuarkMoEMethod(FusedMoEMethodBase): @@ -40,6 +43,8 @@ def get_moe_method( if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + elif quant_config._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFp4MoEMethod(weight_config, input_config) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -242,4 +247,163 @@ def apply( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale) + a2_scale=layer.w2_input_scale, + activation=activation) + + +class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: dict[str, Any], input_config: dict[str, + Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_group" + and input_qscheme == "per_group"): + raise ValueError( + "For MX(FP4) Fused MoE layers, only per-group scales " + "for weights and activations are supported. Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + if self.static_input_scales: + raise NotImplementedError( + "QuarkW4A4MXFp4MoEMethod with static input scales is currently " + "not implemented. Please open an issue.") + + if not current_platform.supports_mx(): + self.emulate = True + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + else: + self.emulate = True + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. 
" + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + params_dtype = torch.uint8 + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.") + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + out = fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_mxfp4_w4a4=True, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + block_shape=None, + activation=activation, + ) + return out 
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py index 3c56251b7a00..880438a22a69 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -6,14 +6,16 @@ import torch import torch.nn.functional as F -import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( - OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) + OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4) from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) from vllm.platforms import current_platform +logger = init_logger(__name__) + __all__ = ["QuarkW4A4MXFP4"] @@ -25,7 +27,29 @@ def __init__(self, weight_quant_spec: dict[str, Any], self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec - self.emulate = not current_platform.supports_mx() + + self.static_input_scales = not input_quant_spec.get("is_dynamic") + + if self.static_input_scales: + raise NotImplementedError( + "QuarkW4A4MXFP4 with static input scales is currently not " + "implemented. Please open an issue.") + + if not current_platform.supports_mx(): + self.emulate = True + logger.warning_once( + "The current platform does not support native MXFP4 " + "computation. Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") + else: + self.emulate = True + logger.warning_once( + "The current platform supports native MXFP4 " + "computation, but kernels are not yet integrated in vLLM. " + "Simulated weight dequantization and activation " + "QDQ (quantize and dequantize) will be used, with the linear " + "layers computed in high precision.") @classmethod def get_min_capability(cls) -> int: @@ -37,43 +61,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False) - if self.emulate: - try: - from quark.torch.export.nn.modules import realquantizer - from quark.torch.quantization.config.config import ( - QuantizationSpec) - except ImportError as err: - raise ImportError( - "The package `amd-quark` is required to use AMD Quark " - "MX-FP4 models. Please install it with `pip install " - "amd-quark`.") from err - - weight_quant_spec = QuantizationSpec.from_dict( - self.weight_quant_spec) - - weight_quantizer = realquantizer.get_real_quantizer( - qspec=weight_quant_spec, - quantizer=None, - real_quantized=True, - reorder=False, - float_dtype=self.out_dtype, - scale_shape=layer.weight_scale.shape, - zero_point_shape=None, - ) - weight_quantizer.scale.data = layer.weight_scale.data - - if not envs.VLLM_QUARK_EMU_MEM_OPT: - layer.weight = torch.nn.Parameter( - weight_quantizer(layer.weight.data).to(self.out_dtype), - requires_grad=False, - ) - else: - self.weight_quantizer = weight_quantizer - layer.weight_scale = None - - # This call is necessary to release the scales memory. 
- torch.cuda.empty_cache() - def create_weights(self, layer: torch.nn.Module, output_partition_sizes: list[int], input_size_per_partition: int, @@ -116,11 +103,10 @@ def apply_weights(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: if self.emulate: - if envs.VLLM_QUARK_EMU_MEM_OPT: - dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) - else: - dq_w = layer.weight - qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) - return F.linear(qdq_x, dq_w, bias) + dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype) + + x = quant_dequant_mxfp4(x) + + return F.linear(x, dq_w, bias) else: raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index c7bc98184d0e..2cb35249f49e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -7,6 +7,8 @@ from torch.nn import Parameter from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale) from vllm.model_executor.parameter import (ChannelQuantScaleParameter, @@ -28,10 +30,14 @@ def __init__(self, weight_config: dict[str, Any], self.is_static_input_scheme = not cast( bool, input_config.get("is_dynamic")) self.input_qscheme = cast(str, input_config.get("qscheme")) - self.use_per_token_if_dynamic = (not self.is_static_input_scheme \ - and self.input_qscheme == "per_channel") + + per_token = (not self.is_static_input_scheme + and self.input_qscheme == "per_channel") + self.act_quant_group_shape = GroupShape.PER_TOKEN \ + if per_token else GroupShape.PER_TENSOR self.fp8_linear = Fp8LinearOp( - use_per_token_if_dynamic=self.use_per_token_if_dynamic) + act_quant_static=self.is_static_input_scheme, + act_quant_group_shape=self.act_quant_group_shape) self.out_dtype = torch.get_default_dtype() @classmethod @@ -44,7 +50,7 @@ def process_weights_after_loading(self, layer) -> None: # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor if self.weight_qscheme == "per_tensor": - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -82,7 +88,7 @@ def process_weights_after_loading(self, layer) -> None: requires_grad=False) else: weight_scale = layer.weight_scale.data - if self.use_per_token_if_dynamic: + if self.act_quant_group_shape == GroupShape.PER_TOKEN: weight_scale = weight_scale.view(-1, 1) layer.weight = Parameter(weight.t(), requires_grad=False) # required by torch.compile to be torch.nn.Parameter diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index cbf8231defc6..ee5f2b51564d 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -5,6 +5,7 @@ import functools import json import os +from collections.abc import Sequence from typing import Any, Callable, Optional, Union import torch @@ -13,12 +14,13 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( - scaled_dequantize) + group_broadcast) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm +from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used logger = init_logger(__name__) @@ -54,7 +56,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl( ) -> torch.Tensor: import aiter as rocm_aiter - return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) + return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) def rocm_aiter_gemm_w8a8_blockscale_fake( @@ -235,7 +237,7 @@ def block_quant_to_tensor_quant( The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. Note only float8 is supported for now. """ - x_dq_block = scaled_dequantize(x_q_block, x_s) + x_dq_block = group_broadcast(x_q_block, x_s) x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) return x_q_tensor, scale @@ -255,6 +257,7 @@ def _per_token_group_quant_fp8( # Information for float8 fp8_min, fp8_max, + use_ue8m0: tl.constexpr, # Meta-parameters BLOCK: tl.constexpr, ): @@ -284,7 +287,8 @@ def _per_token_group_quant_fp8( y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max + scale_raw = _absmax / fp8_max + y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) @@ -308,6 +312,7 @@ def _per_token_group_quant_fp8_colmajor( # Information for float8 fp8_min, fp8_max, + use_ue8m0: tl.constexpr, # Meta-parameters BLOCK: tl.constexpr, ): @@ -346,7 +351,8 @@ def _per_token_group_quant_fp8_colmajor( y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) - y_s = _absmax / fp8_max + scale_raw = _absmax / fp8_max + y_s = tl.math.exp2(tl.ceil(tl.log2(scale_raw))) if use_ue8m0 else scale_raw y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) @@ -360,6 +366,7 @@ def per_token_group_quant_fp8( dtype: Optional[torch.dtype] = None, column_major_scales: bool = False, out_q: Optional[torch.Tensor] = None, + use_ue8m0: bool = is_blackwell_deep_gemm_used(), ) -> tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. It converts the tensor values into signed float8 values and returns the @@ -374,7 +381,7 @@ def per_token_group_quant_fp8( out_q: Optional output tensor. If not provided, function will create. Returns: tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the - scaling factor for quantization. + scaling factor. """ dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( @@ -391,8 +398,7 @@ def per_token_group_quant_fp8( if x_q is None: x_q = torch.empty_like(x, device=x.device, dtype=dtype) - M = x.numel() // group_size - N = group_size + # Allocate the scale tensor in either row- or column-major format. 
if column_major_scales: shape = (x.shape[-1] // group_size, ) + x.shape[:-1] x_s = torch.empty(shape, device=x.device, @@ -401,6 +407,15 @@ def per_token_group_quant_fp8( shape = x.shape[:-1] + (x.shape[-1] // group_size, ) x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + # prefer CUDA kernel if available + if current_platform.is_cuda() and x.is_contiguous(): + torch.ops._C.per_token_group_fp8_quant(x, x_q, x_s, group_size, eps, + fp8_min, fp8_max, use_ue8m0) + return x_q, x_s + + # TRITON FALLBACK + M = x.numel() // group_size + N = group_size BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) @@ -417,6 +432,7 @@ def per_token_group_quant_fp8( eps, fp8_min=fp8_min, fp8_max=fp8_max, + use_ue8m0=use_ue8m0, BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, @@ -432,6 +448,7 @@ def per_token_group_quant_fp8( eps, fp8_min=fp8_min, fp8_max=fp8_max, + use_ue8m0=use_ue8m0, BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, @@ -651,3 +668,124 @@ def grid(META): ) return C + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 +# TODO(wentao): remove this function when DeepGEMM exposes this function +def get_tma_aligned_size(x: int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of + 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return cdiv(x, alignment) * alignment + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 +# TODO(wentao): remove this function when DeepGEMM exposes this function +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` + will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along + the M axis (thus meets the requirement of LHS scaling tensor in + DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. 
+ """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in + # CUDA + assert x.dim() in (2, 3) + remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride( + 2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose( + torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def requant_weight_ue8m0_inplace( + weight: torch.Tensor, + weight_scale: torch.Tensor, + block_size: Sequence[int] = (128, 128), +) -> None: + """Re-quantise *weight* so that its per-block scaling factors are in the + UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace. + + Args: + weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``. + Expected shape ``(..., M, K)``. + weight_scale: Corresponding per-block scale tensor (``torch.float32``) + with shape ``(..., M // block_size[0], K // block_size[1])``. + block_size: 2-element iterable ``[block_m, block_k]`` describing the + block quantisation granularity. + """ + if weight.numel() == 0: + return + + if weight.dtype != torch.float8_e4m3fn: + raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got " + f"{weight.dtype} instead.") + + from vllm.utils.deep_gemm import per_block_cast_to_fp8 + + block_m, block_k = int(block_size[0]), int(block_size[1]) + + # Flatten leading dimensions so we can iterate over the last two dims. + leading_shape = weight.shape[:-2] + if len(leading_shape) == 0: + w_view = weight.unsqueeze(0) + s_view = weight_scale.unsqueeze(0) + else: + w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1]) + s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:]) + + num_mats = w_view.size(0) + for idx in range(num_mats): + w_q = w_view[idx] + s_old = s_view[idx] + + # De-quantise with the *old* scaling factors (float32). + m_cur, k_cur = w_q.shape + s_float = s_old.to(torch.float32) + # Expand scales along rows and cols by block size, then crop. + s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0) + s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1) + s_exp = s_exp[:m_cur, :k_cur] + w_dq = w_q.to(torch.float32) * s_exp + # Re-quantise using power-of-two scaling (UE8M0). + w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k]) + + # Write back the results in-place. 
+ w_q.copy_(w_requant) + s_old.copy_(s_requant) diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 9d4a188f52df..1119045db072 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -1,45 +1,67 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - import torch +from vllm.utils import direct_register_custom_op + OCP_MX_BLOCK_SIZE = 32 -def per_token_group_quant_mxfp4(x: torch.Tensor, - block_k: int, - scale_calculation_mode: str = "even" - ) -> tuple[torch.Tensor, torch.Tensor]: +def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: try: - from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( - fake_quantize_fp4_fp6_per_group_with_scale) - from quark.torch.quantization.utils import (even_round, - reshape_to_blocks) + from quark.torch.kernel import mx except ImportError as err: raise ImportError("The package `amd-quark` is required to use " "MX-FP4 models. Please install it with `pip install " "amd-quark`.") from err - axis = -1 - block_x = reshape_to_blocks(x, block_k, axis) - amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) - amax = amax.squeeze(-1) - - # TODO: there are other rounding strategies supported in quark and in the - # config.json that we do not check for here! - if scale_calculation_mode != "even": - raise NotImplementedError( - f"Scale calculation mode {scale_calculation_mode} is not yet " - "supported in MX-FP4 quantization") - scale = even_round(amax, "fp4") - - # Apply dequantize(quantize(x)). - x = fake_quantize_fp4_fp6_per_group_with_scale( - x, - scale.to(x.device), - axis=axis, - group_size=block_k, - quant_dtype="fp4", + return mx.dq_mxfp4(x, scale, float_dtype) + + +def _dequant_mxfp4_fake(x: torch.Tensor, scale: torch.Tensor, + float_dtype: torch.dtype) -> torch.Tensor: + return torch.empty((*x.shape[:-1], x.shape[-1] * 2), + dtype=float_dtype, + device=x.device) + + +def _quant_dequant_mxfp4(x: torch.Tensor, + scale_calculation_mode: str = "even") -> torch.Tensor: + try: + from quark.torch.kernel import mx + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. 
Please install it with `pip install " + "amd-quark`.") from err + + return mx.qdq_mxfp4(x, scale_calculation_mode) + + +def _quant_dequant_mxfp4_fake(x: torch.Tensor, + scale_calculation_mode: str = "even" + ) -> torch.Tensor: + return torch.empty_like(x) + + +try: + direct_register_custom_op( + op_name="dequant_mxfp4", + op_func=_dequant_mxfp4, + mutates_args=[], + fake_impl=_dequant_mxfp4_fake, ) + dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 +except AttributeError as error: + raise error - return x, scale +try: + direct_register_custom_op( + op_name="quant_dequant_mxfp4", + op_func=_quant_dequant_mxfp4, + mutates_args=[], + fake_impl=_quant_dequant_mxfp4_fake, + ) + quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 +except AttributeError as error: + raise error diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index d6b96774b4e8..54361a2323c2 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -3,7 +3,7 @@ """This file is used for /tests and /benchmarks""" from collections.abc import Mapping from types import MappingProxyType -from typing import Optional +from typing import ClassVar, NamedTuple, Optional import numpy import torch @@ -12,13 +12,30 @@ MARLIN_QQQ_SUPPORTED_NUM_BITS) from vllm.scalar_type import ScalarType, scalar_types -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# Use proxy as NamedTuple direct subclasses cannot have static members +class _GroupShape(NamedTuple): + row: int + col: int + + +class GroupShape(_GroupShape): + """ + This class describes the quantization group shape. + It includes static members for common shapes (per-tensor, per-token). + """ + + # Aliases for common quantization group shapes + PER_TENSOR: ClassVar['GroupShape'] + PER_TOKEN: ClassVar['GroupShape'] + + +GroupShape.PER_TENSOR = GroupShape(-1, -1) +GroupShape.PER_TOKEN = GroupShape(1, -1) # Normalize the group_shape to the full extent for any dims that are -1 -def _normalize_quant_group_shape(x: torch.Tensor, group_shape: tuple[int, - int]): +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): # -1 means full extent return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], group_shape[1] if group_shape[1] > 0 else x.shape[-1]) @@ -58,7 +75,7 @@ def group_broadcast(t, shape): # (i.e. 
per-token-per-group) def scaled_quantize( x: torch.Tensor, - group_shape: tuple[int, int], + group_shape: GroupShape, quant_dtype: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: group_shape = _normalize_quant_group_shape(x, group_shape) @@ -99,7 +116,7 @@ def scaled_quantize( def scaled_dequantize( x_q: torch.Tensor, x_s: torch.Tensor, - group_shape: Optional[tuple[int, int]] = None, + group_shape: Optional[GroupShape] = None, out_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: if group_shape is not None: @@ -332,6 +349,10 @@ def reshape_w(w): ) +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, group_size: int, diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index adc67aa64952..47bb45793281 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -8,6 +8,9 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.config import CompilationLevel, get_current_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + GroupShape) from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting @@ -271,20 +274,21 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, def dispatch_w8a8_scaled_mm( cutlass_fp8_supported: bool, per_tensor_weights: bool, - per_tensor_activations: bool, use_per_token_if_dynamic: Optional[bool] -) -> Callable[..., torch.Tensor]: + per_tensor_activations: bool) -> Callable[..., torch.Tensor]: + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: return cutlass_w8a8_scaled_mm if per_tensor_weights and per_tensor_activations: if current_platform.is_rocm(): return rocm_per_tensor_w8a8_scaled_mm return torch_per_tensor_w8a8_scaled_mm - # torch.scaled_mm supports per tensor weights + activations only - # so fallback to naive if per channel or per token - if (use_per_token_if_dynamic and not per_tensor_weights - and not per_tensor_activations and USE_ROWWISE_TORCH_SCALED_MM): + # If torch.scaled_mm supports per-channel (weights) per-token (inputs) + if not per_tensor_weights and not per_tensor_activations \ + and USE_ROWWISE_TORCH_SCALED_MM: return torch_per_token_w8a8_scaled_mm + # Normally, torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token return torch_channelwise_w8a8_scaled_mm @@ -299,11 +303,11 @@ class Fp8LinearOp: """ def __init__(self, + act_quant_static: bool, cutlass_fp8_supported: bool = cutlass_fp8_supported(), - use_per_token_if_dynamic: bool = False, + act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR, pad_output: Optional[bool] = None): self.cutlass_fp8_supported = cutlass_fp8_supported - self.use_per_token_if_dynamic = use_per_token_if_dynamic # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. @@ -312,9 +316,16 @@ def __init__(self, # as it breaks with dynamic shapes. 
if pad_output is None: config = get_current_vllm_config().compilation_config - pad_output = config.level < CompilationLevel.PIECEWISE - self.output_padding = 17 if ( - pad_output and not current_platform.is_rocm()) else None + pad_output = config.level < CompilationLevel.PIECEWISE and \ + not cutlass_fp8_supported and \ + not current_platform.is_rocm() + + self.output_padding = 17 if pad_output else None + self.act_quant_static = act_quant_static + self.act_quant_group_shape = act_quant_group_shape + self.quant_fp8 = QuantFP8(static=act_quant_static, + group_shape=act_quant_group_shape, + num_token_padding=self.output_padding) def apply( self, @@ -325,8 +336,6 @@ def apply( input_scale: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - # TODO(luka) remove this parameter in favor of __init__ - use_per_token_if_dynamic: Optional[bool] = None ) -> torch.Tensor: # ops.scaled_fp8_quant supports both dynamic and static quant. # If dynamic, layer.input_scale is None and x_scale computed from x. @@ -336,40 +345,27 @@ def apply( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[1]] - # TODO(luka) this is here because currently MLA only decides this - # during the forward method instead of in __init__. - if use_per_token_if_dynamic is None: - use_per_token_if_dynamic = self.use_per_token_if_dynamic - if out_dtype is None: out_dtype = input.dtype - # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A - if self.cutlass_fp8_supported: - assert input.dtype != current_platform.fp8_dtype( - ), "FP8 input to cutlass is not currently implemented" - qinput, x_scale = ops.scaled_fp8_quant( + # If input not quantized + # TODO(luka) remove this path if not used anymore + if input.dtype != current_platform.fp8_dtype(): + qinput, x_scale = self.quant_fp8( input_2d, input_scale, - scale_ub=input_scale_ub, - use_per_token_if_dynamic=use_per_token_if_dynamic) + input_scale_ub, + ) else: - if input.dtype != current_platform.fp8_dtype(): - # Maybe apply padding to output, see comment in __init__ - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=self.output_padding, - use_per_token_if_dynamic=use_per_token_if_dynamic) - else: - qinput, x_scale = input_2d, input_scale + qinput, x_scale = input_2d, input_scale per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) + # TODO(luka) do this dispatch during init (after ScaledMM refactor) w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm( self.cutlass_fp8_supported, per_tensor_weights, - per_tensor_activations, use_per_token_if_dynamic) + per_tensor_activations) return w8a8_scaled_mm_func(qinput=qinput, weight=weight, diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py deleted file mode 100644 index db68f18726d3..000000000000 --- a/vllm/model_executor/layers/rejection_sampler.py +++ /dev/null @@ -1,406 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from functools import cached_property -from importlib.util import find_spec -from typing import Optional - -import torch -import torch.jit - -import vllm.envs as envs -from vllm.logger import init_logger -from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeStochasticBaseSampler) -from vllm.platforms import current_platform - -logger = init_logger(__name__) - -if 
find_spec("flashinfer"): - """ - Consider utilizing the FlashInfer rejection sampling kernel initially, - as it employs a dedicated kernel rather than relying on - Torch tensor operations. This design choice helps to fuse operations, - reduce memory I/O, and consequently enhances performance. - """ - from flashinfer.sampling import chain_speculative_sampling -else: - chain_speculative_sampling = None - - -class RejectionSampler(SpecDecodeStochasticBaseSampler): - """Apply modified rejection sampling as described in "Accelerating Large - Language Model Decoding with Speculative Sampling" - https://arxiv.org/pdf/2302.01318.pdf. - """ - - def __init__(self, - strict_mode: bool = False, - use_flashinfer: Optional[bool] = None): - """Create a rejection sampler. - - Args: - strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. - use_flashinfer: We will use this parameter to determine whether - to use the FlashInfer rejection sampling kernel or not. If it's - None, we will use the default value from the environment variable. - This parameter is only used for testing purposes. - """ - super().__init__(strict_mode=strict_mode) - if use_flashinfer is None: - self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( - chain_speculative_sampling is not None) - else: - self.use_flashinfer = use_flashinfer - - if self.use_flashinfer: - logger.info("Use flashinfer for rejection sampling.") - else: - logger.info("Use pytorch for rejection sampling.") - - def forward( - self, - target_with_bonus_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - seeded_seqs: Optional[dict[int, torch.Generator]] = None, - ) -> torch.Tensor: - """Sample token ids using rejection sampling. This accepts or rejects - tokens proposed by the draft model using the probability of each token - according to the draft and target models. - - In the worst case where all draft tokens are rejected, it is guaranteed - one correct token will be emitted. - - In the case where all draft tokens are accepted, a bonus token will be - accepted as its cheap to have the target model score this speculative - sequence. - - Args: - target_with_bonus_probs: The probability distribution - over token ids given context according to the target model. - shape = [batch_size, num_speculative_tokens + 1, vocab_size] - - bonus_token_ids: The "bonus" token ids that are accepted iff all - speculative tokens in a sequence are accepted. - shape = [batch_size, num_bonus_tokens] - - draft_probs: The probability distribution over token ids given - context according to the draft model. - shape = [batch_size, num_speculative_tokens, vocab_size] - - draft_token_ids: The token ids that were sampled from the draft - probabilities. - shape = [batch_size, num_speculative_tokens] - - seeded_seqs: Dict of batch row index to torch generator, for - sequences using seeded generation. - - Returns: - output_token_ids: The token ids sampled via rejection sampling, - or -1 if unable to sample a token because the previous token - was rejected. - shape = [batch_size, num_speculative_tokens + num_bonus_tokens] - """ - # Only perform shape/dtype/device checking in strict mode, as it adds - # overhead. 
- if self._strict_mode: - self._raise_if_incorrect_input(target_with_bonus_probs, - draft_token_ids, bonus_token_ids, - draft_probs) - - batch_size, k, _ = draft_probs.shape - - # batch_size = 0 when all requests in the batch are - # non_spec requests. In this case, output_token_ids is - # just an empty tensor. - if batch_size == 0: - return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) - - # If use Flashinfer chain_speculative_sampling kernel - # for rejection sampling - if self.use_flashinfer and chain_speculative_sampling is not None: - batch_size, k, _ = draft_probs.shape - - (output_token_ids, accepted_token_num, - emitted_token_num) = chain_speculative_sampling( - draft_probs, - draft_token_ids, - target_with_bonus_probs, - ) - - # num_emitted_tokens returned by flashinfer - # does not include the bonus token - # Flashinfer stops at the first token that violates - # the condition p >= q and does not include recovery/bonus token. - # Therefore, we need to add batch_size here. - self.num_accepted_tokens += accepted_token_num.sum() - self.num_emitted_tokens += emitted_token_num.sum() + batch_size - self.num_draft_tokens += batch_size * k - else: - accepted, recovered_token_ids = ( - self._batch_modified_rejection_sampling( - target_with_bonus_probs[:, :-1], - draft_probs, - draft_token_ids, - seeded_seqs, - )) - - output_token_ids = self._create_output( - accepted, - recovered_token_ids, - draft_token_ids, - bonus_token_ids, - ) - - return output_token_ids - - def _batch_modified_rejection_sampling( - self, - target_probs: torch.Tensor, # [batch_size, k, vocab_size] - draft_probs: torch.Tensor, # [batch_size, k, vocab_size] - draft_token_ids: torch.Tensor, # [batch_size, k] - seeded_seqs: Optional[dict[int, torch.Generator]], - ) -> tuple[torch.Tensor, torch.Tensor]: - """Perform modified rejection sampling on each sequence. - - Returns: - A tuple of two tensors: - 0: A bool tensor of which tokens in each sequence is accepted. - shape = [batch_size, k] - 1: Token ids sampled from a recovered distribution, to be used - when a token is rejected. - shape = [batch_size, k] - """ - - batch_size, k, vocab_size = draft_probs.shape - - # shape [batch_size, k] - accepted = self._get_accepted(target_probs, draft_probs, - draft_token_ids, seeded_seqs) - - recovered_probs = self._get_recovered_probs( - target_probs, draft_probs).reshape(batch_size * k, vocab_size) - - # NOTE: the recovered_probs are overwritten by this method. - recovered_token_ids = _multinomial( - recovered_probs, - num_samples=1, - k=k, - seeded_seqs=seeded_seqs or {}, - ).reshape(batch_size, k) - - return accepted, recovered_token_ids - - def _create_uniform_samples(self, - seeded_seqs: Optional[dict[int, - torch.Generator]], - batch_size: int, k: int, - device: torch.device) -> torch.Tensor: - """ - Generates a batch of uniform random samples, with optional seeding - for specific sequences. - - This method creates a tensor of shape `(batch_size, k + 1)` filled - with uniform random values in the range [0, 1). If `seeded_seqs` - is provided, the sequences corresponding to specific indices - will be generated using the provided `torch.Generator` for - reproducibility. The other sequences will be generated without - a seed. - - Args: - seeded_seqs : Optional[dict[int, torch.Generator]] - A dictionary mapping indices in the batch to - `torch.Generator` objects. If `None`, all samples are - generated without a seed. - batch_size : int - The number of sequences to generate. 
- k : int - The number of random samples per sequence. - device : torch.device - The device on which to allocate the tensor. - - Returns: - uniform_rand : torch.Tensor - A tensor of shape `(batch_size, k + 1)` containing uniform - random values in the range [0, 1). - """ - if not seeded_seqs: - return torch.rand(batch_size, k + 1, device=device) - - uniform_rand = torch.empty(batch_size, k + 1, device=device) - - non_seeded_indices = [] - for idx in range(batch_size): - generator = seeded_seqs.get(idx) - if generator is None: - non_seeded_indices.append(idx) - else: - uniform_rand[idx, :] = torch.rand(1, - k + 1, - dtype=self.probs_dtype, - device=device, - generator=generator) - if non_seeded_indices: - uniform_rand[non_seeded_indices, :] = torch.rand( - len(non_seeded_indices), - k + 1, - dtype=self.probs_dtype, - device=device) - return uniform_rand - - def _get_accepted( - self, - target_probs: torch.Tensor, # [batch_size, k, vocab_size] - draft_probs: torch.Tensor, # [batch_size, k, vocab_size] - draft_token_ids: torch.Tensor, # [batch_size, k] - seeded_seqs: Optional[dict[int, torch.Generator]], - ) -> torch.Tensor: - r"""Create bool matrix over the proposed draft tokens. If - True, then a token can be accepted, else it should be - rejected. - - Given $q(\hat{x}_{n+1}|x_1, \dots, x_n)$, the probability of - $\hat{x}_{n+1}$ given context $x_1, \dots, x_n$ according - to the target model, and $p(\hat{x}_{n+1}|x_1, \dots, x_n)$, the - same conditional probability according to the draft model, the token - is accepted with probability: - - $$ - \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} - {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) - $$ - - This implementation does not apply causality. When using the output, - if a token is rejected, subsequent tokens should not be used. - - Returns a bool tensor of shape [batch_size, k] specifying which tokens - are accepted. - """ - batch_size, k, _ = draft_probs.shape - batch_indices = torch.arange(batch_size, - device=target_probs.device)[:, None] - probs_indices = torch.arange(k, device=target_probs.device) - - # shape [batch_size, k] - selected_draft_probs = draft_probs[batch_indices, probs_indices, - draft_token_ids] - - # shape [batch_size, k] - selected_target_probs = target_probs[batch_indices, probs_indices, - draft_token_ids] - - uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, - k - 1, target_probs.device) - - capped_ratio = torch.minimum( - selected_target_probs / selected_draft_probs, - torch.full((1, ), 1, device=target_probs.device)) - accepted = uniform_rand < capped_ratio - - return accepted - - def _get_recovered_probs( - self, - target_probs: torch.Tensor, # [k, vocab_size] - draft_probs: torch.Tensor, # [k, vocab_size] - ) -> torch.Tensor: - r"""Create a probability distribution for each proposed token which can - be sampled if the proposed token is rejected. - - When this routine is applied sequentially, the true distribution of the - target model is recovered (within hardware numerics). - - The probability distribution used in this rejection case is constructed - as follows. 
Given $q(x|x_1, \dots, x_n)$, the probability of - $x$ given context $x_1, \dots, x_n$ according to the target - model and $p(x|x_1, \dots, x_n)$, the same conditional probability - according to the draft model: - - $$ - x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ - $$ - - where $(f(x))_+$ is defined as: - - $$ - (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} - $$ - - See https://github.com/vllm-project/vllm/pull/2336 for a visualization - of the draft, target, and recovered probability distributions. - - Returns a tensor of shape [batch_size, k, vocab_size]. - - Note: - This batches operations on GPU and thus constructs the recovered - distribution for all tokens, even if they are accepted. This causes - division-by-zero errors, so we use self._smallest_positive_value to - avoid that. This introduces some drift to the distribution. - """ - _, k, _ = draft_probs.shape - - # shape [batch_size, k, vocab_size] - difference = target_probs - draft_probs - - # TODO(cade): Can we use logprobs instead of probs, and avoid the - # division-by-zero errors without introducing distribution drift? - - # shape [batch_size, k, vocab_size] - f = torch.clamp(difference, min=self._smallest_positive_value) - - # shape [batch_size, k, vocab_size] - recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) - - return recovered_probs - - @cached_property - def _smallest_positive_value(self) -> float: - """Return the smallest positive value representable by the probs dtype. - This value is used when constructing a distribution from which to sample - recovered tokens in the first rejection case. - - See _get_recovered_probs for more details - - Note that this isn't actually the smallest positive value representable - by float32, but the smallest positive normal value. - See https://en.wikipedia.org/wiki/Subnormal_number for more information. - """ - return torch.finfo(self.probs_dtype).tiny - - -# torch.multinomial forces a GPU<->CPU sync. -# Therefore, we use an optimized implementation instead that skips the sync. -# Note that we always sample with replacement. -# probs will be modified in place, but this is fine, as we pass -# in a copy already. -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def _multinomial( - probs: torch.Tensor, - num_samples: int, - k: int, - seeded_seqs: dict[int, torch.Generator], -) -> torch.Tensor: - - if num_samples > 1: - # This is equivalent to torch.repeat_interleaved (which also - # forces a GPU<->CPU sync). 
- probs = probs[:, None, :].expand(probs.shape[0], num_samples, - probs.shape[1]).contiguous().view( - -1, probs.shape[1]) - q = torch.empty_like(probs) - if not seeded_seqs: - q.exponential_(1.0) - else: - start = 0 - for idx in range(len(q) // k): - end = start + k - generator = seeded_seqs.get(idx) - # Note: generator might be None for non seeded - q[start:end].exponential_(1.0, generator=generator) - start = end - - return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index a4615132a518..dddd4d6a7117 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -229,64 +229,6 @@ def forward_xpu( self.cos_sin_cache, self.is_neox_style) return query, key - def forward_hpu( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: Optional[torch.Tensor] = None, - offsets: Optional[torch.Tensor] = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - from habana_frameworks.torch.hpex.kernels import ( - RotaryPosEmbeddingMode, apply_rotary_pos_emb) - if offsets is not None: - offsets = offsets.view(positions.shape[0], -1) - positions = positions + offsets - positions = positions.flatten() - num_tokens = positions.shape[0] - cos_sin = self.cos_sin_cache.index_select(0, positions).view( - num_tokens, 1, -1) - cos, sin = cos_sin.chunk(2, dim=-1) - # HPU RoPE kernel requires hidden dimension for cos and sin to be equal - # to query hidden dimension, so the original tensors need to be - # expanded - # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE - # and expansion of cos/sin tensors via concatenation - # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE - # and expansion of cos/sin tensors via repeat_interleave - rope_mode: RotaryPosEmbeddingMode - if self.is_neox_style: - rope_mode = RotaryPosEmbeddingMode.BLOCKWISE - cos = torch.cat((cos, cos), dim=-1) - sin = torch.cat((sin, sin), dim=-1) - else: - rope_mode = RotaryPosEmbeddingMode.PAIRWISE - sin = torch.repeat_interleave(sin, - 2, - dim=-1, - output_size=cos_sin.shape[-1]) - cos = torch.repeat_interleave(cos, - 2, - dim=-1, - output_size=cos_sin.shape[-1]) - - query_shape = query.shape - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., :self.rotary_dim] - query_pass = query[..., self.rotary_dim:] - query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, - rope_mode) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - if key is not None: - key_shape = key.shape - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., :self.rotary_dim] - key_pass = key[..., self.rotary_dim:] - key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, - rope_mode) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) - return query, key - def forward_neuron( self, positions: torch.Tensor, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 08840fc40cf6..e77eb637c894 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -21,7 +21,6 @@ from vllm.sequence import (VLLM_INVALID_TOKEN_ID, CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) -from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): # yapf: disable @@ -119,9 +118,6 @@ class SamplerOutput( # specified in 
lieu of prompt token ids or text. sampled_token_embeds: Optional[torch.Tensor] = None - # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - # Optional last hidden states from the model. hidden_states: Optional[torch.Tensor] = None @@ -159,11 +155,9 @@ def __repr__(self) -> str: else self.sampled_token_probs.shape) sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else self.sampled_token_ids.shape) - return ( - f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr}, " - f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + return (f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr})") class Sampler(nn.Module): diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py deleted file mode 100644 index 0a36fe9be45b..000000000000 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ /dev/null @@ -1,259 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import abstractmethod -from typing import Optional, Union - -import torch -import torch.jit -import torch.nn as nn - -from vllm.platforms import current_platform - - -class SpecDecodeBaseSampler(nn.Module): - """Base class for samplers used for Speculative Decoding verification - step. - """ - - def __init__(self, strict_mode: bool = False): - """Base class constructor. - Args: - strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. - """ - super().__init__() - self._strict_mode = strict_mode - - # NOTE: A "bonus token" is accepted iff all proposal tokens are - # accepted. There is always only one possible bonus token. We store this - # value in a variable for readability. 
- self._num_bonus_tokens = 1 - - self.num_accepted_tokens: Optional[torch.Tensor] = None - self.num_emitted_tokens: Optional[torch.Tensor] = None - self.num_draft_tokens: int = 0 - - def init_gpu_tensors(self, device: Union[int, str]) -> None: - assert self.num_accepted_tokens is None - if isinstance(device, int): - device = f"{current_platform.device_type}:{device}" - elif not isinstance(device, str): - raise ValueError(f"Device must be int or str, get {type(device)}") - self.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - self.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - - def init_tensors(self, - device: Union[int, str], - device_type: Union[torch.device, str] = 'cuda') -> None: - assert self.num_accepted_tokens is None - if isinstance(device_type, torch.device): - device_type = device_type.type - if isinstance(device, int): - device = f"{device_type}:{device}" - self.num_accepted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - self.num_emitted_tokens = torch.tensor(0, - dtype=torch.long, - device=device) - - @property - def probs_dtype(self): - return torch.float32 - - @property - def token_id_dtype(self): - return torch.int64 - - def _create_output( - self, - accepted: torch.Tensor, # [batch_size, k] - substitute_token_ids: torch.Tensor, # [batch_size, k] - draft_token_ids: torch.Tensor, # [batch_size, k] - bonus_token_ids: torch.Tensor, # [batch_size] - ) -> torch.Tensor: - """Format output. Returns a matrix of token ids. When - a token is rejected via sampling, all subsequent token ids are - set to -1 for the sequence. - - Args: - accepted: A boolean tensor indicating if the corresponding - draft token in draft_token_ids should be accepted or not. - substitute_token_ids: A tensor of token_ids that can be used - as substitutes for the draft token ids if the proposed token - is rejected. - draft_token_ids: A tensor of token ids speculated by the - draft model. - bonus_token_ids: Token ids to use as the bonus token if - all the draft tokens are accepted. - Returns: - A tensor containing the accepted token ids. The shape of the - tensor is [batch_size, k + num_bonus_tokens] - """ - batch_size, k = substitute_token_ids.shape - bonus_token_ids = bonus_token_ids.squeeze(-1) - # Determine the index of the first False value for each row. - limits = (accepted == 0).max(1).indices - limits[~(accepted == 0).any(1)] = k - - # Create masks using the indices. - indices = torch.arange(k, device=accepted.device).unsqueeze(0) - accepted_mask = indices < limits.unsqueeze(1) - after_false_mask = indices == limits.unsqueeze(1) - - # Create an extended output tensor - output_with_bonus_tokens = -torch.ones( - (batch_size, k + self._num_bonus_tokens), - dtype=self.token_id_dtype, - device=accepted.device) - output = output_with_bonus_tokens[:, :k] - - # Fill in the first k columns of the output tensor using masks and data - # tensors. - output[:, :k] = torch.where(accepted_mask, draft_token_ids, - -torch.ones_like(draft_token_ids)) - - # Fill the last column. - # We check output directly as accepted may have True values inconsistent - # with causal acceptance. - output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, - bonus_token_ids, -1) - - # Fill the recovered token ids. 
- output.mul_(~after_false_mask).add_( - substitute_token_ids.mul(after_false_mask)) - - self.num_accepted_tokens += accepted.sum() - self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() - self.num_draft_tokens += batch_size * k - - return output_with_bonus_tokens - - def _raise_if_incorrect_input( - self, - target_with_bonus_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: Optional[torch.Tensor] = None, - ) -> None: - self._raise_if_incorrect_shape(target_with_bonus_probs, - draft_token_ids, bonus_token_ids, - draft_probs) - self._raise_if_incorrect_dtype(target_with_bonus_probs, - draft_token_ids, bonus_token_ids, - draft_probs) - self._raise_if_inconsistent_device(target_with_bonus_probs, - draft_token_ids, bonus_token_ids, - draft_probs) - self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], - draft_token_ids, bonus_token_ids) - - def _raise_if_incorrect_shape( - self, - target_with_bonus_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: Optional[torch.Tensor] = None, - ) -> None: - (target_batch_size, num_target_probs, - target_vocab_size) = target_with_bonus_probs.shape - - # Does not count the extra token - num_target_probs -= 1 - - # validate the shape of draft token ids. - draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape - assert draft_token_ids_batch_size == target_batch_size - assert num_draft_token_ids == num_target_probs - - # validate the shape of bonus token ids - bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape - assert bonus_batch_size == target_batch_size - assert num_bonus_tokens == self._num_bonus_tokens - - # validate the shape of draft probs if it is set - if draft_probs is not None: - (draft_batch_size, num_draft_probs, - draft_vocab_size) = draft_probs.shape - assert draft_batch_size == target_batch_size - assert num_draft_probs == num_target_probs - assert (draft_vocab_size == target_vocab_size - ), f"{draft_vocab_size=} {target_vocab_size=}" - - def _raise_if_incorrect_dtype( - self, - target_with_bonus_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: Optional[torch.Tensor] = None, - ) -> None: - assert target_with_bonus_probs.dtype == self.probs_dtype - assert draft_token_ids.dtype == self.token_id_dtype - assert bonus_token_ids.dtype == self.token_id_dtype - if draft_probs is not None: - assert draft_probs.dtype == self.probs_dtype - - def _raise_if_inconsistent_device( - self, - target_with_bonus_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: Optional[torch.Tensor] = None, - ) -> None: - devices = [ - t.device for t in [ - target_with_bonus_probs, bonus_token_ids, draft_probs, - draft_token_ids - ] if t is not None - ] - assert all([devices[0] == device for device in devices]) - - def _raise_if_out_of_bounds_vocab( - self, - vocab_size: int, - draft_token_ids: torch.Tensor, - bonus_token_ids: torch.Tensor, - ) -> None: - assert torch.all(bonus_token_ids < vocab_size) - assert torch.all(bonus_token_ids >= 0) - assert torch.all(draft_token_ids < vocab_size) - assert torch.all(draft_token_ids >= 0) - - -class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): - """Base class for samplers used for Speculative Decoding verification - step which are deterministic. 
- """ - - @abstractmethod - def forward( - self, - target_with_bonus_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> torch.Tensor: - raise NotImplementedError - - -class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): - """Base class for samplers used for Speculative Decoding verification - step which are stochastic - """ - - @abstractmethod - def forward( - self, - target_with_bonus_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - seeded_seqs: Optional[dict[int, torch.Generator]] = None, - ) -> torch.Tensor: - raise NotImplementedError diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py deleted file mode 100644 index 5dabaa5379e7..000000000000 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ /dev/null @@ -1,166 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.jit - -from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeDeterministicBaseSampler) - - -class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): - """Apply typical acceptance sampling as described in section 3.3.1 in - "MEDUSA: Simple LLM Inference Acceleration Framework with - Multiple Decoding Heads" - https://arxiv.org/pdf/2401.10774 - """ - - def __init__( - self, - posterior_threshold: float, - posterior_alpha: float, - strict_mode: bool = False, - ): - """Create a Typical Acceptance Sampler. - - Args: - strict_mode: Whether or not to perform shape/device/dtype checks - during sampling. This catches correctness issues but adds - nontrivial latency. - posterior_threshold : A threshold value that sets a lower bound - on the posterior probability of a token in target model for it - to be accepted. - posterior_alpha : A scaling factor for the entropy-based - threshold in typical acceptance sampling. - """ - self._posterior_threshold = posterior_threshold - self._posterior_alpha = posterior_alpha - super().__init__(strict_mode=strict_mode) - - def forward( - self, - target_with_bonus_probs: torch.Tensor, - bonus_token_ids: torch.Tensor, - draft_probs: torch.Tensor, - draft_token_ids: torch.Tensor, - ) -> torch.Tensor: - """Sample token ids using typical acceptance sampling. This accepts - or rejects tokens proposed by the draft model using the probability - of each token according to the draft and target models. - - In the worst case where all draft tokens are rejected, it is guaranteed - one token will be emitted. - - In the case where all draft tokens are accepted, the bonus token will be - accepted. - - Args: - target_probs: The probability distribution over token ids given - context according to the target model. - shape = [batch_size, num_speculative_tokens, vocab_size] - - bonus_token_ids: The "bonus" token ids that are accepted iff all - speculative tokens in a sequence are accepted. - shape = [batch_size, num_bonus_tokens] - - draft_probs: This parameter is unused by the acceptance sampler. - - draft_token_ids: The token ids that were sampled from the draft - probabilities. - shape = [batch_size, num_speculative_tokens] - - Returns: - output_token_ids: The token ids sampled via rejection sampling, - or -1 if unable to sample a token because the previous token - was rejected. 
- shape = [batch_size, num_speculative_tokens + num_bonus_tokens] - """ - # Only perform shape/dtype/device checking in strict mode, as it adds - # overhead. - if self._strict_mode: - self._raise_if_incorrect_input(target_with_bonus_probs, - draft_token_ids, bonus_token_ids) - target_probs = target_with_bonus_probs[:, :-1] - accepted = self._evaluate_accepted_tokens(target_probs, - draft_token_ids) - recovered_token_ids = self._get_recovered_token_ids(target_probs) - output_token_ids = self._create_output(accepted, recovered_token_ids, - draft_token_ids, - bonus_token_ids) - return output_token_ids - - def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): - r""" - Evaluates and returns a mask of accepted tokens based on the - posterior probabilities. - - Args: - target_probs (torch.Tensor): A tensor of shape - (batch_size, k, vocab_size) representing the probabilities of - each token in the vocabulary for each position in the proposed - sequence. This is the distribution generated by the target - model. - draft_token_ids (torch.Tensor): A tensor of shape (batch_size, k) - representing the proposed token ids. - - A draft token_id x_{n+k} is accepted if it satisfies the - following condition - - $$ - p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > - \min \left( \epsilon, \delta * \exp \left( - -H(p_{\text{original}}( - \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) - $$ - - where $p_{\text{original}}$ corresponds to target_probs - and $\epsilon$ and $\delta$ correspond to hyperparameters - specified using self._posterior_threshold and self._posterior_alpha - - This method computes the posterior probabilities for the given - draft token ids based on the provided target probabilities. It - calculates the entropy of the posterior distribution and determines - a dynamic threshold for each token position using the provided - posterior_threshold and posterior_alpha values. The method then - returns a boolean mask indicating which tokens can be accepted. - - Returns: - torch.Tensor: A boolean tensor of shape (batch_size, k) where each - element indicates whether the corresponding draft token has - been accepted or rejected. True indicates acceptance and false - indicates rejection. - """ - device = target_probs.device - candidates_prob = torch.gather( - target_probs, dim=-1, - index=draft_token_ids.unsqueeze(-1)).squeeze(-1) - # A small constant added to prevent computing the logarithm of zero, - # which can lead to undefined values. - epsilon = 1e-5 - posterior_entropy = -torch.sum( - target_probs * torch.log(target_probs + epsilon), dim=-1) - threshold = torch.minimum( - torch.ones_like(posterior_entropy, device=device) * - self._posterior_threshold, - torch.exp(-posterior_entropy) * self._posterior_alpha, - ) - accepted_mask = candidates_prob > threshold - return accepted_mask - - def _get_recovered_token_ids(self, target_probs): - """ - The recovered token ids will fill the first unmatched token - by the target token. - - Args: - target_probs (torch.Tensor): A tensor of shape - (batch_size, k, vocab_size) containing the target probability - distribution. - - Returns: - torch.Tensor: A tensor of shape (batch_size, k) with the recovered - token ids which are selected from target probs. 
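Aside (not part of the diff): the acceptance rule described in the docstring above, which is being removed together with the rest of the V0 spec-decode samplers, can be reproduced in a few standalone lines. This is a minimal sketch; tensor shapes follow the docstring, the 1e-5 epsilon mirrors the deleted implementation, and the hyperparameter values are purely illustrative.

import torch

batch_size, k, vocab_size = 2, 3, 8
target_probs = torch.softmax(torch.randn(batch_size, k, vocab_size), dim=-1)
draft_token_ids = torch.randint(0, vocab_size, (batch_size, k))
posterior_threshold, posterior_alpha = 0.09, 0.3  # illustrative hyperparameters

# p_target(x_{n+k} | ...) for each proposed draft token
candidates_prob = torch.gather(
    target_probs, dim=-1, index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
# Dynamic threshold: min(posterior_threshold, posterior_alpha * exp(-H(p_target)))
posterior_entropy = -torch.sum(
    target_probs * torch.log(target_probs + 1e-5), dim=-1)
threshold = torch.minimum(
    torch.full_like(posterior_entropy, posterior_threshold),
    torch.exp(-posterior_entropy) * posterior_alpha)
accepted_mask = candidates_prob > threshold  # shape [batch_size, k]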
- """ - max_indices = torch.argmax(target_probs, dim=-1) - - return max_indices diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index f35f969781bd..a5f262c832bf 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -388,20 +388,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # Copy the data. Select chunk corresponding to current shard. loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - - if current_platform.is_hpu(): - # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here, - # so we're using a workaround. Remove this when fixed in - # HPU PT bridge. - padded_weight = torch.cat([ - loaded_weight, - torch.zeros(param.shape[0] - loaded_weight.shape[0], - *loaded_weight.shape[1:]) - ]) - param.data.copy_(padded_weight) - else: - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) - param[loaded_weight.shape[0]:].data.fill_(0) + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) def forward(self, input_): if self.tp_size > 1: diff --git a/vllm/model_executor/model_loader/base_loader.py b/vllm/model_executor/model_loader/base_loader.py index 5018c7d9a360..4cf6c7988960 100644 --- a/vllm/model_executor/model_loader/base_loader.py +++ b/vllm/model_executor/model_loader/base_loader.py @@ -6,9 +6,12 @@ import torch.nn as nn from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.logger import init_logger from vllm.model_executor.model_loader.utils import ( initialize_model, process_weights_after_loading, set_default_torch_dtype) +logger = init_logger(__name__) + class BaseModelLoader(ABC): """Base class for model loaders.""" @@ -32,11 +35,16 @@ def load_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> nn.Module: """Load a model with the given configurations.""" device_config = vllm_config.device_config - target_device = torch.device(device_config.device) + load_config = vllm_config.load_config + load_device = device_config.device if load_config.device is None else \ + load_config.device + target_device = torch.device(load_device) with set_default_torch_dtype(model_config.dtype): with target_device: model = initialize_model(vllm_config=vllm_config, model_config=model_config) + + logger.debug("Loading weights on %s ...", load_device) # Quantization does not happen in `load_weights` but after it self.load_weights(model, model_config) process_weights_after_loading(model, model_config, target_device) diff --git a/vllm/model_executor/model_loader/bitsandbytes_loader.py b/vllm/model_executor/model_loader/bitsandbytes_loader.py index 8e330f7eeaf4..68fcb785691c 100644 --- a/vllm/model_executor/model_loader/bitsandbytes_loader.py +++ b/vllm/model_executor/model_loader/bitsandbytes_loader.py @@ -20,6 +20,7 @@ get_tensor_model_parallel_world_size) # yapf: enable from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, MergedColumnParallelLinear, QKVParallelLinear, @@ -374,7 +375,8 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, if weight_sub_tensor.is_cuda: loaded_weight = weight_sub_tensor else: - loaded_weight = weight_sub_tensor.cuda() + loaded_weight = weight_sub_tensor.to( + device=current_platform.device_type) # remove the following after the issue is fixed: # 
https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 @@ -411,9 +413,33 @@ def _get_bnb_target_modules(self, model: nn.Module) -> None: # in case model has a mixture of disk-merged and disk-split # weights with same last name. self.target_modules.append(name) + elif (isinstance(module, FusedMoE) + and hasattr(module.quant_method, "quant_config")): + if not hasattr(model, "get_expert_mapping"): + raise AttributeError( + f"MoE Model {type(model).__name__} does not support " + "BitsAndBytes quantization yet. Ensure this model has " + "'get_expert_mapping' method.") + # TODO: support FusedMoE with prequant and 8bit. + if self.pre_quant: + raise ValueError( + "Prequant BitsAndBytes models with FusedMoE is not " + "supported yet.") + if self.load_8bit: + raise ValueError( + "BitsAndBytes 8bit quantization with FusedMoE is not " + "supported yet.") + # Get the corresponding weight name using module name and + # get_expert_mapping. + expert_mapping = model.get_expert_mapping() + for exp in expert_mapping: + weight_name = exp[1] + rep_name = name.replace("experts", + "") + weight_name.removesuffix(".") + self.target_modules.append(rep_name) assert (self.target_modules - ), "vllm currently does not support BNB quantization for" + ), "vLLM currently does not support BNB quantization for" f" {type(model).__name__}" def _classify_module_sharding(self, model: nn.Module): @@ -437,6 +463,14 @@ def _classify_module_sharding(self, model: nn.Module): # dimension (dim=-1) elif isinstance(module, (RowParallelLinear, )): self.column_sharded_weights_modules.append(name) + elif isinstance(module, FusedMoE): + expert_mapping = model.get_expert_mapping() + for exp in expert_mapping: + if exp[-1] == "w2": + weight_name = exp[1] + rep_name = name.replace( + "experts", "") + weight_name.removesuffix(".") + self.column_sharded_weights_modules.append(rep_name) def _verify_model_compatibility(self, model: nn.Module, model_config: ModelConfig) -> None: @@ -490,34 +524,132 @@ def _initialize_loader_state(self, model: nn.Module, self._get_bnb_target_modules(model) self._classify_module_sharding(model) - def load_weights(self, model: nn.Module, - model_config: ModelConfig) -> None: + def _dequantize_dq(self, quant_states: Any): + """ + When BNB employs Double Quantization, we perform the dequantization of + these constants during weight loading rather than at inference time, + thereby avoiding this computational overhead during inference. This + comes at the cost of increased memory usage. + """ + from bitsandbytes.functional import QuantState, dequantize_blockwise - self._verify_model_compatibility(model, model_config) - self._initialize_loader_state(model, model_config) + def _dequantize_single_state(quant_state): + """Helper function to dequantize a single QuantState object.""" + if not (isinstance(quant_state, QuantState) + and quant_state.nested): + return - logger.info("Loading weights with BitsAndBytes quantization. " - "May take a while ...") - qweight_iterator, quant_state_dict = ( - self._get_quantized_weights_iterator( - model_config.model, - model_config.revision, - )) - weights_to_load = {name for name, _ in model.named_parameters()} - loaded_weights = model.load_weights(qweight_iterator) - # Some models may have weights loading tracker unimplemented. 
- if loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 + absmax = dequantize_blockwise(quant_state.absmax, + quant_state.state2) + absmax += quant_state.offset - param_dict = dict(model.named_parameters()) + # Ensure float32 dtype + if absmax.dtype != torch.float32: + absmax = absmax.float() + + quant_state.absmax = absmax + quant_state.nested = False + quant_state.offset = None + quant_state.state2 = None + + if isinstance(quant_states, dict): + for quant_state in quant_states.values(): + _dequantize_single_state(quant_state) + else: + _dequantize_single_state(quant_states) + return quant_states + + def _fuse_moe_quant_states(self, model: nn.Module, + quant_states_dict: dict) -> dict: + """ + + This function consolidates individual expert quantization states into + fused representations for w13 and w2. + """ + from bitsandbytes.functional import QuantState + + if not hasattr(model, "get_expert_mapping"): + return dict() + + expert_mapping = model.get_expert_mapping() + expert_qs_dict = {} + for name, module in model.named_modules(): + if not isinstance(module, FusedMoE): + continue + w1_states_lst = [] + w2_states_lst = [] + w3_states_lst = [] + for exp in expert_mapping: + shard_id = exp[-1] + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + layer_prefix = name.split("experts")[0] + weight_qual_name = layer_prefix + exp[1] + "weight" + quant_state = self._dequantize_dq( + quant_states_dict[weight_qual_name]) + if shard_id == "w1": + w1_states_lst.append(quant_state) + elif shard_id == "w2": + w2_states_lst.append(quant_state) + else: + w3_states_lst.append(quant_state) + del quant_states_dict[weight_qual_name] + assert (len(w1_states_lst) == len(w2_states_lst) == + len(w3_states_lst)) + w13_absmax_lst = [] + w2_absmax_lst = [] + w13_total_dim0 = 0 + w2_total_dim0 = 0 + for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, + w3_states_lst): + assert w1_qs.shape == w3_qs.shape + assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize + assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype + # w1 and w3 are interleaved in storage + w13_absmax_lst.append(w1_qs.absmax) + w13_absmax_lst.append(w3_qs.absmax) + w2_absmax_lst.append(w2_qs.absmax) + w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0] + w2_total_dim0 += w2_qs.shape[0] + + w13_absmax = torch.cat(w13_absmax_lst) + w2_absmax = torch.cat(w2_absmax_lst) + # Create fused quantization state for w13. + w13_qs = QuantState( + absmax=w13_absmax, + shape=(w13_total_dim0, w1_states_lst[0].shape[1]), + code=w1_states_lst[0].code, + blocksize=w1_states_lst[0].blocksize, + quant_type="nf4", + dtype=w1_states_lst[0].dtype, + ) + # Create fused quantization state for w2. + w2_qs = QuantState( + absmax=w2_absmax, + shape=(w2_total_dim0, w2_states_lst[0].shape[1]), + code=w2_states_lst[0].code, + blocksize=w2_states_lst[0].blocksize, + quant_type="nf4", + dtype=w2_states_lst[0].dtype, + ) + # The weight suffixes .w13_weight and .w2_weight are consistent + # with the param in BitsAndBytesMoEMethod. 
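# Illustrative note (not part of the patch): for a hypothetical layer with
# E = 2 experts, where each per-expert w1/w3 QuantState has shape (I, H) and
# each w2 has shape (H, I), the fused states built above come out as
#   w13_qs.shape == (2 * E * I, H)   # w1 and w3 absmax interleaved per expert
#   w2_qs.shape  == (E * H, I)
# with the absmax tensors concatenated in the order the experts appear in
# get_expert_mapping(), so the fused states line up with the w13_weight /
# w2_weight params referenced in BitsAndBytesMoEMethod (per the comment above).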
+ w13_weight_name = name + ".w13_weight" + w2_weight_name = name + ".w2_weight" + expert_qs_dict[w13_weight_name] = w13_qs + expert_qs_dict[w2_weight_name] = w2_qs + return expert_qs_dict + + def _stack_quantization_states( + self, model: nn.Module, + quant_state_dict: dict) -> dict[str, dict[int, Any]]: stacked_quant_state_dict: dict[str, dict[int, Any]] = {} # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter - + param_dict = dict(model.named_parameters()) for quant_param_name in quant_state_dict: if is_pp_missing_parameter(quant_param_name, model): continue @@ -558,14 +690,20 @@ def load_weights(self, model: nn.Module, stacked_quant_state_dict[quant_param_name][shard_index] = ( quant_state_dict[non_stacked_param_name]) + return stacked_quant_state_dict + def _bind_quant_states_to_params(self, model: nn.Module, + stacked_quant_state_dict: dict) -> None: # save quant_states and offsets as the attributes of the parameters + param_dict = dict(model.named_parameters()) for param_name, param in param_dict.items(): if param_name in stacked_quant_state_dict: quant_states = stacked_quant_state_dict[param_name] # Dequantize double quantized values during weight loading. - dequantize_dq(quant_states) + self._dequantize_dq(quant_states) set_weight_attrs(param, {"bnb_quant_state": quant_states}) + if not isinstance(quant_states, dict): + continue pack_ratio = getattr(param, "pack_factor", -1) if pack_ratio == -1: @@ -585,29 +723,40 @@ def load_weights(self, model: nn.Module, if self.load_8bit: set_weight_attrs( param, {"matmul_state": [None] * len(quant_states)}) + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + + self._verify_model_compatibility(model, model_config) + self._initialize_loader_state(model, model_config) + + logger.info("Loading weights with BitsAndBytes quantization. " + "May take a while ...") + qweight_iterator, quant_state_dict = ( + self._get_quantized_weights_iterator( + model_config.model, + model_config.revision, + )) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") + expert_quant_state_dict = self._fuse_moe_quant_states( + model, quant_state_dict) + + stacked_quant_state_dict = self._stack_quantization_states( + model, quant_state_dict) + + stacked_quant_state_dict = { + **expert_quant_state_dict, + **stacked_quant_state_dict + } + self._bind_quant_states_to_params(model, stacked_quant_state_dict) torch.cuda.empty_cache() def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) - - -def dequantize_dq(quant_states: dict) -> None: - """ - When BNB employs Double Quantization, we perform the dequantization of - these constants during weight loading rather than at inference time, - thereby avoiding this computational overhead during inference. This comes - at the cost of increased memory usage. 
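Aside (not part of the diff): the double-quantization handling deleted below, and re-introduced as the _dequantize_dq method above, amounts to flattening the nested absmax once at load time. A minimal sketch of that flattening, assuming the bitsandbytes 0.45.x API that the diff itself references and a quant_state taken from a 4-bit BNB parameter:

from bitsandbytes.functional import QuantState, dequantize_blockwise

def flatten_double_quant(quant_state: QuantState) -> None:
    # Only nested (double-quantized) states carry an 8-bit absmax plus
    # state2/offset; plain states are left untouched.
    if not quant_state.nested:
        return
    absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
    absmax += quant_state.offset
    quant_state.absmax = absmax.float()  # keep fp32 absmax for inference
    quant_state.nested = False
    quant_state.offset = None
    quant_state.state2 = None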
- """ - from bitsandbytes.functional import QuantState, dequantize_blockwise - for _, quant_state in quant_states.items(): - # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356 - if isinstance(quant_state, QuantState) and quant_state.nested: - absmax = dequantize_blockwise(quant_state.absmax, - quant_state.state2) - absmax += quant_state.offset - if absmax.dtype != torch.float32: - absmax = absmax.float() - quant_state.absmax = absmax - quant_state.nested = False - quant_state.offset = None - quant_state.state2 = None diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 4624ff01ddc0..2fcae7eb6e6c 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -218,16 +218,6 @@ def _xla_weights_iterator(iterator: Generator): weights_iterator = _xla_weights_iterator(weights_iterator) - elif current_platform.is_hpu(): - import habana_frameworks.torch.core as htcore - - def _hpu_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - htcore.mark_step() - - weights_iterator = _hpu_weights_iterator(weights_iterator) - if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 203c80760145..26af87c1ed67 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -6,6 +6,7 @@ import gguf import torch import torch.nn as nn +from huggingface_hub import hf_hub_download from transformers import AutoModelForCausalLM from vllm.config import LoadConfig, ModelConfig, VllmConfig @@ -32,8 +33,18 @@ def __init__(self, load_config: LoadConfig): def _prepare_weights(self, model_name_or_path: str): if os.path.isfile(model_name_or_path): return model_name_or_path + # for raw HTTPS link + if model_name_or_path.startswith( + ("http://", "https://")) and model_name_or_path.endswith(".gguf"): + return hf_hub_download(url=model_name_or_path) + # repo id/filename.gguf + if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): + repo_id, filename = model_name_or_path.rsplit("/", 1) + return hf_hub_download(repo_id=repo_id, filename=filename) else: - raise ValueError(f"{model_name_or_path} is not a file.") + raise ValueError( + f"Unrecognised GGUF reference: {model_name_or_path} " + "(expected local file, raw URL, or <repo_id>/<filename>.gguf)") def _get_gguf_weights_map(self, model_config: ModelConfig): """ diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 1c14d55fc0fe..3d491be3156b 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -5,18 +5,18 @@ import contextlib import contextvars import dataclasses -import io import json import os +import tempfile import threading import time -from collections.abc import Generator -from dataclasses import dataclass -from functools import partial -from typing import TYPE_CHECKING, Any, BinaryIO, Optional, Union +from collections.abc import Generator, MutableMapping +from dataclasses import asdict, dataclass, field, fields +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union import regex as re import torch +from huggingface_hub import snapshot_download from torch 
import nn from torch.utils._python_dispatch import TorchDispatchMode from transformers import PretrainedConfig @@ -27,6 +27,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser, PlaceholderModule if TYPE_CHECKING: @@ -39,10 +40,6 @@ from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) - _read_stream, _write_stream = (partial( - open_stream, - mode=mode, - ) for mode in ("rb", "wb+")) except ImportError: tensorizer = PlaceholderModule("tensorizer") DecryptionParams = tensorizer.placeholder_attr("DecryptionParams") @@ -54,9 +51,6 @@ get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage") no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor") - _read_stream = tensorizer.placeholder_attr("_read_stream") - _write_stream = tensorizer.placeholder_attr("_write_stream") - __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', @@ -66,6 +60,23 @@ logger = init_logger(__name__) +def is_valid_deserialization_uri(uri: Optional[str]) -> bool: + if uri: + scheme = uri.lower().split("://")[0] + return scheme in {"s3", "http", "https"} or os.path.exists(uri) + return False + + +def tensorizer_kwargs_arg(value): + loaded = json.loads(value) + if not isinstance(loaded, dict): + raise argparse.ArgumentTypeError( + f"Not deserializable to dict: {value}. serialization_kwargs and " + f"deserialization_kwargs must be " + f"deserializable from a JSON string to a dictionary. ") + return loaded + + class MetaTensorMode(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): @@ -137,54 +148,145 @@ def wrapper(*args, **kwargs): @dataclass -class TensorizerConfig: - tensorizer_uri: Union[str, None] = None - vllm_tensorized: Optional[bool] = False - verify_hash: Optional[bool] = False +class TensorizerConfig(MutableMapping): + tensorizer_uri: Optional[str] = None + tensorizer_dir: Optional[str] = None + vllm_tensorized: Optional[bool] = None + verify_hash: Optional[bool] = None num_readers: Optional[int] = None encryption_keyfile: Optional[str] = None s3_access_key_id: Optional[str] = None s3_secret_access_key: Optional[str] = None s3_endpoint: Optional[str] = None - model_class: Optional[type[torch.nn.Module]] = None - hf_config: Optional[PretrainedConfig] = None - dtype: Optional[Union[str, torch.dtype]] = None lora_dir: Optional[str] = None - _is_sharded: bool = False + stream_kwargs: Optional[dict[str, Any]] = None + serialization_kwargs: Optional[dict[str, Any]] = None + deserialization_kwargs: Optional[dict[str, Any]] = None + _extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False, + default=None) + model_class: Optional[type[torch.nn.Module]] = field(init=False, + default=None) + hf_config: Optional[PretrainedConfig] = field(init=False, default=None) + dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None) + _is_sharded: bool = field(init=False, default=False) + _fields: ClassVar[tuple[str, ...]] + _keys: ClassVar[frozenset[str]] + """ + Args for the TensorizerConfig class. These are used to configure the + behavior of model serialization and deserialization using Tensorizer. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. 
This is a required field unless lora_dir is + provided and the config is meant to be used for the + `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or + `lora_dir` is passed to this object's initializer, this is a required + argument. + tensorizer_dir: Path to a directory containing serialized model tensors, + and all other potential model artifacts to load the model, such as + configs and tokenizer files. Can be passed instead of `tensorizer_uri` + where the `model.tensors` file will be assumed to be in this + directory. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/others/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. + lora_dir: Path to a directory containing LoRA adapter artifacts for + serialization or deserialization. When serializing LoRA adapters + this is the only necessary parameter to pass to this object's + initializer. + """ def __post_init__(self): # check if the configuration is for a sharded vLLM model self._is_sharded = isinstance(self.tensorizer_uri, str) \ and re.search(r'%0\dd', self.tensorizer_uri) is not None - if not self.tensorizer_uri and not self.lora_dir: - raise ValueError("tensorizer_uri must be provided.") - if not self.tensorizer_uri and self.lora_dir: - self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors" - assert self.tensorizer_uri is not None, ("tensorizer_uri must be " - "provided.") - self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) - self.lora_dir = self.tensorizer_dir - - @classmethod - def as_dict(cls, *args, **kwargs) -> dict[str, Any]: - cfg = TensorizerConfig(*args, **kwargs) - return dataclasses.asdict(cfg) - def to_dict(self) -> dict[str, Any]: - return dataclasses.asdict(self) + if self.tensorizer_dir and self.lora_dir: + raise ValueError( + "Only one of tensorizer_dir or lora_dir may be specified. " + "Use lora_dir exclusively when serializing LoRA adapters, " + "and tensorizer_dir or tensorizer_uri otherwise.") + if self.tensorizer_dir and self.tensorizer_uri: + logger.warning_once( + "Provided both tensorizer_dir and tensorizer_uri. 
" + "Inferring tensorizer_dir from tensorizer_uri as the " + "latter takes precedence.") + self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) + if not self.tensorizer_uri: + if self.lora_dir: + self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors" + elif self.tensorizer_dir: + self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors" + else: + raise ValueError("Unable to resolve tensorizer_uri. " + "A valid tensorizer_uri or tensorizer_dir " + "must be provided for deserialization, and a " + "valid tensorizer_uri, tensorizer_uri, or " + "lora_dir for serialization.") + else: + self.tensorizer_dir = os.path.dirname(self.tensorizer_uri) + + if not self.serialization_kwargs: + self.serialization_kwargs = {} + if not self.deserialization_kwargs: + self.deserialization_kwargs = {} + + def to_serializable(self) -> dict[str, Any]: + # Due to TensorizerConfig needing to be msgpack-serializable, it needs + # support for morphing back and forth between itself and its dict + # representation + + # TensorizerConfig's representation as a dictionary is meant to be + # linked to TensorizerConfig in such a way that the following is + # technically initializable: + # TensorizerConfig(**my_tensorizer_cfg.to_serializable()) + + # This means the dict must not retain non-initializable parameters + # and post-init attribute states + + # Also don't want to retain private and unset parameters, so only retain + # not None values and public attributes + + raw_tc_dict = asdict(self) + blacklisted = [] + + if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict: + blacklisted.append("tensorizer_dir") + + if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict: + blacklisted.append("tensorizer_dir") + + tc_dict = {} + for k, v in raw_tc_dict.items(): + if (k not in blacklisted and k not in tc_dict + and not k.startswith("_") and v is not None): + tc_dict[k] = v + + return tc_dict def _construct_tensorizer_args(self) -> "TensorizerArgs": - tensorizer_args = { - "tensorizer_uri": self.tensorizer_uri, - "vllm_tensorized": self.vllm_tensorized, - "verify_hash": self.verify_hash, - "num_readers": self.num_readers, - "encryption_keyfile": self.encryption_keyfile, - "s3_access_key_id": self.s3_access_key_id, - "s3_secret_access_key": self.s3_secret_access_key, - "s3_endpoint": self.s3_endpoint, - } - return TensorizerArgs(**tensorizer_args) # type: ignore + return TensorizerArgs(self) # type: ignore def verify_with_parallel_config( self, @@ -209,81 +311,76 @@ def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): tensorizer_args = self._construct_tensorizer_args() return open_stream(self.tensorizer_uri, - **tensorizer_args.stream_params) + **tensorizer_args.stream_kwargs) + + def keys(self): + return self._keys + + def __len__(self): + return len(fields(self)) + + def __iter__(self): + return iter(self._fields) + + def __getitem__(self, item: str) -> Any: + if item not in self.keys(): + raise KeyError(item) + return getattr(self, item) + + def __setitem__(self, key: str, value: Any) -> None: + if key not in self.keys(): + # Disallow modifying invalid keys + raise KeyError(key) + setattr(self, key, value) + + def __delitem__(self, key, /): + if key not in self.keys(): + raise KeyError(key) + delattr(self, key) + + +TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig)) +TensorizerConfig._keys = frozenset(TensorizerConfig._fields) @dataclass class TensorizerArgs: - tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, - bytes, 
os.PathLike, int] - vllm_tensorized: Optional[bool] = False - verify_hash: Optional[bool] = False - num_readers: Optional[int] = None + tensorizer_uri: Optional[str] = None + tensorizer_dir: Optional[str] = None encryption_keyfile: Optional[str] = None - s3_access_key_id: Optional[str] = None - s3_secret_access_key: Optional[str] = None - s3_endpoint: Optional[str] = None - """ - Args for the TensorizerAgent class. These are used to configure the behavior - of the TensorDeserializer when loading tensors from a serialized model. - - Args: - tensorizer_uri: Path to serialized model tensors. Can be a local file - path or a S3 URI. This is a required field unless lora_dir is - provided and the config is meant to be used for the - `tensorize_lora_adapter` function. - vllm_tensorized: If True, indicates that the serialized model is a - vLLM model. This is used to determine the behavior of the - TensorDeserializer when loading tensors from a serialized model. - It is far faster to deserialize a vLLM model as it utilizes - tensorizer's optimized GPU loading. Note that this is now - deprecated, as serialized vLLM models are now automatically - inferred as vLLM models. - verify_hash: If True, the hashes of each tensor will be verified against - the hashes stored in the metadata. A `HashMismatchError` will be - raised if any of the hashes do not match. - num_readers: Controls how many threads are allowed to read concurrently - from the source file. Default is `None`, which will dynamically set - the number of readers based on the number of available - resources and model size. This greatly increases performance. - encryption_keyfile: File path to a binary file containing a - binary key to use for decryption. `None` (the default) means - no decryption. See the example script in - examples/others/tensorize_vllm_model.py. - s3_access_key_id: The access key for the S3 bucket. Can also be set via - the S3_ACCESS_KEY_ID environment variable. - s3_secret_access_key: The secret access key for the S3 bucket. Can also - be set via the S3_SECRET_ACCESS_KEY environment variable. - s3_endpoint: The endpoint for the S3 bucket. Can also be set via the - S3_ENDPOINT_URL environment variable. 
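Aside (not part of the diff): since TensorizerConfig now behaves as a MutableMapping and has to round-trip through msgpack via to_serializable(), the reworked config can be exercised roughly as follows. This is a hedged usage sketch, not an excerpt from the patch; the S3 URI is made up and _construct_tensorizer_args is the internal helper shown in the diff.

from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

cfg = TensorizerConfig(
    tensorizer_uri="s3://my-bucket/llama/model.tensors",  # hypothetical URI
    num_readers=4,
)
cfg["verify_hash"] = True           # dict-style access via MutableMapping
payload = cfg.to_serializable()     # msgpack-friendly dict; drops unset/private fields
assert TensorizerConfig(**payload).tensorizer_uri == cfg.tensorizer_uri

# Per-call kwargs are derived lazily from the config:
args = cfg._construct_tensorizer_args()
print(args.stream_kwargs, args.deserialization_kwargs)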
- """ - def __post_init__(self): - self.file_obj = self.tensorizer_uri - self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID - self.s3_secret_access_key = (self.s3_secret_access_key + def __init__(self, tensorizer_config: TensorizerConfig): + for k, v in tensorizer_config.items(): + setattr(self, k, v) + self.file_obj = tensorizer_config.tensorizer_uri + self.s3_access_key_id = (tensorizer_config.s3_access_key_id + or envs.S3_ACCESS_KEY_ID) + self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY) - self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL - self.stream_params = { - "s3_access_key_id": self.s3_access_key_id, - "s3_secret_access_key": self.s3_secret_access_key, - "s3_endpoint": self.s3_endpoint, + self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL + + self.stream_kwargs = { + "s3_access_key_id": tensorizer_config.s3_access_key_id, + "s3_secret_access_key": tensorizer_config.s3_secret_access_key, + "s3_endpoint": tensorizer_config.s3_endpoint, + **(tensorizer_config.stream_kwargs or {}) } - self.deserializer_params = { - "verify_hash": self.verify_hash, - "encryption": self.encryption_keyfile, - "num_readers": self.num_readers + self.deserialization_kwargs = { + "verify_hash": tensorizer_config.verify_hash, + "encryption": tensorizer_config.encryption_keyfile, + "num_readers": tensorizer_config.num_readers, + **(tensorizer_config.deserialization_kwargs or {}) } if self.encryption_keyfile: with open_stream( - self.encryption_keyfile, - **self.stream_params, + tensorizer_config.encryption_keyfile, + **self.stream_kwargs, ) as stream: key = stream.read() decryption_params = DecryptionParams.from_key(key) - self.deserializer_params['encryption'] = decryption_params + self.deserialization_kwargs['encryption'] = decryption_params @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -405,15 +502,24 @@ def init_tensorizer_model(tensorizer_config: TensorizerConfig, def deserialize_tensorizer_model(model: nn.Module, tensorizer_config: TensorizerConfig) -> None: tensorizer_args = tensorizer_config._construct_tensorizer_args() + if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri): + raise ValueError( + f"{tensorizer_config.tensorizer_uri} is not a valid " + f"tensorizer URI. Please check that the URI is correct. 
" + f"It must either point to a local existing file, or have a " + f"S3, HTTP or HTTPS scheme.") before_mem = get_mem_usage() start = time.perf_counter() - with _read_stream( + with open_stream( tensorizer_config.tensorizer_uri, - **tensorizer_args.stream_params) as stream, TensorDeserializer( + mode="rb", + **tensorizer_args.stream_kwargs) as stream, TensorDeserializer( stream, dtype=tensorizer_config.dtype, - device=f'cuda:{torch.cuda.current_device()}', - **tensorizer_args.deserializer_params) as deserializer: + device=f'xpu:{torch.xpu.current_device()}' + if current_platform.is_xpu() else + f'cuda:{torch.cuda.current_device()}', + **tensorizer_args.deserialization_kwargs) as deserializer: deserializer.load_into_module(model) end = time.perf_counter() @@ -442,9 +548,9 @@ def tensorizer_weights_iterator( "examples/others/tensorize_vllm_model.py example script " "for serializing vLLM models.") - deserializer_args = tensorizer_args.deserializer_params - stream_params = tensorizer_args.stream_params - stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + deserializer_args = tensorizer_args.deserialization_kwargs + stream_kwargs = tensorizer_args.stream_kwargs + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs) with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: yield from state.items() @@ -465,8 +571,8 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: """ tensorizer_args = tensorizer_config._construct_tensorizer_args() deserializer = TensorDeserializer(open_stream( - tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), - **tensorizer_args.deserializer_params, + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs), + **tensorizer_args.deserialization_kwargs, lazy_load=True) if tensorizer_config.vllm_tensorized: logger.warning( @@ -477,13 +583,41 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: return ".vllm_tensorized_marker" in deserializer +def serialize_extra_artifacts( + tensorizer_args: TensorizerArgs, + served_model_name: Union[str, list[str], None]) -> None: + if not isinstance(served_model_name, str): + raise ValueError( + f"served_model_name must be a str for serialize_extra_artifacts, " + f"not {type(served_model_name)}.") + + with tempfile.TemporaryDirectory() as tmpdir: + snapshot_download(served_model_name, + local_dir=tmpdir, + ignore_patterns=[ + "*.pt", "*.safetensors", "*.bin", "*.cache", + "*.gitattributes", "*.md" + ]) + for artifact in os.scandir(tmpdir): + if not artifact.is_file(): + continue + with open(artifact.path, "rb") as f, open_stream( + f"{tensorizer_args.tensorizer_dir}/{artifact.name}", + mode="wb+", + **tensorizer_args.stream_kwargs) as stream: + logger.info("Writing artifact %s", artifact.name) + stream.write(f.read()) + + def serialize_vllm_model( model: nn.Module, tensorizer_config: TensorizerConfig, + model_config: "ModelConfig", ) -> nn.Module: model.register_parameter( "vllm_tensorized_marker", nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + tensorizer_args = tensorizer_config._construct_tensorizer_args() encryption_params = None @@ -497,10 +631,16 @@ def serialize_vllm_model( from vllm.distributed import get_tensor_model_parallel_rank output_file = output_file % get_tensor_model_parallel_rank() - with _write_stream(output_file, **tensorizer_args.stream_params) as stream: - serializer = TensorSerializer(stream, encryption=encryption_params) + with open_stream(output_file, mode="wb+", + 
**tensorizer_args.stream_kwargs) as stream: + serializer = TensorSerializer(stream, + encryption=encryption_params, + **tensorizer_config.serialization_kwargs) serializer.write_module(model) serializer.close() + + serialize_extra_artifacts(tensorizer_args, model_config.served_model_name) + logger.info("Successfully serialized model to %s", str(output_file)) return model @@ -522,8 +662,9 @@ def tensorize_vllm_model(engine_args: "EngineArgs", if generate_keyfile and (keyfile := tensorizer_config.encryption_keyfile) is not None: encryption_params = EncryptionParams.random() - with _write_stream( + with open_stream( keyfile, + mode="wb+", s3_access_key_id=tensorizer_config.s3_access_key_id, s3_secret_access_key=tensorizer_config.s3_secret_access_key, s3_endpoint=tensorizer_config.s3_endpoint, @@ -537,13 +678,13 @@ def tensorize_vllm_model(engine_args: "EngineArgs", engine = LLMEngine.from_engine_args(engine_args) engine.model_executor.collective_rpc( "save_tensorized_model", - kwargs=dict(tensorizer_config=tensorizer_config), + kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, ) else: engine = V1LLMEngine.from_vllm_config(engine_config) engine.collective_rpc( "save_tensorized_model", - kwargs=dict(tensorizer_config=tensorizer_config), + kwargs={"tensorizer_config": tensorizer_config.to_serializable()}, ) @@ -554,7 +695,7 @@ def tensorize_lora_adapter(lora_path: str, needed to load a LoRA adapter are a safetensors-format file called adapter_model.safetensors and a json config file called adapter_config.json. - Serializes the files in the tensorizer_config.lora_dir + Serializes the files in the tensorizer_config.tensorizer_dir """ import safetensors @@ -584,19 +725,19 @@ def tensorize_lora_adapter(lora_path: str, tensorizer_args = tensorizer_config._construct_tensorizer_args() - with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json", + with open_stream(f"{tensorizer_config.tensorizer_dir}/adapter_config.json", mode="wb+", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: f.write(json.dumps(config).encode("utf-8")) - lora_uri = (f"{tensorizer_config.lora_dir}" + lora_uri = (f"{tensorizer_config.tensorizer_dir}" f"/adapter_model.tensors") with open_stream(lora_uri, mode="wb+", - **tensorizer_args.stream_params) as f: + **tensorizer_args.stream_kwargs) as f: serializer = TensorSerializer(f) serializer.write_state_dict(tensors) serializer.close() logger.info("Successfully serialized LoRA files to %s", - str(tensorizer_config.lora_dir)) + str(tensorizer_config.tensorizer_dir)) diff --git a/vllm/model_executor/model_loader/tensorizer_loader.py b/vllm/model_executor/model_loader/tensorizer_loader.py index 0b62e744e445..fa01758ab4ce 100644 --- a/vllm/model_executor/model_loader/tensorizer_loader.py +++ b/vllm/model_executor/model_loader/tensorizer_loader.py @@ -20,6 +20,18 @@ logger = init_logger(__name__) +BLACKLISTED_TENSORIZER_ARGS = { + "device", # vLLM decides this + "dtype", # vLLM decides this + "mode", # Not meant to be configurable by the user +} + + +def validate_config(config: dict): + for k, v in config.items(): + if v is not None and k in BLACKLISTED_TENSORIZER_ARGS: + raise ValueError(f"{k} is not an allowed Tensorizer argument.") + class TensorizerLoader(BaseModelLoader): """Model loader using CoreWeave's tensorizer library.""" @@ -29,8 +41,9 @@ def __init__(self, load_config: LoadConfig): if isinstance(load_config.model_loader_extra_config, TensorizerConfig): self.tensorizer_config = load_config.model_loader_extra_config 
else: + validate_config(load_config.model_loader_extra_config) self.tensorizer_config = TensorizerConfig( - **load_config.model_loader_extra_config) + **load_config.model_loader_extra_config["tensorizer_config"]) def _verify_config(self, model_config: ModelConfig, parallel_config: ParallelConfig): @@ -118,10 +131,12 @@ def load_model(self, vllm_config: VllmConfig, def save_model( model: torch.nn.Module, tensorizer_config: Union[TensorizerConfig, dict], + model_config: ModelConfig, ) -> None: if isinstance(tensorizer_config, dict): tensorizer_config = TensorizerConfig(**tensorizer_config) serialize_vllm_model( model=model, tensorizer_config=tensorizer_config, + model_config=model_config, ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 792a1044a564..4b30336f0132 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -22,8 +22,11 @@ QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models.adapters import (as_embedding_model, - as_reward_model) + as_reward_model, + as_seq_cls_model) from vllm.model_executor.models.interfaces import SupportsQuant +from vllm.model_executor.models.registry import (_PREVIOUSLY_SUPPORTED_MODELS, + _TRANSFORMERS_MODELS) from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -168,9 +171,22 @@ def device_loading_context(module: torch.nn.Module, def resolve_transformers_arch(model_config: ModelConfig, architectures: list[str]): + if model_config.model_impl == ModelImpl.VLLM: + raise ValueError( + "Attempting to resolve architecture from the Transformers library " + "but the model implementation is set to vLLM. This should never " + "happen.") + for i, arch in enumerate(architectures): - if arch == "TransformersForCausalLM": + if arch in _TRANSFORMERS_MODELS: continue + + if model_config.model_impl == ModelImpl.AUTO: + logger.warning( + "%s has no vLLM implementation, falling back to Transformers " + "implementation. Some features may not be supported and " + "performance may not be optimal.", arch) + auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map", None) or dict() # Make sure that config class is always initialized before model class, @@ -198,25 +214,13 @@ def resolve_transformers_arch(model_config: ModelConfig, "not present in the model config's 'auto_map' (relevant " "if the model is custom).") model_module = auto_modules["AutoModel"] - # TODO(Isotr0py): Further clean up these raises. - # perhaps handled them in _ModelRegistry._raise_for_unsupported? - if model_config.model_impl == ModelImpl.TRANSFORMERS: - if not model_module.is_backend_compatible(): - raise ValueError( - f"The Transformers implementation of {arch} is not " - "compatible with vLLM.") - architectures[i] = "TransformersForCausalLM" - if model_config.model_impl == ModelImpl.AUTO: - if not model_module.is_backend_compatible(): - raise ValueError( - f"{arch} has no vLLM implementation and the Transformers " - "implementation is not compatible with vLLM. Try setting " - "VLLM_USE_V1=0.") - logger.warning( - "%s has no vLLM implementation, falling back to Transformers " - "implementation. 
Some features may not be supported and " - "performance may not be optimal.", arch) - architectures[i] = "TransformersForCausalLM" + + if not model_module.is_backend_compatible(): + raise ValueError( + f"The Transformers implementation of '{arch}' is not " + "compatible with vLLM.") + + architectures[i] = model_config._get_transformers_backend_cls() return architectures @@ -227,15 +231,49 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ - "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" + "fp8", + "compressed-tensors", + "gptq_marlin", + "awq_marlin", + "quark", + "bitsandbytes", ] vllm_supported_archs = ModelRegistry.get_supported_archs() - vllm_not_supported = not any(arch in vllm_supported_archs - for arch in architectures) + is_supported = lambda arch: (arch in vllm_supported_archs and arch not in + _TRANSFORMERS_MODELS) + vllm_not_supported = not any(is_supported(arch) for arch in architectures) + + if vllm_not_supported: + # try automatic conversion in adapters.py + for arch in architectures: + if not arch.endswith("ForSequenceClassification"): + continue + + assert model_config.task == "classify" + causal_lm_arch = arch.replace("ForSequenceClassification", + "ForCausalLM") + causal_lm_arch_vllm_supported = (causal_lm_arch + in vllm_supported_archs) + if not causal_lm_arch_vllm_supported: + continue + + architectures = [causal_lm_arch] + vllm_not_supported = False + break + + if any(arch in _PREVIOUSLY_SUPPORTED_MODELS for arch in architectures): + previous_version = _PREVIOUSLY_SUPPORTED_MODELS[architectures[0]] + raise ValueError( + f"Model architecture {architectures[0]} was supported" + f" in vLLM until version {previous_version}, and is " + "not supported anymore. 
Please use an older version" + " of vLLM if you want to use this model architecture.") + if (model_config.model_impl == ModelImpl.TRANSFORMERS or - model_config.model_impl != ModelImpl.VLLM and vllm_not_supported): + model_config.model_impl == ModelImpl.AUTO and vllm_not_supported): architectures = resolve_transformers_arch(model_config, architectures) + logger.debug_once("Resolve transformers arch %s", str(architectures)) elif (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): @@ -243,12 +281,13 @@ def get_model_architecture( model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": + logger.debug_once("Automatic conversion using `as_embedding_model`.") model_cls = as_embedding_model(model_cls) elif model_config.task == "classify": - # Cannot automatically run as_seq_cls_model, - # otherwise it will cause a circular reference on is_cross_encoder_model - pass + logger.debug_once("Automatic conversion using `as_seq_cls_model`.") + model_cls = as_seq_cls_model(model_cls) elif model_config.task == "reward": + logger.debug_once("Automatic conversion using `as_reward_model`.") model_cls = as_reward_model(model_cls) return model_cls, arch diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 857f4bca6824..074126fa669e 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -14,7 +14,6 @@ from typing import Any, Callable, Optional, Union import filelock -import gguf import huggingface_hub.constants import numpy as np import torch @@ -40,6 +39,11 @@ SafetensorsStreamer = runai_model_streamer.placeholder_attr( "SafetensorsStreamer") +try: + import gguf +except ImportError: + gguf = PlaceholderModule("gguf") + try: from fastsafetensors import SafeTensorsFileLoader, SingleGroup except ImportError: @@ -148,8 +152,8 @@ def get_quant_config(model_config: ModelConfig, quant_cls = get_quantization_config(model_config.quantization) # GGUF doesn't have config file - if model_config.quantization == "gguf": - return quant_cls.from_config({}) + if model_config.quantization in ("gguf", "inc"): + return quant_cls() # Read the quantization config from the HF model config, if available. 
hf_quant_config = getattr(model_config.hf_config, "quantization_config", @@ -478,14 +482,20 @@ def runai_safetensors_weights_iterator( ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" with SafetensorsStreamer() as streamer: - for st_file in tqdm( - hf_weights_files, - desc="Loading safetensors using Runai Model Streamer", - disable=not enable_tqdm(use_tqdm_on_load), - bar_format=_BAR_FORMAT, - ): - streamer.stream_file(st_file) - yield from streamer.get_tensors() + streamer.stream_files(hf_weights_files) + total_tensors = sum( + len(tensors_meta) + for tensors_meta in streamer.files_to_tensors_metadata.values()) + + tensor_iter = tqdm( + streamer.get_tensors(), + total=total_tensors, + desc="Loading safetensors using Runai Model Streamer", + bar_format=_BAR_FORMAT, + disable=not enable_tqdm(use_tqdm_on_load), + ) + + yield from tensor_iter def fastsafetensors_weights_iterator( @@ -758,6 +768,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: modelopt_scale_names = [ ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale" ] + # Also support qkv_proj scale parameters (from stacked parameter processing) + qkv_proj_scale_names = [ + ".self_attn.qkv_proj.k_scale", ".self_attn.qkv_proj.v_scale" + ] for scale_name in possible_scale_names: if name.endswith(scale_name): if any(mo_scale_name in name @@ -765,6 +779,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: remapped_name = name.replace( f".self_attn.{scale_name[1]}_proj{scale_name}", f".self_attn.attn{scale_name}") + elif any(qkv_scale_name in name + for qkv_scale_name in qkv_proj_scale_names): + # Handle qkv_proj scale parameters + remapped_name = name.replace( + f".self_attn.qkv_proj{scale_name}", + f".self_attn.attn{scale_name}") else: remapped_name = name.replace(scale_name, f".attn{scale_name}") if remapped_name not in params_dict: diff --git a/vllm/model_executor/models/adapters.py b/vllm/model_executor/models/adapters.py index 78d86f6f2044..867de2c68b4c 100644 --- a/vllm/model_executor/models/adapters.py +++ b/vllm/model_executor/models/adapters.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypeVar, cast import torch import torch.nn as nn @@ -13,7 +13,6 @@ if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.model_executor.layers.pooler import PoolingType _T = TypeVar("_T", bound=type[nn.Module]) @@ -34,21 +33,14 @@ def _get_pooling_model_name(orig_model_name: str, pooling_suffix: str) -> str: return model_name + pooling_suffix -def _create_pooling_model_cls( - orig_cls: _T, - *, - default_pooling_type: "PoolingType", - default_normalize: bool, - default_softmax: bool, -) -> _T: +def _create_pooling_model_cls(orig_cls: _T) -> _T: # Lazy import - from vllm.model_executor.layers.pooler import Pooler, PoolerOutput - from vllm.model_executor.pooling_metadata import PoolingMetadata - from .utils import AutoWeightsLoader, WeightsMapper class ModelForPooling(orig_cls, VllmModelForPooling): + is_pooling_model = True + def __init__( self, *, @@ -58,29 +50,19 @@ def __init__( ) -> None: super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) + self.vllm_config = vllm_config + # These are not used in pooling models for attr in ("lm_head", "logits_processor"): if hasattr(self, attr): delattr(self, 
attr) - pooler_config = vllm_config.model_config.pooler_config - assert pooler_config is not None - # If the model already defines a pooler instance, don't overwrite it - if not getattr(self, "_pooler", None): - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=default_pooling_type, - normalize=default_normalize, - softmax=default_softmax, - ) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - return self._pooler(hidden_states, pooling_metadata) + if not getattr(self, "pooler", None): + self._init_pooler(vllm_config, prefix=prefix) + + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + raise NotImplementedError def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): # TODO: Support uninitialized params tracking @@ -133,14 +115,20 @@ def as_embedding_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.pooler import PoolingType - - ModelForEmbedding = _create_pooling_model_cls( - cls, - default_pooling_type=PoolingType.LAST, - default_normalize=True, - default_softmax=False, - ) + from vllm.model_executor.layers.pooler import DispatchPooler, Pooler + + class ModelForEmbedding(_create_pooling_model_cls(cls)): + + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + { + "encode": Pooler.for_encode(pooler_config), + "embed": Pooler.for_embed(pooler_config), + }, ) + ModelForEmbedding.__name__ = \ _get_pooling_model_name(cls.__name__, "ForEmbedding") @@ -165,47 +153,60 @@ def as_seq_cls_model(cls: _T) -> _T: # Lazy import from vllm.model_executor.layers.linear import RowParallelLinear - from vllm.model_executor.layers.pooler import PoolerOutput, PoolingType + from vllm.model_executor.layers.pooler import (ClassifierPooler, + DispatchPooler, Pooler, + PoolingMethod, PoolingType) from vllm.model_executor.models.interfaces import SupportsCrossEncoding - from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors from .utils import maybe_prefix - ModelForPooling = _create_pooling_model_cls( - cls, - default_pooling_type=PoolingType.LAST, - default_normalize=False, - default_softmax=True, - ) - - class ModelForSequenceClassification(ModelForPooling, + class ModelForSequenceClassification(_create_pooling_model_cls(cls), SupportsCrossEncoding): - def __init__( - self, - *, - vllm_config: "VllmConfig", - prefix: str = "", - **kwargs: Any, - ) -> None: - super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config - self.vllm_config = vllm_config - self.task = vllm_config.model_config.task - self.pooling_type = ( - vllm_config.model_config.pooler_config.pooling_type) - - self.score = RowParallelLinear(config.hidden_size, - config.num_labels, - quant_config=quant_config, - input_is_parallel=False, - bias=False, - prefix=maybe_prefix( - prefix, "score")) + self.score = RowParallelLinear( + config.hidden_size, + config.num_labels, + input_is_parallel=False, + bias=False, + params_dtype=torch.float32, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "score"), + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + pooling_type_str = 
pooler_config.pooling_type + pooling_type = (PoolingType.LAST if pooling_type_str is None else + PoolingType[pooling_type_str]) + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=PoolingMethod.from_pooling_type(pooling_type), + classifier=self._classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) + + def _classifier(self, x: torch.Tensor): + x, _ = self.score(x.float()) + return x def forward( self, @@ -217,33 +218,6 @@ def forward( return super().forward(input_ids, positions, intermediate_tensors, inputs_embeds) - def pooler( - self, - hidden_states: Union[torch.Tensor, list[torch.Tensor]], - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - - def get_logits(hidden_states): - if isinstance(hidden_states, list): - logits = [self.score(state)[0] for state in hidden_states] - else: - logits, _ = self.score(hidden_states) - return logits - - if self.pooling_type == PoolingType.ALL: - logits = get_logits(hidden_states) - return self._pooler(logits, pooling_metadata) - else: - hidden_states = self._pooler.extract_states( - hidden_states, pooling_metadata) - logits = get_logits(hidden_states) - pooled_data = self._pooler.head(logits, pooling_metadata) - - pooled_outputs = [ - self._pooler.build_output(data) for data in pooled_data - ] - return PoolerOutput(outputs=pooled_outputs) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): tokens = getattr(self.config, "classifier_from_token", None) method = getattr(self.config, "method", None) @@ -277,14 +251,16 @@ def as_reward_model(cls: _T) -> _T: return cls # Lazy import - from vllm.model_executor.layers.pooler import PoolingType + from vllm.model_executor.layers.pooler import DispatchPooler, Pooler + + class ModelForReward(_create_pooling_model_cls(cls)): - ModelForReward = _create_pooling_model_cls( - cls, - default_pooling_type=PoolingType.ALL, - default_normalize=False, - default_softmax=False, - ) + def _init_pooler(self, vllm_config: "VllmConfig", prefix: str = ""): + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) ModelForReward.__name__ = \ _get_pooling_model_name(cls.__name__, "ForReward") @@ -312,12 +288,18 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: else: config.num_labels = len(tokens) + # `llm as reranker` defaults to not using pad_token + use_pad_token = getattr(config, "use_pad_token", False) + config.use_pad_token = use_pad_token + def load_weights_using_from_2_way_softmax( model, weights: Iterable[tuple[str, torch.Tensor]]): # refer to https://huggingface.co/Qwen/Qwen3-Reranker-0.6B/discussions/3 from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead) + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) from vllm.model_executor.models.utils import AutoWeightsLoader model_config = model.vllm_config.model_config @@ -325,8 +307,6 @@ def load_weights_using_from_2_way_softmax( tokens = cast(list[int], tokens) assert len(tokens) == 2 - device = model.score.weight.device - if model.config.tie_word_embeddings: model.lm_head = model.model.embed_tokens else: @@ -345,10 +325,56 @@ def 
load_weights_using_from_2_way_softmax( false_id = tokenizer.convert_tokens_to_ids(tokens[0]) true_id = tokenizer.convert_tokens_to_ids(tokens[1]) - weight = model.lm_head.weight.data[true_id].to(device).to( - torch.float32) - model.lm_head.weight.data[false_id].to(device).to( + score_weight = model.lm_head.weight.data[[true_id]].to( + torch.float32) - model.lm_head.weight.data[[false_id]].to( torch.float32) - model.score.weight.data.copy_(weight) + + param = model.score.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, score_weight) + + del model.lm_head + loaded_weights.add("score.weight") + loaded_weights.discard("lm_head.weight") + return loaded_weights + + +def load_weights_no_post_processing(model, + weights: Iterable[tuple[str, + torch.Tensor]]): + from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead) + from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader) + from vllm.model_executor.models.utils import AutoWeightsLoader + + model_config = model.vllm_config.model_config + tokens = getattr(model.config, "classifier_from_token", []) + tokens = cast(list[int], tokens) + assert len(tokens) > 0 + + if model.config.tie_word_embeddings: + model.lm_head = model.model.embed_tokens + else: + model.lm_head = ParallelLMHead(model.config.vocab_size, + model.config.hidden_size, + quant_config=model.quant_config) + + loader = AutoWeightsLoader(model) + loaded_weights = loader.load_weights(weights) + + from vllm.transformers_utils.tokenizer import get_tokenizer + tokenizer = get_tokenizer(model_config.tokenizer, + revision=model_config.tokenizer_revision, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code) + + token_ids = [tokenizer.convert_tokens_to_ids(t) for t in tokens] + score_weight = model.lm_head.weight.data[token_ids] + + param = model.score.weight + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, score_weight) del model.lm_head loaded_weights.add("score.weight") @@ -358,6 +384,7 @@ def load_weights_using_from_2_way_softmax( SEQ_CLS_LOAD_METHODS = { "from_2_way_softmax": load_weights_using_from_2_way_softmax, + "no_post_processing": load_weights_no_post_processing, } @@ -368,6 +395,9 @@ def seq_cls_model_loader(model, weights: Iterable[tuple[str, torch.Tensor]]): # - Qwen3-Reranker # - Qwen2ForCausalLM # - mxbai-rerank-v2 + # - no_post_processing: + # - GemmaForCausalLM + # - bge-reranker-v2-gemma config = model.vllm_config.model_config.hf_config method = getattr(config, "method", None) diff --git a/vllm/model_executor/models/arcee.py b/vllm/model_executor/models/arcee.py new file mode 100644 index 000000000000..4e3ba107ba7e --- /dev/null +++ b/vllm/model_executor/models/arcee.py @@ -0,0 +1,347 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2023-2025 vLLM Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# +# Inference-only Arcee (AFM) model – adds support for ReLU^2 feed-forward +# activation. 
+ +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.activation import ReLUSquaredActivation +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, + make_empty_intermediate_tensors_factory, make_layers) + + +class ArceeMLP(nn.Module): + """Feed-forward layer for Arcee using ReLU^2 activation + (no gating as in LLaMA).""" + + def __init__(self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[Any] = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True) -> None: + super().__init__() + # Single linear projection up to intermediate size + # (no separate gate projection) + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + # Down projection back to hidden size + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "relu2": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only 'relu2' is supported for AFM.") + # Define ReLU^2 activation: (ReLU(x))^2 elementwise + self.act_fn = ReLUSquaredActivation() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.up_proj(x) # Project to intermediate size + x = self.act_fn(x) # Apply ReLU^2 activation elementwise + x, _ = self.down_proj(x) # Project back down to hidden size + return x + + +class ArceeDecoderLayer(nn.Module): + """Transformer decoder block for Arcee, with self-attention and + ReLU^2 MLP.""" + + def __init__(self, + config: LlamaConfig, + cache_config: Optional[Any] = None, + quant_config: Optional[Any] = None, + prefix: str = "") -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Rotary embedding parameters (reuse LLaMA defaults) + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Determine if attention bias is needed (some variants use bias terms) + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + bias_o_proj = attention_bias + if hasattr(config, "qkv_bias"): + attention_bias = config.qkv_bias + + # Self-Attention (using LLaMA's attention structure) + from vllm.model_executor.models.llama import ( + LlamaAttention) # import here to avoid circular import + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + bias_o_proj=bias_o_proj, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + attn_type=getattr( + config, "attn_type", + "decoder"), # assume decoder (causal) unless specified + ) + # MLP with ReLU^2 activation + self.mlp = ArceeMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + # Layer normalization layers (RMSNorm as in LLaMA) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self-Attention block + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + # Fused residual add + layernorm if supported + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + # Feed-forward block + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class ArceeModel(nn.Module): + """The transformer model backbone for Arcee (embedding layer + stacked + decoder blocks + final norm).""" + + def __init__(self, + *, + vllm_config, + prefix: str = "", + 
layer_type: type[nn.Module] = ArceeDecoderLayer) -> None: + super().__init__() + config: LlamaConfig = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.org_vocab_size = config.vocab_size + + # Word embeddings (parallelized if using pipeline parallel) + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer( + ) # placeholder on non-embedding ranks + + # Build decoder layers across pipeline ranks + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + # Final RMSNorm on the last pipeline stage + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + # For optional capturing of intermediate hidden states + # (not used by default) + self.aux_hidden_state_layers: tuple[int, ...] = tuple() + + # Prepare factory for empty intermediate tensors + # (for pipeline scheduling) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor, + list[torch.Tensor]]]: + # Embedding lookup (on first pipeline rank) + if get_pp_group().is_first_rank: + hidden_states = (inputs_embeds if inputs_embeds is not None else + self.get_input_embeddings(input_ids)) + residual = None + else: + assert intermediate_tensors is not None, ( + "IntermediateTensors must be provided for non-first " + "pipeline ranks") + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + aux_hidden_states: list[torch.Tensor] = [] + for idx, layer in enumerate( + self.layers[self.start_layer:self.end_layer]): + if idx in self.aux_hidden_state_layers: + aux_hidden_states.append( + hidden_states + + residual) # capture pre-layer hidden state if needed + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + # Send intermediate results to the next pipeline stage + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + # On last rank: apply final layer norm + hidden_states, _ = self.norm(hidden_states, residual) + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states + + +class ArceeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + """Arcee Model for causal language modeling, integrated with vLLM + runtime.""" + # Map fused module names to their sub-module components + # (for quantization and LoRA) + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + } + + def __init__(self, *, vllm_config, prefix: str = "") -> 
None: + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + + # Initialize the inner Transformer model (ArceeModel) + self.model = ArceeModel(vllm_config=vllm_config, + prefix=f"{prefix}.model") + # On the last pipeline stage, set up the LM head and logits processor + if get_pp_group().is_last_rank: + # Determine vocabulary size (including any LoRA extra tokens + # for padded LM head) + self.unpadded_vocab_size = config.vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=vllm_config.quant_config, + bias=getattr(config, "lm_head_bias", False), + prefix=f"{prefix}.lm_head", + ) + if config.tie_word_embeddings: + # Tie output weights with input embedding matrix + self.lm_head = self.lm_head.tie_weights( + self.model.embed_tokens) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + # Placeholder for lm_head on non-last ranks + self.lm_head = PPMissingLayer() + # Provide a reference to the model's method for generating empty + # tensors (used in pipeline parallel schedule) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, IntermediateTensors]: + # Forward pass through the Arcee model backbone + model_output = self.model(input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return model_output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata) -> Optional[torch.Tensor]: + # Compute final logits from hidden states (last pipeline rank only) + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + """Load weights into the model (delegates to inner model and handles + tied embeddings).""" + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + skip_substrs=["gate_proj"]) + # AutoWeightLoader handles weight name remapping, including fusing + # separate q_proj, k_proj, v_proj into qkv_proj + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/bailing_moe.py b/vllm/model_executor/models/bailing_moe.py new file mode 100644 index 000000000000..853c13b135ea --- /dev/null +++ b/vllm/model_executor/models/bailing_moe.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/inclusionAI/Ling/blob/master/models/modeling_bailing_moe.py +# Copyright 2023 The vLLM team. +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only BailingMoE model compatible with HuggingFace weights.""" +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class BailingAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_kv_heads = config.num_key_value_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.total_num_heads % tp_size == 0 + assert self.total_kv_heads % tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = config.head_dim or (self.hidden_size // + self.total_num_heads) + self.q_size_per_rank = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // tp_size + self.kv_size_per_rank = self.num_kv_heads * self.head_dim + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_kv_heads, + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, 
+ self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + prefix=f"{prefix}.attn") + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + is_neox_style=True, + rope_scaling=config.rope_scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + ) -> torch.Tensor: + + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([ + self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank + ], + dim=-1) + + q, k = self.rotary_emb(position_ids, q, k) + + context_layer = self.attn(q, k, v) + + attn_output, _ = self.dense(context_layer) + return attn_output + + +class BailingMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class BailingMoE(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + ): + super().__init__() + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_expert_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + # Gate always runs at half / full precision for now. 
+ self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts") + + if self.num_shared_experts > 0: + intermediate_size = (config.moe_intermediate_size * + self.num_shared_experts) + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts") + else: + self.shared_experts = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + +class BailingMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.attention = BailingAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") + self.post_attention_layernorm = RMSNorm(hidden_size, + eps=config.rms_norm_eps) + self.mlp = BailingMoE(intermediate_size, + config, + quant_config, + True, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.attention( + hidden_states=hidden_states, + position_ids=position_ids, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BailingMoeModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, self.embed_dim) + else: + self.word_embeddings = PPMissingLayer() + + self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BailingMoeBlock( + config=config, + 
cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers") + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + hidden_states, + position_ids, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() + for name, loaded_weight in weights: + if self.config.norm_head and "lm_head.weight" in name: + loaded_weight = F.normalize(loaded_weight, + dim=0, + p=2, + eps=1e-7) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class BailingMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + self.max_position_embeddings = config.max_position_embeddings + self.model = BailingMoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = (self.word_embeddings if config.tie_word_embeddings + else ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config)) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index d743c52074c6..0f5494427634 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -11,8 +11,9 @@ from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -23,8 +24,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -99,8 +100,7 @@ def __init__(self, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mixer", - chunk_size=config.mamba_chunk_size) + prefix=f"{prefix}.mixer") self.feed_forward = BambaMLP(config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, @@ -123,11 +123,10 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.mamba(hidden_states, mamba_cache_params, - mamba2_metadata) + output = torch.empty_like(hidden_states) + self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) # Fully Connected - hidden_states, residual = self.pre_ff_layernorm( - hidden_states, residual) + hidden_states, residual = self.pre_ff_layernorm(output, residual) hidden_states = self.feed_forward(hidden_states) return hidden_states, residual @@ -170,7 +169,7 @@ def __init__( self.max_position_embeddings = max_position_embeddings if hasattr(config, "partial_rotary_factor"): - rotary_dim = self.head_dim * config.partial_rotary_factor + rotary_dim = int(self.head_dim * config.partial_rotary_factor) elif hasattr(config, "attn_rotary_emb"): rotary_dim = config.attn_rotary_emb # for backward compatibility else: @@ -259,6 +258,7 @@ def forward( } +@support_torch_compile class BambaModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -436,6 +436,38 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config @@ -492,10 +524,13 @@ def forward(self, self.vllm_config.parallel_config, LayerBlockType.mamba ) - - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -511,38 +546,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.mamba_expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index a0ec12674f19..3d328c88ff6e 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -46,7 +46,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsQuant, SupportsV0Only -from .utils import maybe_prefix +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix logger = logging.get_logger(__name__) @@ -700,7 +700,8 @@ def forward( class BartModel(nn.Module, SupportsQuant): _tied_weights_keys = [ - "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", ] def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -763,10 +764,54 @@ 
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, return decoder_outputs + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + other_weights = [] + loaded_stacked_params = [] + model_params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in model_params_dict: + continue + param = model_params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_stacked_params.append(name) + break + else: + if name in model_params_dict: + other_weights.append((name, loaded_weight)) + + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(other_weights) + loaded_params.update(loaded_stacked_params) + return loaded_params + class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant): - packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} - base_model_prefix = "model" + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "decoder.": "model.decoder.", + "encoder.": "model.encoder.", + "shared.": "model.shared." + }, + orig_to_new_substr={ + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + }, + ) def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -789,7 +834,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = BartParallelLMHead(config.vocab_size, config.d_model, embed_scale=embed_scale) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) @@ -828,61 +872,12 @@ def compute_logits( sampling_metadata) return logits - stacked_params_mapping = { - "q_proj": { - "param_name": "qkv_proj", - "shard_id": "q", - }, - "k_proj": { - "param_name": "qkv_proj", - "shard_id": "k", - }, - "v_proj": { - "param_name": "qkv_proj", - "shard_id": "v", - }, - } - - params_mapping = { - "beta": "bias", - "gamma": "weight", - "LayerNorm": "layernorm", - } - - def _rename_key(self, key: str): - prefix = f"{self.base_model_prefix}." 
- key = key[len(prefix):] if key.startswith(prefix) else key - - for src, dst in self.params_mapping.items(): - key = key.replace(src, dst) - - return key - - def _rename_stacked_param( - self, - name: str, - ) -> tuple[str, Optional[str]]: - for key, mapping in self.stacked_params_mapping.items(): - if key in name: - name = name.replace(key, mapping["param_name"]) - return name, mapping["shard_id"] - return name, None - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - - model_params_dict = dict(self.model.named_parameters()) - top_params_dict = dict(self.named_parameters()) - + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: weights_tuple_list = list(weights) shared_embedding_weight = None - shared_embedding_shard_id = None - for name, loaded_weight in weights_tuple_list: - - name = self._rename_key(name) - name, shard_id = self._rename_stacked_param(name) - if ('shared.weight' in name or 'encoder.embed_tokens.weight' in name or 'decoder.embed_tokens.weight' in name @@ -890,49 +885,24 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): assert shared_embedding_weight is None, ( "Conflicting embedding weights.") shared_embedding_weight = loaded_weight - shared_embedding_shard_id = shard_id - else: - # Skip the specific downstream task weight. - if name.startswith('cls.'): - continue - # use Pooler instead. - if name.startswith('pooler.'): - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in model_params_dict: - continue - param = model_params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - if shard_id: - weight_loader(param, loaded_weight, shard_id) - else: - weight_loader(param, loaded_weight) - - # Assign shared weight values - encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] - encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", - default_weight_loader) - - decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] - decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", - default_weight_loader) - - lm_head_in_param = top_params_dict['lm_head.weight'] - lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", - default_weight_loader) - - assert shared_embedding_weight is not None - - if shared_embedding_shard_id: - encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, - shared_embedding_shard_id) - decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, - shared_embedding_shard_id) - lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, - shared_embedding_shard_id) - else: - encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) - decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) - lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) + loader = AutoWeightsLoader( + self, + skip_prefixes=(["cls.", "pooler."]), + ) + loaded_params = loader.load_weights(weights_tuple_list, + mapper=self.hf_to_vllm_mapper) + + if shared_embedding_weight is not None: + weight_loader = getattr(self.lm_head.weight, "weight_loader", + default_weight_loader) + weight_loader(self.lm_head.weight, shared_embedding_weight) + + self.model.encoder.embed_tokens.weight = self.lm_head.weight + self.model.decoder.embed_tokens.weight = self.lm_head.weight + loaded_params.update({ + 'model.encoder.embed_tokens.weight', 'lm_head.weight', + 'model.decoder.embed_tokens.weight' + }) + + return loaded_params 
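Note on the bart.py hunk above: the hand-rolled _rename_key / _rename_stacked_param plumbing is replaced by AutoWeightsLoader driven by a declarative hf_to_vllm_mapper, with the shared embedding weight loaded once into lm_head and then tied to the encoder/decoder embed_tokens. The short sketch below is illustrative only and not part of the patch; the checkpoint entries are made up, and it assumes WeightsMapper.apply yields (name, tensor) pairs the way the surrounding diff uses it. It shows how such a mapper rewrites Hugging Face-style names before AutoWeightsLoader matches them against the model's named_parameters().

# Illustrative sketch, not part of the patch. Checkpoint names are hypothetical.
import torch
from vllm.model_executor.models.utils import WeightsMapper

mapper = WeightsMapper(
    orig_to_new_prefix={
        "decoder.": "model.decoder.",
        "encoder.": "model.encoder.",
        "shared.": "model.shared.",
    },
    orig_to_new_substr={
        "beta": "bias",
        "gamma": "weight",
        "LayerNorm": "layernorm",
    },
)

# Hypothetical checkpoint entries standing in for an HF-exported BART state dict.
fake_checkpoint = [
    ("encoder.layernorm_embedding.gamma", torch.zeros(16)),
    ("decoder.embed_tokens.weight", torch.zeros(16, 16)),
]

for name, _ in mapper.apply(fake_checkpoint):
    print(name)
# model.encoder.layernorm_embedding.weight
# model.decoder.embed_tokens.weight

The remapped names are then fed to AutoWeightsLoader(self, skip_prefixes=["cls.", "pooler."]).load_weights(...), which skips task-specific heads and dispatches each tensor to the matching parameter's weight_loader, so the per-model loaders no longer need bespoke renaming logic.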
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index 6e955e1c5121..9dc6115f850e 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable -from typing import Optional +from collections.abc import Iterable, Set +from typing import Optional, Union import torch from torch import nn @@ -17,17 +17,20 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import (ClassifierPooler, Pooler, +from vllm.model_executor.layers.pooler import (ClassifierPooler, + DispatchPooler, Pooler, + PoolingMethod, + PoolingParamsUpdate, PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.pooling_params import PoolingTask +from vllm.sequence import IntermediateTensors from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only -from .utils import WeightsMapper, maybe_prefix +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix class BertEmbedding(nn.Module): @@ -44,9 +47,11 @@ def __init__(self, config: BertConfig): config.type_vocab_size, config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.position_ids = nn.Parameter( - torch.empty((1, config.max_position_embeddings)), ) + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).unsqueeze(0), + ) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": raise ValueError("Only 'absolute' position_embedding_type" + @@ -79,21 +84,40 @@ def forward( return embeddings -class BertPooler(nn.Module): +class BertPooler(Pooler): def __init__(self, config: BertConfig): super().__init__() + + self.pooling = PoolingMethod.from_pooling_type(PoolingType.CLS) self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. 
- first_token_tensor = hidden_states[0, :] - pooled_output = self.dense(first_token_tensor) + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.pooling.get_pooling_updates(task) + + def _head(self, pooled_output: torch.Tensor): + pooled_output = self.dense(pooled_output) pooled_output = self.activation(pooled_output) return pooled_output + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + pooled_output = self.pooling(hidden_states, pooling_metadata) + + if isinstance(pooled_output, list): + pooled_output = [self._head(output) for output in pooled_output] + else: + pooled_output = self._head(pooled_output) + + return pooled_output + @support_torch_compile class BertEncoder(nn.Module): @@ -314,20 +338,24 @@ def forward(self, hidden_states: torch.Tensor, class BertModel(nn.Module, SupportsQuant): + + is_pooling_model = True + packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]} - def __init__(self, - *, - vllm_config: VllmConfig, - prefix: str = "", - embedding_class: type = BertEmbedding, - add_pooling_layer: bool = False): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type[nn.Module] = BertEmbedding, + ) -> None: super().__init__() + config = vllm_config.model_config.hf_config self.embeddings = embedding_class(config) self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") - self.pooler = BertPooler(config) if add_pooling_layer else None def forward( self, @@ -349,8 +377,7 @@ def forward( token_type_ids=token_type_ids) return self.encoder(hidden_states) - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "query", "q"), @@ -358,52 +385,90 @@ def load_weights(self, weights: Iterable[tuple[str, ("qkv_proj", "value", "v"), ] + loaded_stacked_params = [] + other_weights = [] params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() for name, loaded_weight in weights: - if self.pooler is None and "pooler" in name: - continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue + name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) + loaded_stacked_params.append(name) break else: - # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) + if name in params_dict: + other_weights.append((name, loaded_weight)) + + return other_weights, loaded_stacked_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + other_weights, loaded_stacked_params = self._load_weights(weights) + + loader = AutoWeightsLoader(self, skip_prefixes=["pooler."]) + loaded_params = loader.load_weights(other_weights) + loaded_params.update(loaded_stacked_params) + return loaded_params + + +class BertPoolingModel(BertModel): + + is_pooling_model = True + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + embedding_class: type[nn.Module] = BertEmbedding, + ) -> None: + super().__init__( + vllm_config=vllm_config, + prefix=prefix, + embedding_class=embedding_class, + ) + + config = vllm_config.model_config.hf_config + self.pooler = BertPooler(config) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + other_weights, loaded_stacked_params = self._load_weights(weights) + + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(other_weights) + loaded_params.update(loaded_stacked_params) return loaded_params class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant): """A model that uses Bert to provide embedding functionalities. - This class encapsulates the BertModel and provides an interface for - embedding operations and customized pooling functions. + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. - Attributes: - model: An instance of BertModel used for forward operations. - _pooler: An instance of Pooler used for pooling operations. - """ - hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + is_pooling_model = True def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + self.model = self._build_model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) - self._pooler = self._build_pooler(pooler_config) + self.pooler = self._build_pooler(pooler_config) def forward( self, @@ -417,18 +482,16 @@ def forward( inputs_embeds=inputs_embeds, intermediate_tensors=intermediate_tensors) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - weights = ((name, data) for name, data in weights - if not name.startswith("lm_head.")) - self.model.load_weights(weights) + weights_list = list(weights) + + has_model_prefix = any( + name.startswith("model.") for name, _ in weights_list) + if not has_model_prefix: + mapper = WeightsMapper(orig_to_new_prefix={"": "model."}) + + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."]) + return loader.load_weights(weights_list, mapper=mapper) def _build_model(self, vllm_config: VllmConfig, @@ -438,10 +501,15 @@ def _build_model(self, embedding_class=BertEmbedding) def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: - return Pooler.from_config_with_defaults(pooler_config, - pooling_type=PoolingType.CLS, - normalize=True, - softmax=False) + return DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "embed": + Pooler.for_embed( + pooler_config, + default_pooling_type=PoolingType.CLS, + ), + }) class BertForSequenceClassification(nn.Module, SupportsV0Only, @@ -456,47 +524,44 @@ class BertForSequenceClassification(nn.Module, SupportsV0Only, _pooler: An instance of Pooler used for pooling operations. 
""" + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.num_labels = config.num_labels - self.bert = BertModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "bert"), - embedding_class=BertEmbedding, - add_pooling_layer=True) + self.bert = BertPoolingModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = ClassifierPooler(vllm_config.model_config, - self.classifier, self.bert.pooler) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - - self_weights = [] - - def weight_filter(): - for name, weight in weights: - if name.startswith("bert."): - yield (name[len("bert."):], weight) - else: - self_weights.append((name, weight)) - - self.bert.load_weights(weight_filter()) - params_dict = dict(self.named_parameters()) - - for name, loaded_weight in self_weights: - if name.startswith("classifier"): - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=self.bert.pooler, + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loaded_params = loader.load_weights(weights) + return loaded_params def forward( self, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 817c6bb9a7f9..c4f6144ed91f 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -189,10 +189,13 @@ def __init__( layer_idx = extract_layer_index(prefix) layer_has_sliding_window = ( - getattr(config, "sliding_window_pattern", False) - and (layer_idx + 1) % self.config.sliding_window_pattern != 0) + getattr(config, "sliding_window_pattern", False) and + (layer_idx + 1) % self.config.sliding_window_pattern + != 0) or (getattr(config, "layer_types", False) + and config.layer_types[layer_idx] == "sliding_attention") self.sliding_window = (interleaved_sliding_window + or config.sliding_window if layer_has_sliding_window else None) self.attn = Attention(self.num_heads, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 552c4b074216..cb07fe7d9e1d 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -3,9 +3,14 @@ from copy import deepcopy from typing import TYPE_CHECKING +import vllm.envs as envs from vllm.logger import init_logger +from vllm.model_executor.models import ModelRegistry +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: + from vllm.config import VllmConfig logger = init_logger(__name__) @@ 
-170,6 +175,15 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: vllm_config.model_config.hf_config.method = "from_2_way_softmax" +class JinaVLForSequenceClassificationConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config.hf_config + + config.num_labels = 1 + + class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): @staticmethod @@ -191,10 +205,110 @@ def verify_and_update_config(vllm_config: "VllmConfig") -> None: } +class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig): + + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + config = vllm_config.model_config + config.max_seq_len_to_capture = config.max_model_len + logger.info( + "Setting max_seq_len_to_capture to %d " + "to ensure that CUDA graph capture " + "covers sequences of length up to max_model_len.", + config.max_model_len) + + +class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. + + Args: + vllm_config: vLLM Config + """ + + if not envs.VLLM_USE_V1: + return + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + use_mla=model_config.use_mla).page_size_bytes + + model_cls = ModelRegistry.resolve_model_cls( + model_config._model_info.architecture)[0] + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtype=kv_cache_dtype, + block_size=model_config.max_model_len, + ).page_size_bytes + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = 16 * cdiv(mamba_page_size, + 16 * attn_page_size_1_token) + + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. 
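+        # Illustrative example (hypothetical sizes, for clarity only):
+        # with mamba_page_size = 500000 bytes and
+        # attn_page_size_1_token = 1024 bytes, cdiv(500000, 16 * 1024) = 31,
+        # so the suggested block size is 16 * 31 = 496 tokens and the
+        # resulting attention page is 496 * 1024 = 507904 bytes, which is
+        # >= the mamba page size.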
+ if (cache_config.block_size is None + or cache_config.block_size < attn_block_size): + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + # compute new attention page size + attn_page_size = \ + cache_config.block_size * attn_page_size_1_token + + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if (cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size): + cache_config.mamba_page_size_padded = (attn_page_size) + mamba_padding_pct = 100 * (attn_page_size - + mamba_page_size) / mamba_page_size + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, "NomicBertModel": NomicBertModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "XLMRobertaModel": JinaRobertaModelConfig, + "JinaVLForRanking": JinaVLForSequenceClassificationConfig, + "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig, } diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2fa1294b79b9..79ddd3d0f627 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -42,6 +42,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, + MergedReplicatedLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -336,7 +337,7 @@ def forward( kv_a, _ = latent_cache.split( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) latent_cache = latent_cache.unsqueeze(1) - kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) @@ -407,14 +408,24 @@ def __init__( self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") + self.fused_qkv_a_proj = MergedReplicatedLinear( + self.hidden_size, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_qkv_a_proj") + else: + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + + if self.q_lora_rank is not None: self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False, @@ -427,13 +438,6 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") 
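+        # kv_a_layernorm normalizes the KV latent (kv_a / kv_c) produced by
+        # either fused_qkv_a_proj (when q_lora_rank is set) or
+        # kv_a_proj_with_mqa, before it is expanded by kv_b_proj.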
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_b_proj = ColumnParallelLinear( @@ -495,15 +499,24 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: + q_c = None + kv_lora = None + if self.q_lora_rank is not None: - q_c = self.q_a_proj(hidden_states)[0] + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_lora = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) q_c = self.q_a_layernorm(q_c) q = self.q_b_proj(q_c)[0] else: + kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] q = self.q_proj(hidden_states)[0] - kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( - [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + + kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], + dim=-1) + kv_c_normed = self.kv_a_layernorm(kv_c) q = q.view(-1, self.num_local_heads, self.qk_head_dim) # Add head dim of 1 to k_pe @@ -739,14 +752,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_expert_groups = config.n_group self.moe_layers: list[FusedMoE] = [] + example_moe = None for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + assert isinstance(layer, DeepseekV2DecoderLayer) if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.mlp self.moe_layers.append(layer.mlp.experts) - # Pick last one layer since the first ones may be dense layers. - example_moe = typing.cast( - DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + if example_moe is None: + raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") + self.num_logical_experts = example_moe.n_logical_experts self.num_physical_experts = example_moe.n_physical_experts self.num_local_physical_experts = example_moe.n_local_physical_experts @@ -770,6 +789,24 @@ def set_eplb_state( logical_replica_count=logical_replica_count, ) + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = (num_physical_experts - + self.num_logical_experts) + for layer in self.model.layers: + if isinstance(layer.mlp, DeepseekV2MoE): + moe = layer.mlp + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) @@ -813,6 +850,8 @@ def load_weights(self, weights: Iterable[tuple[str, # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales @@ -846,7 +885,16 @@ def load_weights(self, weights: Iterable[tuple[str, # for mlp.experts[0].gate_gate_up_proj, which breaks load. if (("mlp.experts." 
in name) and name not in params_dict): continue - name = name.replace(weight_name, param_name) + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ((param_name == "fused_qkv_a_proj") + and name_mapped not in params_dict): + continue + else: + name = name_mapped # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue @@ -925,9 +973,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, weight_name: str) -> Optional[int]: - if hasattr(config, - "num_nextn_predict_layers") and (config.num_nextn_predict_layers - > 0): + if (hasattr(config, "num_nextn_predict_layers") + and config.num_nextn_predict_layers > 0): layer_idx = config.num_hidden_layers for i in range(config.num_nextn_predict_layers): if weight_name.startswith(f"model.layers.{layer_idx+i}."): diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index a9654f5f4602..a222c4cbe9d0 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -351,11 +351,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): embed_std = 1 / torch.sqrt( torch.tensor(self.projector_config.n_embed, dtype=torch.float32)) if self.tile_tag == "2D": - # <|view_separator|>, <|\n|> + # <|view_seperator|>, <|\n|> self.image_newline = nn.Parameter( torch.randn(self.projector_config.n_embed) * embed_std) # This is a typo in original implementation - self.view_separator = nn.Parameter( + self.view_seperator = nn.Parameter( torch.randn(self.projector_config.n_embed) * embed_std) else: raise ValueError( @@ -560,13 +560,13 @@ def _pixel_values_to_embedding( if self.global_view_pos == "head": global_local_features = torch.cat([ global_features, - self.view_separator[None, :], + self.view_seperator[None, :], local_features, ]) else: global_local_features = torch.cat([ local_features, - self.view_separator[None, :], + self.view_seperator[None, :], global_features, ]) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py deleted file mode 100644 index c551ecd68ef8..000000000000 --- a/vllm/model_executor/models/eagle.py +++ /dev/null @@ -1,261 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from collections.abc import Iterable -from typing import Optional - -import torch -import torch.nn as nn - -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors - -from .utils import maybe_prefix - -logger = init_logger(__name__) - - -class DummyInputLayerNorm(nn.Module): - - def __init__(self, weight=None, bias=None): - super().__init__() - self.weight = nn.Parameter(weight) if weight is not None else None - self.bias = nn.Parameter(bias) if bias is not None else None - - def forward(self, x): - return x - - -class 
DummyOutputNorm(nn.Module): - - def forward(self, x, residual): - if residual is None: - return x - else: - return x + residual, None - - -class EAGLE(nn.Module): - """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 - Reference implementation: https://github.com/SafeAILab/EAGLE - - Differences from reference implementation: - 1. In reference, LlamaDecoderLayer implementation doesn't have - input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427). - Following this approach, our implementation also disables - the input_layernorm for the first decoder layer. - 2. We allow any decoder layer to be used in EAGLE whereas in reference - decoder layer is fixed to be LlamaDecoderLayer. - 3. We have an optional token_map which reduces draft vocab to most - frequently used tokens to give some additional speed-up by reducing - sampling overhead. This is disabled unless the checkpoint file has - explicit token_map tensor and config has an optional attribute - truncated_vocab_size < vocab_size. To use this technique, one has to find - the top-k most frequent tokens in target dataset and add that as a tensor - in the draft checkpoint (using key token_map). Also, the draft config - needs to have truncated_vocab_size (=k) as an attribute. - 4. We allow an enhanced EAGLE architecture similar to the DeepSeek MTP - module with regards to the use of additional RMS norms. The original - EAGLE architecture 1) skips the pre-attention norm in its first - transformer block, and 2) skips the final output norm, both of which we - found to be suboptimal. We also add the support for separate norms - applying to both the token embedding and hidden states before projection - as in DeepSeek MTP, which we found to improve performance as well. - """ - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - self.dtype = vllm_config.model_config.dtype - self.config = config - - architectures = getattr(self.config.model, "architectures", []) - model_cls, _ = ModelRegistry.resolve_model_cls(architectures) - - self.model = model_cls(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - self.fc = nn.Linear(config.model.hidden_size * 2, - config.model.hidden_size, - bias=getattr(self.config, "eagle_fc_bias", False)) - - # Modify layer normalization and residual connections as suggested - # in the EAGLE framework: https://github.com/SafeAILab/EAGLE - # While weights and biases are generally not needed, - # they are retained here to support certain unit tests - # (e.g., spec_decode/e2e/test_eagle_correctness.py). 
- if not hasattr(self.config.model, - "skip_prenorm") or self.config.model.skip_prenorm: - self.model.model.layers[0].input_layernorm = DummyInputLayerNorm( - weight=self.model.model.layers[0].input_layernorm.weight) - - if not hasattr( - self.config.model, - "skip_output_norm") or self.config.model.skip_output_norm: - self.model.model.norm = DummyOutputNorm() - - self.add_para_norm = False - if hasattr(self.config.model, - "add_para_norm") and self.config.model.add_para_norm: - self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.add_para_norm = True - - self.orig_vocab_size = config.vocab_size - self.truncated_vocab_size = config.truncated_vocab_size - self.unpadded_vocab_size = self.truncated_vocab_size - - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=self.truncated_vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - ) - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - self.truncated_vocab_size, - logit_scale) - - # Token map is a idx to token mapping to reduce the vocab size for - # the draft model. Using smaller vocab size for draft, containing - # only most frequent tokens reduces the speculation overhead. This - # doesn't affect the acceptance rate much and thus gives more speed - # -up. By default, this is disabled and is only used if the EAGLE - # checkpoint file has token_map tensor. - self.token_map = None - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - previous_hidden_states: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) - - # Handle both empty previous_hidden_states - # and mismatched batch size - batch_size = inputs_embeds.size(0) - if previous_hidden_states.size(0) == 0 or \ - previous_hidden_states.size(0) != batch_size: - hidden_dim = self.config.model.hidden_size - device = inputs_embeds.device - # Create zero tensor with matching batch size - previous_hidden_states = \ - torch.zeros(batch_size, hidden_dim, device=device) - - if self.add_para_norm: - inputs_embeds = torch.cat([ - self.enorm(inputs_embeds), - self.hnorm(previous_hidden_states) - ], - dim=-1) - else: - inputs_embeds = torch.cat([inputs_embeds, previous_hidden_states], - dim=-1) - - inputs_embeds = self.fc(inputs_embeds) - - inputs_embeds[positions == 0] = 0 # masking inputs at position=0 - - hidden_states = self.model.model( - input_ids=None, - inputs_embeds=inputs_embeds, - positions=positions, - intermediate_tensors=intermediate_tensors, - ) - return hidden_states - - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - - if self.token_map is not None: - _logits = logits - logits = -torch.inf * torch.ones( - size=(*_logits.shape[:-1], self.orig_vocab_size), - device=_logits.device, - dtype=_logits.dtype) - - logits[..., self.token_map] = _logits - - return logits - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - # This implementation is incompatible with 
https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B - # due to missing lm_head weights and its config being that of a - # Llama model. Here's a compatible version with the same weights: - # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm - # Also, here's an example script for converting trained EAGLE - # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d - model_weights = {} - for name, loaded_weight in weights: - if name == "token_map": - if self.config.truncated_vocab_size < self.config.vocab_size: - self.token_map = nn.Parameter(loaded_weight, - requires_grad=False) - elif name.startswith("fc.weight"): - weight_loader = getattr(self.fc.weight, "weight_loader", - default_weight_loader) - weight_loader(self.fc.weight, loaded_weight) - elif name.startswith("fc.bias"): - if self.fc.bias is not None: - weight_loader = getattr(self.fc.bias, "weight_loader", - default_weight_loader) - weight_loader(self.fc.bias, loaded_weight) - else: - logger.warning_once("Found bias in the loaded weights but " - "the model config doesn't have bias.") - elif name.startswith("enorm.weight"): - weight_loader = getattr(self.enorm.weight, "weight_loader", - default_weight_loader) - weight_loader(self.enorm.weight, loaded_weight) - elif name.startswith("hnorm.weight"): - weight_loader = getattr(self.hnorm.weight, "weight_loader", - default_weight_loader) - weight_loader(self.hnorm.weight, loaded_weight) - elif name.startswith("model.lm_head.") or name.startswith( - "model.model."): - model_weights[name.split("model.", 1)[-1]] = loaded_weight - elif name.startswith("lm_head.") or name.startswith("model."): - model_weights[name] = loaded_weight - else: - model_weights[f"model.{name}"] = loaded_weight - - if "lm_head.weight" in model_weights: - lm_head_weight = model_weights.pop("lm_head.weight") - - if self.token_map is not None and\ - lm_head_weight.shape[0] > self.token_map.shape[0]: - - lm_head_weight = lm_head_weight[self.token_map] - - else: - # NOTE(Shangming): initialize the placeholder for lm_head weight. 
- lm_head_weight = torch.zeros( - self.lm_head.org_vocab_size, - self.lm_head.embedding_dim, - dtype=self.dtype, - ) - - weight_loader = getattr(self.lm_head.weight, "weight_loader", - default_weight_loader) - weight_loader(self.lm_head.weight, lm_head_weight) - - self.model.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index e7a50ff7a1c9..984003e62d11 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -51,8 +51,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP -from .utils import (PPMissingLayer, extract_layer_index, +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -427,66 +427,15 @@ def forward( return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: -class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - fall_back_to_pt_during_load = False - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = Ernie4_5_MoeModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - else: - self.lm_head = PPMissingLayer() - - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - hidden_states = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.moe_num_experts) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -499,16 +448,9 @@ def load_weights(self, weights: Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - 
ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.moe_num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if self.config.tie_word_embeddings and name.endswith( "lm_head.weight"): @@ -581,3 +523,76 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params + + +class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Ernie4_5_MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/exaone4.py b/vllm/model_executor/models/exaone4.py new file mode 100644 index 000000000000..97aeb6fd7b17 --- /dev/null +++ b/vllm/model_executor/models/exaone4.py @@ -0,0 +1,547 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +# Adapted from +# https://github.com/lgai-exaone/transformers/blob/add-exaone4/src/transformers/models/exaone4/modeling_exaone4.py +# Copyright 2025 The LG CNS Gen AI Solution Delivery Team. +# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.exaone4 import Exaone4Config + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class Exaone4GatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Exaone4Attention(nn.Module): + + def __init__( + self, + config: Exaone4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 1000000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", None) + if self.head_dim is None: + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.apply_all_layers = False # apply rotary embeddings to every layer. 
+ layer_idx = extract_layer_index(prefix) + interleaved_sliding_window = getattr(config, + "interleaved_sliding_window", + 4096) + sliding_window_pattern = getattr(config, "sliding_window_pattern", + "LLLG") + + if sliding_window_pattern: + layer_has_sliding_window = ( + layer_idx + 1) % sliding_window_pattern.__len__() != 0 + else: + layer_has_sliding_window = False + self.apply_all_layers = True + + if layer_has_sliding_window: + self.sliding_window = interleaved_sliding_window + else: + self.sliding_window = None + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + per_layer_sliding_window=self.sliding_window, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q = q.unflatten(-1, (self.num_heads, self.head_dim)) + q = self.q_norm(q) + q = q.flatten(-2, -1) + k = k.unflatten(-1, (self.num_kv_heads, self.head_dim)) + k = self.k_norm(k) + k = k.flatten(-2, -1) + + if self.sliding_window or self.apply_all_layers: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Exaone4DecoderLayer(nn.Module): + + def __init__( + self, + config: Exaone4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + + self.self_attn = Exaone4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = Exaone4GatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + residual = hidden_states + + # Self Attention + hidden_states = self.self_attn( + 
positions=positions, + hidden_states=hidden_states, + ) + + # Use post-LN + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + + # Fully Connected + hidden_states = self.mlp(hidden_states) + + # Use post-LN + hidden_states = self.post_feedforward_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +@support_torch_compile +class Exaone4Model(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.quant_config = quant_config + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Exaone4DecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer:self.end_layer]: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Exaone4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = Exaone4Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + 
hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index a76e1f256e04..6a58b1501fe6 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -10,8 +10,9 @@ from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -22,8 +23,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -109,7 +110,6 @@ def __init__( quant_config=quant_config, use_rms_norm=config.mamba_rms_norm, prefix=f"{prefix}.mixer", - chunk_size=config.mamba_chunk_size, ) # n_groups is overridden later by `MambaMixer2` self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state @@ -180,13 +180,15 @@ def forward( mamba2_metadata: Mamba2Metadata, **kwargs, ): - hidden_states = self.mamba( + output = torch.empty_like(hidden_states) + self.mamba( hidden_states, + output, mamba_cache_params, mamba2_metadata=mamba2_metadata, mup_vector=self.mup_vector, ) - return hidden_states, residual + return output, residual class FalconH1AttentionDecoderLayer(nn.Module): @@ -399,6 +401,7 @@ def forward( return hidden_states +@support_torch_compile class FalconH1Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -515,6 +518,42 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + intermediate_size = (int(hf_config.mamba_expand * + hf_config.hidden_size) + if hf_config.mamba_d_ssm is None else + hf_config.mamba_d_ssm) + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config @@ -581,12 +620,15 @@ def forward( mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is None: + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager( self.vllm_config, self.lm_head.weight.dtype if hasattr( self.lm_head, 'weight') else torch.bfloat16, self.config.num_hidden_layers, - *self._get_mamba_cache_shape(), + *mamba_state_shape, ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -607,39 +649,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = (int(self.config.mamba_expand * - hidden_size) if self.config.mamba_d_ssm - is None else self.config.mamba_d_ssm) - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size) - - # - heads and n_groups are TP-ed - conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 26c8f80d5a0c..558d4fbb4de1 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -175,12 +175,21 @@ def _call_hf_processor( # Original output: (1, num_images, Pn, Px * Py * C) # New output: (num_images, Pn, Px * Py * C) - assert (isinstance(image_patches, list) - and len(image_patches) == 1) - assert (isinstance(image_patches[0], torch.Tensor) - and len(image_patches[0]) == len(images)) - - processed_outputs["image_patches"] = image_patches[0] + # image_patches is a list with shape: + # (1, num_images, Pn, Px * Py * C) + # 
before Transformers 4.53 + if isinstance(image_patches, list): + assert len(image_patches) == 1 + assert (isinstance(image_patches[0], torch.Tensor) + and len(image_patches[0]) == len(images)) + processed_outputs["image_patches"] = image_patches[0] + # image_patches is a tensor with shape: + # (num_images, Pn, Px * Py * C) + # after Transformers 4.53 + elif isinstance(image_patches, torch.Tensor): + assert len(image_patches) == len(images) + else: + raise AssertionError("This line should be unreachable.") return processed_outputs @@ -193,8 +202,10 @@ def _apply_hf_processor_tokens_only( vocab = tokenizer.get_vocab() boa_token_id = vocab["<0x04>"] + if prompt_tokens[-1] != boa_token_id: + prompt_tokens.append(boa_token_id) - return prompt_tokens + [boa_token_id] + return prompt_tokens def _get_mm_fields_config( self, diff --git a/vllm/model_executor/models/gemma3.py b/vllm/model_executor/models/gemma3.py index 954e48d25f67..1a2ce65d1e4c 100644 --- a/vllm/model_executor/models/gemma3.py +++ b/vllm/model_executor/models/gemma3.py @@ -149,14 +149,17 @@ def __init__(self, # TODO(woosuk): Add reference to the original HF implementation. layer_idx = extract_layer_index(prefix) self.is_sliding = (getattr( - config, "interleaved_sliding_window", None) is not None and bool( - (layer_idx + 1) % config.sliding_window_pattern)) + config, "interleaved_sliding_window", None) is not None and (bool( + (layer_idx + 1) % config.sliding_window_pattern))) or ( + getattr(config, "layer_types", None) is not None + and config.layer_types[layer_idx] == "sliding_attention") # Initialize the rotary embedding. if self.is_sliding: # Local attention. Override the values in config.json. self.rope_theta = config.rope_local_base_freq self.rope_scaling = {"rope_type": "default"} - self.sliding_window = config.interleaved_sliding_window + self.sliding_window = (config.interleaved_sliding_window + or config.sliding_window) else: # Global attention. Use the values in config.json. self.rope_theta = config.rope_theta diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py index a3908e30ec6e..0996bcf60aa1 100644 --- a/vllm/model_executor/models/glm4_1v.py +++ b/vllm/model_executor/models/glm4_1v.py @@ -65,7 +65,7 @@ MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import _Backend from vllm.sequence import IntermediateTensors @@ -1213,7 +1213,10 @@ def get_video_replacement_glm4v(item_idx: int): placeholder.append(eoi_token_id) placeholder.extend(frame_idx) placeholder.append(eov_token_id) - return placeholder + return PromptUpdateDetails.select_token_id( + placeholder, + embed_token_id=hf_processor.video_token_id, + ) return [ PromptReplacement( diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py new file mode 100644 index 000000000000..bdca293d21db --- /dev/null +++ b/vllm/model_executor/models/glm4_moe.py @@ -0,0 +1,685 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The ZhipuAI Team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GLM-4.5 model compatible with HuggingFace weights.""" +import typing +from collections.abc import Callable, Iterable +from typing import Any, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import (get_ep_group, get_pp_group, + get_tensor_model_parallel_world_size) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Glm4MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Glm4MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = self.ep_group.rank() + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + + # noaux_tc is not set in transformers new config now + self.gate.e_score_correction_bias = (nn.Parameter( + torch.empty(config.n_routed_experts))) + + # Load balancing settings. + vllm_config = get_current_vllm_config() + parallel_config = vllm_config.parallel_config + self.enable_eplb = enable_eplb + + self.n_redundant_experts = parallel_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = (self.n_logical_experts + + self.n_redundant_experts) + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = (self.ep_rank * + self.n_local_physical_experts) + self.physical_expert_end = (self.physical_expert_start + + self.n_local_physical_experts) + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func="sigmoid", + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = Glm4MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=self.experts.must_reduce_shared_expert_outputs( + ), + prefix=f"{prefix}.shared_experts", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = ( + self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states)) + return final_hidden_states.view(num_tokens, hidden_dim) + + +class Glm4MoeAttention(nn.Module): + + def __init__( + self, + 
config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, + max_position_embeddings: int = 131072, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-05, + qkv_bias: bool = False, + use_qk_norm: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.use_qk_norm = use_qk_norm + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q = self.q_norm(q.reshape(-1, self.num_heads, + self.head_dim)).reshape(q.shape) + k = self.k_norm(k.reshape(-1, self.num_kv_heads, + self.head_dim)).reshape(k.shape) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + enable_eplb: bool = False, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 131072) + # DecoderLayers 
are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + + self.self_attn = Glm4MoeAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + head_dim=config.head_dim, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=config.attention_bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + use_qk_norm=config.use_qk_norm, + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace): + self.mlp = Glm4MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile +class Glm4MoeModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + enable_eplb = vllm_config.parallel_config.enable_eplb + self.config = config + + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Glm4MoeDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + enable_eplb=enable_eplb, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = 
inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or not + # here since otherwise we may skip experts with other + # available replicas. 
+ weight_loader = typing.cast(Callable[..., bool], + param.weight_loader) + success = weight_loader(param, + loaded_weight, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True) + if success: + name = name_mapped + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Glm4MoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = Glm4MoeModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + for layer in self.model.layers: + assert isinstance(layer, Glm4MoeDecoderLayer) + if isinstance(layer.mlp, Glm4MoE): + self.moe_layers.append(layer.mlp.experts) + + # Pick last one layer since the first ones may be dense layers. + example_moe = typing.cast( + Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp) + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def set_eplb_state( + self, + expert_load_view: torch.Tensor, + logical_to_physical_map: torch.Tensor, + logical_replica_count: torch.Tensor, + ) -> None: + for layer_idx, layer in enumerate(self.moe_layers): + # Register the expert weights. 
+ self.expert_weights.append(layer.get_expert_weights()) + layer.set_eplb_state( + moe_layer_idx=layer_idx, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + ) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + +def get_spec_layer_idx_from_weight_name(config: PretrainedConfig, + weight_name: str) -> Optional[int]: + if hasattr(config, + "num_nextn_predict_layers") and (config.num_nextn_predict_layers + > 0): + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if f"layers.{layer_idx+i}." in weight_name: + return layer_idx + i + return None diff --git a/vllm/model_executor/models/glm4_moe_mtp.py b/vllm/model_executor/models/glm4_moe_mtp.py new file mode 100644 index 000000000000..0624640054d1 --- /dev/null +++ b/vllm/model_executor/models/glm4_moe_mtp.py @@ -0,0 +1,307 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright 2025 The ZhipuAI Team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
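# --- A minimal, standalone sketch (not from the patch) of the routing scheme
# Glm4MoE configures on FusedMoE above: sigmoid scoring, an expert-score
# correction bias that steers selection only, grouped top-k restricted to the
# best `topk_group` of `n_group` expert groups, and renormalized combine
# weights. Tensor shapes, toy sizes, and the function name are assumptions for
# illustration; the production path is the fused kernel inside FusedMoE, and
# the routed output is further scaled by `routed_scaling_factor` before the
# shared-expert output is added (see Glm4MoE.forward above).
import torch


def grouped_topk_route(router_logits: torch.Tensor,
                       e_score_correction_bias: torch.Tensor,
                       n_group: int, topk_group: int, top_k: int):
    num_tokens, n_experts = router_logits.shape
    scores = router_logits.sigmoid()
    # The bias only influences which experts are *selected*; the unbiased
    # sigmoid scores are used as combine weights.
    biased = scores + e_score_correction_bias
    # Score each group by the sum of its two strongest experts and keep the
    # `topk_group` best groups per token.
    group_scores = (biased.view(num_tokens, n_group, -1)
                    .topk(2, dim=-1).values.sum(dim=-1))
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_scores.topk(topk_group, dim=-1).indices, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, n_group, n_experts // n_group).reshape(num_tokens, -1)
    # Pick the top-k experts within the surviving groups, then renormalize.
    topk_ids = (biased.masked_fill(expert_mask == 0, float("-inf"))
                .topk(top_k, dim=-1).indices)
    topk_weights = scores.gather(1, topk_ids)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


logits = torch.randn(4, 8)  # 4 tokens routed over 8 experts in 4 groups
bias = torch.zeros(8)       # stands in for e_score_correction_bias
weights, ids = grouped_topk_route(logits, bias,
                                  n_group=4, topk_group=2, top_k=2)
assert weights.shape == ids.shape == (4, 2)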
+"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights.""" + +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import CacheConfig, VllmConfig +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name +from .interfaces import SupportsPP +from .utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +class Glm4MoeMultiTokenPredictorLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + self.shared_head = SharedHead(config=config, quant_config=quant_config) + self.mtp_block = Glm4MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds[positions == 0] = 0 + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class Glm4MoeMultiTokenPredictor(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + Glm4MoeMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + 
config.vocab_size, + config.hidden_size, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( + input_ids, + positions, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers[str(self.mtp_start_layer_idx + + current_step_idx)] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + + +class Glm4MoeMTP(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.config = vllm_config.model_config.hf_config + self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, + previous_hidden_states, inputs_embeds, + spec_step_idx) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> Optional[torch.Tensor]: + return self.model.compute_logits(hidden_states, sampling_metadata, + spec_step_idx) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is None: + continue + name = self._rewrite_spec_layer_name(spec_layer, name) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." 
in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # According to DeepSeek-V3 Technical Report, MTP modules + # shares embedding layer. We only load the first weights. + if (spec_layer != self.model.mtp_start_layer_idx + and ".layers" not in name): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: + """ + Rewrite the weight name to match the format of the original model. + Add .mtp_block for modules in transformer layer block for spec layer + and rename shared layer weights to be top level. + """ + spec_layer_weight_names = [ + "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head" + ] + shared_weight_names = ["embed_tokens"] + spec_layer_weight = False + shared_weight = False + for weight_name in spec_layer_weight_names: + if weight_name in name: + spec_layer_weight = True + if weight_name in shared_weight_names: + shared_weight = True + break + if not spec_layer_weight: + # treat rest weights as weights for transformer layer block + name = name.replace(f"model.layers.{spec_layer}.", + f"model.layers.{spec_layer}.mtp_block.") + elif shared_weight: + # treat shared weights as top level weights + name = name.replace(f"model.layers.{spec_layer}.", "model.") + return name diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 27021550f998..98d76337395b 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -40,11 +40,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors -from ..layers.pooler import Pooler, PoolingType +from ..layers.pooler import DispatchPooler, Pooler from .interfaces import SupportsPP from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -332,30 +331,29 @@ class GPT2ForSequenceClassification(nn.Module): _pooler: An instance of Pooler used for pooling operations. 
""" + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.transformer = GPT2Model(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "gpt2")) self.score = nn.Linear(config.n_embd, config.num_labels, bias=False) + pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=True) + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + Pooler.for_classify(pooler_config, classifier=None), + }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self) return loader.load_weights(weights) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def forward( self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index bd4d5d0b6b28..507a9206c428 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -273,6 +273,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, quant_config=quant_config, ) else: diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py index 6c7c9f5cc938..6a4dee9ae48d 100644 --- a/vllm/model_executor/models/granite_speech.py +++ b/vllm/model_executor/models/granite_speech.py @@ -36,7 +36,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -549,7 +548,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str): self.config = config self.quant_config = quant_config self.cache_config = cache_config - self.sampler = get_sampler() # The language model is typically a Granite LLM self.language_model = init_vllm_registered_model( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 5a70f3a616c6..7d31854dce8d 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -24,7 +24,7 @@ # limitations under the License. """Inference-only GraniteMoe model.""" from collections.abc import Iterable -from typing import Optional +from typing import Any, Optional import torch from torch import nn @@ -45,12 +45,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from . 
import mixtral from .interfaces import SupportsLoRA, SupportsPP -from .utils import AutoWeightsLoader, make_layers, maybe_prefix +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_layers, + maybe_prefix) class GraniteMoeMoE(nn.Module): @@ -111,6 +113,7 @@ def __init__( num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, + rope_scaling: Optional[dict[str, Any]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, attention_multiplier: Optional[float] = None, @@ -161,6 +164,7 @@ def __init__( max_position=max_position, base=int(self.rope_theta), is_neox_style=True, + rope_scaling=rope_scaling, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -196,12 +200,14 @@ def __init__( self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) self.self_attn = GraniteMoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, + rope_scaling=rope_scaling, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -307,6 +313,103 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def _load_weights(self, + weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + This function is copied from `MixtralModel.load_weights`, mainly to + decouple from mixtral, avoiding impact on support like BNB + quantization. + """ + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if (self.quant_config is not None and + (scale_name := self.quant_config.get_cache_scale(name))): + # Loading kv cache quantization scales + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else + loaded_weight[0]) + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name.endswith("scale"): + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: new_weights = {} @@ -339,7 +442,7 @@ def load_weights(self, weights: Iterable[tuple[str, new_weights[gate_name] = p else: new_weights[n] = p - return mixtral.MixtralModel.load_weights(self, new_weights.items()) + return self._load_weights(new_weights.items()) class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 676ef24fc4da..59c1dce48ee7 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -11,8 +11,9 @@ from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm @@ -21,8 +22,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -69,8 +70,7 @@ def __init__(self, rms_norm_eps=config.rms_norm_eps, activation=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mixer", - chunk_size=config.mamba_chunk_size) + prefix=f"{prefix}.mixer") self.block_sparse_moe = None if getattr(config, "num_local_experts", 0) > 0: @@ -105,9 +105,9 @@ def forward( ): residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.mamba(hidden_states, mamba_cache_params, - mamba2_metadata) - hidden_states = residual + hidden_states * self.residual_multiplier + output = torch.empty_like(hidden_states) + self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata) + hidden_states = residual + output * self.residual_multiplier residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) @@ -308,6 +308,7 @@ def 
forward( } +@support_torch_compile class GraniteMoeHybridModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -525,6 +526,38 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -588,9 +621,13 @@ def forward(self, self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.model_config.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.model_config.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -606,38 +643,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.mamba_expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py index bb160dbce45b..1e2e8544179c 100644 --- a/vllm/model_executor/models/granitemoeshared.py +++ b/vllm/model_executor/models/granitemoeshared.py @@ -27,8 +27,7 @@ from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from . import mixtral -from .granitemoe import GraniteMoeAttention, GraniteMoeMoE +from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE from .interfaces import SupportsLoRA, SupportsPP from .utils import AutoWeightsLoader, make_layers, maybe_prefix @@ -82,12 +81,14 @@ def __init__( self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) self.self_attn = GraniteMoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, + rope_scaling=rope_scaling, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", @@ -242,7 +243,7 @@ def load_weights(self, weights: Iterable[tuple[str, new_weights[gate_name] = p else: new_weights[n] = p - return mixtral.MixtralModel.load_weights(self, new_weights.items()) + return GraniteMoeModel._load_weights(self, new_weights.items()) class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 4273afbf4699..8a3fbc6a49f0 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -1,19 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Set +from typing import Optional, Union -from array import array -from typing import Optional - +import numpy as np import torch import torch.nn as nn from vllm.config import ModelConfig, VllmConfig from vllm.logger import init_logger -from vllm.model_executor.layers.pooler import PoolerHead +from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, + PoolerHead, PoolerNormalize, + PoolingParamsUpdate, + build_output, get_prompt_lens, + get_prompt_token_ids) from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.pooling_metadata import (PoolingMetadata, - PoolingTensors) -from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingTask +from vllm.sequence import PoolerOutput from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config from .interfaces import SupportsV0Only @@ -21,7 +25,8 @@ logger = init_logger(__name__) -class GritLMPooler(nn.Module): +class GritLMMeanPool(nn.Module): + """As `MeanPool`, but only includes non-instruction tokens.""" def __init__(self, model_config: ModelConfig): super().__init__() @@ -39,8 +44,8 @@ def __init__(self, model_config: ModelConfig): for tok in ["<s>", "▁<", "<", "|", "embed", ">", "<0x0A>", "user"] } - def tokens_to_ids(tokens: list[str]) -> array: - return array("i", [self.token_ids[token] for token in tokens]) + def tokens_to_ids(tokens: list[str]) -> np.ndarray: + return np.array([self.token_ids[token] for token in tokens]) self.user_pattern_ids = tokens_to_ids( ["▁<", "|", "user", "|", ">", "<0x0A>"]) @@ -49,32 +54,44 @@ def tokens_to_ids(tokens: list[str]) -> array: self.embed_pattern_ids = tokens_to_ids( ["▁<", "|", "embed", "|", ">", "<0x0A>"]) - self.head = PoolerHead(normalize=True, softmax=False) - - def _find_array(self, arr: array, target: array, start_idx: int) -> int: + 
def _find_array( + self, + arr: np.ndarray, + target: np.ndarray, + start_idx: int = 0, + end_idx: Optional[int] = None, + ) -> int: """ - Find the first occurrence of target in arr starting from start_idx. + Find the first occurrence of `target` in `arr` starting from + `start_idx`. Args: - arr: The array to search within - target: The consecutive subsequence to find - start_idx: The starting index to search from + arr: The array to search within. + target: The consecutive subsequence to find. + start_idx: The starting index to search from (inclusive). + end_idx: The ending index to search from (exclusive). Returns: - int: The index of the first occurrence of target in arr. + The index of the first occurrence of `target` in `arr`. """ if start_idx < 0: - raise ValueError("start_idx must be non-negative") - if not target or not arr: - raise ValueError("Empty arr or target not allowed") + raise ValueError("`start_idx` must be non-negative") + if len(arr) == 0 or len(target) == 0: + raise ValueError("Empty `arr` or `target` not allowed") + arr_len = len(arr) target_len = len(target) - for i in range(start_idx, len(arr) - target_len + 1): - if arr[i:i + target_len] == target: + + if end_idx is None: + end_idx = arr_len + + for i in range(start_idx, min(end_idx, arr_len - target_len + 1)): + if (arr[i:i + target_len] == target).all(): return i + return -1 - def _get_instruction_len(self, prompt_token_ids: array) -> int: + def _get_instruction_len(self, prompt_token_ids: np.ndarray) -> int: """ Get the length of the instruction in the prompt. @@ -84,7 +101,6 @@ def _get_instruction_len(self, prompt_token_ids: array) -> int: The pattern matching is done using integers instead of strings because the prompt is given as a list of token IDs. """ - instruction_len = 0 # Return no instruction in case of missing BOS token. @@ -99,7 +115,8 @@ def _get_instruction_len(self, prompt_token_ids: array) -> int: embed_pattern_ids = self.embed_pattern_ids if self._find_array(prompt_token_ids, self.user_pattern_ids, - start_idx=1) == 1: + start_idx=1, + end_idx=2) == 1: embed_pattern_ids = self.embed_newline_pattern_ids # Find the embed pattern in the prompt. @@ -117,64 +134,85 @@ def _get_instruction_len(self, prompt_token_ids: array) -> int: return instruction_len - def forward( + def get_supported_tasks(self) -> Set[PoolingTask]: + return {"encode", "embed"} + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return PoolingParamsUpdate(requires_token_ids=True) + + def forward_one( self, hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> PoolerOutput: - """ - Pool the hidden states by summing the embeddings of - non-instruction tokens. 
- """ - prompts_token_ids = [ - token_ids.prompt_token_ids_array - for _, token_ids in pooling_metadata.seq_data.items() - ] + prompt_len: Optional[torch.Tensor] = None, + instr_len: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert prompt_len is None or prompt_len == hidden_states.shape[0], \ + "partial prefill not supported with MEAN pooling" + + return hidden_states[instr_len:].mean(dim=0, dtype=torch.float32) + + def forward_all( + self, + hidden_states: torch.Tensor, + prompt_lens: torch.Tensor, + instr_lens: torch.Tensor, + ) -> Union[list[torch.Tensor], torch.Tensor]: + offset = 0 + pooled_data = list[torch.Tensor]() - instruction_lens = torch.tensor( + for prompt_len, instr_len in zip(prompt_lens, instr_lens): + pooled_data.append(hidden_states[offset + instr_len:offset + + prompt_len].mean( + dim=0, dtype=torch.float32)) + offset += prompt_len + + return pooled_data + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[list[torch.Tensor], torch.Tensor]: + prompt_lens = get_prompt_lens(hidden_states, pooling_metadata) + instr_lens = torch.tensor( [ - self._get_instruction_len(prompt_token_ids) - for prompt_token_ids in prompts_token_ids + self._get_instruction_len(token_ids.cpu().numpy()) + for token_ids in get_prompt_token_ids(pooling_metadata) ], - device=hidden_states.device, + device=prompt_lens.device, ) - prompt_lens = PoolingTensors.from_pooling_metadata( - pooling_metadata, hidden_states.device).prompt_lens + if isinstance(hidden_states, list): + return [ + self.forward_one(h, prompt_len, instr_len) for h, prompt_len, + instr_len in zip(hidden_states, prompt_lens, instr_lens) + ] - mask = torch.zeros_like(hidden_states, dtype=torch.bool) + return self.forward_all(hidden_states, prompt_lens, instr_lens) - start_idx = 0 - for prompt_len, instruction_len in zip(prompt_lens, instruction_lens): - end_idx = start_idx + prompt_len - mask[start_idx + instruction_len:end_idx] = True - start_idx = end_idx - masked_hidden_states = hidden_states.masked_fill(~mask, 0.0) +class GritLMPooler(Pooler): - sum_embeddings = torch.zeros(len(prompt_lens), - hidden_states.size(1), - device=hidden_states.device) - - start_idx = 0 - for i, prompt_len in enumerate(prompt_lens): - end_idx = start_idx + prompt_len - sum_embeddings[i] = masked_hidden_states[start_idx:end_idx].sum( - dim=0) - start_idx = end_idx + def __init__(self, model_config: ModelConfig): + super().__init__() - num_non_instruction_tokens = prompt_lens - instruction_lens - mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( - 1) + self.pooling = GritLMMeanPool(model_config) + self.head = PoolerHead(PoolerNormalize()) - pooled_data = self.head(mean_embeddings, - pooling_metadata=pooling_metadata) + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() - pooled_outputs = [ - PoolingSequenceGroupOutput(data) for data in pooled_data - ] + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.pooling.get_pooling_updates(task) - return PoolerOutput(outputs=pooled_outputs) + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + pooled_data = self.pooling(hidden_states, pooling_metadata) + pooled_data = self.head(pooled_data, pooling_metadata) + return build_output(pooled_data) class GritLM(LlamaForCausalLM, SupportsV0Only): @@ -195,13 +233,15 @@ class GritLM(LlamaForCausalLM, SupportsV0Only): - 
"<|user|>\nPROMPT\n<|assistant|>\n" """ + is_pooling_model = True + def __init__( self, vllm_config: VllmConfig, prefix: str = "", **kwargs, ) -> None: - # Use full attention for pooling + # Use full attention for pooling (this is why V1 is not supported yet) if vllm_config.model_config.runner_type == "pooling": hf_config = vllm_config.model_config.hf_config hf_config.is_causal = False @@ -214,11 +254,11 @@ def __init__( super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs) - self._pooler = GritLMPooler(vllm_config.model_config) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) + pooler_config = vllm_config.model_config.pooler_config + if pooler_config is not None: + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "embed": + GritLMPooler(vllm_config.model_config), + }) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 2d930527b2be..3659249cd8bd 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -360,6 +360,16 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Map Grok1's unique expert parameter names to standard names + # Grok1 uses "num_experts" in its config + num_experts = getattr(self.config, "num_experts", 8) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", # Grok1 specific + ckpt_down_proj_name="linear_1", # Grok1 specific + ckpt_up_proj_name="linear_v", # Grok1 specific + num_experts=num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -369,18 +379,9 @@ def load_weights(self, weights: Iterable[tuple[str, ("qkv_proj", "v_proj", "v"), ] - # Map Grok1's unique expert parameter names to standard names - # Grok1 uses "num_experts" in its config - num_experts = getattr(self.config, "num_experts", 8) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="linear", # Grok1 specific - ckpt_down_proj_name="linear_1", # Grok1 specific - ckpt_up_proj_name="linear_v", # Grok1 specific - num_experts=num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() - + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -544,3 +545,6 @@ def load_weights(self, weights: Iterable[tuple[str, skip_prefixes=skip_prefixes, ) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/hunyuan_v1_moe.py b/vllm/model_executor/models/hunyuan_v1.py similarity index 92% rename from vllm/model_executor/models/hunyuan_v1_moe.py rename to vllm/model_executor/models/hunyuan_v1.py index 89ca3e8a6071..fbba849a76f2 100644 --- a/vllm/model_executor/models/hunyuan_v1_moe.py +++ b/vllm/model_executor/models/hunyuan_v1.py @@ -49,7 +49,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( 
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -57,7 +56,22 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_layers) + + +def _is_moe(config: PretrainedConfig) -> bool: + num_experts = getattr(config, "num_experts", None) + if isinstance(num_experts, int): + return num_experts > 1 + if isinstance(num_experts, list) and num_experts: + # Ensure all elements are integers before calling max. + if all(isinstance(e, int) for e in num_experts): + return max(num_experts) > 1 + else: + return False + return False def _get_cla_factor(config: PretrainedConfig) -> int: @@ -139,8 +153,8 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - # MistralConfig has an optional head_dim introduced by Mistral-Nemo - if hasattr(config, "head_dim"): + + if hasattr(config, "head_dim") and config.head_dim: self.head_dim = config.head_dim elif hasattr(config, "attention_head_dim"): self.head_dim = config.attention_head_dim @@ -489,12 +503,23 @@ def __init__( else: raise RuntimeError(f"Unsupported attention type: {attention_type}") - self.mlp = HunYuanSparseMoeBlock( - config=config, - quant_config=quant_config, - layer_id=layer_id, - prefix=f"{prefix}.mlp", - ) + if _is_moe(config): + self.mlp = HunYuanSparseMoeBlock( + config=config, + quant_config=quant_config, + layer_id=layer_id, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = HunYuanMLP( + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -618,95 +643,6 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - -class HunYuanMoEV1ForCausalLM(nn.Module): - packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - lora_config = vllm_config.lora_config - self.config = config - self.quant_config = quant_config - self.lora_config = lora_config - - self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") - if get_pp_group().is_last_rank: - self.unpadded_vocab_size = config.vocab_size - if lora_config: - self.unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead( - self.unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - - logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, - logit_scale) - self.sampler = get_sampler() - else: - self.lm_head = PPMissingLayer() - - def forward( - self, - 
input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) - return model_output - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def _split_qkv_weight(self, qkv: torch.Tensor): num_attention_heads = self.config.num_attention_heads num_kv_heads = getattr(self.config, "num_key_value_heads", @@ -729,6 +665,19 @@ def _split_qkv_weight(self, qkv: torch.Tensor): v = v.reshape(-1, hidden_size) return torch.concat((q, k, v)) + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + if _is_moe(self.config): + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + else: + return [] + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): cla_factor = _get_cla_factor(self.config) stacked_params_mapping = [ @@ -755,16 +704,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): ), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts, - ) - params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -816,7 +758,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - + loaded_params.add(name) is_found = True break if is_found: @@ -895,3 +837,101 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class HunYuanV1Base(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + 
self.model = HunYuanModel(vllm_config=vllm_config, prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() + + +class HunYuanDenseV1ForCausalLM(HunYuanV1Base): + pass + + +class HunYuanMoEV1ForCausalLM(HunYuanV1Base): + pass diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py index 4643468af4ce..de216a81e934 100644 --- a/vllm/model_executor/models/idefics3.py +++ b/vllm/model_executor/models/idefics3.py @@ -22,8 +22,8 @@ import torch from torch import nn -from transformers import (AddedToken, BatchFeature, Idefics3Config, - Idefics3ImageProcessor, Idefics3Processor) +from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor, + Idefics3Processor) from vllm.config import VllmConfig from vllm.model_executor.layers.linear import ReplicatedLinear @@ -199,21 +199,14 @@ def get_num_patches( return grid_w * grid_h + 1 - # TODO: Remove after requiring transformers>=4.52 - def _get_content(self, token: Union[AddedToken, str]) -> str: - if isinstance(token, str): - return token - - return token.content - def _get_image_token( self, processor: Optional[Idefics3Processor]) -> tuple[str, str, str]: if processor is None: processor = self.get_hf_processor() - image_token = self._get_content(processor.image_token) - fake_image_token = self._get_content(processor.fake_image_token) + image_token = processor.image_token + fake_image_token = processor.fake_image_token global_image_token = processor.global_image_tag return image_token, fake_image_token, global_image_token diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index a018bd5d09d9..957b57276b4c 100644 --- 
a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -5,10 +5,14 @@ from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, Union, overload, runtime_checkable) +import numpy as np import torch from torch import Tensor from typing_extensions import Self, TypeIs +from vllm.config import ModelConfig, SpeechToTextConfig +from vllm.inputs import TokensPrompt +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -18,6 +22,7 @@ if TYPE_CHECKING: from vllm.attention import AttentionMetadata + from vllm.config import VllmConfig from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors @@ -89,11 +94,22 @@ def get_input_embeddings( ) -> Tensor: ... + # TODO: Remove this overload once v0 is deprecated @overload def get_input_embeddings( self, input_ids: Tensor, multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> Tensor: + ... + + def get_input_embeddings( + self, + input_ids: Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + # Only necessary so that the v0 overload is valid + # TODO: Remove attn_metadata once v0 is deprecated + attn_metadata: Optional["AttentionMetadata"] = None, ) -> Tensor: """ Returns the input embeddings merged from the text embeddings from @@ -103,13 +119,6 @@ def get_input_embeddings( ... -# We can't use runtime_checkable with ClassVar for issubclass checks -# so we need to treat the class as an instance and use isinstance instead -@runtime_checkable -class _SupportsMultiModalType(Protocol): - supports_multimodal: Literal[True] - - @overload def supports_multimodal( model: type[object]) -> TypeIs[type[SupportsMultiModal]]: @@ -124,10 +133,86 @@ def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: def supports_multimodal( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: - if isinstance(model, type): - return isinstance(model, _SupportsMultiModalType) + return getattr(model, "supports_multimodal", False) - return isinstance(model, SupportsMultiModal) + +@runtime_checkable +class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal_raw_input: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs and processes + them in their raw form and not embeddings. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + +@overload +def supports_multimodal_raw_input( + model: object) -> TypeIs[SupportsMultiModalWithRawInput]: + ... + + +@overload +def supports_multimodal_raw_input( + model: type[object]) -> TypeIs[type[SupportsMultiModalWithRawInput]]: + ... + + +def supports_multimodal_raw_input( + model: Union[type[object], object] +) -> Union[TypeIs[type[SupportsMultiModalWithRawInput]], + TypeIs[SupportsMultiModalWithRawInput]]: + return getattr(model, "supports_multimodal_raw_input", False) + + +@runtime_checkable +class SupportsScoreTemplate(Protocol): + """The interface required for all models that support score template.""" + + supports_score_template: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports score template. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. 
+ """ + + @classmethod + def get_score_template(cls, query: str, document: str) -> Optional[str]: + """ + Generate a full prompt by populating the score template with query and document content. + """ # noqa: E501 + ... + + @classmethod + def post_process_tokens(cls, prompt: TokensPrompt) -> None: + """ + Perform architecture-specific manipulations on the input tokens. + """ + ... + + +@overload +def supports_score_template( + model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]: + ... + + +@overload +def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: + ... + + +def supports_score_template( + model: Union[type[object], object], +) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]: + return getattr(model, "supports_score_template", False) @runtime_checkable @@ -337,11 +422,6 @@ class HasInnerState(Protocol): """ -@runtime_checkable -class _HasInnerStateType(Protocol): - has_inner_state: ClassVar[Literal[True]] - - @overload def has_inner_state(model: object) -> TypeIs[HasInnerState]: ... @@ -355,10 +435,7 @@ def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: def has_inner_state( model: Union[type[object], object] ) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]: - if isinstance(model, type): - return isinstance(model, _HasInnerStateType) - - return isinstance(model, HasInnerState) + return getattr(model, "has_inner_state", False) @runtime_checkable @@ -374,11 +451,6 @@ class IsAttentionFree(Protocol): """ -@runtime_checkable -class _IsAttentionFreeType(Protocol): - is_attention_free: ClassVar[Literal[True]] - - @overload def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: ... @@ -392,10 +464,7 @@ def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: def is_attention_free( model: Union[type[object], object] ) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]: - if isinstance(model, type): - return isinstance(model, _IsAttentionFreeType) - - return isinstance(model, IsAttentionFree) + return getattr(model, "is_attention_free", False) @runtime_checkable @@ -410,10 +479,24 @@ class IsHybrid(Protocol): , also indicates that the model's hf_config has 'layers_block_type' """ + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. -@runtime_checkable -class _IsHybridType(Protocol): - is_hybrid: ClassVar[Literal[True]] + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + ... @overload @@ -429,10 +512,7 @@ def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: def is_hybrid( model: Union[type[object], object] ) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]: - if isinstance(model, type): - return isinstance(model, _IsHybridType) - - return isinstance(model, IsHybrid) + return getattr(model, "is_hybrid", False) @runtime_checkable @@ -497,6 +577,13 @@ def set_eplb_state( """ ... + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + ... 
+ def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]: return isinstance(model, MixtureOfExperts) @@ -507,11 +594,6 @@ class HasNoOps(Protocol): has_noops: ClassVar[Literal[True]] = True -@runtime_checkable -class _HasNoOpsType(Protocol): - has_noops: ClassVar[Literal[True]] - - @overload def has_noops(model: object) -> TypeIs[HasNoOps]: ... @@ -525,10 +607,7 @@ def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: def has_noops( model: Union[type[object], object] ) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]: - if isinstance(model, type): - return isinstance(model, _HasNoOpsType) - - return isinstance(model, HasNoOps) + return getattr(model, "has_noops", False) @runtime_checkable @@ -552,11 +631,7 @@ def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: def _supports_cross_encoding( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: - - if isinstance(model, type): - return isinstance(model, SupportsCrossEncoding) - - return isinstance(model, SupportsCrossEncoding) + return getattr(model, "supports_cross_encoding", False) def supports_cross_encoding( @@ -565,12 +640,6 @@ def supports_cross_encoding( return is_pooling_model(model) and _supports_cross_encoding(model) -def has_step_pooler(model: Union[type[object], object]) -> bool: - """Check if the model uses step pooler.""" - return is_pooling_model(model) and any( - type(module).__name__ == "StepPool" for module in model.modules()) - - class SupportsQuant: """The interface required for all models that support quantization.""" @@ -589,13 +658,9 @@ def __new__(cls, *args, **kwargs) -> Self: instance.quant_config = quant_config # apply model mappings to config for proper config-model matching - # NOTE: `TransformersForCausalLM` is not supported due to how this - # class defines `hf_to_vllm_mapper` as a post-init `@property`. - # After this is fixed, get `instance.hf_to_vllm_mapper` directly - if getattr(instance, "hf_to_vllm_mapper", None) is not None: - instance.quant_config.apply_vllm_mapper( - instance.hf_to_vllm_mapper) - if getattr(instance, "packed_modules_mapping", None) is not None: + if (hf_to_vllm_mapper := instance.hf_to_vllm_mapper) is not None: + instance.quant_config.apply_vllm_mapper(hf_to_vllm_mapper) + if instance.packed_modules_mapping is not None: instance.quant_config.packed_modules_mapping.update( instance.packed_modules_mapping) @@ -623,10 +688,21 @@ class SupportsTranscription(Protocol): supports_transcription: ClassVar[Literal[True]] = True + supports_transcription_only: ClassVar[bool] = False + """ + Transcription models can opt out of text generation by setting this to + `True`. + """ + @classmethod - def get_decoder_prompt(cls, language: str, task_type: str, - prompt: str) -> str: - """Get the decoder prompt for the ASR model.""" + def get_generation_prompt(cls, audio: np.ndarray, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, language: str, + task_type: str, + request_prompt: str) -> PromptType: + """Get the prompt for the ASR model. + The model has control over the construction, as long as it + returns a valid PromptType.""" ... @classmethod @@ -634,6 +710,25 @@ def validate_language(cls, language: str) -> bool: """Check if the model supports a specific ISO639_1 language.""" ... 
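+    # For example (illustrative only, not mandated by this interface): a
+    # fixed-rate audio encoder could implement the get_num_audio_tokens()
+    # hook added below as roughly
+    #   math.ceil(audio_duration_s * tokens_per_second)
+    # for some model-specific tokens_per_second, letting the server estimate
+    # processing cost without a forward pass; returning None opts out.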
+ @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, + task_type: Literal["transcribe", + "translate"]) -> SpeechToTextConfig: + """Get the speech to text config for the ASR model.""" + ... + + @classmethod + def get_num_audio_tokens(cls, audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig) -> Optional[int]: + """ + Map from audio duration to number of audio tokens produced by the ASR + model, without running a forward pass. + This is used for estimating the amount of processing for this audio. + """ + return None + @overload def supports_transcription( @@ -649,10 +744,7 @@ def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: def supports_transcription( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]: - if isinstance(model, type): - return isinstance(model, SupportsTranscription) - - return isinstance(model, SupportsTranscription) + return getattr(model, "supports_transcription", False) @runtime_checkable @@ -675,7 +767,4 @@ def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]: def supports_v0_only( model: Union[type[object], object], ) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]: - if isinstance(model, type): - return isinstance(model, SupportsV0Only) - - return isinstance(model, SupportsV0Only) + return getattr(model, "supports_v0_only", False) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py index 4a1ea74a218a..4d68227b2af8 100644 --- a/vllm/model_executor/models/interfaces_base.py +++ b/vllm/model_executor/models/interfaces_base.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload, - runtime_checkable) +from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, + Union, overload, runtime_checkable) import torch import torch.nn as nn @@ -13,8 +12,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig - from vllm.model_executor.layers.pooler import PoolerOutput - from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.layers.pooler import Pooler from vllm.model_executor.sampling_metadata import SamplingMetadata logger = init_logger(__name__) @@ -130,16 +128,20 @@ def is_text_generation_model( @runtime_checkable -class VllmModelForPooling(VllmModel[T], Protocol[T]): +class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]): """The interface required for all pooling models in vLLM.""" - def pooler( - self, - hidden_states: T, - pooling_metadata: "PoolingMetadata", - ) -> "PoolerOutput": - """Only called on TP rank 0.""" - ... + is_pooling_model: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pooling. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. 
+ """ + + pooler: "Pooler" + """The pooler is only called on TP rank 0.""" @overload @@ -158,7 +160,4 @@ def is_pooling_model( if not is_vllm_model(model): return False - if isinstance(model, type): - return isinstance(model, VllmModelForPooling) - - return isinstance(model, VllmModelForPooling) + return getattr(model, "is_pooling_model", False) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index e8549b4e0538..d29779a35e5c 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -22,15 +22,14 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .utils import (is_pp_missing_parameter, @@ -404,6 +403,8 @@ def load_weights(self, weights: Iterable[tuple[str, class InternLM2ForRewardModel(InternLM2ForCausalLM): + is_pooling_model = True + def __init__( self, *, @@ -428,12 +429,10 @@ def __init__( ) pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.ALL, - normalize=False, - softmax=False, - ) + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) def forward( self, @@ -446,10 +445,3 @@ def forward( inputs_embeds) logits, _ = self.v_head(hidden_states) return logits - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 8294f846bbd1..34281b2e99ee 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -19,16 +19,16 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer -from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, + PoolingType) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.mamba_cache import (MambaCacheManager, MambaCacheParams) -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import LayerBlockType from .interfaces import 
(HasInnerState, IsHybrid, SupportsLoRA, SupportsPP, @@ -562,31 +562,37 @@ def _is_moe_layer(name: str): class JambaForSequenceClassification(JambaForCausalLM): + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) + config = vllm_config.model_config.hf_config num_labels: int = config.num_labels score_bias: bool = getattr(config, 'score_bias', False) - self.score = nn.Linear(config.hidden_size, num_labels, bias=score_bias) - pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.LAST, - normalize=False, - softmax=False) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - hidden_states = hidden_states.float() - logits = self.score(hidden_states) - return self._pooler(logits, pooling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - # TODO: The reward weights themselves have float32 accuracy data, we + # TODO: The original reward weights have float32 accuracy data, we # would like to load them in fp32 to get that extra precision. - super().load_weights(weights) - self.score = self.score.float() + # Currently weight_loader passes the weight which is already in bf16 + self.score = nn.Linear( + config.hidden_size, + num_labels, + bias=score_bias, + dtype=torch.float32, + ) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + Pooler.for_classify( + pooler_config, + classifier=self.score, + default_pooling_type=PoolingType.LAST, + default_normalize=False, + default_softmax=False, + ), + }) diff --git a/vllm/model_executor/models/jina_vl.py b/vllm/model_executor/models/jina_vl.py new file mode 100644 index 000000000000..0c4284f7daaa --- /dev/null +++ b/vllm/model_executor/models/jina_vl.py @@ -0,0 +1,149 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable, Mapping +from typing import Optional + +import torch +import torch.nn as nn +from transformers import BatchFeature, PretrainedConfig + +from vllm.config import VllmConfig +from vllm.inputs import TokensPrompt +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import DispatchPooler, Pooler +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors + +from .interfaces import (SupportsCrossEncoding, SupportsMultiModal, + SupportsScoreTemplate) +from .qwen2_vl import (Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, + Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo) +from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix + +logger = init_logger(__name__) + + +class JinaVLScorer(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + bias=True) + self.out_proj = RowParallelLinear(config.hidden_size, + config.num_labels, + bias=True) + + def forward(self, x, **kwargs): + x, _ = self.dense(x) + x = torch.relu(x) + x, _ = self.out_proj(x) + return x + + +class JinaVLMultiModalProcessor(Qwen2VLMultiModalProcessor): + + def _call_hf_processor( + 
self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + + # NOTE: We should reverse the order of the mm_data because the + # query prompt is placed after the document prompt in the score + # template for JinaVLForRanking model, but in mm_data they are + # stored in the opposite order (query first, then document). + for _, value in mm_data.items(): + value.reverse() + return super()._call_hf_processor(prompt, mm_data, mm_kwargs, + tok_kwargs) + + +@MULTIMODAL_REGISTRY.register_processor(JinaVLMultiModalProcessor, + info=Qwen2VLProcessingInfo, + dummy_inputs=Qwen2VLDummyInputsBuilder) +class JinaVLForSequenceClassification(Qwen2VLForConditionalGeneration, + SupportsCrossEncoding, + SupportsMultiModal, + SupportsScoreTemplate): + + is_pooling_model = True + weight_mapper = WeightsMapper( + orig_to_new_prefix={ + "score.0.": "score.dense.", + "score.2.": "score.out_proj.", + # mapping for new names in checkpoint saved after transformers v4.52 + "model.language_model.": "language_model.model.", + "visual.": "visual.", + # mapping for original checkpoint + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "qwen2_vl")) + config = vllm_config.model_config.hf_config + pooler_config = vllm_config.model_config.pooler_config + + # logit bias for sigmoid normalization + self.LOGIT_BIAS = 2.65 + + self.score = JinaVLScorer(config) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + Pooler.for_classify(pooler_config, classifier=None), + "score": + Pooler.for_classify(pooler_config, classifier=None), + }) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<|vision_start|><|image_pad|><|vision_end|>" + + raise ValueError("Only image modality is supported") + + @classmethod + def get_score_template(cls, query: str, document: str) -> Optional[str]: + return f"**Document**:\n{document}\n**Query**:\n{query}" + + @classmethod + def post_process_tokens(cls, prompt: TokensPrompt) -> None: + + # add score target token at the end of prompt tokens + prompt['prompt_token_ids'].append(100) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> torch.Tensor: + hidden_states = super().forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + logits = self.score(hidden_states) - self.LOGIT_BIAS + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.weight_mapper) diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 34cd26b4c062..3e1c64bb62ea 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1592,6 +1592,9 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: return modalities + def get_language_model(self) -> torch.nn.Module: + return self.language_model + def 
get_multimodal_embeddings( self, **kwargs: object) -> Optional[MultiModalEmbeddings]: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5d5080479e51..48ec611df12d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): "qscale_act": "input_scale", "qscale_weight": "weight_scale", "kv_fake_quantizer.qscale_act": "kv_scale", + "q_fake_quantizer.qscale_act": "attn.q_scale", + "k_fake_quantizer.qscale_act": "k_scale", + "v_fake_quantizer.qscale_act": "v_scale", "wq": "q_proj", "wk": "k_proj", "wv": "v_proj", diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 0c9baab1f2e4..fab1c163ac28 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -35,7 +35,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk, @@ -432,12 +433,24 @@ def load_weights(self, weights: Iterable[tuple[str, for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name or "experts" in name: continue - name = name.replace(weight_name, param_name) + # This check is for ModelOpt ckpts with kv cache quant enabled + if not (name.endswith( + (".k_scale", ".v_scale")) and "self_attn" in name): + name = name.replace(weight_name, param_name) if is_pp_missing_parameter(name, self): continue + if name.endswith("scale") and "expert" not in name: + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader == default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break else: @@ -452,6 +465,44 @@ def load_weights(self, weights: Iterable[tuple[str, if not moe_loaded: if is_pp_missing_parameter(name, self): continue + + # Handle flat expert scale parameters that + # don't match per-expert patterns + if ("experts." 
in name and ("w13_input_scale" in name + or "w13_weight_scale" in name + or "w2_input_scale" in name + or "w2_weight_scale" in name)): + # These are flat expert scales that apply to all experts + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + + # Check for MoE-specific loading support via + # attribute instead of expensive runtime reflection + supports_moe = getattr(weight_loader, + 'supports_moe_loading', False) + + if supports_moe: + # This is a MoE weight loader + if "w13_" in name: + shard_id = "w1" + elif "w2_" in name: + shard_id = "w2" + else: + shard_id = "w1" + + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=0) + else: + # Regular weight loader (handles both + # param.weight_loader and default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py new file mode 100644 index 000000000000..222ab5dfaee4 --- /dev/null +++ b/vllm/model_executor/models/llama4_eagle.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 the LLAMA4, Meta Inc., vLLM, and HuggingFace Inc. team. +# All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
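+#
+# This module implements the EAGLE draft model for Llama-4 speculative
+# decoding: `LlamaModel` stacks Llama4DecoderLayer blocks and fuses the draft
+# token embeddings with the target model's hidden states through `fc`
+# (a Linear over their concatenation), and `EagleLlama4ForCausalLM` reuses the
+# target model's lm_head, so "lm_head." weights are skipped at load time.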
+ +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.torchao import TorchAOConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama4 import (Llama4DecoderLayer, + Llama4ForCausalLM) +from vllm.model_executor.models.utils import extract_layer_index + +from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +@support_torch_compile +class LlamaModel(nn.Module): + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + start_layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) + self.validate_and_update_config(start_layer_id, quant_config) + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + + self.layers = nn.ModuleList([ + Llama4DecoderLayer( + self.config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + self.norm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.fc( + torch.cat((input_embeds, hidden_states), dim=-1)) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states, hidden_states + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + name = name.removeprefix("model.") + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." 
in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + for name in params_dict: + # if PP disabled then draft will share embed with target + if get_pp_group().world_size == 1 and \ + "embed_tokens." in name: + continue + assert name in loaded_params, f"{name} is not loaded!" + return loaded_params + + def validate_and_update_config( + self, + start_layer_id: int, + quant_config: Optional[QuantizationConfig] = None) -> None: + # yoco and moe is not supported by draft model yet + assert self.config.yoco_global_kv_layer is None + assert self.config.yoco_local_kv_layer is None + assert len(self.config.moe_layers) == 0 + # draft model layer index is increased by start_layer_id, + # so we need to pad relevant configs accordingly + self.config.no_rope_layers = [ + 0 + ] * start_layer_id + self.config.no_rope_layers + # currently only TorchAO quantization is supported + if isinstance(quant_config, TorchAOConfig): + + def pad_layer_name(layer: str) -> str: + layer_index = extract_layer_index(layer) + return layer.replace(str(layer_index), + str(layer_index + start_layer_id)) + + quant_config.torchao_config.module_fqn_to_config = { + pad_layer_name(layer): quantization + for layer, quantization in + quant_config.torchao_config.module_fqn_to_config.items() + } + + +class EagleLlama4ForCausalLM(Llama4ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = ( + vllm_config.speculative_config.draft_model_config.hf_config) + target_layer_num = vllm_config.model_config.get_num_layers( + vllm_config.parallel_config) + # draft model quantization config may differ from target model + quant_config = VllmConfig.get_quantization_config( + vllm_config.speculative_config.draft_model_config, + vllm_config.load_config) + self.model = LlamaModel(vllm_config=vllm_config, + prefix="model", + start_layer_id=target_layer_num, + quant_config=quant_config) + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + return self.model(input_ids, positions, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> None: + loader = AutoWeightsLoader( + self, + # lm_head is tied with target model (Llama4ForCausalLM) + skip_prefixes=(["lm_head."]), + ) + + model_weights = {} + weights = [ + self.permute_qk_weight_for_rotary(name, loaded_weight) + for name, loaded_weight in weights + ] + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + model_weights[name] = loaded_weight + + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index d2403ccbb972..adad181617e6 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -10,16 +10,16 @@ from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -62,8 +62,7 @@ def __init__(self, rms_norm_eps=config.layer_norm_epsilon, activation=config.hidden_act, quant_config=quant_config, - prefix=f"{prefix}.mixer", - chunk_size=config.chunk_size) + prefix=f"{prefix}.mixer") self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -81,11 +80,12 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params, - mamba2_metadata) - return hidden_states, residual + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + return output, residual +@support_torch_compile class Mamba2Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -199,6 +199,38 @@ def load_weights(self, weights: Iterable[tuple[str, class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.n_groups, + num_heads=hf_config.num_heads, + head_dim=hf_config.head_dim, + state_size=hf_config.state_size, + conv_kernel=hf_config.conv_kernel, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -254,9 +286,13 @@ def forward(self, self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) else: @@ -275,39 +311,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = getattr( - self.config, "intermediate_size", - self.config.expand * self.config.hidden_size) - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = ( - self.config.n_groups + - extra_groups_for_head_shards(self.config.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.conv_kernel - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.num_heads, world_size), - self.config.head_dim, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py index 49ba974c69a5..27685c59a3ea 100644 --- a/vllm/model_executor/models/mamba_cache.py +++ b/vllm/model_executor/models/mamba_cache.py @@ -36,10 +36,12 @@ def __init__(self, vllm_config: VllmConfig, dtype: torch.dtype, # Initialize parent class super().__init__(max_batch_size) + # assume conv_state = (dim, state_len) + assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + - conv_state_shape, + (conv_state_shape[1], conv_state_shape[0]), dtype=dtype, - device="cuda") + 
device="cuda").transpose(-1, -2) temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + temporal_state_shape, dtype=dtype, diff --git a/vllm/model_executor/models/mimo.py b/vllm/model_executor/models/mimo.py index 9b83f848ef42..5b497dd9d89f 100644 --- a/vllm/model_executor/models/mimo.py +++ b/vllm/model_executor/models/mimo.py @@ -36,7 +36,6 @@ from vllm.distributed import get_pp_group from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) @@ -176,7 +175,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mimo_mtp.py b/vllm/model_executor/models/mimo_mtp.py index 6066ec76c5fc..19afc5be3fb8 100644 --- a/vllm/model_executor/models/mimo_mtp.py +++ b/vllm/model_executor/models/mimo_mtp.py @@ -30,7 +30,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -161,8 +160,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size) - self.sampler = get_sampler() - def forward( self, input_ids: torch.Tensor, @@ -187,14 +184,6 @@ def compute_logits( return self.model.compute_logits(hidden_states, self.lm_head, sampling_metadata, spec_step_idx) - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ diff --git a/vllm/model_executor/models/minicpmo.py b/vllm/model_executor/models/minicpmo.py index 71593d4bb896..4e4fc3d5c762 100644 --- a/vllm/model_executor/models/minicpmo.py +++ b/vllm/model_executor/models/minicpmo.py @@ -30,8 +30,10 @@ from torch import nn from transformers import BatchFeature, PretrainedConfig from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.models.whisper.modeling_whisper import ( - ACT2FN, WHISPER_ATTENTION_CLASSES, WhisperConfig, WhisperEncoder) +from transformers.models.whisper.modeling_whisper import (ACT2FN, + WhisperAttention, + WhisperConfig, + WhisperEncoder) from vllm.config import VllmConfig from vllm.model_executor.layers.quantization import QuantizationConfig @@ -378,14 +380,13 @@ class MiniCPMWhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig, layer_idx: int): super().__init__() self.embed_dim = config.d_model - self.self_attn = WHISPER_ATTENTION_CLASSES[ - config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, - dropout=config.attention_dropout, - 
config=config, - layer_idx=layer_idx, - ) + self.self_attn = WhisperAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + layer_idx=layer_idx, + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 87480796ae98..f2773af490c5 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -667,16 +667,24 @@ def __init__( eps=config.rms_norm_eps) if config.attention_type == 0: self.layernorm_attention_alpha = getattr( - config, 'layernorm_linear_attention_alpha', 1) + config, 'layernorm_linear_attention_alpha', + getattr(config, 'linear_attn_alpha_factor', 1)) self.layernorm_attention_beta = getattr( - config, 'layernorm_linear_attention_beta', 1) + config, 'layernorm_linear_attention_beta', + getattr(config, 'linear_attn_beta_factor', 1)) else: self.layernorm_attention_alpha = getattr( - config, 'layernorm_full_attention_alpha', 1) + config, 'layernorm_full_attention_alpha', + getattr(config, 'full_attn_alpha_factor', 1)) self.layernorm_attention_beta = getattr( - config, 'layernorm_full_attention_beta', 1) - self.layernorm_mlp_alpha = getattr(config, 'layernorm_mlp_alpha', 1) - self.layernorm_mlp_beta = getattr(config, 'layernorm_mlp_beta', 1) + config, 'layernorm_full_attention_beta', + getattr(config, 'full_attn_beta_factor', 1)) + self.layernorm_mlp_alpha = getattr( + config, 'layernorm_mlp_alpha', + getattr(config, 'mlp_alpha_factor', 1)) + self.layernorm_mlp_beta = getattr( + config, 'layernorm_mlp_beta', getattr(config, 'mlp_beta_factor', + 1)) self.postnorm = getattr(config, 'postnorm', False) self.shared_moe = False @@ -794,6 +802,18 @@ def __init__( self.decoder_attention_types = getattr( config, "attn_type_list", False) or getattr( config, "decoder_attention_types", False) + # The HF format uses "layer_types" instead of "attn_type_list" + # where "linear_attention" is 0 and "full_attention" is 1 + if not self.decoder_attention_types and hasattr(config, "layer_types"): + self.decoder_attention_types = [] + for layer_type in config.layer_types: + if layer_type == "linear_attention": + self.decoder_attention_types.append(0) + elif layer_type == "full_attention": + self.decoder_attention_types.append(1) + else: + raise ValueError(f"Unsupported layer type: {layer_type}") + # Default to full attention if not self.decoder_attention_types: self.decoder_attention_types = [1] * config.num_hidden_layers self.num_layers = config.num_hidden_layers @@ -1022,8 +1042,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: else: self.lm_head = PPMissingLayer() self.lm_head.float() - flash_layer_count = sum(1 for attn_type in self.config.attn_type_list - if attn_type == 1) + flash_layer_count = sum( + 1 for attn_type in self.model.decoder_attention_types + if attn_type == 1) self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] return @@ -1085,9 +1106,10 @@ def which_layer(name: str) -> int: return None def is_linear_attn_layer(layer_idx: int) -> bool: - if layer_idx is None or not hasattr(self.config, "attn_type_list"): + if layer_idx is None or layer_idx >= len( + self.model.decoder_attention_types): return False - return self.config.attn_type_list[layer_idx] == 0 + return self.model.decoder_attention_types[layer_idx] == 0 
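+        # NOTE: self.model.decoder_attention_types is consulted here instead
+        # of config.attn_type_list so that checkpoints exposing only the HF
+        # "layer_types" field (mapped to 0/1 at model construction) are
+        # handled as well.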
def is_moe_weight(name: str) -> bool: return "block_sparse_moe" in name and not name.endswith(".bias") @@ -1275,7 +1297,7 @@ def load_basic_weight(name: str, loaded_weight: torch.Tensor, for name, loaded_weight in weights: weight_at_layer = which_layer(name) if weight_at_layer and weight_at_layer >= len( - self.config.attn_type_list): + self.model.decoder_attention_types): continue if is_layer_norm_weight(name): diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index dec365119c72..30de83da49e0 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -317,6 +317,15 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -326,16 +335,9 @@ def load_weights(self, weights: Iterable[tuple[str, ("qkv_proj", "v_proj", "v"), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -486,3 +488,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 1276d626a7c3..dea85d320adf 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], } @classmethod @@ -902,32 +903,109 @@ def _consolidate_qkv_weights( qkv_weight = torch.cat(weight, dim=0) yield key, qkv_weight - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: + def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str: + """Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM + format.""" + if name.startswith("model."): + # Handle expert scale parameters with flat naming + if "feed_forward.experts." 
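# Illustrative sketch (not part of the patch; dummy names): the Mixtral hunk
# above factors the expert-parameter mapping into a reusable
# get_expert_mapping() method and has the *ForCausalLM wrapper forward to it,
# so callers can query the (param_name, weight_name, expert_id, shard_id)
# tuples from the top-level model instead of re-deriving them inside
# load_weights(). A minimal recap of that delegation pattern:

class InnerModel:
    num_experts = 4  # hypothetical expert count

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Stand-in for FusedMoE.make_expert_params_mapping(...): one entry
        # per (expert, projection), mapping checkpoint names to fused params.
        return [("experts.w13_weight", f"experts.{e}.w1.weight", e, "w1")
                for e in range(self.num_experts)]


class ForCausalLM:
    def __init__(self) -> None:
        self.model = InnerModel()

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Same signature as the inner model; callers only see the top level.
        return self.model.get_expert_mapping()


assert len(ForCausalLM().get_expert_mapping()) == 4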
in name and ("_input_scale" in name or + "_weight_scale" in name): + renamed = name.replace("model.", "language_model.model.", 1) + # Map checkpoint naming to vLLM's expected naming + if "down_proj_input_scale" in renamed: + return renamed.replace("down_proj_input_scale", + "w2_input_scale") + elif "down_proj_weight_scale" in renamed: + return renamed.replace("down_proj_weight_scale", + "w2_weight_scale") + elif "gate_up_proj_input_scale" in renamed: + return renamed.replace("gate_up_proj_input_scale", + "w13_input_scale") + elif "gate_up_proj_weight_scale" in renamed: + return renamed.replace("gate_up_proj_weight_scale", + "w13_weight_scale") + return renamed + + # Handle attention scale parameters + elif "self_attn." in name and (".k_scale" in name + or ".v_scale" in name): + renamed = name.replace("model.", "language_model.model.", 1) + if ".k_proj.k_scale" in renamed: + return renamed.replace(".k_proj.k_scale", ".attn.k_scale") + elif ".v_proj.v_scale" in renamed: + return renamed.replace(".v_proj.v_scale", ".attn.v_scale") + return renamed + + # Standard model.* to language_model.model.* renaming + return name.replace("model.", "language_model.model.", 1) + + elif name.startswith("lm_head.weight"): + return name.replace("lm_head.weight", + "language_model.lm_head.weight") + + return name + + def _separate_and_rename_weights( + self, weights: Iterable[tuple[str, torch.Tensor]] + ) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]]]: + """Rename weights and separate them into language_model and other + weights.""" + language_model_weights = [] + other_weights = [] - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), - (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), - (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), - ] - params_dict = dict(self.named_parameters()) - updated_params: set[str] = set() + for name, weight in weights: + renamed = self._rename_weight_for_modelopt_checkpoint(name) - # language_model is an Llama4ForCausalLM instance. We load it's - # using llama4's load_weights routine. - language_model_weights, other_weights = self.separate_weights( - weights, prefix="language_model.") - loader = AutoWeightsLoader(self) - loaded_language_model_params = loader.load_weights( - language_model_weights) - assert loaded_language_model_params is not None - updated_params.update(loaded_language_model_params) + if renamed.startswith("language_model."): + language_model_weights.append((renamed, weight)) + else: + other_weights.append((renamed, weight)) + + return language_model_weights, other_weights + + def _handle_expert_scale_broadcasting( + self, weights: list[tuple[str, torch.Tensor]], params_dict: dict + ) -> tuple[list[tuple[str, torch.Tensor]], set[str]]: + """Handle expert scale parameters that need broadcasting. + + ModelOpt checkpoints use a single value tensor scalar for BMM style + experts, vLLM expects the scale to be broadcasted across all experts. + """ + regular_weights = [] + expert_scale_weights = [] + updated_params = set() + + for name, weight in weights: + # Check if this is an expert scale parameter that needs broadcasting + if ("feed_forward.experts." 
in name and "scale" in name + and ".shared_expert" not in name): + if name in params_dict: + param = params_dict[name] + if (hasattr(param, 'data') and param.data.numel() > 1 + and weight.numel() == 1): + # Broadcast single value to all experts + param.data.fill_(weight.item()) + updated_params.add(name) + continue + + expert_scale_weights.append((name, weight)) + else: + regular_weights.append((name, weight)) + + return regular_weights, expert_scale_weights, updated_params + + def _load_other_weights(self, other_weights: Iterable[tuple[str, + torch.Tensor]], + params_dict: dict, + stacked_params_mapping: list) -> set[str]: + """Load non-language-model weights with stacking support.""" + updated_params = set() if self.use_data_parallel: other_weights = self._consolidate_qkv_weights(other_weights) for name, loaded_weight in other_weights: + # Try stacked parameter mapping first for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name or self.use_data_parallel: continue @@ -938,10 +1016,56 @@ def load_weights(self, weights: Iterable[tuple[str, weight_loader(param, loaded_weight, shard_id) break else: + # Use regular weight loading param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) updated_params.add(name) + + return updated_params + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".self_attn.qkv_proj", ".self_attn.q_proj", "q"), + (".self_attn.qkv_proj", ".self_attn.k_proj", "k"), + (".self_attn.qkv_proj", ".self_attn.v_proj", "v"), + # Shared expert gate_up_proj stacking + (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0), + (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1), + # Feed forward gate_up_proj stacking (for non-MoE layers if any) + (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0), + (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params: set[str] = set() + + # Separate and rename weights + language_model_weights, other_weights = ( + self._separate_and_rename_weights(weights)) + + # Handle expert scale parameters + regular_weights, expert_scale_weights, updated_params_from_experts = ( + self._handle_expert_scale_broadcasting(language_model_weights, + params_dict)) + updated_params.update(updated_params_from_experts) + + loader = AutoWeightsLoader(self) + loaded_language_model_params = loader.load_weights(regular_weights) + assert loaded_language_model_params is not None + updated_params.update(loaded_language_model_params) + + if expert_scale_weights: + loaded_expert_scale_params = loader.load_weights( + expert_scale_weights) + if loaded_expert_scale_params: + updated_params.update(loaded_expert_scale_params) + + updated_params.update( + self._load_other_weights(other_weights, params_dict, + stacked_params_mapping)) + return updated_params diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py index 9d619b38d38d..be1c3438d9db 100644 --- a/vllm/model_executor/models/modernbert.py +++ b/vllm/model_executor/models/modernbert.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable -from typing import Optional +from collections.abc import Iterable, Set +from typing import Optional, Union import torch from 
torch import nn @@ -13,13 +13,18 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import ClassifierPooler +from vllm.model_executor.layers.pooler import (ClassifierPooler, + DispatchPooler, Pooler, + PoolingMethod, + PoolingParamsUpdate, + PoolingType) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.pooling_params import PoolingTask +from vllm.sequence import IntermediateTensors from .interfaces import SupportsCrossEncoding, SupportsV0Only from .utils import WeightsMapper, maybe_prefix @@ -252,10 +257,13 @@ def forward( return norm_outputs -class ModernBertPooler(nn.Module): +class ModernBertPooler(Pooler): def __init__(self, config: ModernBertConfig): super().__init__() + + pooling_type = PoolingType[config.classifier_pooling.upper()] + self.pooling = PoolingMethod.from_pooling_type(pooling_type) self.dense = nn.Linear(config.hidden_size, config.hidden_size, config.classifier_bias) self.pooling_type = config.classifier_pooling @@ -264,22 +272,35 @@ def __init__(self, config: ModernBertConfig): eps=config.norm_eps, bias=config.norm_bias) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - pooled_output = hidden_states - if self.pooling_type == "mean": - pooled_output = pooled_output.mean(dim=0, keepdim=False) - elif self.pooling_type == "cls": - pooled_output = pooled_output[0, :] + def get_supported_tasks(self) -> Set[PoolingTask]: + return self.pooling.get_supported_tasks() + + def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate: + return self.pooling.get_pooling_updates(task) + + def _head(self, pooled_output: torch.Tensor): + return self.norm(self.act(self.dense(pooled_output))) + + def forward( + self, + hidden_states: Union[torch.Tensor, list[torch.Tensor]], + pooling_metadata: PoolingMetadata, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + pooled_output = self.pooling(hidden_states, pooling_metadata) + + if isinstance(pooled_output, list): + pooled_output = [self._head(output) for output in pooled_output] else: - raise ValueError("Pooling type should be either `cls` or `mean`, " - f"but got {self.pooling_type}") - pooled_output = self.norm(self.act(self.dense(pooled_output))) + pooled_output = self._head(pooled_output) + return pooled_output class ModernBertForSequenceClassification(nn.Module, SupportsV0Only, SupportsCrossEncoding): + is_pooling_model = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config @@ -287,9 +308,28 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = ModernBertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "modernbert")) self.classifier = nn.Linear(config.hidden_size, config.num_labels) - self._pooler = ClassifierPooler(vllm_config.model_config, - self.classifier, - ModernBertPooler(config)) + + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + 
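# Illustrative sketch (not part of the patch; sizes are made up and GELU
# stands in for the configured activation): the ModernBERT pooler head keeps
# the same math after the refactor above, only the pooling step is now
# dispatched through PoolingMethod. Pool the token states (first token for
# "cls", mean over tokens for "mean"), then apply dense -> activation -> norm:

import torch
import torch.nn as nn

hidden_size, num_tokens = 64, 10
dense = nn.Linear(hidden_size, hidden_size)
act = nn.GELU()
norm = nn.LayerNorm(hidden_size)

hidden_states = torch.randn(num_tokens, hidden_size)  # one sequence

cls_pooled = hidden_states[0, :]         # classifier_pooling == "cls"
mean_pooled = hidden_states.mean(dim=0)  # classifier_pooling == "mean"

pooled_output = norm(act(dense(cls_pooled)))  # what _head() computes
assert pooled_output.shape == (hidden_size,)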
pooling=ModernBertPooler(config), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=ModernBertPooler(config), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): @@ -318,13 +358,6 @@ def weight_filter(): default_weight_loader) weight_loader(param, loaded_weight) - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def forward( self, input_ids: Optional[torch.LongTensor], diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 5d51b01df9db..6a999e2254e7 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -25,8 +25,9 @@ from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import ReLUSquaredActivation @@ -37,8 +38,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -154,7 +155,6 @@ def __init__( activation=config.mamba_hidden_act, quant_config=quant_config, prefix=f"{prefix}.mixer", - chunk_size=config.chunk_size, ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -173,9 +173,9 @@ def forward( else: hidden_states, residual = self.norm(hidden_states, residual) - hidden_states = self.mixer(hidden_states, mamba_cache_params, - mamba2_metadata) - return hidden_states, residual + output = torch.empty_like(hidden_states) + self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata) + return output, residual class NemotronHAttention(nn.Module): @@ -293,6 +293,7 @@ def forward( } +@support_torch_compile class NemotronHModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -460,6 +461,38 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.n_groups, + num_heads=hf_config.mamba_num_heads, + head_dim=hf_config.mamba_head_dim, + state_size=hf_config.ssm_state_size, + conv_kernel=hf_config.conv_kernel, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config @@ -516,10 +549,13 @@ def forward(self, self.vllm_config.parallel_config, LayerBlockType.mamba ) - - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -535,39 +571,6 @@ def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = ( - self.config.n_groups + - extra_groups_for_head_shards(self.config.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.ssm_state_size) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.conv_kernel - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_num_heads, world_size), - self.config.mamba_head_dim, - self.config.ssm_state_size, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/nemotron_vl.py b/vllm/model_executor/models/nemotron_vl.py new file mode 100644 index 000000000000..5d0513d70741 --- /dev/null +++ b/vllm/model_executor/models/nemotron_vl.py @@ -0,0 +1,505 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from abc import ABC +from collections.abc import Iterable +from typing import Optional + +import torch +import torch.nn as nn +from 
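# Illustrative sketch (not part of the patch; all hyperparameters below are
# hypothetical): the removed _get_mamba_cache_shape() above shows what the
# shared get_mamba_state_shape() helper has to produce for NemotronH: a
# conv-state shape and a temporal (SSM) state shape, sharded across tensor
# parallelism. A plain-Python recap, ignoring the n_groups padding that
# extra_groups_for_head_shards handled and any V1-specific differences:

hidden_size, expand = 4096, 2
n_groups, ssm_state_size = 8, 128
mamba_num_heads, mamba_head_dim = 128, 64
conv_kernel, tp_world_size = 4, 2

intermediate_size = expand * hidden_size
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size
conv_state_shape = (conv_dim // tp_world_size, conv_kernel - 1)
temporal_state_shape = (mamba_num_heads // tp_world_size,
                        mamba_head_dim, ssm_state_size)

assert conv_state_shape == (5120, 3)
assert temporal_state_shape == (64, 64, 128)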
PIL import Image +from transformers import AutoModel, PretrainedConfig +from transformers.image_processing_utils_fast import BaseImageProcessorFast + +from vllm.config import VllmConfig +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.models.internvl import ( + BaseInternVLDummyInputsBuilder, BaseInternVLMultiModalProcessor, + BaseInternVLProcessingInfo, InternVLImageEmbeddingInputs, + InternVLImageInputs, InternVLImagePixelInputs, InternVLProcessor) +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.processing import PromptUpdateDetails +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import ( + cached_image_processor_from_config) +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .interfaces import (MultiModalEmbeddings, SupportsLoRA, + SupportsMultiModal, SupportsPP) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + maybe_prefix, merge_multimodal_embeddings) + +IMG_START = '<img>' +IMG_END = '</img>' +IMG_CONTEXT = '<image>' + + +class NemotronVLProcessor(InternVLProcessor): + + def __init__( + self, + config: PretrainedConfig, + tokenizer: AnyTokenizer, + image_processor: BaseImageProcessorFast, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> None: + ABC.__init__(self) + self.config = config + self.tokenizer = tokenizer + self.image_processor = image_processor + image_size: int = config.force_image_size + patch_size: int = config.patch_size + + if min_dynamic_patch is None: + min_dynamic_patch = 1 + assert isinstance(min_dynamic_patch, int) + + if max_dynamic_patch is None: + max_dynamic_patch = self.image_processor.max_num_tiles + assert isinstance(max_dynamic_patch, int) + + if dynamic_image_size is None: + dynamic_image_size = True + assert isinstance(dynamic_image_size, bool) + + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.image_size = image_size + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail: bool = self.image_processor.use_thumbnail + + @property + def image_token_id(self) -> int: + return self.tokenizer.get_vocab()[IMG_CONTEXT] + + def _preprocess_image( + self, + text: list[str], + images: list[Image.Image], + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + ) -> tuple[list[str], dict[str, torch.Tensor]]: + if len(images) == 0: + image_inputs = {} + else: + pixel_values_lst = self._images_to_pixel_values_lst( + images, + min_dynamic_patch=min_dynamic_patch, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + image_inputs: dict[str, NestedTensors] = { + "pixel_values_flat": + torch.cat(pixel_values_lst), + "image_num_patches": + torch.tensor([len(item) for item in pixel_values_lst]), + } + + for pixel_values in pixel_values_lst: + num_patches = pixel_values.shape[0] + feature_size = num_patches * self.num_image_token + image_repl = self.get_image_repl(feature_size, num_patches) + NVL_IMAGE_CONTEXT = 
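# Illustrative sketch (not part of the patch; the numbers are hypothetical,
# the real values come from config.force_image_size, config.patch_size and
# config.downsample_ratio): NemotronVLProcessor above derives the number of
# <image> placeholder tokens per image tile as
# (image_size // patch_size)**2 * downsample_ratio**2. A worked example:

image_size, patch_size, downsample_ratio = 512, 16, 0.5

num_image_token = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)
assert num_image_token == 256  # (512 / 16)^2 = 1024 patches, * 0.5^2 = 256

# An image split into `num_patches` crops then occupies
# num_patches * num_image_token context tokens in the prompt.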
image_repl.full.replace( + "<image>", "<NVL_IMG_CONTEXT>") + text = [ + t.replace('<image>', NVL_IMAGE_CONTEXT, 1) for t in text + ] + text = [t.replace("<NVL_IMG_CONTEXT>", IMG_CONTEXT) for t in text] + return text, image_inputs + + def get_image_repl( + self, + feature_size: int, + num_patches: Optional[int], + ) -> PromptUpdateDetails[str]: + repl_features = IMG_CONTEXT * feature_size + repl_full = IMG_START + repl_features + IMG_END + + return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT) + + +class NemotronVLProcessingInfo(BaseInternVLProcessingInfo): + """Processing info for Nemotron VL models.""" + + def get_hf_processor( + self, + *, + min_dynamic_patch: Optional[int] = None, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, + **kwargs: object, + ) -> NemotronVLProcessor: + if min_dynamic_patch is not None: + kwargs["min_dynamic_patch"] = min_dynamic_patch + if max_dynamic_patch is not None: + kwargs["max_dynamic_patch"] = max_dynamic_patch + if dynamic_image_size is not None: + kwargs["dynamic_image_size"] = dynamic_image_size + + image_processor = self.get_image_processor() + return self.ctx.init_processor( + NemotronVLProcessor, + config=self.get_hf_config(), + tokenizer=self.get_tokenizer(), + image_processor=image_processor, + **kwargs, + ) + + def get_image_processor( + self, + **kwargs: object, + ): + return cached_image_processor_from_config( + self.ctx.model_config, + **kwargs, + ) + + +@MULTIMODAL_REGISTRY.register_processor( + BaseInternVLMultiModalProcessor[NemotronVLProcessingInfo], + info=NemotronVLProcessingInfo, + dummy_inputs=BaseInternVLDummyInputsBuilder[NemotronVLProcessingInfo]) +class LlamaNemotronVLChatModel(nn.Module, SupportsMultiModal, SupportsPP, + SupportsLoRA): + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: + if modality.startswith("image"): + return "<image>" + + raise ValueError("Only image modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + self._patch_quant_config(config, quant_config) + + image_size = config.force_image_size or config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.patch_size = patch_size + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + + self.llm_arch_name = config.text_config.architectures[0] + self.vision_model = self._init_vision_model( + config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "vision_model"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + + self.mlp1 = self._init_mlp1(config) + + self.img_context_token_id = None + + self.visual_token_mask = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + def _patch_quant_config(self, config: PretrainedConfig, + quant_config: QuantizationConfig): + # the awq models from OpenGVLab missing `modules_to_not_convert` + # patch the quant_config to add `modules_to_not_convert` back + if isinstance(quant_config, AWQConfig): + text_config = 
config.text_config + llm_quant_config = getattr(text_config, "quantization_config", + None) + if (not quant_config.modules_to_not_convert) and \ + (llm_quant_config is not None): + quant_config.modules_to_not_convert.append("vision_model") + + def _init_vision_model( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + prefix: str, + ): + return AutoModel.from_config(config.vision_config, + trust_remote_code=True) + + def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + vit_hidden_size = config.vit_hidden_size + vision_projection_hidden_size = config.projector_hidden_size + llm_hidden_size = config.text_config.hidden_size + + return nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2, + bias=True), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + vision_projection_hidden_size, + bias=True), + nn.GELU(), + nn.Linear(vision_projection_hidden_size, llm_hidden_size), + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + pass + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: + # https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/modeling.py#L177 + vit_embeds = self.vision_model(x=pixel_values).features + vit_embeds = vit_embeds.to(dtype=torch.bfloat16) + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + #use force_image_size to get image_size + h = w = self.config.force_image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[InternVLImageInputs]: + pixel_values_flat = kwargs.pop("pixel_values_flat", None) + image_num_patches = kwargs.pop("image_num_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values_flat is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. 
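# Illustrative sketch (not part of the patch; sizes are made up): with the
# default scale_factor of 0.5, pixel_shuffle() above trades spatial
# resolution for channels, turning an (n, w, h, c) grid of ViT features into
# roughly an (n, h/2, w/2, 4*c) grid (the final axis order depends on
# ps_version), i.e. 4x fewer tokens with 4x wider features, before mlp1
# projects them to the LLM hidden size. A shape-only check of the core
# reshape, omitting the ps_version-dependent final permute:

import torch

def pixel_shuffle_shapes(n: int, w: int, h: int, c: int,
                         scale_factor: float = 0.5) -> tuple[int, ...]:
    x = torch.zeros(n, w, h, c)
    x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, int(h * scale_factor), int(w * scale_factor),
               int(c / (scale_factor * scale_factor)))
    return tuple(x.shape)

assert pixel_shuffle_shapes(1, 32, 32, 1024) == (1, 16, 16, 4096)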
" + f"Got type: {type(image_embeds)}") + + return InternVLImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + image_token_id = kwargs["image_token_id"] + assert isinstance(image_token_id, torch.Tensor) + self.img_context_token_id = image_token_id.flatten().unique().item() + + if pixel_values_flat is not None: + if not isinstance(pixel_values_flat, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values_flat)}") + + if not isinstance(image_num_patches, (torch.Tensor, list)): + raise ValueError("Incorrect type of image_num_patches. " + f"Got type: {type(image_num_patches)}") + + pixel_values_flat = flatten_bn(pixel_values_flat, concat=True) + image_num_patches = flatten_bn(image_num_patches, concat=True) + + return InternVLImagePixelInputs( + type="pixel_values", + pixel_values_flat=self._validate_pixel_values( + pixel_values_flat), + num_patches=image_num_patches, + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: InternVLImageInputs, + ) -> tuple[torch.Tensor, ...]: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + + image_embeds = self.extract_feature(image_input["pixel_values_flat"]) + + num_patches = image_input["num_patches"] + + # Only one image in the current batch + if len(num_patches) == 1: + return (image_embeds.view(-1, + self.config.text_config.hidden_size), ) + + # NOTE: Image embeddings are split into separate tensors for each image + # by the size of each embedding. + feature_size = image_embeds.shape[1] + image_embeds = image_embeds.view(-1, + self.config.text_config.hidden_size) + image_feature_sizes = [ + num_patches * feature_size for num_patches in num_patches + ] + return image_embeds.split(image_feature_sizes) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values_flat", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + + return modalities + + def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None: + self.visual_token_mask = None + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings(self, + **kwargs: object) -> MultiModalEmbeddings: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return [] + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. 
+ for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None \ + and len(multimodal_embeddings) != 0: + context_token_ids = [self.img_context_token_id] + assert len(context_token_ids) >= 1 + self._set_visual_token_mask(input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + context_token_ids, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> IntermediateTensors: + + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. + elif inputs_embeds is None: + vision_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None + + forward_kwargs = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + } + + # Only required if the model is mono-architecture + if self.visual_token_mask is not None: + forward_kwargs.update( + {"visual_token_mask": self.visual_token_mask}) + self.visual_token_mask = None + + hidden_states = self.language_model.model(**forward_kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + ## Ignore registered_buffers + ## see https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/input_conditioner.py#L28 # noqa: E501 + skip_substrs = ["norm_mean", "norm_std"] + loader = AutoWeightsLoader(self, skip_substrs=skip_substrs) + return loader.load_weights(weights) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp1", + tower_model="vision_model") diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index ebfdb690fe29..7552f64c423e 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -330,6 +330,15 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -341,16 +350,9 @@ def load_weights(self, weights: 
Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -425,6 +427,17 @@ def load_weights(self, weights: Iterable[tuple[str, class OlmoeForCausalLM(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -466,3 +479,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 77197abe5710..b1f2e53b0c71 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -125,7 +125,7 @@ def _call_hf_processor( ) -> BatchFeature: tokenizer = self.info.get_tokenizer() if not mm_data: - prompt_ids = tokenizer.encode(prompt) + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt") return super()._call_hf_processor( diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py deleted file mode 100644 index 754ddda233f4..000000000000 --- a/vllm/model_executor/models/phi3_small.py +++ /dev/null @@ -1,465 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from collections.abc import Iterable -from typing import Optional, Union - -import torch -from torch import nn -from transformers.configuration_utils import PretrainedConfig - -from vllm.attention import Attention -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform -from vllm.sequence import IntermediateTensors - -from .interfaces import SupportsPP -from .utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, - maybe_prefix) - - -def load_column_parallel_weight(param: torch.nn.Parameter, - 
loaded_weight: torch.Tensor): - tp = get_tensor_model_parallel_world_size() - rk = get_tensor_model_parallel_rank() - assert param.size(0) * tp == loaded_weight.size(0) - s = rk * param.size(0) - e = (rk + 1) * param.size(0) - loaded_weight = loaded_weight[s:e] - assert param.shape == loaded_weight.shape - param.data.copy_(loaded_weight) - - -class HeadMajorQKVParallelLinear(QKVParallelLinear): - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor): - return load_column_parallel_weight(param, loaded_weight) - - -class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): - - def weight_loader(self, param: torch.nn.Parameter, - loaded_weight: torch.Tensor): - return load_column_parallel_weight(param, loaded_weight) - - -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def quick_gelu(x): - return x * torch.sigmoid(1.702 * x) - - -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) -def gegelu(input, limit: Optional[float] = None): - a_gelu, a_linear = input[..., ::2], input[..., 1::2] - if limit is not None: - a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, - a_gelu.clamp(min=None, max=limit)) - a_linear = torch.where( - torch.isinf(a_linear), - a_linear, - a_linear.clamp(min=-limit, max=limit), - ) - out_gelu = quick_gelu(a_gelu) - return out_gelu * (a_linear + 1) - - -class Phi3SmallMLP(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - self.config = config - assert (self.config.hidden_act == "gegelu" - ), "Only `gegelu` is supported for the 4.7 series of models .." - self.hidden_size = config.hidden_size - self.gegelu_limit = config.gegelu_limit - self.intermediate_size = config.intermediate_size - - self.up_proj = HeadMajorColumnParallelLinear( - self.hidden_size, - 2 * [self.intermediate_size], - bias=True, - quant_config=quant_config, - ) - self.down_proj = RowParallelLinear( - self.intermediate_size, - self.hidden_size, - bias=True, - quant_config=quant_config, - ) - - def forward(self, x): - gate_up, _ = self.up_proj(x) - x = gegelu(gate_up) - x, _ = self.down_proj(x) - return x - - -class Phi3SmallSelfAttention(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__() - self.layer_idx = layer_idx - self.config = config - self.sparse_block_size = config.blocksparse_block_size - self.homo_heads = config.blocksparse_homo_head_pattern - self.local_blocks = config.blocksparse_num_local_blocks - self.vert_stride = config.blocksparse_vert_stride - - assert (config.blocksparse_block_size == - config.blocksparse_triton_kernel_block_size) - - self.hidden_size = config.hidden_size - # Number of Query Heads - self.num_heads = config.num_attention_heads - - self.head_dim = self.hidden_size // self.num_heads - self.tp_size = get_tensor_model_parallel_world_size() - # Number of total Key Value Heads before tensor parallel - self.num_key_value_heads = config.num_key_value_heads - self.num_q_per_kv = self.num_heads // self.num_key_value_heads - if self.tp_size > 1: - assert self.num_key_value_heads % self.tp_size == 0 - self.num_kv_heads_per_partition = max( - 1, self.num_key_value_heads // self.tp_size) - self.num_heads_per_partition = self.num_heads // self.tp_size - - self.max_position_embeddings = 
config.max_position_embeddings - self.rope_embedding_base = config.rope_embedding_base - self.rope_position_scale = config.rope_position_scale - self.is_causal = True - - norm_factor = None - if config.mup_use_scaling: - norm_factor = self.head_dim / config.mup_attn_multiplier - else: - norm_factor = math.sqrt(self.head_dim) - self.scale = 1 / norm_factor - - self.query_key_value = HeadMajorQKVParallelLinear( - self.hidden_size, - self.head_dim, - self.num_heads, - self.num_key_value_heads, - bias=True, - quant_config=quant_config, - ) - - self.dense = RowParallelLinear(self.hidden_size, - self.hidden_size, - bias=True, - quant_config=quant_config) - - if getattr(self.config, "rope_scaling", None) is not None: - rope_scaling = self.config.rope_scaling - for key in rope_scaling: - if isinstance(rope_scaling[key], list): - rope_scaling[key] = tuple(rope_scaling[key]) - - if "factor" not in rope_scaling: - rope_scaling["factor"] = self.rope_position_scale - else: - rope_scaling = { - "rope_type": "linear", - "factor": self.rope_position_scale, - } - - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_embedding_base, - rope_scaling=rope_scaling, - ) - - # blocksparse params - self.blocksparse_block_size = config.blocksparse_block_size - self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks - self.blocksparse_vert_stride = config.blocksparse_vert_stride - - use_dense_attn = (getattr(self.config, - "dense_attention_every_n_layers", None) - and (self.layer_idx + 1) % - self.config.dense_attention_every_n_layers == 0) - - bs_params = None - if not use_dense_attn: - bs_params = { - 'max_seqlen': self.max_position_embeddings, - 'num_heads': self.num_heads_per_partition, - "num_kv_heads": self.num_kv_heads_per_partition, - "block_size": self.sparse_block_size, - "local_blocks": self.local_blocks, - "vert_stride": self.vert_stride, - "homo_head": self.homo_heads - } - - self.attn = Attention(self.num_heads_per_partition, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads_per_partition, - cache_config=cache_config, - quant_config=quant_config, - blocksparse_params=bs_params, - prefix=f"{prefix}.attn") - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], - Optional[tuple[torch.Tensor]]]: - qkv, _ = self.query_key_value(hidden_states) - - qkv = qkv.view(qkv.shape[:-1] + - (-1, (self.num_q_per_kv + 2), self.head_dim)) - q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) - - # NOTE: this is required by RotaryEmbed, which indeed does not have to - # TODO: allow 3D QK for rotary forward - q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) - k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partition) - v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partition) - - q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v) - output, _ = self.dense(attn_output) - - return output - - -class Phi3SmallDecoderLayer(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - layer_idx: int, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = Phi3SmallSelfAttention(config, - layer_idx, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn") - self.mlp = Phi3SmallMLP(config, 
quant_config) - - self.input_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.post_attention_layernorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_epsilon) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class Phi3SmallModel(nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.config = config - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) - self.mup_embedding_multiplier = config.mup_embedding_multiplier - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: Phi3SmallDecoderLayer(config, - int(prefix.split('.')[-1]), - cache_config, - quant_config, - prefix=prefix), - prefix=f"{prefix}.layers") - - self.final_layernorm = nn.LayerNorm(config.hidden_size, - eps=config.layer_norm_epsilon) - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - intermediate_tensors: Optional[IntermediateTensors], - inputs_embeds: Optional[torch.Tensor], - ) -> Union[torch.Tensor, IntermediateTensors]: - if get_pp_group().is_first_rank: - if inputs_embeds is not None: - hidden_states = inputs_embeds - else: - hidden_states = self.get_input_embeddings(input_ids) - if (self.mup_embedding_multiplier is not None - and self.mup_embedding_multiplier > 0.0): - hidden_states = hidden_states * self.mup_embedding_multiplier - else: - assert intermediate_tensors - hidden_states = intermediate_tensors["hidden_states"] - for layer in self.layers[self.start_layer:self.end_layer]: - hidden_states = layer(positions, hidden_states) - if not get_pp_group().is_last_rank: - return IntermediateTensors({"hidden_states": hidden_states}) - hidden_states = self.final_layernorm(hidden_states) - return hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class Phi3SmallForCausalLM(nn.Module, SupportsPP): - _tied_weights_keys = ["lm_head.weight"] - - hf_to_vllm_mapper = WeightsMapper( - orig_to_new_suffix={"rotary_emb.inv_freq": None}) - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = 
vllm_config.quant_config - self.config = config - self.quant_config = quant_config - self.model = Phi3SmallModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) - self.vocab_size = config.vocab_size - self.mup_width_multiplier = config.mup_width_multiplier - self.lm_head = ParallelLMHead( - self.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - ) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # tokens in tiktoken but not used - if hasattr(config, 'dummy_token_indices'): - device = self.lm_head.weight.device - self.register_buffer('dummy_token_indices', - torch.LongTensor( - config.dummy_token_indices).to(device), - persistent=False) - else: - self.dummy_token_indices = None - - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.get_input_embeddings(input_ids) - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, value): - self.lm_head = value - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - if self.dummy_token_indices is not None and logits is not None: - logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) - logits = logits / self.mup_width_multiplier - return logits - - def forward( - self, - input_ids: torch.LongTensor, - positions: Optional[torch.LongTensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, IntermediateTensors]: - output_hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - output_hidden_states = output_hidden_states - return output_hidden_states - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader( - self, - skip_prefixes=(["lm_head.weight"] - if self.config.tie_word_embeddings else None)) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phi4flash.py b/vllm/model_executor/models/phi4flash.py new file mode 100644 index 000000000000..a4ded2b7a304 --- /dev/null +++ b/vllm/model_executor/models/phi4flash.py @@ -0,0 +1,736 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable +from typing import Optional, Union + +import torch +import torch.nn as nn +from transformers.activations import ACT2FN + +import vllm.envs as envs +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.selector import _Backend +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, + SupportsV0Only) +from vllm.model_executor.models.mamba_cache import (MambaCacheManager, + MambaCacheParams) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import make_layers, maybe_prefix + +logger = init_logger(__name__) + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +class SambaYMLP(nn.Module): + """Gated Linear Unit. + + Reference: + Language Modeling with Gated Convolutional Networks. + https://arxiv.org/pdf/1612.08083v3.pdf. + + """ + + def __init__(self, config): + super().__init__() + + self.config = config + self.fc1 = nn.Linear(config.hidden_size, + 2 * config.intermediate_size, + bias=False) + self.fc2 = nn.Linear(config.intermediate_size, + config.hidden_size, + bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + y = self.fc1(hidden_states) + gate, y = y.chunk(2, dim=-1) + y = y * self.activation_fn(gate) + return self.fc2(y) + + +def get_virtual_engine(): + forward_context: ForwardContext = get_forward_context() + return forward_context.virtual_engine + + +class SambaYAttention(nn.Module): + + def __init__(self, + config, + layer_idx: Optional[int] = None, + yoco_cross: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = ""): + super().__init__() + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing " + "a `layer_idx` is not recommended and will lead to errors " + "during the forward call if caching is used. 
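# Illustrative sketch (not part of the patch; sizes are made up and SiLU
# stands in for the configured ACT2FN[config.hidden_act]): SambaYMLP above is
# a gated linear unit. fc1 doubles the width, the result is chunked into
# (gate, y), y is modulated by activation(gate), and fc2 projects back down:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 16, 32
fc1 = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(3, hidden_size)
gate, y = fc1(x).chunk(2, dim=-1)
out = fc2(y * F.silu(gate))

assert out.shape == (3, hidden_size)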
Please make " + "sure to provide a `layer_idx` when creating this class.") + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.yoco_cross = yoco_cross + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError("hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size} and " + f"`num_heads`: {self.num_heads}).") + + op_size = self.num_heads * self.head_dim + 2 * ( + self.num_key_value_heads * self.head_dim) + self.out_proj = nn.Linear(self.num_heads * self.head_dim, + self.hidden_size, + bias=True) + if yoco_cross: + self.Wqkv = nn.Linear(self.hidden_size, + self.num_heads * self.head_dim, + bias=True) + else: + self.Wqkv = nn.Linear(self.hidden_size, op_size, bias=True) + + # disable sliding window for the second half of the model + sliding_window = config.interleaved_sliding_window[layer_idx] + if layer_idx >= config.num_hidden_layers // 2: + assert sliding_window is None, \ + "sliding_window must be none for the second decoder" + else: + assert sliding_window is not None, \ + "sliding_window must be set for the first decoder" + + assert self.num_heads % 2 == 0, 'num_heads should be even' + assert self.num_key_value_heads % 2 == 0, 'num_heads should be even' + + self.lambda_init = self.lambda_init_fn(layer_idx) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, + std=0.1)) + self.subln = nn.RMSNorm(2 * self.head_dim, + eps=1e-5, + elementwise_affine=True) + + params = { + 'differential_flash_attention_config': { + 'lambda_init': self.lambda_init, + 'lambda_q1': self.lambda_q1, + 'lambda_k1': self.lambda_k1, + 'lambda_q2': self.lambda_q2, + 'lambda_k2': self.lambda_k2, + "subln": self.subln, + } + } + + if yoco_cross: + kv_shared_layer_index = config.num_hidden_layers // 2 + 1 + kv_sharing_target_layer_name = \ + f"model.layers.{kv_shared_layer_index}.self_attn.attn" + else: + kv_sharing_target_layer_name = None + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.head_dim**-0.5, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER, + kv_sharing_target_layer_name=kv_sharing_target_layer_name, + **params) + assert self.attn.backend == _Backend.DIFFERENTIAL_FLASH_ATTN,\ + "DIFFERENTIAL_FLASH_ATTN required" + + def lambda_init_fn(self, depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + def forward( + self, + hidden_states: torch.Tensor, + ): + + if not self.yoco_cross: # need to generate kv-cache + qkv = self.Wqkv(hidden_states) + q, k, v = qkv.split([ + self.hidden_size, self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim + ], + dim=-1) + attn_output = self.attn(q, k, v) + else: # reuse the kv cache, full attention + q = self.Wqkv(hidden_states) + attn_output = self.attn(q, None, None) + attn_output = attn_output.view(-1, self.num_heads * self.head_dim) + return self.out_proj(attn_output) + + +class Phi4Mamba(nn.Module): + + def __init__( + self, + 
d_model, + d_state=16, + d_conv=4, + expand=2, + dt_rank="auto", + dt_min=0.001, + dt_max=0.1, + dt_init="random", # difference + dt_scale=1.0, # difference + dt_init_floor=1e-4, + conv_bias=True, + bias=False, + use_fast_path=True, # Fused kernel options + layer_idx=None, + device=None, + dtype=None, + yoco_cross=False, + yoco_kv=False, + ): + factory_kwargs = {"params_dtype": dtype} # difference + super().__init__() + self.yoco_cross = yoco_cross + self.yoco_kv = yoco_kv + self.d_model = d_model + self.d_state = d_state + self.d_conv = d_conv + self.expand = expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / + 16) if dt_rank == "auto" else dt_rank + self.use_fast_path = use_fast_path + self.layer_idx = layer_idx + self.swiGluActivation = SwiGLUActivation() + if self.yoco_cross: + self.in_proj = MergedColumnParallelLinear(self.d_model, + [self.d_inner], + bias=bias, + **factory_kwargs) + self.out_proj = RowParallelLinear(self.d_inner, + self.d_model, + bias=bias, + **factory_kwargs) + return + self.conv1d = ColumnParallelLinear( + input_size=d_conv, + output_size=self.d_inner, + bias=conv_bias, + params_dtype=dtype, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear( + self.d_model, + [self.d_inner] * 2, + bias=bias, + params_dtype=dtype, + ) + + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.d_inner, + self.dt_rank + self.d_state * 2, + bias=False, + params_dtype=dtype, + ) + + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear( + self.dt_rank, + self.d_inner, + bias=True, + skip_bias_add=True, + params_dtype=dtype, + ) + + # # D "skip" parameter + # self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 + self.A = nn.Parameter( + torch.empty( + self.d_inner, + self.d_state, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.d_inner, dtype=torch.float32)) + + self.out_proj = RowParallelLinear( + self.d_inner, + self.d_model, + bias=bias, + input_is_parallel=True, + params_dtype=dtype, + ) + self.activation = "silu" + + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + yoco_key_values=None) -> torch.Tensor: + + if self.yoco_cross: + out = self.in_proj(hidden_states)[0] + out = self.swiGluActivation(yoco_key_values, out) + out = self.out_proj(out) + return out[0], yoco_key_values + + # 1. Gated MLP's linear projection + # projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + projected_states = self.in_proj( + hidden_states.to(self.in_proj.weight.dtype))[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=mamba_cache_params.conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + cache_indices=mamba_cache_params.state_indices_tensor, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + mamba_cache_params.conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=mamba_cache_params.state_indices_tensor) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.dt_rank, self.d_state, self.d_state], + dim=-1, + ) + + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + mamba_cache_params.ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + # z, + None if self.yoco_kv else gate, + time_proj_bias, + delta_softplus=True, + cache_indices=mamba_cache_params.state_indices_tensor, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + mamba_cache_params.ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + # z + # gate.transpose(0, 1), + None if self.yoco_kv else gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + state_batch_indices=mamba_cache_params.state_indices_tensor) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + if self.yoco_kv: + # gate = gate.transpose(-1,-2).contiguous() + yoco_key_values = scan_outputs.transpose(-2, -1) + scan_outputs = self.swiGluActivation(scan_outputs, gate) + + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + + return contextualized_states, yoco_key_values + + +class SambaYDecoderLayer(nn.Module): + + def __init__( + self, + config, + layer_idx, + cache_config, + prefix: str = "", + ) -> None: + super().__init__() + + self.config = config + self.layer_idx = layer_idx + + self.mlp = SambaYMLP(config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + self.yoco_mb = False + self.yoco_cross = False + if layer_idx >= config.num_hidden_layers // 2: + self.yoco_mb = True + self.yoco_cross = (layer_idx + >= (config.num_hidden_layers // 2 + 2)) + self.use_mamba = config.mb_per_layer > 0 and \ + layer_idx % config.mb_per_layer == 0 + if self.use_mamba: + factory_kwargs = {"dtype": None} + self.attn = Phi4Mamba(config.hidden_size, + layer_idx=layer_idx, + yoco_cross=self.yoco_cross, + yoco_kv=self.yoco_mb, + **factory_kwargs) + else: + self.attn = SambaYAttention(config, + layer_idx=layer_idx, + yoco_cross=self.yoco_cross, + cache_config=cache_config, + prefix=f"{prefix}.self_attn") + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + mamba_cache_params: MambaCacheParams, + ssm_output: Optional[torch.LongTensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if self.use_mamba: + assert mamba_cache_params is not None + else: + assert mamba_cache_params is None + + residual = hidden_states + hidden_states = self.input_layernorm( + hidden_states.to(dtype=self.input_layernorm.weight.dtype)) + + if self.use_mamba: + attn_outputs, ssm_output = self.attn(hidden_states, + attn_metadata, + mamba_cache_params, + yoco_key_values=ssm_output) + residual = residual.to(torch.float32) + else: + attn_outputs = self.attn(hidden_states, ) + hidden_states = residual + attn_outputs + residual = hidden_states + hidden_states = self.post_attention_layernorm( + hidden_states.to(dtype=self.post_attention_layernorm.weight.dtype)) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, ssm_output + + +class SambaYModel(nn.Module): + + def __init__(self, + config, + cache_config=None, + quant_config=None, + lora_config=None, + prefix: str = "") -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + # Pipeline parallel is not supported since the second half of + # the layers share the kv cache. 
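+        # (YOCO: the cross-decoder layers in the second half reuse the KV
+        # cache written by layer num_hidden_layers // 2 + 1, so splitting the
+        # layers across pipeline stages would separate those layers from the
+        # cache they depend on.)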
+        if get_pp_group().world_size != 1:
+            raise ValueError("Pipeline Parallel not supported")
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: SambaYDecoderLayer(config,
+                                              int(prefix.split('.')[-1]),
+                                              cache_config,
+                                              prefix=prefix),
+            prefix=f"{prefix}.layers")
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_eps)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        mamba_cache_params: MambaCacheParams,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+
+        mamba_state_idx = 0
+        ssm_output = None
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if i == self.config.num_hidden_layers // 2 + 2:
+                # profile run
+                kv_cache_idx = self.config.num_hidden_layers // 2 + 1
+                cache_layer = self.layers[kv_cache_idx]
+                kv_cache = cache_layer.attn.attn.kv_cache
+                if kv_cache[0].numel() == 0:
+                    break
+
+                # Starting from this layer, we do not need to calculate
+                # the kv cache since we reuse the kv cache from last layer.
+                # If in prefill phase, we can truncate
+                # the hidden state to save computation cost.
+                if attn_metadata.prefill_metadata and not envs.VLLM_USE_V1:
+                    selected_token_indices = torch.cumsum(
+                        attn_metadata.seq_lens_tensor, dim=0) - 1
+                    hidden_states = hidden_states.index_select(
+                        0, selected_token_indices)
+                    ssm_output = ssm_output.index_select(
+                        0, selected_token_indices)
+
+            if layer.use_mamba:
+                if i < self.config.num_hidden_layers // 2 or \
+                        not layer.yoco_cross:
+                    mamba_cache = mamba_cache_params.at_layer_idx(
+                        mamba_state_idx)
+                    mamba_state_idx += 1
+                else:
+                    mamba_cache = mamba_cache_params.at_layer_idx(
+                        mamba_state_idx - 1)
+
+                hidden_states, ssm_output = layer(hidden_states,
+                                                  positions,
+                                                  attn_metadata,
+                                                  mamba_cache,
+                                                  ssm_output=ssm_output)
+            else:
+                hidden_states, ssm_output = layer(
+                    hidden_states,
+                    positions,
+                    attn_metadata,
+                    None,  # mamba_cache_params
+                    ssm_output=ssm_output)
+
+        hidden_states = self.final_layernorm(
+            hidden_states.to(dtype=self.final_layernorm.weight.dtype))
+        return hidden_states
+
+
+class Phi4FlashForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        lora_config = vllm_config.lora_config
+        quant_config = vllm_config.quant_config
+        scheduler_config = vllm_config.scheduler_config
+        self.compilation_config = vllm_config.compilation_config
+        self.vllm_config = vllm_config
+        # Prefix caching and chunked prefill are not supported for this model.
+        assert not cache_config.enable_prefix_caching, \
+            "Phi4Flash currently does not support prefix caching"
+        assert not scheduler_config.chunked_prefill_enabled, \
+            "Phi4Flash currently does not support chunked prefill"
+        super().__init__()
+        self.config = config
+        self.model_config = vllm_config.model_config
+        self.scheduler_config = scheduler_config
+        self.model = SambaYModel(config,
+                                 cache_config=cache_config,
+                                 prefix=maybe_prefix(prefix, "model"))
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.lm_head = ParallelLMHead(
+            self.unpadded_vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=(
+                DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size),
+            quant_config=quant_config,
+        )
+        self.embedding_bias = None
+        # Used to track and store the Mamba cache between steps.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size,
+                                                logits_as_input=False)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if self.mamba_cache is None:
+            num_mamba_layers = self.config.num_hidden_layers \
+                // 2 // self.config.mb_per_layer + 1
+            self.mamba_cache = MambaCacheManager(
+                self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
+                *self._get_mamba_cache_shape())
+        mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
+
+        attn_metadata = get_forward_context().attn_metadata
+        # input_ids and hidden_states do not map one-to-one in the prefill
+        # stage due to the YOCO optimization.
+        hidden_states = self.model(input_ids, positions, attn_metadata,
+                                   mamba_cache_params, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def _get_mamba_cache_shape(
+        self
+    ) -> tuple[Optional[tuple[int, int]], Optional[tuple[int, int]]]:
+        world_size = get_tensor_model_parallel_world_size()
+        hidden_size = self.config.hidden_size
+        mamba_expand = self.config.mamba_expand  # 2
+        mamba_d_conv = self.config.mamba_d_conv  # 4
+        mamba_d_state = self.config.mamba_d_state  # 16
+        conv_state_shape = (
+            mamba_expand * hidden_size // world_size,
+            mamba_d_conv - 1,
+        )
+        temporal_state_shape = (
+            mamba_expand * hidden_size // world_size,
+            mamba_d_state,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        # If the shapes already match, the hidden states were pruned manually
+        # during the forward pass.
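+        # (During V0 prefill, SambaYModel.forward index-selects the last token
+        # of each sequence once the YOCO cross-decoder layers are reached, so
+        # hidden_states may already line up with selected_token_indices here.)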
+ prune_hidden_states = hidden_states.size( + 0) != sampling_metadata.selected_token_indices.size(0) + processed_logits = self.logits_processor( + self.lm_head, + hidden_states, + sampling_metadata, + self.embedding_bias, + prune_hidden_states=prune_hidden_states) + return processed_logits + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]], + ): + weights = {name: weight for name, weight in weights} + adjusted_weights = {} + for name, weight in weights.items(): + if "A_log" in name: + name = name.replace("A_log", "A") + weight = -torch.exp(weight.float()) + if "inner_cross_attn." in name: + name = name.replace("inner_cross_attn.", "") + adjusted_weights[name] = weight + adjusted_weights["lm_head.weight"] = weights[ + "model.embed_tokens.weight"] + loaded_params: set[str] = set() + for name, param in self.named_parameters(): + weight = adjusted_weights.get(name) + if weight is not None and weight.shape != param.shape: + logger.warning("Shape mismatch: %s %s %s", name, weight.shape, + param.shape) + loaded_params.add(name) + missing_keys, unexpected_keys = self.load_state_dict(adjusted_weights, + strict=False) + assert len(unexpected_keys) == 0, f"Unexpected keys: {unexpected_keys}" + assert len(missing_keys) == 0, f"Missing keys: {missing_keys}" + return loaded_params diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 2ab4edc18ccf..cfe0982204fa 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -516,6 +516,14 @@ def forward( hidden_states = self.norm(hidden_states) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -525,14 +533,9 @@ def load_weights(self, weights: Iterable[tuple[str, ("qkv_proj", "v_proj", "v"), ] - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="w1", - ckpt_down_proj_name="w2", - ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if (self.quant_config is not None and (scale_name := self.quant_config.get_cache_scale(name))): @@ -672,3 +675,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/prithvi_geospatial_mae.py b/vllm/model_executor/models/prithvi_geospatial_mae.py index a36f24bc80ec..0f00fd47fe4f 100644 --- a/vllm/model_executor/models/prithvi_geospatial_mae.py +++ b/vllm/model_executor/models/prithvi_geospatial_mae.py @@ -16,6 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only IBM/NASA Prithvi Geospatial model.""" + from collections.abc import Iterable, Mapping, Sequence from typing import Optional, Union @@ -24,21 +25,22 @@ from transformers import BatchFeature from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import (AllPool, PoolerHead, + PoolerIdentity, SimplePooler) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import (IsAttentionFree, - SupportsMultiModal, - SupportsV0Only) +from vllm.model_executor.models.interfaces import ( + IsAttentionFree, MultiModalEmbeddings, SupportsMultiModalWithRawInput) from vllm.model_executor.models.utils import AutoWeightsLoader -from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, - MultiModalInputs, MultiModalKwargs) + MultiModalFieldElem, MultiModalInputs, + MultiModalKwargs, MultiModalKwargsItem, + MultiModalSharedField, PlaceholderRange) from vllm.multimodal.parse import MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder -from vllm.sequence import (IntermediateTensors, PoolerOutput, - PoolingSequenceGroupOutput) +from vllm.sequence import IntermediateTensors class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo): @@ -62,8 +64,9 @@ def get_dummy_mm_data( # The size of pixel_values might change in the cases where we resize # the input but never exceeds the dimensions below. return { - "pixel_values": torch.full((1, 6, 512, 512), 1.0), - "location_coords": torch.full((1, 2), 1.0), + "pixel_values": torch.full((6, 512, 512), 1.0, + dtype=torch.float16), + "location_coords": torch.full((1, 2), 1.0, dtype=torch.float16), } @@ -75,8 +78,10 @@ def _get_mm_fields_config( hf_processor_mm_kwargs: Mapping[str, object], ) -> Mapping[str, MultiModalFieldConfig]: return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - location_coords=MultiModalFieldConfig.batched("image"), + pixel_values=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), + location_coords=MultiModalFieldConfig.shared(batch_size=1, + modality="image"), ) def _get_prompt_updates( @@ -99,24 +104,51 @@ def apply( for k, v in mm_data.items(): mm_kwargs[k] = v + mm_placeholders = {"image": [PlaceholderRange(offset=0, length=0)]} + + # This model receives in input a multi-dimensional tensor representing + # a single image patch and therefore it is not to be split + # into multiple elements, but rather to be considered a single one. + # Hence, the decision of using a MultiModalSharedField. + # The expected shape is (num_channels, width, height). + + # This model however allows the user to also submit multiple image + # patches as a batch, adding a further dimension to the above shape. + # At this stage we only support submitting one patch per request and + # batching is achieved via vLLM batching. + # TODO (christian-pinto): enable support for multi patch requests + # in tandem with vLLM batching. 
+ multimodal_kwargs_items = [ + MultiModalKwargsItem.from_elems([ + MultiModalFieldElem( + modality="image", + key=key, + data=data, + field=MultiModalSharedField(1), + ) for key, data in mm_kwargs.items() + ]) + ] return MultiModalInputs( type="multimodal", prompt=prompt, prompt_token_ids=[1], - mm_kwargs=MultiModalKwargs(mm_kwargs), + mm_kwargs=MultiModalKwargs.from_items(multimodal_kwargs_items), mm_hashes=None, - mm_placeholders={}, + mm_placeholders=mm_placeholders, ) @MULTIMODAL_REGISTRY.register_processor( PrithviGeoSpatialMAEMultiModalProcessor, info=PrithviGeoSpatialMAEProcessingInfo, - dummy_inputs=PrithviGeoSpatialMAEInputBuilder) -class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal, - SupportsV0Only): - """ Prithvi Masked Autoencoder""" + dummy_inputs=PrithviGeoSpatialMAEInputBuilder, +) +class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, + SupportsMultiModalWithRawInput): + """Prithvi Masked Autoencoder""" + + is_pooling_model = True @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -126,10 +158,10 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError("Only image modality is supported") def _instantiate_model(self, config: dict) -> Optional[nn.Module]: - # We might be able/need to support different tasks with this same model if config["task_args"]["task"] == "SemanticSegmentationTask": from terratorch.cli_tools import SemanticSegmentationTask + task = SemanticSegmentationTask( config["model_args"], config["task_args"]["model_factory"], @@ -142,7 +174,8 @@ def _instantiate_model(self, config: dict) -> Optional[nn.Module]: scheduler_hparams=config["scheduler_params"], plot_on_val=config["task_args"]["plot_on_val"], freeze_decoder=config["task_args"]["freeze_decoder"], - freeze_backbone=config["task_args"]["freeze_backbone"]) + freeze_backbone=config["task_args"]["freeze_backbone"], + ) return task.model else: @@ -162,14 +195,14 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = ""): "Only SemanticSegmentationTask is supported for now " "by PrithviGeospatialMAE.") + self.pooler = SimplePooler(AllPool(), PoolerHead(PoolerIdentity())) + def _parse_and_validate_multimodal_data( self, **kwargs) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - pixel_values = kwargs.pop("pixel_values", None) if not isinstance(pixel_values, torch.Tensor): raise ValueError(f"Incorrect type of pixel_values. " f"Got type: {type(pixel_values)}") - pixel_values = torch.unbind(pixel_values, dim=0)[0] location_coords = kwargs.pop("location_coords", None) if not isinstance(location_coords, torch.Tensor): @@ -181,6 +214,17 @@ def _parse_and_validate_multimodal_data( return pixel_values, location_coords + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + # We do not really use any input tokens and therefore no embeddings + # to be calculated. However, due to the mandatory token ids in + # the input prompt we pass one token and the size of the dummy + # embedding tensors must reflect that. 
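+        # Returning a (num_tokens, 0) tensor keeps the batch dimension
+        # consistent while contributing no embedding features.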
+ return torch.empty((input_ids.shape[0], 0)) + def forward( self, input_ids: Optional[torch.Tensor], @@ -189,7 +233,6 @@ def forward( inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ): - pixel_values, location_coords = ( self._parse_and_validate_multimodal_data(**kwargs)) model_output = self.model(pixel_values, @@ -197,13 +240,6 @@ def forward( return model_output.output - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return PoolerOutput([PoolingSequenceGroupOutput(hidden_states)]) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: params_list = [] diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7ef9d248da4b..23f65b99c22c 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -50,7 +50,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, @@ -496,6 +495,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -Qwen2ForSequenceClassification = as_seq_cls_model(Qwen2ForCausalLM) diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 377a34f2088a..c5a5c10d9509 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -144,8 +144,16 @@ def get_hf_processor( ) -> Qwen2_5OmniProcessor: if fps is not None: kwargs["fps"] = fps + + # Monkey patch for Transformers v4.53 + processor_class = Qwen2_5OmniProcessor + if processor_class.image_processor_class != "AutoImageProcessor": + processor_class.image_processor_class = "AutoImageProcessor" + if processor_class.video_processor_class != "AutoVideoProcessor": + processor_class.video_processor_class = "AutoVideoProcessor" + processor = self.ctx.get_hf_processor( - Qwen2_5OmniProcessor, + processor_class, image_processor=self.get_image_processor(min_pixels=min_pixels, max_pixels=max_pixels, size=size, diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 42a87c4a796e..8ae096536fdc 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -974,7 +974,7 @@ def _process_image_input( grid_thw_list = grid_thw.tolist() if image_input["type"] == "image_embeds": - image_embeds = image_input["image_embeds"] + image_embeds = image_input["image_embeds"].type(self.visual.dtype) else: pixel_values = image_input["pixel_values"] image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list) @@ -994,7 +994,7 @@ def _process_video_input( grid_thw_list = grid_thw.tolist() if video_input["type"] == "video_embeds": - video_embeds = video_input["video_embeds"] + video_embeds = video_input["video_embeds"].type(self.visual.dtype) else: pixel_values_videos = video_input["pixel_values_videos"] video_embeds = self.visual(pixel_values_videos, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index a2c65f4b5edb..b061e2f69a6c 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from 
vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -391,6 +391,15 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -402,16 +411,9 @@ def load_weights(self, weights: Iterable[tuple[str, ("gate_up_proj", "up_proj", 1), ] - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -446,6 +448,7 @@ def load_weights(self, weights: Iterable[tuple[str, if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. 
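+                    # (i.e. expert weights owned by other pipeline-parallel
+                    # ranks, which this rank does not materialize.)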
if is_pp_missing_parameter(name, self): continue @@ -490,9 +493,20 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class Qwen2MoeForCausalLM(nn.Module, SupportsPP): +class Qwen2MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): fall_back_to_pt_during_load = False + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -538,3 +552,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py index 9a8508081678..f12e9a041a94 100644 --- a/vllm/model_executor/models/qwen2_rm.py +++ b/vllm/model_executor/models/qwen2_rm.py @@ -15,9 +15,9 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler, + PoolingType) +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2Model @@ -25,6 +25,10 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP): + + is_pooling_model = True + pooler: Pooler + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -61,7 +65,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config=quant_config, return_bias=False), ) - self._pooler: SimplePooler self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) @@ -80,13 +83,6 @@ def forward( logits = self.score(hidden_states) return logits - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, @@ -96,27 +92,33 @@ def load_weights(self, weights: Iterable[tuple[str, class Qwen2ForRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config, prefix=""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 1 super().__init__(vllm_config=vllm_config, prefix=prefix) + pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.ALL, - normalize=False, - softmax=False) + assert pooler_config is not None + + self.pooler = DispatchPooler( + {"encode": Pooler.for_encode(pooler_config)}, ) class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel): - def __init__(self, *, vllm_config, prefix=""): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config.model_config.hf_config.num_labels = 2 super().__init__(vllm_config=vllm_config, prefix=prefix) + pooler_config = vllm_config.model_config.pooler_config - self._pooler = Pooler.from_config_with_defaults( - pooler_config, - pooling_type=PoolingType.STEP, - normalize=False, - softmax=True, - step_tag_id=151651, - ) + assert 
pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode( + pooler_config, + default_pooling_type=PoolingType.STEP, + default_normalize=False, + default_softmax=True, + default_step_tag_id=151651, + ) + }) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 41b38b855ebf..ad63bb4af4e9 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -823,10 +823,11 @@ def get_image_processor( def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} - def get_max_tokens_per_item( - self, seq_len: int, - mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]: - + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: max_image_tokens = self.get_max_image_tokens() max_video_tokens = self.get_max_video_tokens(seq_len, mm_counts) return {"image": max_image_tokens, "video": max_video_tokens} diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py index de99a76f2897..393ce41a91a0 100644 --- a/vllm/model_executor/models/qwen3.py +++ b/vllm/model_executor/models/qwen3.py @@ -44,7 +44,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .adapters import as_seq_cls_model from .interfaces import SupportsLoRA, SupportsPP from .qwen2 import Qwen2MLP as Qwen3MLP from .qwen2 import Qwen2Model @@ -320,6 +319,3 @@ def load_weights(self, weights: Iterable[tuple[str, if self.config.tie_word_embeddings else None), ) return loader.load_weights(weights) - - -Qwen3ForSequenceClassification = as_seq_cls_model(Qwen3ForCausalLM) diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py index ff182aadf738..12899c28016b 100644 --- a/vllm/model_executor/models/qwen3_moe.py +++ b/vllm/model_executor/models/qwen3_moe.py @@ -50,7 +50,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsPP +from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, @@ -375,6 +375,15 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ @@ -391,16 +400,9 @@ def load_weights(self, weights: Iterable[tuple[str, ".v_scale", "_v_scale", ".weight_scale", "_weight_scale", ".input_scale", "_input_scale") - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = FusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.num_experts) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + expert_params_mapping = self.get_expert_mapping() for name, 
loaded_weight in weights: for (param_name, weight_name, shard_id) in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). @@ -480,7 +482,7 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params -class Qwen3MoeForCausalLM(nn.Module, SupportsPP): +class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -539,3 +541,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self) return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1e38f57304ec..2aaac7798fc0 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -12,19 +12,18 @@ import tempfile from abc import ABC, abstractmethod from collections.abc import Set -from dataclasses import dataclass, field +from dataclasses import asdict, dataclass, field from functools import lru_cache from typing import Callable, Optional, TypeVar, Union -import cloudpickle import torch.nn as nn from vllm.logger import init_logger from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_pp, - supports_transcription, supports_v0_only) + supports_multimodal, supports_multimodal_raw_input, + supports_pp, supports_transcription, supports_v0_only) from .interfaces_base import is_text_generation_model logger = init_logger(__name__) @@ -34,13 +33,16 @@ # [Decoder-only] "AquilaModel": ("llama", "LlamaForCausalLM"), "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArceeForCausalLM": ("arcee", "ArceeForCausalLM"), "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "MiniMaxForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxText01ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), "MiniMaxM1ForCausalLM": ("minimax_text_01", "MiniMaxText01ForCausalLM"), # baichuan-7b, upper case 'C' in the class name "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-13b, lower case 'c' in the class name "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), + "BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"), "BambaForCausalLM": ("bamba", "BambaForCausalLM"), "BloomForCausalLM": ("bloom", "BloomForCausalLM"), "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), @@ -56,6 +58,7 @@ "Ernie4_5_ForCausalLM": ("ernie45", "Ernie4_5_ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), + "Exaone4ForCausalLM": ("exaone4", "Exaone4ForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "Fairseq2LlamaForCausalLM": ("fairseq2_llama", "Fairseq2LlamaForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), @@ -65,6 +68,7 @@ "Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501 "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), + "Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"), "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), @@ -75,7 +79,8 @@ "GraniteMoeSharedForCausalLM": ("granitemoeshared", 
"GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), - "HunYuanMoEV1ForCausalLM": ("hunyuan_v1_moe", "HunYuanMoEV1ForCausalLM"), + "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), + "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"), @@ -83,7 +88,6 @@ "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - "Llama4ForCausalLM": ("llama4", "Llama4ForCausalLM"), # For decapoda-research/llama-* "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "MambaForCausalLM": ("mamba", "MambaForCausalLM"), @@ -109,8 +113,8 @@ "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Phi4FlashForCausalLM": ("phi4flash", "Phi4FlashForCausalLM"), "Plamo2ForCausalLM": ("plamo2", "Plamo2ForCausalLM"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), @@ -180,8 +184,7 @@ "ModernBertForSequenceClassification": ("modernbert", "ModernBertForSequenceClassification"), # [Auto-converted (see adapters.py)] - "Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501 - "Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501 + "JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501, } _MULTIMODAL_MODELS = { @@ -202,6 +205,7 @@ "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 + "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"), "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 @@ -227,6 +231,7 @@ "Phi4MMForCausalLM": ("phi4mm", "Phi4MMForCausalLM"), "TarsierForConditionalGeneration": ("tarsier", "TarsierForConditionalGeneration"), # noqa: E501 "Tarsier2ForConditionalGeneration": ("qwen2_vl", "Tarsier2ForConditionalGeneration"), # noqa: E501 + "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 # [Encoder-decoder] "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"), # noqa: E501 "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 @@ -237,16 +242,20 @@ _SPECULATIVE_DECODING_MODELS = { "MiMoMTPModel": ("mimo_mtp", "MiMoMTP"), - "EAGLEModel": ("eagle", "EAGLE"), "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), + "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", 
"DeepSeekMTP"), + "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), + # Temporarily disabled. + # # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1. + # "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } _TRANSFORMERS_MODELS = { + "TransformersForMultimodalLM": ("transformers", "TransformersForMultimodalLM"), # noqa: E501 "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"), } # yapf: enable @@ -268,6 +277,8 @@ sys.executable, "-m", "vllm.model_executor.models.registry" ] +_PREVIOUSLY_SUPPORTED_MODELS = {"Phi3SmallForCausalLM": "0.9.2"} + @dataclass(frozen=True) class _ModelInfo: @@ -276,12 +287,14 @@ class _ModelInfo: is_pooling_model: bool supports_cross_encoding: bool supports_multimodal: bool + supports_multimodal_raw_input: bool supports_pp: bool has_inner_state: bool is_attention_free: bool is_hybrid: bool has_noops: bool supports_transcription: bool + supports_transcription_only: bool supports_v0_only: bool @staticmethod @@ -292,11 +305,14 @@ def from_model_cls(model: type[nn.Module]) -> "_ModelInfo": is_pooling_model=True, # Can convert any model into a pooling model supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), + supports_multimodal_raw_input=supports_multimodal_raw_input(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), is_hybrid=is_hybrid(model), supports_transcription=supports_transcription(model), + supports_transcription_only=(supports_transcription(model) and + model.supports_transcription_only), supports_v0_only=supports_v0_only(model), has_noops=has_noops(model), ) @@ -452,10 +468,26 @@ def _try_load_model_cls(self, return _try_load_model_cls(model_arch, self.models[model_arch]) def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: - if model_arch not in self.models: - return None + if model_arch in self.models: + return _try_inspect_model_cls(model_arch, self.models[model_arch]) + + if model_arch.endswith("ForSequenceClassification"): + causal_lm_arch = model_arch.replace("ForSequenceClassification", + "ForCausalLM") + if causal_lm_arch not in self.models: + return None - return _try_inspect_model_cls(model_arch, self.models[model_arch]) + info = _try_inspect_model_cls(causal_lm_arch, + self.models[causal_lm_arch]) + + info = _ModelInfo(**dict( + asdict(info), **{ + "architecture": model_arch, + "supports_cross_encoding": True + })) + return info + + return None def _normalize_archs( self, @@ -470,9 +502,23 @@ def _normalize_archs( normalized_arch = list( filter(lambda model: model in self.models, architectures)) - # make sure Transformers backend is put at the last as a fallback - if len(normalized_arch) != len(architectures): - normalized_arch.append("TransformersForCausalLM") + # try automatic conversion in adapters.py + for arch in architectures: + if not arch.endswith("ForSequenceClassification"): + continue + causal_lm_arch = arch.replace("ForSequenceClassification", + "ForCausalLM") + if causal_lm_arch in self.models: + normalized_arch.append(arch) + + # NOTE(Isotr0py): Be careful of architectures' order! + # Make sure Transformers backend architecture is at the end of the + # list, otherwise pooling models automatic conversion will fail! 
+ for arch in normalized_arch: + if arch.startswith("TransformersFor"): + normalized_arch.remove(arch) + normalized_arch.append(arch) + return normalized_arch def inspect_model_cls( @@ -529,6 +575,13 @@ def is_multimodal_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_multimodal + def supports_multimodal_raw_input( + self, + architectures: Union[str, list[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_multimodal_raw_input + def is_pp_supported_model( self, architectures: Union[str, list[str]], @@ -571,6 +624,13 @@ def is_transcription_model( model_cls, _ = self.inspect_model_cls(architectures) return model_cls.supports_transcription + def is_transcription_only_model( + self, + architectures: Union[str, list[str]], + ) -> bool: + model_cls, _ = self.inspect_model_cls(architectures) + return model_cls.supports_transcription_only + def is_v1_compatible( self, architectures: Union[str, list[str]], @@ -598,6 +658,7 @@ def _run_in_subprocess(fn: Callable[[], _T]) -> _T: output_filepath = os.path.join(tempdir, "registry_output.tmp") # `cloudpickle` allows pickling lambda functions directly + import cloudpickle input_bytes = cloudpickle.dumps((fn, output_filepath)) # cannot use `sys.executable __file__` here because the script diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 048fa827fb2b..c6b411644034 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import itertools from collections.abc import Iterable from typing import Optional, Union @@ -10,14 +9,14 @@ from transformers import RobertaConfig from vllm.config import VllmConfig -from vllm.model_executor.layers.pooler import ClassifierPooler +from vllm.model_executor.layers.pooler import (ClassifierPooler, CLSPool, + DispatchPooler, Pooler) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.model_executor.models.utils import WeightsMapper, maybe_prefix -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + maybe_prefix) +from vllm.sequence import IntermediateTensors from .bert_with_rope import BertWithRope, JinaRobertaModel from .interfaces import SupportsCrossEncoding, SupportsV0Only @@ -39,8 +38,10 @@ def __init__(self, config: RobertaConfig): config.hidden_size) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.position_ids = nn.Parameter( - torch.empty((1, config.max_position_embeddings)), ) + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).unsqueeze(0), + ) self.position_embedding_type = config.position_embedding_type if self.position_embedding_type != "absolute": @@ -63,16 +64,10 @@ def forward( # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - 
pos_list = [] - token_list = [] - offset = 0 - for seq_len in seq_lens: - pos_list.append(position_ids[offset:offset + seq_len]) - token_list.append(input_ids[offset:offset + seq_len]) - offset += seq_len - + seq_lens_list = seq_lens.tolist() new_pos_list = [] - for positions, tokens in zip(pos_list, token_list): + for positions, tokens in zip(position_ids.split(seq_lens_list), + input_ids.split(seq_lens_list)): # Verify assumption that incoming position are # always a sequence from 0 to N. expected_pos = torch.arange(positions.size()[0], @@ -105,8 +100,8 @@ def __init__(self, config: RobertaConfig): self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) - def forward(self, features, **kwargs): - x = features[0, :] # take <s> token (equiv. to [CLS]) + def forward(self, x: torch.Tensor) -> torch.Tensor: + # CLSPool has already been applied in `pooling` x = self.dense(x) x = torch.tanh(x) x = self.out_proj(x) @@ -136,16 +131,20 @@ def _build_model(self, embedding_class=RobertaEmbedding) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - weights = self.hf_to_vllm_mapper.apply(weights) - # Separate weights in "roberta"-prefixed and all else (not in memory). - # For use with models like FacebookAI/roberta-base. - bert_weights, task_weights = roberta_task_weights_filter(weights) - loaded = self.model.load_weights(bert_weights) - if not len(loaded): - # Fix for models like `sentence-transformers/stsb-roberta-base-v2` - # which use the same architecture, but have no "roberta" prefix. - loaded = self.model.load_weights(task_weights) - assert len(loaded), "Unable to load RobertaEmbeddingModel" + weights_list = list(weights) + has_roberta_prefix = any( + name.startswith("roberta.") for name, _ in weights_list) + if has_roberta_prefix: + # For models with the `roberta.` prefix e.g. + # `FacebookAI/roberta-base` + mapper = WeightsMapper(orig_to_new_prefix={"roberta.": "model."}) + else: + # For models without the `roberta.` prefix e.g. + # `sentence-transformers/stsb-roberta-base-v2` + mapper = WeightsMapper(orig_to_new_prefix={"": "model."}) + + loader = AutoWeightsLoader(self, skip_prefixes=["lm_head."]) + return loader.load_weights(weights_list, mapper=mapper) class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, @@ -160,6 +159,7 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding, _pooler: An instance of Pooler used for pooling operations. 
""" + is_pooling_model = True jina_to_vllm_mapper = WeightsMapper( orig_to_new_substr={ 'emb_ln': "embeddings.LayerNorm", @@ -179,34 +179,34 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_labels = config.num_labels self.roberta = BertModel(vllm_config=vllm_config, prefix=maybe_prefix(prefix, "bert"), - embedding_class=RobertaEmbedding, - add_pooling_layer=False) + embedding_class=RobertaEmbedding) self.classifier = RobertaClassificationHead(config) - self._pooler = ClassifierPooler(vllm_config.model_config, - self.classifier) + pooler_config = vllm_config.model_config.pooler_config + assert pooler_config is not None + + self.pooler = DispatchPooler({ + "encode": + Pooler.for_encode(pooler_config), + "classify": + ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_seq_cls( + vllm_config.model_config), + ), + "score": + ClassifierPooler( + pooling=CLSPool(), + classifier=self.classifier, + act_fn=ClassifierPooler.act_fn_for_cross_encoder( + vllm_config.model_config), + ), + }) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - bert_weights, task_weights = roberta_task_weights_filter(weights) - bert_weights = self.jina_to_vllm_mapper.apply(bert_weights) - - self.roberta.load_weights(bert_weights) - - params_dict = dict(self.named_parameters()) - - for name, loaded_weight in task_weights: - if name.startswith("classifier"): - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - def pooler( - self, - hidden_states: torch.Tensor, - pooling_metadata: PoolingMetadata, - ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.jina_to_vllm_mapper) def forward( self, @@ -245,27 +245,3 @@ def create_position_ids_from_input_ids(input_ids, past_key_values_length) * mask return incremental_indices.long() + padding_idx - - -def roberta_task_weights_filter( - all_weights: Iterable[tuple[str, torch.Tensor]] -) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[str, - torch.Tensor]]]: - """ - Separate task-specific weights that are applied on top - of the encoder-decoder bert base. - To do so, return two generators over the original iterator. - Also, remove the "roberta." prefix to make it loadable - from vanilla BertModel. - """ - # Copy of a lazy iterator without in-memory overhead so both - # iterators can be iterated upon independently. 
- all_weights1, all_weights2 = itertools.tee(all_weights) - - def encoder_decoder_weights(): - for name, weight in all_weights1: - if name.startswith("roberta."): - yield (name[len("roberta."):], weight) - - return encoder_decoder_weights(), ((n, w) for n, w in all_weights2 - if not n.startswith("roberta.")) diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 25f026e9bef8..979d789b330c 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -13,8 +13,7 @@ from transformers import PretrainedConfig, SiglipVisionConfig from transformers.image_utils import ImageInput, get_image_size, to_numpy_array from transformers.models.llava import LlavaProcessor -from transformers.processing_utils import (ProcessingKwargs, Unpack, - _validate_images_text_input_order) +from transformers.processing_utils import ProcessingKwargs, Unpack from transformers.tokenization_utils_base import PreTokenizedInput, TextInput from vllm.config import VllmConfig @@ -94,9 +93,6 @@ def __call__( raise ValueError( "You have to specify at least one of `images` or `text`.") - # check if images and text inputs are reversed for BC - images, text = _validate_images_text_input_order(images, text) - output_kwargs = self._merge_kwargs( TarsierProcessorKwargs, tokenizer_init_kwargs=self.tokenizer.init_kwargs, diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 04ee3a454f9d..610f8e752dbd 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -15,8 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Wrapper around `transformers` models""" -from collections.abc import Iterable -from contextlib import nullcontext +from collections.abc import Iterable, Mapping +from contextlib import contextmanager, nullcontext from typing import Literal, Optional, Union import regex as re @@ -41,11 +41,21 @@ ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalInputs, PlaceholderRange) +from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo) +from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.processor import cached_get_processor +from vllm.utils import is_list_of -from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant +from .interfaces import (SupportsLoRA, SupportsMultiModal, SupportsPP, + SupportsQuant) from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, - is_pp_missing_parameter, + flatten_bn, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, maybe_prefix) logger = init_logger(__name__) @@ -112,6 +122,271 @@ def replace_linear_class( ) +# Copied from `accelerate` +@contextmanager +def init_on_device_without_buffers(device: torch.device): + """ + A context manager under which models are initialized with all + parameters on the specified device. However buffers are not + initialized on specified device. + + Args: + device (`torch.device`): + Device to initialize all parameters on. 
+ """ + + old_register_parameter = nn.Module.register_parameter + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls( + module._parameters[name].to(device), **kwargs) + + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + try: + nn.Module.register_parameter = register_empty_parameter + for torch_function_name in tensor_constructors_to_patch: + setattr( + torch, torch_function_name, + patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + nn.Module.register_parameter = old_register_parameter + for torch_function_name, old_torch_function in ( + tensor_constructors_to_patch.items()): + setattr(torch, torch_function_name, old_torch_function) + + +class MultiModalProcessingInfo(BaseProcessingInfo): + + def get_hf_config(self): + return self.ctx.model_config.hf_config + + def get_supported_mm_limits(self): + return {"image": None} + + def get_mm_max_tokens_per_item(self, seq_len, mm_counts): + return {"image": self.get_max_image_tokens()} + + def get_max_image_tokens(self) -> int: + width, height = self.get_max_image_size() + processor = self.get_hf_processor() + mm_processor_kwargs = self.ctx.model_config.mm_processor_kwargs or {} + mm_tokens = processor._get_num_multimodal_tokens( + image_sizes=([height, width], ), **mm_processor_kwargs) + image_tokens = mm_tokens["num_image_tokens"][0] + return image_tokens + + def get_hf_processor(self): + processor = cached_get_processor(self.ctx.model_config.model) + return processor + + def get_max_image_size(self): + return 10_000, 10_000 # hardcode for arbitrary very large size + + +class MultiModalDummyInputsBuilder( + BaseDummyInputsBuilder[MultiModalProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + + processor = self.info.get_hf_processor() + if "gemma3" in processor.__class__.__name__.lower(): + image_token = processor.boi_token + else: + image_token = getattr(processor, "image_token", "") + return image_token * num_images + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + + target_width, target_height = self.info.get_max_image_size() + + return { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + } + + +class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]): + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ): + """ + Given the original multi-modal items for this modality + and HF-processed data, output the updates to perform. + + The information returned by this method is used to update token inputs + which bypass the HF processor. It is also used to update the output of + HF processor if the HF process does not apply prompt updates to text + inputs. + + Moreover, this information is critical to determine the token positions + in order to construct :class:`~vllm-multimodal.input.PlaceholderRange` + for each multi-modal item. 
+ """ + return None + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs, + num_image_patches: torch.Tensor = None, + ): + # HF Processors always return a mask but vLLM doesn't need it + hf_inputs.pop("attention_mask", None) + mm_fields = { + key: MultiModalFieldConfig.flat_from_sizes("image", + num_image_patches) + for key in hf_inputs + } + mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes( + "image", num_image_patches) + mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image") + return mm_fields + + def _apply_hf_processor_text_mm( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ): + """ + Apply the HF processor on the prompt text and multi-modal data + together. + + In addition, return whether prompt replacements have been applied. + """ + processor_data, passthrough_data = self._get_hf_mm_data(mm_items) + processor_data["return_mm_token_type_ids"] = True + + processed_data = self._call_hf_processor( + prompt=prompt_text, + mm_data=processor_data, + mm_kwargs=hf_processor_mm_kwargs, + tok_kwargs=tokenization_kwargs, + ) + processed_data.update(passthrough_data) + + prompt_ids, = processed_data.pop("input_ids").tolist() + mm_token_type_ids = processed_data.pop( + "mm_token_type_ids" + ) if "mm_token_type_ids" in processed_data else processed_data.pop( + "token_type_ids") # for gemma3 only + + return prompt_ids, processed_data, mm_token_type_ids + + def apply( + self, + prompt: Union[str, list[int]], + mm_data: MultiModalDataDict, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Optional[Mapping[str, object]] = None, + return_mm_hashes: bool = False, + ) -> MultiModalInputs: + """ + Process multi-modal inputs to be used in vLLM. + + Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + """ + if tokenization_kwargs is None: + tokenization_kwargs = {} + + mm_items = self._to_mm_items(mm_data) + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + if not isinstance(prompt, str): + # the prompt is the tokenized ids which is not supported + # by the hf_processor, which is why we would need to decode the ids + # into string + prompt = hf_processor.decode(prompt) + + (prompt_ids, processed_data, + mm_token_type_ids) = self._apply_hf_processor_text_mm( + prompt_text=prompt, + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + ) + + # HF processor will return `mm_token_type_ids` from which + # we can infer mm_placeholders. Until then hardcode to make code run + # Below tested on Llava. 
Prompts and `mm_token_type_ids` are always bs=1 + mm_positions = torch.where(mm_token_type_ids == 1)[1] + images = mm_items.get_items("image", ImageProcessorItems) + mm_processor_kwargs = (self.info.ctx.model_config.mm_processor_kwargs + or {}) + image_sizes = [] + for item_idx in range(len(images)): + image_size = images.get_image_size(item_idx) + image_sizes.append((image_size.height, image_size.width)) + + mm_tokens_per_modality = hf_processor._get_num_multimodal_tokens( + image_sizes=image_sizes, **mm_processor_kwargs) + + mm_placeholders = {} + split_sizes = mm_tokens_per_modality["num_image_tokens"] + if split_sizes: + chunked_mm_positions = torch.split(mm_positions, split_sizes) + mm_tokens = torch.tensor(prompt_ids)[mm_token_type_ids[0].bool()] + chunked_mm_tokens = torch.split(mm_tokens, split_sizes) + ranges = [ + PlaceholderRange( + offset=positions[0].item(), + length=positions.shape[0], + is_embed=(mm_tokens == hf_processor.image_token_id).bool()) + for positions, mm_tokens in zip(chunked_mm_positions, + chunked_mm_tokens) + ] + mm_placeholders = {"image": ranges} + + num_image_patches = torch.tensor( + mm_tokens_per_modality["num_image_patches"] + ) if "num_image_patches" in mm_tokens_per_modality else None + processed_data['num_image_patches'] = num_image_patches + mm_kwargs = MultiModalKwargs.from_hf_inputs( + processed_data, + self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs, + num_image_patches), + ) + + mm_hashes = self._hash_mm_items(mm_items, hf_processor_mm_kwargs, + tokenization_kwargs) + return MultiModalInputs( + type="multimodal", + prompt=prompt, + prompt_token_ids=prompt_ids, + mm_kwargs=mm_kwargs, + mm_hashes=mm_hashes, + mm_placeholders=mm_placeholders, + ) + + class ConfigOverride: """Context manager to temporarily override config attributes.""" @@ -139,7 +414,7 @@ def __exit__(self, exc_type, exc_value, traceback): setattr(self.config, key, value) -class TransformersModel(nn.Module): +class TransformersModel: def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -153,6 +428,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config: QuantizationConfig = vllm_config.quant_config self.config = config + self.text_config = config.get_text_config() self.cache_config = cache_config self.device_config = device_config self.model_config = model_config @@ -173,14 +449,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config_override = ConfigOverride( config, sliding_window=config.interleaved_sliding_window) - # Use meta device to delay allocating GPU tensors - with torch.device("meta"), config_override: - # FIXME(Isotr0py): We need to refactor this part in the future to - # avoid registering an extra model layer, otherwise we will need a - # weights mapper to rename weights. 
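The placeholder bookkeeping in `MultiModalProcessor.apply` above boils down to: find the positions where `mm_token_type_ids == 1`, split them per image using the per-image token counts, and record each run as an offset/length pair. A minimal standalone sketch of that step, with toy tensors and a plain tuple standing in for `PlaceholderRange` (values are made up for illustration):

import torch

# Toy prompt: 0 = text token, 1 = multimodal (image) token.
mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])
num_image_tokens = [3, 2]  # tokens contributed by each of the two images

# Positions of all multimodal tokens in the (batch-size-1) prompt.
mm_positions = torch.where(mm_token_type_ids == 1)[1]

# Split the position list back into one chunk per image and record
# (offset, length) for each placeholder run.
placeholder_ranges = [
    (positions[0].item(), positions.shape[0])
    for positions in torch.split(mm_positions, num_image_tokens)
]
print(placeholder_ranges)  # [(2, 3), (6, 2)]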
+ # Set correct attn and init on "meta" to delay allocating GPU tensors + # TODO: @raushan, use the public `model.set_attn_implementation()` + # method after v4.54.0 is released + self.text_config._attn_implementation = "vllm" + with init_on_device_without_buffers("meta"), config_override: self.model: PreTrainedModel = AutoModel.from_config( config, - attn_implementation="vllm", torch_dtype=model_config.dtype, trust_remote_code=model_config.trust_remote_code, ) @@ -189,27 +464,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.tensor_parallel() # Input embeddings + text_config = config.get_text_config() if not isinstance(self.model.get_input_embeddings(), PPMissingLayer): self.model.set_input_embeddings( VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size, + text_config.vocab_size, + text_config.hidden_size, + org_num_embeddings=text_config.vocab_size, quant_config=quant_config, )) # Attention layers self.attention_instances = self.create_attention_instances() - # Initialize buffers (e.g. rotary embedding inverse frequency) - self.init_buffers(self.model) - # Initialize any parameters that have not had their modules replaced self.init_parameters(self.model) self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory(["hidden_states"], - config.hidden_size)) + text_config.hidden_size)) def pipeline_parallel(self): """ @@ -240,14 +513,15 @@ def pipeline_parallel(self): # Layers before module list for name in pp_plan[:module_list_idx]: - if self.pp_group.is_first_rank or (self.config.tie_word_embeddings - and self.pp_group.is_last_rank): + if self.pp_group.is_first_rank or ( + self.text_config.tie_word_embeddings + and self.pp_group.is_last_rank): continue setattr(self.model, name, PPMissingLayer()) # Module list - start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers, - self.pp_rank, self.pp_size) + start_layer, end_layer = get_pp_indices( + self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) layers_name = pp_plan[module_list_idx] layers = getattr(self.model, layers_name) for i in range(len(layers)): @@ -298,7 +572,7 @@ def create_attention_instances(self) -> dict[int, Attention]: self.parallel_config) head_size = self.model_config.get_head_size() num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - start, end = get_pp_indices(self.config.num_hidden_layers, + start, end = get_pp_indices(self.text_config.num_hidden_layers, self.pp_rank, self.pp_size) attention_instances = {} @@ -323,35 +597,6 @@ def create_attention_instances(self) -> dict[int, Attention]: prefix=f"{i}.attn") return attention_instances - def init_buffers(self, module: nn.Module): - """ - If a `buffer` is on the `meta` device, then its parent - `module` is the original module created by: - - ```python - with torch.device("meta"): - self.model: PreTrainedModel = AutoModel.from_config(...) - ``` - - This means that: - - `type(module)` is a class from `transformers` - - This class is constructed using a `PretrainedConfig` - """ - for name, buffer in module.named_buffers(recurse=False): - if buffer.device == torch.device("meta"): - if module == self.model: - logger.warning( - "To initialize buffers correctly, we instantiate the " - "parent module and and extract the value of the " - "buffer from it. In this case, the parent module is " - "the base model. Instantiating the entire model here " - "risks GPU OOM. 
Could this buffer be moved to a child " - "module?") - new_buffer = getattr(type(module)(self.config), name) - setattr(module, name, new_buffer) - for child in module.children(): - self.init_buffers(child) - def init_parameters(self, module: nn.Module): """ If a `parameter` is on the `meta` device, then its parent @@ -366,14 +611,12 @@ def init_parameters(self, module: nn.Module): if param.device == torch.device("meta"): new_param = nn.Parameter( torch.empty_like(param.data, + dtype=self.model_config.dtype, device=self.device_config.device)) setattr(module, name, new_param) for child in module.children(): self.init_parameters(child) - def get_input_embeddings(self) -> nn.Module: - return self.model.get_input_embeddings() - def forward( self, input_ids: Optional[torch.Tensor], @@ -391,11 +634,16 @@ def forward( if inputs_embeds is not None: inputs_embeds = inputs_embeds[None, ...] + if self.model_config.uses_mrope: + position_ids = positions[:, None] + else: + position_ids = positions[None, ...] + hidden_states = self.model( input_ids=input_ids, inputs_embeds=inputs_embeds, use_cache=False, - position_ids=positions[None, ...], + position_ids=position_ids, attention_instances=self.attention_instances, return_dict=False)[0][0, ...] # we remove batch dimension for now @@ -440,7 +688,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config - self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix) + self.transformers_model = TransformersModel(vllm_config=vllm_config, + prefix=prefix) + self.model = self.transformers_model.model if get_pp_group().is_last_rank: self.unpadded_vocab_size = config.vocab_size @@ -462,22 +712,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - - # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend, - # this makes thing complicated. We need to remove this mapper after refactor - # `TransformersModel` in the future. - # NOTE: `SupportsQuant` can be updated after property decorator is removed - @property - def hf_to_vllm_mapper(self): - prefix_mapper = { - name: "model." + name - for name, _ in self.model.model.named_children() - } - return WeightsMapper( - orig_to_new_substr={"model.": "model.model."}, - orig_to_new_prefix=prefix_mapper, - ) + self.transformers_model.make_empty_intermediate_tensors) def forward( self, @@ -486,8 +721,116 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: - model_output = self.model(input_ids, positions, intermediate_tensors, - inputs_embeds) + model_output = self.transformers_model.forward(input_ids, positions, + intermediate_tensors, + inputs_embeds) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + skip_prefixes = ["lm_head." 
+ ] if self.config.tie_word_embeddings else None + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) + return loader.load_weights(weights) + + +@MULTIMODAL_REGISTRY.register_processor( + MultiModalProcessor, + info=MultiModalProcessingInfo, + dummy_inputs=MultiModalDummyInputsBuilder) +class TransformersForMultimodalLM(nn.Module, SupportsQuant, SupportsLoRA, + SupportsPP, SupportsMultiModal): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens"] + + # Backwards compatibility for prev released models. State dicts back then + # had different formats and cannot be loaded with `AutoModel` mapping as is + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "model.language_model", + "text_model.model": "model.text_model", + "vision_tower": "model.vision_tower", + "vqmodel": "model.vqmodel", + "visual": "model.visual", + "vision_model": "model.vision_model", + "vision_embed_tokens": "model.vision_embed_tokens", + "image_newline": "model.image_newline", + "multi_modal_projector": "model.multi_modal_projector", + "text_model.lm_head": "lm_head", + "language_model.lm_head": "lm_head", + # Qwen models used "model" as the name for the language model. + # Therefore, we must map each of submodule explicitly to avoid + # conflicts with newer models that use "model.language_model". + "model.embed_tokens": "model.language_model.embed_tokens", + "model.layers": "model.language_model.layers", + "model.norm": "model.language_model.norm", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: PretrainedConfig = vllm_config.model_config.hf_config + quant_config: QuantizationConfig = vllm_config.quant_config + + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.transformers_model = TransformersModel(vllm_config=vllm_config, + prefix=prefix) + self.model = self.transformers_model.model + text_config = config.get_text_config() + + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = text_config.vocab_size + self.lm_head = ParallelLMHead( + text_config.vocab_size, + text_config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if text_config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + self.model.get_input_embeddings()) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + text_config.vocab_size, + logit_scale) + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.transformers_model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. 
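The `hf_to_vllm_mapper` defined above exists only to rename checkpoint keys from older layouts to the layout that `AutoModel` now produces. The renaming idea can be sketched in a few lines of plain Python; this is a simplified stand-in for vLLM's `WeightsMapper`, and the key and mapping below are made up for illustration:

def remap_prefixes(name: str, orig_to_new_prefix: dict) -> str:
    # Apply the first matching prefix rule; leave non-matching names untouched.
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# Hypothetical example: an old-style key renamed to the new layout.
mapping = {"language_model.model": "model.language_model"}
print(remap_prefixes("language_model.model.layers.0.mlp.gate_proj.weight",
                     mapping))
# -> model.language_model.layers.0.mlp.gate_proj.weight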
+ if inputs_embeds is None: + multimodal_embeds = self.get_multimodal_embeddings(**kwargs) + if multimodal_embeds is not None: + inputs_embeds = self.get_input_embeddings( + input_ids, multimodal_embeds) + input_ids = None + + model_output = self.transformers_model.forward(input_ids, positions, + intermediate_tensors, + inputs_embeds) return model_output def compute_logits( @@ -503,7 +846,76 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader( self, - skip_prefixes=(["lm_head."] - if self.config.tie_word_embeddings else None), + skip_prefixes=([ + "lm_head." + ] if self.config.get_text_config().tie_word_embeddings else None), ) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_multimodal_embeddings(self, **kwargs): + pixel_values = kwargs.pop("pixel_values", None) + pixel_values = pixel_values if pixel_values is not None else kwargs.pop( + "image_patches", None) + image_embeds = kwargs.pop("image_embeds", None) + + if image_embeds is not None: + return image_embeds + + if pixel_values is None and image_embeds is None: + return None + + num_image_patches = kwargs.pop("num_image_patches") + if pixel_values is not None: + if isinstance(pixel_values, torch.Tensor): + pixel_values = flatten_bn(pixel_values).to(self.dtype) + elif is_list_of(pixel_values, torch.Tensor): + pixel_values = flatten_bn(flatten_bn(pixel_values), + concat=True).to(self.dtype) + else: + raise ValueError( + f"Unsupported pixel_values type {type(pixel_values)}. " + "Expected `torch.Tensor` or list of `torch.Tensor`.") + + if isinstance(num_image_patches, list): + num_image_patches = torch.cat(num_image_patches) + + vision_embeddings = self.model.get_image_features( + pixel_values, + **{ + k: v.flatten(0, 1) + for k, v in kwargs.items() + }, + ) + + if isinstance(vision_embeddings, torch.Tensor): + if vision_embeddings.ndim == 2: + vision_embeddings = vision_embeddings.unsqueeze(0) + + # Embeddings have to be 2D tensors of length `num_images` + # but transformers returns concat tensors if each patch + # is of different size. 
We split it back to make vLLM happy + vision_embeddings = torch.split( + vision_embeddings, + num_image_patches.flatten().tolist()) + vision_embeddings = [ + embed.flatten(start_dim=0, end_dim=-2) + for embed in vision_embeddings + ] + + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings=None, + ) -> torch.Tensor: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + if (multimodal_embeddings is not None + and len(multimodal_embeddings) != 0): + mask = (input_ids == self.config.image_token_id) + mask = mask.unsqueeze(-1).expand_as(inputs_embeds) + multimodal_embeddings = torch.cat(multimodal_embeddings) + + inputs_embeds = inputs_embeds.masked_scatter( + mask, multimodal_embeddings) + return inputs_embeds diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py new file mode 100644 index 000000000000..97cab628317e --- /dev/null +++ b/vllm/model_executor/models/voxtral.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from math import ceil +from typing import Optional, Union, cast + +import numpy as np +import regex as re +import torch +import torch.nn as nn +from mistral_common.audio import mel_filter_bank +from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio, + TextChunk, UserMessage) +from mistral_common.protocol.instruct.request import ChatCompletionRequest +from mistral_common.protocol.transcription.request import TranscriptionRequest +from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder +from transformers import TensorType, WhisperConfig +from transformers.tokenization_utils_base import TextInput + +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import SupportsPP +# yapf: disable +from vllm.model_executor.models.whisper import ( + WhisperEncoder, WhisperForConditionalGeneration) +# yapf: enable +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, + MultiModalKwargs, NestedTensors) +from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems, + MultiModalDataParser) +from vllm.multimodal.processing import (BaseMultiModalProcessor, + BaseProcessingInfo, MultiModalHashes, + PromptReplacement, PromptUpdate) +from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.tokenizer import (MistralTokenizer, + cached_tokenizer_from_config) + +from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, + SupportsTranscription) +from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + + +class VoxtralProcessorAdapter: + """ + Provide a HF-compatible interface for + :class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`. 
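`get_input_embeddings` above relies on `masked_scatter` to overwrite the rows belonging to image-placeholder tokens with the projected vision embeddings. A small self-contained sketch of the same merge, using toy shapes and a made-up image-token id:

import torch

IMAGE_TOKEN_ID = 9  # made-up placeholder id for this sketch
hidden = 4

input_ids = torch.tensor([1, 9, 9, 2])   # two image placeholders
inputs_embeds = torch.zeros(4, hidden)   # stand-in text embeddings
vision_embeds = torch.ones(2, hidden)    # one row per placeholder token

# Expand the boolean mask to embedding width, then scatter the vision rows
# into the placeholder positions (row order is preserved).
mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
merged = inputs_embeds.masked_scatter(mask, vision_embeds)
print(merged[1], merged[2])  # both rows now hold the vision embeddings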
+ """ + + def __init__(self, tokenizer: MistralTokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + + @cached_property + def _audio_processor(self) -> AudioEncoder: + audio_encoder = self.tokenizer.instruct.audio_encoder + assert isinstance(audio_encoder, AudioEncoder) + return audio_encoder + + @cached_property + def audio_token_id(self) -> int: + return self._audio_processor.special_ids.audio + + @cached_property + def begin_audio_token_id(self) -> int: + return self._audio_processor.special_ids.begin_audio + + # @cached_property + # def begin_transcript_token_id(self) -> int: + # return self._audio_processor.special_ids.begin_transcript + + # @cached_property + # def end_transcript_token_id(self) -> int: + # return self._audio_processor.special_ids.end_transcript + + @cached_property + def sampling_rate(self) -> int: + return self._audio_processor.audio_config.sampling_rate + + @cached_property + def frame_rate(self) -> float: + return self._audio_processor.audio_config.frame_rate + + def get_num_audio_tokens( + self, + audio_length: int, + ) -> int: + pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames( + audio_length, self.sampling_rate) + return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate)) + + def __call__( + self, + text: Optional[Union[TextInput, list[TextInput]]] = None, + audios: Optional[Union[np.ndarray, list[np.ndarray]]] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> Mapping[str, NestedTensors]: + if text is None: + text = [] + if not isinstance(text, list): + text = [text] + if audios is None: + audios = [] + if not isinstance(audios, list): + audios = [audios] + + if not audios: + input_ids = self.tokenizer(text).input_ids + return {"input_ids": torch.tensor(input_ids)} + + # Allow dummy text, which is used for profiling as well as token inputs + if any(len(t) > 0 for t in text): + raise ValueError( + "You've passed text inputs instead of token inputs. " + "Make sure to process your input via `mistral_common`'s " + "tokenizer or pass a chat completion request. 
" + "For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.") + + audios_tokens = list[torch.Tensor]() + audios_processed = list[torch.Tensor]() + for audio in audios: + assert isinstance(audio, np.ndarray) + assert audio.ndim == 1 + + # pad if necessary + audio = self._audio_processor.pad(audio, self.sampling_rate) + + audio_tokens = [ + self.begin_audio_token_id + ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio)) + + audios_tokens.append(torch.tensor(audio_tokens)) + audios_processed.append(torch.tensor(audio)) + + return { + "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1), + "audio_arrays": audios_processed, + } + + +class VoxtralProcessingInfo(BaseProcessingInfo): + + def get_tokenizer(self) -> MistralTokenizer: + tokenizer = cached_tokenizer_from_config(self.ctx.model_config) + if not isinstance(tokenizer, MistralTokenizer): + raise ValueError("This model requires `--tokenizer-mode mistral`") + + return tokenizer + + def get_hf_processor(self) -> VoxtralProcessorAdapter: + return VoxtralProcessorAdapter(self.get_tokenizer()) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"audio": 5} # Performance tends to degrade after 5 + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int]: + return {"audio": self.get_max_audio_tokens()} + + def get_max_audio_tokens(self) -> int: + return self.ctx.model_config.max_model_len + + def get_max_audio_array_len(self) -> int: + processor = self.get_hf_processor() + return self.get_max_audio_tokens() * int( + processor.sampling_rate // processor.frame_rate) + + +class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]): + + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + return "" + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalDataDict: + num_audios = mm_counts.get("audio", 0) + + target_length = self.info.get_max_audio_array_len() + + return { + "audio": + self._get_dummy_audios(length=target_length, num_audios=num_audios) + } + + def get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + tokenizer = self.info.get_tokenizer() + + dummy_text = self.get_dummy_text(mm_counts) + dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts) + dummy_audios = dummy_mm_data.get("audio", []) + + audio_chunks: list[AudioChunk] = [] + format = "wav" + for audio in dummy_audios: + audio_item = Audio( + audio_array=audio, + sampling_rate=self.info.get_hf_processor().sampling_rate, + format=format, + ) + chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item)) + audio_chunks.append(chunk) + + request = ChatCompletionRequest(messages=[ + UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]), + ]) + res = tokenizer.mistral.encode_chat_completion(request) + dummy_tokens = res.tokens + # whixtral tokenizer adds padding to the audio + # so we need to update the audio arrays + dummy_mm_data["audio"] = [a.audio_array for a in res.audios] + + return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data) + + +class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo] + ): + + def _get_mm_fields_config( + self, + hf_inputs: Mapping[str, NestedTensors], + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(audio_arrays=MultiModalFieldConfig.batched("audio")) + + def _get_prompt_updates( + self, + 
mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + audio_id = processor.audio_token_id + + def get_replacement(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + + nb_audio_tokens = processor.get_num_audio_tokens(audio_len) + + return [audio_id] * nb_audio_tokens + + return [ + PromptReplacement( + modality="audio", + target="", # Never match the prompt (see below note) + replacement=get_replacement, + ), + ] + + def _cached_apply_hf_processor( + self, + prompt: Union[str, list[int]], + mm_data_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + *, + return_mm_hashes: bool, + ) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]: + prompt_ids, mm_kwargs, mm_hashes, _ = super( + )._cached_apply_hf_processor( + prompt=prompt, + mm_data_items=mm_data_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + tokenization_kwargs=tokenization_kwargs, + return_mm_hashes=return_mm_hashes, + ) + + # NOTE: The tokens are already inserted by the chat template + return prompt_ids, mm_kwargs, mm_hashes, True + + def _get_data_parser(self) -> MultiModalDataParser: + sampling_rate = self.info.get_hf_processor().sampling_rate + return MultiModalDataParser(target_sr=sampling_rate) + + +@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor, + info=VoxtralProcessingInfo, + dummy_inputs=VoxtralDummyInputsBuilder) +class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP, SupportsTranscription): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config) + + config = vllm_config.model_config.hf_config + self.config = config + self.downsample_factor = self.config.audio_config.downsample_factor + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + ) + self.whisper_encoder = VoxtralEncoderModel( + vllm_config.with_hf_config(config.audio_config), + prefix=maybe_prefix(prefix, "whisper_encoder"), + ) + self.audio_language_adapter = AudioLanguageAdapter( + hidden_size=config.audio_config.d_model * self.downsample_factor, + dim=config.text_config.hidden_size, + ) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner, this + # condition is for v0 compatibility. 
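`get_num_audio_tokens` in the processor adapter above is essentially `ceil(padded_samples / (sampling_rate // frame_rate))`. A back-of-the-envelope sketch of that formula follows; the 16 kHz sampling rate and 12.5 Hz frame rate are assumptions chosen for illustration, not values read from the model config, and the optional chunk padding only loosely mirrors `next_multiple_of_chunk_frames`:

from math import ceil

def num_audio_tokens(num_samples: int,
                     sampling_rate: int = 16_000,
                     frame_rate: float = 12.5,
                     chunk_frames: int | None = None) -> int:
    samples_per_token = int(sampling_rate // frame_rate)
    if chunk_frames is not None:
        # Round the audio up to a whole number of encoder chunks first.
        chunk_samples = chunk_frames * samples_per_token
        num_samples = ceil(num_samples / chunk_samples) * chunk_samples
    return ceil(num_samples / samples_per_token)

# 10 s of 16 kHz audio -> 160_000 samples -> 125 tokens at 12.5 tokens/s.
print(num_audio_tokens(10 * 16_000))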
+ elif inputs_embeds is None: + audio_embeddings = self.get_multimodal_embeddings(**kwargs) + inputs_embeds = self.get_input_embeddings(input_ids, + audio_embeddings) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def get_multimodal_embeddings( + self, **kwargs + ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], + None]: + audio_inputs = self._parse_and_validate_audio_arrays(**kwargs) + if audio_inputs is None: + return None + + audio_embeddings = self.whisper_encoder(audio_inputs) + + for i, audio_embedding in enumerate(audio_embeddings): + seq_len, dim = audio_embedding.shape + # Pad such that seq_len is divisible by downsample_factor + target_seq_len = self.downsample_factor * math.ceil( + seq_len / self.downsample_factor) + audio_embedding = torch.nn.functional.pad( + audio_embedding, + (0, 0, 0, target_seq_len - seq_len), + ) + audio_embeddings[i] = audio_embedding.reshape( + target_seq_len // self.downsample_factor, + dim * self.downsample_factor) + + # Concat, project and resplit + audio_embeddings_packed = torch.cat(audio_embeddings, dim=0) + audio_embeddings_packed = self.audio_language_adapter( + audio_embeddings_packed) + audio_embeddings = torch.split(audio_embeddings_packed, + [a.shape[0] for a in audio_embeddings], + dim=0) + + return audio_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + audio_encoder = self.tokenizer.instruct.audio_encoder + audio_tok_id = audio_encoder.audio_token + + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, audio_tok_id) + return inputs_embeds + + def _parse_and_validate_audio_arrays( + self, **kwargs: object) -> Union[list[torch.Tensor], None]: + audio_arrays = kwargs.pop("audio_arrays", None) + if audio_arrays is None: + return None + + if not isinstance(audio_arrays, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio_arrays. 
" + f"Got type: {type(audio_arrays)}") + + audio_arrays = flatten_bn(audio_arrays) + if isinstance(audio_arrays, torch.Tensor): + audio_arrays = list(audio_arrays.unbind(0)) + return audio_arrays + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + @classmethod + def get_speech_to_text_config(cls, model_config: ModelConfig, + task_type: str) -> SpeechToTextConfig: + tokenizer = cached_tokenizer_from_config(model_config) + audio_config = tokenizer.instruct.audio_encoder.audio_config + max_audio_clip_s = audio_config.chunk_length_s + sample_rate = audio_config.sampling_rate + return SpeechToTextConfig( + max_audio_clip_s=max_audio_clip_s, + sample_rate=sample_rate, + # mistral_common and whisper encoder take care of chunking + min_energy_split_window_size=None, + ) + + @classmethod + # for speech-to-text transcription + def get_generation_prompt(cls, audio: np.ndarray, + model_config: ModelConfig, + stt_config: SpeechToTextConfig, language: str, + task_type: str, + request_prompt: str) -> PromptType: + tokenizer = cached_tokenizer_from_config(model_config) + audio = Audio(audio, int(stt_config.sample_rate), + format="wav") # lossless + req = TranscriptionRequest(model=model_config.model, + audio=RawAudio.from_audio(audio), + language=language) + + tokenized = tokenizer.instruct.encode_transcription(req) + audio = (tokenized.audios[0].audio_array, stt_config.sample_rate) + prompts_dict = {"multi_modal_data": {"audio": audio}} + prompts_dict["prompt_token_ids"] = tokenized.tokens + return cast(PromptType, prompts_dict) + + @classmethod + def validate_language(cls, language: str) -> bool: + # same as whisper + return WhisperForConditionalGeneration.validate_language(language) + + @classmethod + def get_num_audio_tokens(cls, audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig) -> Optional[int]: + """ + Map from audio duration to number of audio tokens produced by the ASR + model, without running a forward pass. + This is used for estimating the amount of processing for this audio. 
+ """ + tokenizer = cached_tokenizer_from_config(model_config) + adapter = VoxtralProcessorAdapter(tokenizer) + return adapter.get_num_audio_tokens( + int(audio_duration_s * stt_config.sample_rate)) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + # fmt: off + remapping_rules = [ + (r"mm_whisper_embeddings\.(.*)", r"\1"), + (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"), + (r"audio_language_adapter\.0\.weight", r"audio_language_adapter.w_in.weight"), # noqa: E501 + (r"audio_language_adapter\.2\.weight", r"audio_language_adapter.w_out.weight"), # noqa: E501 + ] + # fmt: on + + audio_params = dict( + nn.ModuleDict({ + "audio_language_adapter": + self.audio_language_adapter, + }).named_parameters()) + + loaded_weights = set() + + def llm_weights_generator(): + nonlocal loaded_weights + for name, w in weights: + is_encoder = ( + name.startswith("mm_whisper_embeddings") and + not name.startswith("mm_whisper_embeddings.tok_embeddings") + and not name.startswith( + "mm_whisper_embeddings.audio_language_projection")) + + for pattern, repl in remapping_rules: + if re.fullmatch(pattern, name): + name = re.sub(pattern, repl, name) + + if is_encoder: + name = self.whisper_encoder.load_weight((name, w)) + loaded_weights.add(f"whisper_encoder.{name}") + continue + + if name in audio_params: + param = audio_params[name] + with torch.no_grad(): + default_weight_loader(param, w) + loaded_weights.add(name) + else: + yield (name, w) + + for name in self.language_model.load_weights(llm_weights_generator()): + loaded_weights.add(f"language_model.{name}") + + # potentially manually add position embeddings + sin_key = "whisper_encoder.whisper_encoder.embed_positions.weight" + if sin_key not in loaded_weights: + # make sure we don't hit an error here + loaded_weights.add(sin_key) + + return loaded_weights + + +class AudioLanguageAdapter(nn.Module): + + def __init__(self, hidden_size: int, dim: int) -> None: + super().__init__() + self.w_in = nn.Linear(hidden_size, dim, bias=False) + self.gelu = nn.GELU() + self.w_out = nn.Linear(dim, dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w_out(self.gelu(self.w_in(x))) + + +class VoxtralEncoderModel(nn.Module): + packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]} + + # fmt: off + mistral_remapping = [ + (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"), # noqa: E501 + (r"whisper_encoder\.conv_layers\.1\.(weight|bias)", r"whisper_encoder.conv2.\1"), # noqa: E501 + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.\2_proj.\3"), # noqa: E501 + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.wo\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn.out_proj.\2"), # noqa: E501 + (r"whisper_encoder\.transformer\.layers\.(\d+)\.attention_norm\.(weight|bias)", r"whisper_encoder.layers.\1.self_attn_layer_norm.\2"), # noqa: E501 + (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w1\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc1.\2"), # noqa: E501 + (r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", r"whisper_encoder.layers.\1.mlp.fc2.\2"), # noqa: E501 + (r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)", r"whisper_encoder.layers.\1.final_layer_norm.\2"), # noqa: E501 + (r"whisper_encoder\.transformer\.norm\.(weight|bias)", r"whisper_encoder.layer_norm.\1"), # noqa: E501 + ] + # 
fmt: on + + def __init__( + self, + vllm_config: VllmConfig, + *, + prefix: str = "", + ) -> None: + super().__init__() + self.config = cast(WhisperConfig, vllm_config.model_config.hf_config) + self.dtype: torch.dtype = vllm_config.model_config.dtype + self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "whisper_encoder"), + is_standalone_encoder=True, + init_in_fp32=True) + mel_filters = mel_filter_bank( + num_frequency_bins=1 + self.config.window_size // 2, + num_mel_bins=self.config.num_mel_bins, + min_frequency=0.0, + max_frequency=8000.0, + sampling_rate=self.config.sampling_rate, + ) + self.mel_filters = torch.tensor(mel_filters, dtype=torch.float32) + + def compute_whisper_melspec( + self, + audio_waveforms: torch.Tensor, + ) -> torch.Tensor: + input_dtype = audio_waveforms.dtype + window = torch.hann_window(self.config.window_size).to( + audio_waveforms.device) + stft = torch.stft( + audio_waveforms, + self.config.window_size, + self.config.hop_length, + window=window, + return_complex=True, + ) + magnitudes = stft[..., :-1].abs()**2 + mel_spec = self.mel_filters.T @ magnitudes + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + return log_spec.to(input_dtype) + + @property + def downsample_factor(self) -> int: + return self.whisper_encoder.conv1.stride[ + 0] * self.whisper_encoder.conv2.stride[0] + + @property + def chunk_size(self) -> int: + return self.config.max_source_positions * self.downsample_factor + + def prepare_inputs_for_conv( + self, + audio_waveforms: list[torch.Tensor], + ) -> tuple[torch.Tensor, list[int]]: + assert isinstance(audio_waveforms, list) + # list[num_mel_bins, seq_len] + input_features = [ + self.compute_whisper_melspec(audio).to(self.dtype) + for audio in audio_waveforms + ] + + chunked_features: list[torch.Tensor] = [] + chunks_per_example: list[int] = [] + for feature in input_features: + chunks = feature.split(self.chunk_size, dim=-1) + chunked_features += chunks + chunks_per_example.append(len(chunks)) + + # [total_num_chunks, num_mel_bins, chunk_size] + return torch.stack(chunked_features), chunks_per_example + + def forward( + self, input_features: Union[torch.Tensor, list[torch.Tensor]] + ) -> list[torch.Tensor]: + if not isinstance(input_features, list): + input_features = [input_features] + + # Split long inputs into chunks + input_embeds, chunks_per_example = ( + self.prepare_inputs_for_conv(input_features)) + + # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size] + out = self.whisper_encoder([input_embeds]) + + # Re-concatenate the chunks + chunk_idx = 0 + results = [] + for n_chunks in chunks_per_example: + result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1) + results.append(result) + chunk_idx += n_chunks + + return results + + def load_weight(self, weight: tuple[str, torch.Tensor]) -> str: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) + + name, loaded_weight = weight + for pattern, repl in self.mistral_remapping: + if re.fullmatch(pattern, name): + name = re.sub(pattern, repl, name) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + 
weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + return name diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 344d6fc8f452..d98dab5fac0e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -3,8 +3,10 @@ import math from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, TypedDict, Union +from contextlib import nullcontext +from typing import Optional, TypedDict, Union, cast +import numpy as np import torch from torch import nn from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor, @@ -12,8 +14,11 @@ from transformers.models.whisper.modeling_whisper import sinusoids from vllm.attention import Attention, AttentionType -from vllm.config import CacheConfig, VllmConfig +from vllm.attention.layer import MultiHeadAttention +from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig, + VllmConfig) from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs.data import PromptType from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -23,6 +28,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors @@ -33,6 +39,7 @@ EncDecMultiModalProcessor, PromptReplacement, PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder +from vllm.transformers_utils.processor import cached_get_processor from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription, SupportsV0Only) @@ -174,6 +181,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + standalone_encoder: bool = False, ): super().__init__() self.embed_dim = embed_dim @@ -209,16 +217,24 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.out_proj", ) - self.attn = Attention( - self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - attn_type=self.attn_type, - ) + if standalone_encoder: + self.attn = MultiHeadAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + ) + else: + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) def _init_qkv( self, @@ -353,7 +369,11 @@ def forward(self, hidden_states: torch.Tensor): class WhisperEncoderLayer(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + is_standalone_encoder: bool = False): super().__init__() config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -367,6 +387,7 @@ def 
__init__(self, *, vllm_config: VllmConfig, prefix: str = ""): cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + standalone_encoder=is_standalone_encoder, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.mlp = WhisperMLP( @@ -458,10 +479,16 @@ def forward( class WhisperEncoder(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + is_standalone_encoder: bool = False, + init_in_fp32: bool = False): super().__init__() config = vllm_config.model_config.hf_config embed_dim = config.d_model + self.is_standalone_encoder = is_standalone_encoder self.num_mel_bins = config.num_mel_bins self.max_source_positions = config.max_source_positions self.embed_scale = (math.sqrt(embed_dim) @@ -476,17 +503,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): kernel_size=3, stride=2, padding=1) - self.embed_positions = nn.Embedding(self.max_source_positions, - embed_dim) self.start_layer, self.end_layer, self.layers = make_layers( config.encoder_layers, lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config, - prefix=f"{prefix}.layers"), + prefix=f"{prefix}.layers", + is_standalone_encoder= + is_standalone_encoder), prefix=f"{prefix}.layers", ) self.layer_norm = nn.LayerNorm(config.d_model) - with torch.no_grad(): + maybe_fp32_init_ctx = set_default_torch_dtype( + torch.float32) if init_in_fp32 else nullcontext() + + with ( + torch.no_grad(), + maybe_fp32_init_ctx, + ): + self.embed_positions = nn.Embedding(self.max_source_positions, + embed_dim) self.embed_positions.weight.copy_( sinusoids(*self.embed_positions.weight.shape)) @@ -495,8 +530,10 @@ def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]): for features in input_features: embeds = nn.functional.gelu(self.conv1(features)) embeds = nn.functional.gelu(self.conv2(embeds)) - embeds = embeds.permute(1, 0) - embeds = embeds + self.embed_positions.weight[:embeds.size(0), :] + embeds = embeds.transpose(-1, -2) + embeds = (embeds + + self.embed_positions.weight[:embeds.size(-2), :]).to( + embeds.dtype) hidden_states.append(embeds) hidden_states = torch.cat(hidden_states) @@ -634,7 +671,14 @@ def get_hf_config(self) -> WhisperConfig: def get_hf_processor(self, sampling_rate: Optional[int] = None ) -> WhisperProcessor: - return self.ctx.get_hf_processor(WhisperProcessor) + # HACK: Transformers 4.53.0 has issue with whisper tokenizer to + # initialize processor. We use a monkeypatch to fix it here. + # See: https://github.com/vllm-project/vllm/issues/20224 + processor_class = WhisperProcessor + tokenizer_class = ("WhisperTokenizer", "WhisperTokenizerFast") + if processor_class.tokenizer_class != tokenizer_class: + processor_class.tokenizer_class = tokenizer_class + return self.ctx.get_hf_processor(processor_class) def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": 1} @@ -761,6 +805,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription, ".fc2.": ".mlp.fc2." }) + # Whisper only supports audio-conditioned generation. 
+ supports_transcription_only = True + @classmethod def validate_language(cls, language: str) -> bool: if language in ISO639_1_SUPPORTED_LANGS: @@ -778,11 +825,28 @@ def validate_language(cls, language: str) -> bool: f"or {list(ISO639_1_OTHER_LANGS.values())}") @classmethod - def get_decoder_prompt(cls, language: str, task_type: str, - prompt: str) -> str: - return ((f"<|prev|>{prompt}" if prompt else "") + - f"<|startoftranscript|><|{language}|>" + - f"<|{task_type}|><|notimestamps|>") + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: str, + task_type: str, + request_prompt: str) -> PromptType: + prompt = { + "encoder_prompt": { + # Whisper does not support encoder prompt. + "prompt": "", + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + }, + "decoder_prompt": + ((f"<|prev|>{request_prompt}" if request_prompt else "") + + f"<|startoftranscript|><|{language}|>" + + f"<|{task_type}|><|notimestamps|>") + } + return cast(PromptType, prompt) @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: @@ -791,6 +855,30 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: raise ValueError("Only audio modality is supported") + @classmethod + def get_speech_to_text_config(cls, model_config: ModelConfig, + task_type: str) -> SpeechToTextConfig: + processor = cached_get_processor(model_config.model) + + return SpeechToTextConfig( + max_audio_clip_s=processor.feature_extractor.chunk_length, + sample_rate=processor.feature_extractor.sampling_rate, + ) + + @classmethod + def get_num_audio_tokens(cls, audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig) -> Optional[int]: + processor = cached_get_processor(model_config.model) + hop_length = processor.feature_extractor.hop_length + assert hop_length is not None + # NOTE(NickLucche) users can't pass encoder + # prompts directly, at least not to Whisper. + # One indicator of how much the encoder has to process + # is the log-mel spectrogram length.
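+ # e.g. with the stock Whisper feature extractor (16 kHz audio, + # hop_length=160), a 30 s clip maps to ceil(30 * 16000 / 160) = 3000 + # log-mel frames.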
+ return math.ceil(audio_duration_s * stt_config.sample_rate / + hop_length) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 54c80cfa5922..7764fd9b9e08 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -17,8 +17,9 @@ from vllm import envs from vllm.attention.layer import Attention +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -30,8 +31,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -501,8 +502,7 @@ def __init__(self, rms_norm_eps=config.rms_norm_eps, activation="silu", quant_config=quant_config, - prefix=f"{prefix}.mixer", - chunk_size=config.chunk_size) + prefix=f"{prefix}.mixer") # Input normalization self.input_layernorm = RMSNorm(config.hidden_size, @@ -549,14 +549,16 @@ def forward( hidden_states = self.input_layernorm(hidden_states) # Process through Mamba mixer - hidden_states = self.mamba( + output = torch.empty_like(hidden_states) + self.mamba( hidden_states, + output, mamba_cache_params=mamba_cache_params, mamba2_metadata=mamba2_metadata, ) # residual connection after mamba - hidden_states = residual + hidden_states + hidden_states = residual + output return hidden_states @@ -647,6 +649,7 @@ def forward( return layer_outputs +@support_torch_compile class Zamba2Model(nn.Module): """Core Zamba2 model combining transformer and Mamba architectures. @@ -844,6 +847,39 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): "1.weight": "B.weight", }) + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_ngroups, + num_heads=hf_config.n_mamba_heads, + head_dim=hf_config.mamba_headdim, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model for causal language modeling. @@ -926,9 +962,13 @@ def forward(self, if not envs.VLLM_USE_V1: if self.mamba_cache is None: num_mamba_layers = self.config.num_hidden_layers - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) # Get cache parameters for current run mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -969,49 +1009,6 @@ def get_seqlen_agnostic_capture_inputs( """ return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - """Calculate shapes for Mamba's convolutional and state caches. - - Returns: - Tuple containing: - - conv_state_shape: Shape for convolutional state cache - - temporal_state_shape: Shape for state space model cache - """ - world_size = get_tensor_model_parallel_world_size() - - intermediate_size = self.config.mamba_expand * self.config.hidden_size - - # Extend groups if needed to ensure all groups needed by a head - # are sharded together - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards( - self.config.mamba_ngroups, world_size)) - - # Calculate conv state shape (includes groups) - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # Calculate temporal state shape (per-head states) - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(divide(intermediate_size, self.config.mamba_headdim), - world_size), - self.config.mamba_headdim, - self.config.mamba_d_state, - ) - - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py index 4dd443bc26ea..e6f1ca61dd29 100644 --- a/vllm/model_executor/pooling_metadata.py +++ b/vllm/model_executor/pooling_metadata.py @@ -38,6 +38,13 @@ def __repr__(self) -> str: f"seq_data={self.seq_data}, " f"prompt_lens={self.prompt_lens})") + def __getitem__(self, indices: slice): + return PoolingMetadata( + 
seq_groups=self.seq_groups[indices], + seq_data=dict(list(self.seq_data.items())[indices]), + prompt_lens=self.prompt_lens[indices], + ) + + @dataclass class PoolingTensors: diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index aa7889fc3cc5..78d244a6b4fc 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1100,24 +1100,29 @@ def get_allowed_mm_limits(self) -> Mapping[str, int]: return allowed_limits - def get_max_tokens_per_item( - self, seq_len: int, - mm_counts: Optional[Mapping[str, - int]]) -> Optional[Mapping[str, int]]: - """Return the maximum number of tokens per item of for each modality. - By default, returns `None`. When `None` is returned, vLLM will generate - dummy inputs (images/videos) at maximum possible sizes and process them - to determine the maximum token count per modality. + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Optional[Mapping[str, int]]: + """ + Return the maximum number of tokens per item for each modality. + + When `None` (the default) is returned, vLLM will generate dummy inputs + (images/videos) at maximum possible sizes and process them to determine + the maximum token count per modality. + This approach works but can be very slow for certain models (e.g., Qwen2.5-VL), leading to very long startup time. For better performance, each model can override this method to return pre-computed maximum token counts, avoiding the need for dummy input generation and processing. - NOTE: The maximum number of tokens per item of each modality returned - from this function should respect to the model maximum sequence length - and the maximum number of items of each modality allowed, and agrees - with dummy inputs (images/videos) at maximum possible sizes. - + Note: + The maximum number of tokens per item of each modality returned + from this function should respect the model's maximum sequence + length and the maximum number of items of each modality allowed, + and agree with dummy inputs (images/videos) at maximum possible + sizes. """ return None diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index fb5a7b64c419..7f6fb47a21fa 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -258,8 +258,13 @@ def get_mm_max_tokens( seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, ) -> Mapping[str, int]: - max_tokens_per_item = self.processing_info.get_max_tokens_per_item( - seq_len=seq_len, mm_counts=mm_counts) + if mm_counts is None: + mm_counts = self.get_mm_limits() + + max_tokens_per_item = self.processing_info.get_mm_max_tokens_per_item( + seq_len=seq_len, + mm_counts=mm_counts, + ) if max_tokens_per_item is not None: if mm_counts is None: total_mm_tokens = sum(max_tokens_per_item.values()) @@ -270,7 +275,7 @@ def get_mm_max_tokens( if total_mm_tokens > seq_len: logger.warning_once( "The sequence length (%d) is smaller than the pre-defined" - " wosrt-case total number of multimodal tokens (%d). " + " worst-case total number of multimodal tokens (%d). " "This may cause certain multi-modal inputs to fail during " "inference.
To avoid this, you should increase " "`max_model_len` or reduce `mm_counts`.", diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 27aaa661c35c..c44fcacd246c 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -266,7 +266,7 @@ def create_processor( if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if tokenizer is None: + if tokenizer is None and not model_config.skip_tokenizer_init: tokenizer = cached_tokenizer_from_config(model_config) if disable_cache is None: mm_config = model_config.get_multimodal_config() diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 13453d2c4b4b..c13659f8a06e 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Optional from vllm.plugins import load_plugins_by_group -from vllm.utils import resolve_obj_by_qualname +from vllm.utils import resolve_obj_by_qualname, supports_xccl from .interface import _Backend # noqa: F401 from .interface import CpuArchEnum, Platform, PlatformEnum @@ -116,33 +116,25 @@ def rocm_platform_plugin() -> Optional[str]: return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None -def hpu_platform_plugin() -> Optional[str]: - is_hpu = False - logger.debug("Checking if HPU platform is available.") - try: - from importlib import util - is_hpu = util.find_spec('habana_frameworks') is not None - if is_hpu: - logger.debug("Confirmed HPU platform is available.") - else: - logger.debug("HPU platform is not available because " - "habana_frameworks is not found.") - except Exception as e: - logger.debug("HPU platform is not available because: %s", str(e)) - - return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None - - def xpu_platform_plugin() -> Optional[str]: is_xpu = False logger.debug("Checking if XPU platform is available.") try: # installed IPEX if the machine has XPUs. 
import intel_extension_for_pytorch # noqa: F401 - import oneccl_bindings_for_pytorch # noqa: F401 import torch + if supports_xccl(): + dist_backend = "xccl" + else: + dist_backend = "ccl" + import oneccl_bindings_for_pytorch # noqa: F401 + if hasattr(torch, 'xpu') and torch.xpu.is_available(): is_xpu = True + from vllm.platforms.xpu import XPUPlatform + XPUPlatform.dist_backend = dist_backend + logger.debug("Confirmed %s backend is available.", + XPUPlatform.dist_backend) logger.debug("Confirmed XPU platform is available.") except Exception as e: logger.debug("XPU platform is not available because: %s", str(e)) @@ -199,7 +191,6 @@ def neuron_platform_plugin() -> Optional[str]: 'tpu': tpu_platform_plugin, 'cuda': cuda_platform_plugin, 'rocm': rocm_platform_plugin, - 'hpu': hpu_platform_plugin, 'xpu': xpu_platform_plugin, 'cpu': cpu_platform_plugin, 'neuron': neuron_platform_plugin, diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index dccd60f4463a..31a67183ff12 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -1,13 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json import os import platform +import subprocess import sys +from dataclasses import dataclass from importlib.util import find_spec from typing import TYPE_CHECKING, Optional -import psutil import torch from vllm.logger import init_logger @@ -32,11 +34,41 @@ def get_max_threads(pid=0): raise NotImplementedError("Unsupported OS") +@dataclass +class LogicalCPUInfo: + id: int = -1 + physical_core: int = -1 + numa_node: int = -1 + + @classmethod + def _int(cls, value: str) -> int: + try: + int_value = int(value) + except Exception: + int_value = -1 + return int_value + + @staticmethod + def json_decoder(obj_dict: dict): + id = obj_dict.get("cpu") + physical_core = obj_dict.get("core") + numa_node = obj_dict.get("node") + + if not (id is None or physical_core is None or numa_node is None): + return LogicalCPUInfo( + id=LogicalCPUInfo._int(id), + physical_core=LogicalCPUInfo._int(physical_core), + numa_node=LogicalCPUInfo._int(numa_node)) + else: + return obj_dict + + class CpuPlatform(Platform): _enum = PlatformEnum.CPU device_name: str = "cpu" device_type: str = "cpu" dispatch_key: str = "CPU" + dist_backend: str = "gloo" @property def supported_dtypes(self) -> list[torch.dtype]: @@ -64,17 +96,34 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: - logger.info("Using CPU MLA backend.") - return "vllm.attention.backends.cpu_mla.CPUMLABackend" + raise NotImplementedError("MLA is not supported on CPU.") logger.info("Using Torch SDPA backend.") - if use_v1: - return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" - else: - return "vllm.attention.backends.torch_sdpa.TorchSDPABackend" + if not use_v1: + raise ValueError("CPU backend only supports V1.") + return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend" @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: - return psutil.virtual_memory().total + import vllm.envs as envs + from vllm.utils import GiB_bytes + + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + if kv_cache_space is None: + kv_cache_space = 4 * GiB_bytes # type: ignore + logger.warning_once( + "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " + "for CPU backend is not set, using 4 by default.") + else: + kv_cache_space *= 
GiB_bytes + + return kv_cache_space + + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + torch.cpu.set_device(device) @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: @@ -86,11 +135,10 @@ def inference_mode(cls): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - import vllm.envs as envs - from vllm.utils import GiB_bytes model_config = vllm_config.model_config - model_config.disable_cascade_attn = True + if model_config is not None: + model_config.disable_cascade_attn = True cache_config = vllm_config.cache_config @@ -117,26 +165,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "CPU backend doesn't support fp8_e4m3 KV cache type, " "cast to fp8_e5m2.") - if (cache_config.cache_dtype != "auto" + if (cache_config.cache_dtype != "auto" and model_config is not None and model_config.dtype == torch.half): logger.warning("FP8 KV cache on the CPU backend only does not" " support fp16 for now, cast to bf16.") model_config.dtype = torch.bfloat16 - kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE - - if kv_cache_space >= 0: - if kv_cache_space == 0: - cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore - logger.warning( - "Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) " - "for CPU backend is not set, using 4 by default.") - else: - cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa - else: - raise RuntimeError( - "Invalid environment variable VLLM_CPU_KVCACHE_SPACE" - f" {kv_cache_space}, expect a positive integer value.") + cache_config.cpu_kvcache_space_bytes = \ + CpuPlatform.get_device_total_memory() parallel_config = vllm_config.parallel_config if (parallel_config.world_size > 1 @@ -147,26 +183,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "mp" if parallel_config.worker_cls == "auto": - if vllm_config.speculative_config: - parallel_config.worker_cls = \ - "vllm.spec_decode.spec_decode_worker.create_spec_worker" - parallel_config.sd_worker_cls = \ - "vllm.worker.cpu_worker.CPUWorker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.cpu_worker.CPUWorker" - else: - parallel_config.worker_cls = \ - "vllm.worker.cpu_worker.CPUWorker" + parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker" # Note: workaround for v1 gpu_model_runner from vllm.config import CompilationLevel vllm_config.compilation_config.cudagraph_capture_sizes = [] compilation_config = vllm_config.compilation_config - if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE): + if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE: # Note: vLLM V1 is using PIECEWISE level compilation, which will # take time to compile kernels just-in-time with the inductor @@ -189,8 +213,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: False, "nan_asserts": False, - "memory_planning": - True, "epilogue_fusion": True, }) @@ -235,7 +257,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: os.environ["LOCAL_WORLD_SIZE"] = str( vllm_config.parallel_config.tensor_parallel_size) - if vllm_config.model_config and vllm_config.model_config.use_mla: + if model_config is not None and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " "prefill and 
prefix caching to be disabled.") @@ -245,6 +267,38 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: vllm_config.scheduler_config.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS) + @classmethod + def get_allowed_cpu_memory_node_list( + cls) -> tuple[list[int], list[LogicalCPUInfo]]: + assert platform.system() == "Linux" + + # Init LogicalCPUInfo from lscpu + lscpu_output = subprocess.check_output("lscpu -J -e=CPU,CORE,NODE", + shell=True, + text=True) + logical_cpu_list: list[LogicalCPUInfo] = json.loads( + lscpu_output, object_hook=LogicalCPUInfo.json_decoder)['cpus'] + + # Filter CPUs with invalid attributes + logical_cpu_list = [ + x for x in logical_cpu_list + if -1 not in (x.id, x.physical_core, x.numa_node) + ] + + # Filter allowed CPUs + allowed_cpu_id_list = os.sched_getaffinity(0) + logical_cpu_list = [ + x for x in logical_cpu_list if x.id in allowed_cpu_id_list + ] + + # Get allowed NUMA nodes + allowed_numa_nodes = set() + for x in logical_cpu_list: + allowed_numa_nodes.add(x.numa_node) # type: ignore + allowed_numa_nodes_list = sorted(allowed_numa_nodes) + + return allowed_numa_nodes_list, logical_cpu_list + @classmethod def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on CPU.") @@ -277,5 +331,6 @@ def default_v1(cls, model_config) -> bool: """Returns whether the current platform can use v1 by default for the supplied model configuration. """ - return cls.supports_v1( - model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86 + arch = cls.get_cpu_architecture() + return (cls.supports_v1(model_config) and arch + in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM)) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 0a5f4004e448..9a8941e3cdd1 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -56,6 +56,7 @@ class CudaPlatformBase(Platform): device_type: str = "cuda" dispatch_key: str = "CUDA" ray_device_key: str = "GPU" + dist_backend: str = "nccl" device_control_env_var: str = "CUDA_VISIBLE_DEVICES" @property @@ -76,7 +77,7 @@ def set_device(cls, device: torch.device) -> None: """ Set the device for the current platform. """ - super().set_device(device) + torch.cuda.set_device(device) # With this trick we can force the device to be set eagerly # see https://github.com/pytorch/pytorch/issues/155668 # for why and when it is needed @@ -98,7 +99,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager: + if enforce_eager and not envs.VLLM_USE_V1: logger.warning( "To see benefits of async output processing, enable CUDA " "graph. 
Since, enforce-eager is enabled, async output " @@ -131,14 +132,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config.worker_cls = \ "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.gpu_worker.Worker" - else: - parallel_config.worker_cls = \ - "vllm.spec_decode.spec_decode_worker.create_spec_worker" - parallel_config.sd_worker_cls = \ - "vllm.worker.worker.Worker" + if not envs.VLLM_USE_V1: + raise NotImplementedError( + "Speculative decoding is not supported on vLLM V0.") + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" else: if envs.VLLM_USE_V1: parallel_config.worker_cls = \ @@ -165,20 +162,26 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") + use_cutlass_mla = (envs.VLLM_ATTENTION_BACKEND is not None \ + and envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1") + if use_cutlass_mla and cache_config.block_size != 128: + cache_config.block_size = 128 + logger.info("Forcing kv cache block size to 128 for " + "CUTLASS_MLA_VLLM_V1 backend.") + + compilation_config = vllm_config.compilation_config if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" and parallel_config.data_parallel_size > 1 - and vllm_config.compilation_config.use_cudagraph): + and compilation_config.use_cudagraph): logger.info( "Data Parallel: Forcing enforce eager to be True since DP " "with DeepEP high-throughput kernels are not CUDA Graph " "compatible. The DeepEP low-latency kernels are CUDA Graph " "compatible. Set the all_to_all backend to deepep_low_latency " "to use those kernels instead.") - vllm_config.compilation_config.use_cudagraph = False - vllm_config.model_config.enforce_eager = True - # TODO (varun): Turning this ON gives incorrect results for the - # Deepseek-V2-lite model. 
- vllm_config.compilation_config.use_inductor = False + compilation_config.use_cudagraph = False + if model_config is not None: + model_config.enforce_eager = True @classmethod def get_current_memory_usage(cls, @@ -241,6 +244,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if selected_backend == _Backend.FLASHINFER: logger.info_once("Using FlashInfer backend on V1 engine.") + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + set_kv_cache_layout("HND") return FLASHINFER_V1 elif selected_backend == _Backend.FLEX_ATTENTION: logger.info_once("Using FlexAttention backend on V1 engine.") @@ -252,44 +259,68 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info_once("Using Flash Attention backend on V1 engine.") return FLASH_ATTN_V1 - from vllm.attention.selector import supports_head_size + from vllm.attention.selector import is_attn_backend_supported # Default backends for V1 engine - # FP32 is only supported by FlexAttention - if dtype not in (torch.float16, torch.bfloat16): - logger.info_once( - "Using FlexAttention backend for %s on V1 engine.", - dtype, - ) - return FLEX_ATTENTION_V1 - # Prefer FlashInfer for Blackwell GPUs if installed - if cls.is_device_capability(100) and \ - supports_head_size(FLASHINFER_V1, head_size): - try: - import flashinfer # noqa: F401 + if cls.is_device_capability(100): + if is_default_backend_supported := is_attn_backend_supported( + FLASHINFER_V1, head_size, dtype): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + logger.info_once( - "Using FlashInfer backend on V1 engine by default for " - "Blackwell (SM 10.0) GPUs.") + "Using FlashInfer backend with HND KV cache layout on " + "V1 engine by default for Blackwell (SM 10.0) GPUs.") + set_kv_cache_layout("HND") + return FLASHINFER_V1 - except ImportError: - logger.info_once( + + if not is_default_backend_supported.can_import: + logger.warning_once( "FlashInfer failed to import for V1 engine on " "Blackwell (SM 10.0) GPUs; it is recommended to " "install FlashInfer for better performance.") - pass + # FlashAttention is the default for SM 8.0+ GPUs - if cls.has_device_capability(80) and \ - supports_head_size(FLASH_ATTN_V1, head_size): - logger.info_once("Using Flash Attention backend on V1 engine.") - return FLASH_ATTN_V1 + if cls.has_device_capability(80): + if is_default_backend_supported := is_attn_backend_supported( + FLASH_ATTN_V1, head_size, dtype, + allow_import_error=False): + logger.info_once("Using Flash Attention backend on " + "V1 engine.") + return FLASH_ATTN_V1 + + # FlexAttention is the default for older GPUs + else: + logger.info_once("Using FlexAttention backend on V1 engine.") + return FLEX_ATTENTION_V1 + + assert not is_default_backend_supported + + use_flex_attention_reason = {} + if not is_default_backend_supported.head_size: + use_flex_attention_reason["head_size"] = head_size + if not is_default_backend_supported.dtype: + use_flex_attention_reason["dtype"] = dtype - logger.info_once("Using FlexAttention backend on V1 engine.") + logger.info_once( + "Using FlexAttention backend for %s on V1 engine.", + ", ".join(f"{k}={v}" + for k, v in use_flex_attention_reason.items()), + ) return FLEX_ATTENTION_V1 # Backends for V0 engine if selected_backend == _Backend.FLASHINFER: logger.info("Using FlashInfer backend.") + if cls.has_device_capability(100): + from vllm.v1.attention.backends.utils import ( + set_kv_cache_layout) + logger.info_once( + "Using HND KV cache 
layout on V1 engine by default for " + "Blackwell (SM 10.0) GPUs.") + set_kv_cache_layout("HND") return "vllm.attention.backends.flashinfer.FlashInferBackend" elif selected_backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") @@ -298,6 +329,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using DualChunkFlashAttention backend.") return ("vllm.attention.backends.dual_chunk_flash_attn." "DualChunkFlashAttentionBackend") + elif selected_backend == _Backend.DIFFERENTIAL_FLASH_ATTN: + logger.info("Using DifferentialFlashAttention backend.") + return ("vllm.attention.backends.differential_flash_attn." + "DifferentialFlashAttentionBackend") elif selected_backend == _Backend.FLASH_ATTN: pass elif selected_backend: @@ -421,6 +456,19 @@ def stateless_init_device_torch_dist_pg( def device_count(cls) -> int: return cuda_device_count_stateless() + @classmethod + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + fp8_attention = kv_cache_dtype.startswith("fp8") + will_use_fa = (not envs.is_set("VLLM_ATTENTION_BACKEND") + ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1" + supported = False + if cls.is_device_capability(100): + supported = True + elif fp8_attention and will_use_fa: + from vllm.attention.utils.fa_utils import flash_attn_supports_fp8 + supported = flash_attn_supports_fp8() + return supported + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py deleted file mode 100644 index 3cf28950190c..000000000000 --- a/vllm/platforms/hpu.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from typing import TYPE_CHECKING, Optional - -import torch - -from vllm import envs -from vllm.logger import init_logger -from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS - -from .interface import Platform, PlatformEnum, _Backend - -if TYPE_CHECKING: - from vllm.config import VllmConfig -else: - VllmConfig = None - -logger = init_logger(__name__) - - -class HpuPlatform(Platform): - _enum = PlatformEnum.HPU - device_name: str = "hpu" - device_type: str = "hpu" - dispatch_key: str = "HPU" - ray_device_key: str = "HPU" - device_control_env_var: str = "HABANA_VISIBLE_MODULES" - - @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, - dtype: torch.dtype, kv_cache_dtype: Optional[str], - block_size: int, use_v1: bool, - use_mla: bool) -> str: - logger.info("Using HPUAttention backend.") - return "vllm.attention.backends.hpu_attn.HPUAttentionBackend" - - @classmethod - def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return True - - @classmethod - def inference_mode(cls): - return torch.no_grad() - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - - scheduler_config = vllm_config.scheduler_config - parallel_config = vllm_config.parallel_config - if scheduler_config.is_multi_step: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_hpu_worker.MultiStepHPUWorker" - - if vllm_config.speculative_config is not None: - raise NotImplementedError( - "Speculative decoding is not implemented for HPU") - - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" - - # NOTE(kzawora): default block size for Gaudi should be 128 - # smaller sizes still work, but very inefficiently - cache_config = vllm_config.cache_config 
- if cache_config and cache_config.block_size is None: - cache_config.block_size = 128 - if (parallel_config.distributed_executor_backend == 'mp' - and envs.VLLM_WORKER_MULTIPROC_METHOD == 'fork'): - if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", - None) is not None: - logger.warning("On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork " - "might cause application hangs on exit. Using " - "VLLM_WORKER_MULTIPROC_METHOD=fork anyway, " - "as it was explicitly requested.") - else: - logger.warning( - "On HPU, VLLM_WORKER_MULTIPROC_METHOD=fork " - "might cause application hangs on exit. Setting " - "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. " - "To override that behavior, please set " - "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.") - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - - if vllm_config.model_config and vllm_config.model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled.") - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.chunked_prefill_enabled = False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.scheduler_config.max_model_len, - DEFAULT_MAX_NUM_BATCHED_TOKENS) - - @classmethod - def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on HPU.") - return False - - @classmethod - def get_punica_wrapper(cls) -> str: - return "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU" - - @classmethod - def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator" # noqa diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 567d5cbf503f..02cc392244ba 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -54,12 +54,11 @@ class _Backend(enum.Enum): FLASHMLA_VLLM_V1 = enum.auto() FLASHMLA = enum.auto() # Supported by V1 CUTLASS_MLA_VLLM_V1 = enum.auto() - HPU_ATTN = enum.auto() PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() - BLOCK_SPARSE_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto() + DIFFERENTIAL_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() FLEX_ATTENTION = enum.auto() @@ -68,7 +67,6 @@ class PlatformEnum(enum.Enum): CUDA = enum.auto() ROCM = enum.auto() TPU = enum.auto() - HPU = enum.auto() XPU = enum.auto() CPU = enum.auto() NEURON = enum.auto() @@ -129,6 +127,9 @@ class Platform: # compilation strategy. simple_compile_backend: str = "inductor" + # The backend used for distributed communication. + dist_backend: str = "" + supported_quantization: list[str] = [] additional_env_vars: list[str] = [] @@ -150,9 +151,6 @@ def is_rocm(self) -> bool: def is_tpu(self) -> bool: return self._enum == PlatformEnum.TPU - def is_hpu(self) -> bool: - return self._enum == PlatformEnum.HPU - def is_xpu(self) -> bool: return self._enum == PlatformEnum.XPU @@ -302,7 +300,7 @@ def set_device(cls, device: torch.device) -> None: """ Set the device for the current platform. """ - torch.cuda.set_device(device) + raise NotImplementedError @classmethod def pre_register_and_update(cls, @@ -545,6 +543,13 @@ def stateless_init_device_torch_dist_pg( """ raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + @classmethod + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + """ + Returns if the kv_cache_dtype is supported by the current platform. 
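+ The base implementation here returns False; platform subclasses + override this to advertise support (e.g. for fp8 KV cache formats).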
+ """ + return False + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 04e918d7aebe..cb8ac8db669f 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -30,6 +30,7 @@ class NeuronPlatform(Platform): device_type: str = "neuron" ray_device_key: str = "neuron_cores" supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"] + dist_backend: str = "gloo" device_control_env_var: str = "NEURON_RT_VISIBLE_CORES" @classmethod diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4550ef570684..b2e69f60343f 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -164,6 +164,7 @@ class RocmPlatform(Platform): device_type: str = "cuda" dispatch_key: str = "CUDA" ray_device_key: str = "GPU" + dist_backend: str = "nccl" # rocm shares the same device control env var as CUDA device_control_env_var: str = "CUDA_VISIBLE_DEVICES" @@ -240,6 +241,13 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, logger.info("Using ROCmFlashAttention backend.") return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + torch.cuda.set_device(device) + @classmethod @lru_cache(maxsize=8) def get_device_capability(cls, @@ -291,7 +299,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - if enforce_eager: + if enforce_eager and not envs.VLLM_USE_V1: logger.warning( "To see benefits of async output processing, enable CUDA " "graph. Since, enforce-eager is enabled, async output " @@ -318,15 +326,10 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config.worker_cls = \ "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: - if envs.VLLM_USE_V1: + if not envs.VLLM_USE_V1: raise NotImplementedError( - "Speculative decoding is not yet supported on vLLM V1." 
- ) - else: - parallel_config.worker_cls = \ - "vllm.spec_decode.spec_decode_worker.create_spec_worker" - parallel_config.sd_worker_cls = \ - "vllm.worker.worker.Worker" + "Speculative decoding is not supported on vLLM V0.") + parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker" else: if envs.VLLM_USE_V1: parallel_config.worker_cls = \ @@ -451,3 +454,7 @@ def stateless_init_device_torch_dist_pg( @classmethod def device_count(cls) -> int: return cuda_device_count_stateless() + + @classmethod + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + return True \ No newline at end of file diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0387e348965d..146801c9d773 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -6,7 +6,6 @@ import torch from tpu_info import device -import vllm.envs as envs from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger from vllm.sampling_params import SamplingParams, SamplingType @@ -32,10 +31,13 @@ class TpuPlatform(Platform): device_type: str = "tpu" dispatch_key: str = "XLA" ray_device_key: str = "TPU" + dist_backend: str = "gloo" device_control_env_var: str = "TPU_VISIBLE_CHIPS" simple_compile_backend: str = "openxla" - supported_quantization: list[str] = ["tpu_int8", "compressed-tensors"] + supported_quantization: list[str] = [ + "fp8", "tpu_int8", "compressed-tensors" + ] additional_env_vars: list[str] = [ "TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS" @@ -50,12 +52,17 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) - if use_v1: - logger.info("Using Pallas V1 backend.") - return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" - else: - logger.info("Using Pallas backend.") - return "vllm.attention.backends.pallas.PallasAttentionBackend" + if not use_v1: + raise ValueError("TPU backend only supports V1.") + logger.info("Using Pallas V1 backend.") + return "vllm.v1.attention.backends.pallas.PallasAttentionBackend" + + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + torch.tpu.set_device(device) @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -68,7 +75,7 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @classmethod def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: - return not envs.VLLM_USE_V1 + return False @classmethod def get_punica_wrapper(cls) -> str: @@ -111,37 +118,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: assert vllm_config.speculative_config is None, \ "TPU does not support speculative decoding" - if vllm_config.model_config.dtype in (torch.float16, torch.float32): + model_config = vllm_config.model_config + if model_config is not None and model_config.dtype in (torch.float16, + torch.float32): logger.warning( "The TPU backend currently does not support %s. 
" - "Using bfloat16 instead.", vllm_config.model_config.dtype) - vllm_config.model_config.dtype = torch.bfloat16 + "Using bfloat16 instead.", model_config.dtype) + model_config.dtype = torch.bfloat16 - if envs.VLLM_USE_V1: - from vllm.v1.attention.backends.pallas import ( - PallasAttentionBackend) - cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) # type: ignore[assignment] + from vllm.v1.attention.backends.pallas import PallasAttentionBackend + cache_config.block_size = PallasAttentionBackend.get_page_size( + vllm_config) # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": if scheduler_config.is_multi_step: - if envs.VLLM_USE_V1: - raise NotImplementedError( - "Multi-step scheduling is not supported (and not " - "needed) on vLLM V1. Please launch without " - "--num-scheduler-steps.") - else: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" - else: - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.tpu_worker.TPUWorker" - else: - parallel_config.worker_cls = \ - "vllm.worker.tpu_worker.TPUWorker" + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. Please launch without " + "--num-scheduler-steps.") + parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" assert not vllm_config.speculative_config, ( "Speculative decoding is not yet supported for TPU backend") @@ -153,7 +150,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") scheduler_config.disable_chunked_mm_input = True - if vllm_config.model_config and vllm_config.model_config.use_mla: + if model_config and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " "prefill and prefix caching to be disabled.") @@ -189,13 +186,13 @@ def validate_request( processed_inputs: ProcessorInputs, ) -> None: """Raises if this request is unsupported on this platform""" - if isinstance(params, SamplingParams): - if params.guided_decoding is not None and not envs.VLLM_USE_V1: - raise ValueError("Structured output is not supported on " - f"{cls.device_name} V0.") - if params.sampling_type == SamplingType.RANDOM_SEED: - raise ValueError( - "Torch XLA does not support per-request seed.") + if (isinstance(params, SamplingParams) + and params.sampling_type == SamplingType.RANDOM_SEED): + raise ValueError("Torch XLA does not support per-request seed.") + + @classmethod + def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool: + return True try: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 61a0453dcbc8..c4530c1dfaa3 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -29,6 +29,7 @@ class XPUPlatform(Platform): # Intel XPU's device key is "GPU" for Ray. 
# see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501 ray_device_key: str = "GPU" + dist_backend: str = "ccl" # ccl | xccl device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR" @classmethod @@ -36,15 +37,20 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool) -> str: - if selected_backend != _Backend.IPEX: + if selected_backend is not None and selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) use_v1 = envs.VLLM_USE_V1 - if use_v1: - logger.info("Using Flash Attention backend on V1 engine.") - return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" - else: - logger.info("Using IPEX attention backend.") - return "vllm.attention.backends.ipex_attn.IpexAttnBackend" + if not use_v1: + raise ValueError("XPU backend only supports V1.") + logger.info("Using Flash Attention backend on V1 engine.") + return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" + + @classmethod + def set_device(cls, device: torch.device) -> None: + """ + Set the device for the current platform. + """ + torch.xpu.set_device(device) @classmethod def get_device_capability( @@ -59,6 +65,10 @@ def get_device_capability( def get_device_name(cls, device_id: int = 0) -> str: return torch.xpu.get_device_name(device_id) + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" + @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.xpu.get_device_properties(device_id) @@ -75,18 +85,22 @@ def inference_mode(cls): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config + model_config = vllm_config.model_config # in V1(or with ipex chunked prefill) block_size is 64 if cache_config and cache_config.block_size is None: - if envs.VLLM_USE_V1: - cache_config.block_size = 64 - else: - cache_config.block_size = 16 + cache_config.block_size = 64 + + # FIXME: Temporarily forcing eager mode + # remove after t.compile support stabilizes. + if (envs.VLLM_USE_V1 and model_config is not None + and not vllm_config.model_config.enforce_eager): + from vllm.config import CompilationLevel + vllm_config.compilation_config.level = CompilationLevel.NO_COMPILATION # noqa: E501 # Instances created using VllmConfig() typically have model_config as # None by default. The modification involves adding a check to prevent # potential null exceptions check and update model config. 
- if vllm_config.model_config is not None: - model_config = vllm_config.model_config + if model_config is not None: if model_config.dtype == torch.bfloat16: bf16_supported = cls.device_support_bf16() if not bf16_supported: @@ -97,20 +111,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "mode.") model_config.enforce_eager = True - if vllm_config.speculative_config is not None: - raise NotImplementedError( - "XPU does not support speculative decoding") - - if vllm_config.device_config is not None: - assert vllm_config.device_config.device_type == "xpu" - # check and update parallel config parallel_config = vllm_config.parallel_config - if envs.VLLM_USE_V1: - parallel_config.worker_cls =\ - "vllm.v1.worker.xpu_worker.XPUWorker" - else: - parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" + parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker" if parallel_config.distributed_executor_backend is None: if parallel_config.world_size > 1: @@ -125,15 +128,17 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" logger.warning( "Please use spawn as start method if you want to use mp.") - elif parallel_config.distributed_executor_backend != "ray" and \ - parallel_config.distributed_executor_backend != "uni": + elif (parallel_config.distributed_executor_backend != "ray" + and parallel_config.distributed_executor_backend != "uni" + and parallel_config.distributed_executor_backend + != "external_launcher"): logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "ray" - if vllm_config.model_config and vllm_config.model_config.use_mla: + if model_config and model_config.use_mla: logger.info( "MLA is enabled on a non-GPU platform; forcing chunked " "prefill and prefix caching to be disabled.") @@ -145,8 +150,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: @classmethod def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on XPU.") - return False + return True @classmethod def get_current_memory_usage(cls, @@ -189,4 +193,4 @@ def supports_v1(cls, model_config: ModelConfig) -> bool: @classmethod def device_count(cls) -> int: - return torch.xpu.device_count() \ No newline at end of file + return torch.xpu.device_count() diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 2cb177b9ba78..51c78ddc1a9d 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import logging -import os from typing import Any, Callable import torch @@ -75,18 +74,6 @@ def load_general_plugins(): if current_platform.is_xpu(): # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158 torch._dynamo.config.disable = True - elif current_platform.is_hpu(): - # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1) - # does not support torch.compile - # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for - # torch.compile support - is_lazy = os.environ.get('PT_HPU_LAZY_MODE', '1') == '1' - if is_lazy: - torch._dynamo.config.disable = True - # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only) - # requires enabling lazy collectives - # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501 - os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] 
= 'true' plugins = load_plugins_by_group(group=DEFAULT_PLUGINS_GROUP) # general plugins, we only need to execute the loaded functions diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py index 106f3e8b22b7..868facbe2557 100644 --- a/vllm/pooling_params.py +++ b/vllm/pooling_params.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Literal, Optional import msgspec @@ -10,31 +10,49 @@ if TYPE_CHECKING: from vllm.config import ModelConfig +PoolingTask = Literal["encode", "embed", "classify", "score"] + class PoolingParams( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] array_like=True): # type: ignore[call-arg] - """API parameters for pooling models. This is currently a placeholder. + """API parameters for pooling models. Attributes: dimensions: Reduce the dimensions of embeddings if model support matryoshka representation. - additional_data: Any additional data needed for pooling. """ dimensions: Optional[int] = None - use_cross_encoder: bool = False - additional_data: Optional[Any] = None + output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY + task: Optional[PoolingTask] = None + """Internal use only.""" + + requires_token_ids: bool = False + """Internal use only.""" + def clone(self) -> "PoolingParams": """Returns a deep copy of the PoolingParams instance.""" - return PoolingParams(dimensions=self.dimensions, - use_cross_encoder=self.use_cross_encoder, - additional_data=self.additional_data) + return PoolingParams( + dimensions=self.dimensions, + task=self.task, + requires_token_ids=self.requires_token_ids, + ) + + def verify(self, task: PoolingTask, model_config: "ModelConfig") -> None: + if self.task is None: + self.task = task + elif self.task != task: + msg = f"You cannot overwrite {self.task=!r} with {task=!r}!" + raise ValueError(msg) + + # NOTE: Task validation needs to be done against the model instance, + # which is not available in model config. So, it's not included
So, it's not included + # in this method - def verify(self, model_config: "ModelConfig") -> None: if self.dimensions is not None: if not model_config.is_matryoshka: raise ValueError( @@ -56,8 +74,8 @@ def verify(self, model_config: "ModelConfig") -> None: def __repr__(self) -> str: return (f"PoolingParams(" f"dimensions={self.dimensions}, " - f"use_cross_encoder={self.use_cross_encoder}, " - f"additional_metadata={self.additional_data})") + f"task={self.task}, " + f"requires_token_ids={self.requires_token_ids})") def __post_init__(self) -> None: assert self.output_kind == RequestOutputKind.FINAL_ONLY,\ diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py deleted file mode 100644 index b5b925d042f2..000000000000 --- a/vllm/prompt_adapter/layers.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass -from typing import Optional - -import torch -from torch import nn - -from vllm.adapter_commons.layers import AdapterMapping -from vllm.config import PromptAdapterConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) - - -@dataclass -class PromptAdapterMapping(AdapterMapping): - pass - - -class VocabParallelEmbeddingWithPromptAdapter(nn.Module): - - def __init__(self, base_layer: VocabParallelEmbedding) -> None: - super().__init__() - self.base_layer = base_layer - self.emb_layer = self.base_layer - if 'LoRA' in base_layer.__class__.__name__: - self.emb_layer = self.base_layer.base_layer - - def create_prompt_adapter_weights( - self, prompt_adapter_config: PromptAdapterConfig): - self.embeddings_tensors = torch.zeros( - ( - prompt_adapter_config.max_prompt_adapters, - prompt_adapter_config.max_prompt_adapter_token, - self.emb_layer.embedding_dim, - ), - dtype=self.emb_layer.weight.dtype, - device=self.emb_layer.weight.device, - ) - self.adapter_lengths = torch.zeros( - prompt_adapter_config.max_prompt_adapters, - dtype=torch.long, - device=self.emb_layer.weight.device) - - self.indices_gpu: torch.Tensor - self.embedding_indices_gpu: torch.Tensor - - def reset_prompt_adapter(self, index: int): - self.embeddings_tensors[index] = 0 - - def set_prompt_adapter( - self, - index: int, - adapter_model: Optional[torch.Tensor], - ): - self.reset_prompt_adapter(index) - if adapter_model is not None: - length = adapter_model.shape[0] - self.embeddings_tensors[index, :length] = adapter_model - self.adapter_lengths[index] = length - - def set_mapping( - self, - prompt_indices: torch.Tensor, - prompt_embedding_indices: torch.Tensor, - ): - self.indices_gpu = prompt_indices.to( - device=self.emb_layer.weight.device) - self.embedding_indices_gpu = prompt_embedding_indices.to( - device=self.emb_layer.weight.device) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - hidden_states = self.base_layer(x) - if self.embedding_indices_gpu.ndim > 1: - valid_mask = self.indices_gpu != -1 - gathered_embeddings = self.embeddings_tensors[ - self.embedding_indices_gpu[:, 0], - self.embedding_indices_gpu[:, 1]] - - # Update hidden states - hidden_states[valid_mask] = gathered_embeddings - return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py deleted file mode 100644 index 864b50c861e1..000000000000 --- a/vllm/prompt_adapter/models.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project - -import logging -import math -from typing import Any, Callable, Dict, List, Optional, Type - -import torch -from torch import nn - -from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, - AdapterModelManager) -from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, - get_adapter, list_adapters, - remove_adapter, set_adapter_mapping) -from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.layers import ( - VocabParallelEmbeddingWithPromptAdapter) # yapf: disable -from vllm.prompt_adapter.layers import PromptAdapterMapping -from vllm.prompt_adapter.utils import load_peft_weights - -logger = logging.getLogger(__name__) - -_GLOBAL_PROMPT_ADAPTER_ID = 0 - - -def get_prompt_adapter_id(): - global _GLOBAL_PROMPT_ADAPTER_ID - _GLOBAL_PROMPT_ADAPTER_ID += 1 - return _GLOBAL_PROMPT_ADAPTER_ID - - -def convert_to_embedding_indices(indices): - embedding_indices = [] - count = 0 - - for value in indices: - if value == -1: - count = 0 - else: - embedding_indices.append([value, count]) - count += 1 - - return torch.tensor(embedding_indices) - - -def convert_mapping( - mapping: PromptAdapterMapping, - prompt_adapter_index_to_id: List[Optional[int]], -) -> torch.Tensor: - """Converts PromptAdapterMapping to index tensors. - - Args: - mapping: PromptAdapterMapping mapping rows in a - batch to PromptAdapter ids. - prompt_adapter_index_to_id: List mapping PromptAdapter - ids to PromptAdapter indices. - - Returns: - pa_indices: Tensor of shape [batch_size] mapping batch rows to - PromptAdapter indices. - """ - id_to_index = { - id_: idx - for idx, id_ in enumerate(prompt_adapter_index_to_id) - if id_ is not None - } - pa_indices = ([ - id_to_index.get(id_, -1) if id_ > 0 else -1 - for id_ in mapping.index_mapping - ]) - - pa_embedding_mapping = convert_to_embedding_indices(pa_indices) - pa_indices = torch.tensor(pa_indices) - return pa_indices, pa_embedding_mapping - - -class PromptAdapterModel(AdapterModel): - - def __init__(self, - prompt_adapter_id=None, - num_virtual_tokens=None, - prompt_embedding=None) -> None: - self.id = prompt_adapter_id - self.prompt_embedding = prompt_embedding - self.num_virtual_tokens = num_virtual_tokens - - @classmethod - def from_local_checkpoint( - cls, - adapter_model_path: str, - prompt_adapter_id: int, - num_virtual_tokens: int, - config: PromptAdapterConfig, - device: str = "cuda", - ) -> "PromptAdapterModel": - - if num_virtual_tokens > config.max_prompt_adapter_token: - raise ValueError( - f'num_virtual_tokens ({num_virtual_tokens}) should be <= ' - f'max_prompt_adapter_token({config.max_prompt_adapter_token})') - - adapters_weights = load_peft_weights(adapter_model_path, device) - prompt_embedding = adapters_weights["prompt_embeddings"].to( - config.prompt_adapter_dtype) - - return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) - - -class PromptAdapterModelManager(AdapterModelManager): - """A manager that manages multiple Prompt Adapter models.""" - - def __init__( - self, - model: nn.Module, - max_num_seqs: int, - max_num_batched_tokens: int, - prompt_adapter_config: PromptAdapterConfig, - ): - """Create a PromptAdapterModel and adapter for a given model. - - Args: - model: the model to be adapted. - max_num_seqs: the maximum number of sequences model can run in a - single batch. - max_num_batched_tokens: the maximum number of tokens model can run - in a single batch. 
- prompt_adapter_config: the PromptAdapter config, - """ - self.model: nn.Module = model - # Dict instead of a Set for compatibility with LRUCache. - self.prompt_adapter_index_to_id: List[ - Optional[int]] = [None] * self.prompt_adapter_slots - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.prompt_adapter_config = prompt_adapter_config - self.model.prompt_adapter_manager = self - self.adapter_type = 'PromptAdapter' - - self.base_indices = torch.tensor([-1]) - self.base_embedding_indices = torch.tensor([]) - - self.modules: Dict[str, nn.Module] = {} - self._create_prompt_adapter_modules() - self._last_mapping: Optional[PromptAdapterMapping] = None - - @property - def prompt_adapter_slots(self) -> int: - return self.prompt_adapter_config.max_prompt_adapters - - @property - def adapter_slots(self) -> int: - return self.prompt_adapter_slots - - @property - def capacity(self) -> int: - return self.prompt_adapter_config.max_cpu_prompt_adapters - - def activate_adapter( - self, - prompt_adapter_id: int, - ) -> bool: - """Move PromptAdapter into a GPU buffer - to be used in the forward pass.""" - if prompt_adapter_id in self._active_adapters: - return False - first_free_slot = next( - ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( - self.prompt_adapter_index_to_id) if prompt_adapter_id is None), - None) - if first_free_slot is None: - raise ValueError("No free prompt_adapter slots") - index, _ = first_free_slot - self._active_adapters[prompt_adapter_id] = None - prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) - logger.debug("Activating prompt_adapter. int id: %d, slot index: %d", - prompt_adapter_model.id, index) - self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id - for _, v in self.modules.items(): - v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) - return True - - def _deactivate_adapter(self, prompt_adapter_id: int): - try: - index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) - self.prompt_adapter_index_to_id[index] = None - for _, v in self.modules.items(): - v.reset_prompt_adapter(index) - except ValueError: - pass - - def _add_adapter(self, prompt_adapter: PromptAdapterModel): - self._registered_adapters[prompt_adapter.id] = prompt_adapter - - def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - base_indices, base_embedding_indices = convert_mapping( - mapping, self.prompt_adapter_index_to_id) - for k, v in self.modules.items(): - v.set_mapping(base_indices, base_embedding_indices) - - def _create_prompt_adapter_modules(self): - for module_name, module in self.model.named_modules( - remove_duplicate=False): - if "VocabParallel" in module.__class__.__name__: - new_module = VocabParallelEmbeddingWithPromptAdapter(module) - new_module.create_prompt_adapter_weights( - self.prompt_adapter_config) - replaced_module = self.replace_submodule( - self.model, module_name, new_module) - self.register_module(module.__class__.__name__, - replaced_module) - replaced_module.set_mapping(self.base_indices, - self.base_embedding_indices) - break - - def replace_submodule(self, model: nn.Module, module_name: str, - new_module: nn.Module) -> nn.Module: - """Replace a submodule in a model with a new module.""" - parent = model.get_submodule(".".join(module_name.split(".")[:-1])) - target_name = module_name.split(".")[-1] - setattr(parent, target_name, new_module) - return new_module - - def register_module(self, module_name: str, module: 
nn.Module): - self.modules[module_name] = module - - def pin_adapter(self, prompt_adapter_id: int) -> bool: - """Pin a PromptAdapterModel in the manager cache.""" - raise NotImplementedError( - "Pinning is not supported in PromptAdapterModelManager. " - "Use LRUCachePromptAdapterModelManager for pinning" - ) # type: ignore - - def remove_all_adapters(self): - """Remove all PromptAdapterModel from the manager.""" - self._registered_adapters.clear() - self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots - self._active_adapters.clear() - - def deactivate_adapter(self, adapter_id: int) -> bool: - return deactivate_adapter(adapter_id, self._active_adapters, - self._deactivate_adapter) - - def add_adapter(self, adapter: PromptAdapterModel) -> bool: - return add_adapter(adapter, self._registered_adapters, self.capacity, - self._add_adapter) - - def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: - self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, - self._set_adapter_mapping) - - def remove_adapter(self, adapter_id: int) -> bool: - return remove_adapter(adapter_id, self._registered_adapters, - self.deactivate_adapter) - - def list_adapters(self) -> Dict[int, Any]: - return list_adapters(self._registered_adapters) - - def get_adapter(self, adapter_id: int) -> Optional[Any]: - return get_adapter(adapter_id, self._registered_adapters) - - -class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): - - def __init__(self, capacity: int, - deactivate_prompt_adapter_fn: Callable[[int], bool]): - super().__init__(capacity, deactivate_prompt_adapter_fn) - - -class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): - """A model manager that manages multiple prompt_adapters with LRU cache.""" - - def __init__( - self, - model: nn.Module, - max_num_seqs: int, - max_num_batched_tokens: int, - prompt_adapter_config: PromptAdapterConfig, - ): - self.prompt_adapter_config = prompt_adapter_config - super().__init__(model, max_num_seqs, max_num_batched_tokens, - prompt_adapter_config) - self._registered_adapters = PromptAdapterLRUCache( - self.capacity, self.deactivate_adapter) - self._active_adapters = PromptAdapterLRUCache( - self.prompt_adapter_slots, self._deactivate_adapter) - - def list_adapters(self) -> Dict[int, PromptAdapterModel]: - """List all registered PromptAdapterModel.""" - return dict(self._registered_adapters.cache) - - def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: - """Add a PromptAdapterModel to the manager.""" - if prompt_adapter.id not in self._registered_adapters: - self._add_adapter(prompt_adapter) - was_added = True - else: - # We always touch to update the LRU cache order - self._registered_adapters.touch(prompt_adapter.id) - was_added = False - return was_added - - def activate_adapter( - self, - prompt_adapter_id: int, - ) -> bool: - if prompt_adapter_id not in self._active_adapters and len( - self._active_adapters) >= self.prompt_adapter_slots: - self._active_adapters.remove_oldest() - result = super().activate_adapter(prompt_adapter_id) - # We always touch to update the LRU cache order - self._active_adapters.touch(prompt_adapter_id) - return result - - def remove_oldest_adapter(self) -> bool: - if len(self._registered_adapters) > 0: - self._registered_adapters.remove_oldest() - return True - return False - - def pin_adapter(self, prompt_adapter_id: int) -> bool: - """Pin a PromptAdapterModel in the manager cache.""" - self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) - 
self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) - return True - - def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): - try: - self._registered_adapters.pin(prompt_adapter_id) - except ValueError as err: - raise ValueError( - "Pinning failed. " - f"Prompt Adapter {prompt_adapter_id} is not registered." - ) from err - - def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): - if prompt_adapter_id not in self._active_adapters: - # move adapter to gpu if not already active - self.activate_adapter(prompt_adapter_id) - self._active_adapters.pin(prompt_adapter_id) - - -def create_prompt_adapter_manager( - model: nn.Module, - max_num_seqs: int, - max_num_batched_tokens: int, - prompt_adapter_config: PromptAdapterConfig, - prompt_adapter_manager_cls: Type[ - PromptAdapterModelManager] = PromptAdapterModelManager, - **kwargs) -> PromptAdapterModelManager: - """Create a PromptAdapterModel for a given model.""" - prompt_adapter_manager = prompt_adapter_manager_cls( - model=model, - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - prompt_adapter_config=prompt_adapter_config, - **kwargs) - return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py deleted file mode 100644 index 3ce50d0a26bb..000000000000 --- a/vllm/prompt_adapter/request.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import msgspec - -from vllm.adapter_commons.request import AdapterRequest - - -class PromptAdapterRequest( - msgspec.Struct, - array_like=True, # type: ignore[call-arg] - omit_defaults=True, # type: ignore[call-arg] - frozen=True): # type: ignore[call-arg] - """ - Request for a Prompt adapter. - """ - __metaclass__ = AdapterRequest - - prompt_adapter_name: str - prompt_adapter_id: int - prompt_adapter_local_path: str - prompt_adapter_num_virtual_tokens: int - - def __hash__(self): - return super().__hash__() - - @property - def adapter_id(self): - return self.prompt_adapter_id - - @property - def name(self): - return self.prompt_adapter_name - - @property - def local_path(self): - return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py deleted file mode 100644 index ddd007868f6b..000000000000 --- a/vllm/prompt_adapter/utils.py +++ /dev/null @@ -1,98 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420 - -import os -from typing import Optional - -import torch -from huggingface_hub import file_exists, hf_hub_download -from huggingface_hub.utils import EntryNotFoundError -from safetensors.torch import load_file as safe_load_file - -from vllm.platforms import current_platform - -WEIGHTS_NAME = "adapter_model.bin" -SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" - - -# Get current device name based on available devices -def infer_device() -> str: - if current_platform.is_cuda_alike(): - return "cuda" - return "cpu" - - -def load_peft_weights(model_id: str, - device: Optional[str] = None, - **hf_hub_download_kwargs) -> dict: - r""" - A helper method to load the PEFT weights from the HuggingFace Hub or locally - - Args: - model_id (`str`): - The local path to the adapter weights or the name of the adapter to - load from the HuggingFace Hub. 
- device (`str`): - The device to load the weights onto. - hf_hub_download_kwargs (`dict`): - Additional arguments to pass to the `hf_hub_download` method when - loading from the HuggingFace Hub. - """ - path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) if - hf_hub_download_kwargs.get("subfolder") is not None else model_id) - - if device is None: - device = infer_device() - - if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): - filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) - use_safetensors = True - elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): - filename = os.path.join(path, WEIGHTS_NAME) - use_safetensors = False - else: - token = hf_hub_download_kwargs.get("token") - if token is None: - token = hf_hub_download_kwargs.get("use_auth_token") - - hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"], - SAFETENSORS_WEIGHTS_NAME) - if hf_hub_download_kwargs.get("subfolder") is not None - else SAFETENSORS_WEIGHTS_NAME) - has_remote_safetensors_file = file_exists( - repo_id=model_id, - filename=hub_filename, - revision=hf_hub_download_kwargs.get("revision"), - repo_type=hf_hub_download_kwargs.get("repo_type"), - token=token, - ) - use_safetensors = has_remote_safetensors_file - - if has_remote_safetensors_file: - # Priority 1: load safetensors weights - filename = hf_hub_download( - model_id, - SAFETENSORS_WEIGHTS_NAME, - **hf_hub_download_kwargs, - ) - else: - try: - filename = hf_hub_download(model_id, WEIGHTS_NAME, - **hf_hub_download_kwargs) - except EntryNotFoundError: - raise ValueError( # noqa: B904 - f"Can't find weights for {model_id} in {model_id} or \ - in the Hugging Face Hub. " - f"Please check that the file {WEIGHTS_NAME} or \ - {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.") - - if use_safetensors: - adapters_weights = safe_load_file(filename, device=device) - else: - adapters_weights = torch.load(filename, - map_location=torch.device(device), - weights_only=True) - - return adapters_weights diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py deleted file mode 100644 index 56265de8087c..000000000000 --- a/vllm/prompt_adapter/worker_manager.py +++ /dev/null @@ -1,179 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import logging -from typing import Any, Optional, Set, Type - -import torch - -from vllm.adapter_commons.utils import (add_adapter_worker, - apply_adapters_worker, - list_adapters_worker, - set_active_adapters_worker) -from vllm.adapter_commons.worker_manager import AbstractWorkerManager -from vllm.config import PromptAdapterConfig -from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, - PromptAdapterModel, - PromptAdapterModelManager, - create_prompt_adapter_manager) -from vllm.prompt_adapter.request import PromptAdapterRequest - -logger = logging.getLogger(__name__) - - -class WorkerPromptAdapterManager(AbstractWorkerManager): - """WorkerPromptAdapterManager that manages - prompt_adapter models on the worker side. 
- - Every request, the requested prompt_adapters will be - loaded (unless they are already loaded), - and every other prompt_adapter will be unloaded.""" - - _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager - - def __init__( - self, - max_num_seqs: int, - max_num_batched_tokens: int, - device: torch.device, - prompt_adapter_config: PromptAdapterConfig, - prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel - ): - self._adapter_manager: PromptAdapterModelManager - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens - self._prompt_adapter_model_cls = prompt_adapter_model_cls - self.prompt_adapter_config = prompt_adapter_config - super().__init__(device) - - @property - def is_enabled(self) -> bool: - return True - - def create_prompt_adapter_manager( - self, - model: torch.nn.Module, - ) -> Any: - prompt_adapter_manager = create_prompt_adapter_manager( - model, - max_num_seqs=self.max_num_seqs, - max_num_batched_tokens=self.max_num_batched_tokens, - prompt_adapter_config=self.prompt_adapter_config, - prompt_adapter_manager_cls=self._manager_cls, - ) - self._adapter_manager = prompt_adapter_manager - return prompt_adapter_manager.model - - def _load_adapter( - self, prompt_adapter_request: PromptAdapterRequest - ) -> PromptAdapterModel: - try: - prompt_adapter = ( - self._prompt_adapter_model_cls.from_local_checkpoint( - prompt_adapter_request.prompt_adapter_local_path, - prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, - num_virtual_tokens=prompt_adapter_request. - prompt_adapter_num_virtual_tokens, - config=self.prompt_adapter_config, - device=str(self.device), - )) - except Exception as e: - raise RuntimeError( - f"Loading prompt_adapter " - f"{prompt_adapter_request.prompt_adapter_local_path}" - f" failed") from e - return prompt_adapter - - def add_dummy_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return True - - def pin_adapter(self, adapter_id: int) -> bool: - return self._adapter_manager.pin_adapter(adapter_id) - - def set_active_adapters(self, requests: Set[Any], - mapping: Optional[Any]) -> None: - set_active_adapters_worker(requests, mapping, self._apply_adapters, - self._adapter_manager.set_adapter_mapping) - - def add_adapter(self, adapter_request: Any) -> bool: - return add_adapter_worker(adapter_request, self.list_adapters, - self._load_adapter, - self._adapter_manager.add_adapter, - self._adapter_manager.activate_adapter) - - def _apply_adapters(self, adapter_requests: Set[Any]) -> None: - apply_adapters_worker(adapter_requests, self.list_adapters, - self._adapter_manager.adapter_slots, - self.remove_adapter, self.add_adapter) - - def remove_adapter(self, adapter_id: int) -> bool: - return self._adapter_manager.remove_adapter(adapter_id) - - def remove_all_adapters(self): - self._adapter_manager.remove_all_adapters() - - def list_adapters(self) -> Set[int]: - return list_adapters_worker(self._adapter_manager.list_adapters) - - -class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): - """WorkerPromptAdapterManager that manages - prompt_adapter models on the worker side. - - Uses an LRU Cache. 
Every request, the requested - prompt_adapters will be loaded (unless they are already loaded) - and least recently used prompt_adapters will - be unloaded if the cache is above capacity.""" - - _prompt_adapter_manager_cls: Type[ - LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager - - def create_prompt_adapter_manager( - self, - model: torch.nn.Module, - ) -> Any: - prompt_adapter_manager = create_prompt_adapter_manager( - model, - max_num_seqs=self.max_num_seqs, - max_num_batched_tokens=self.max_num_batched_tokens, - prompt_adapter_config=self.prompt_adapter_config, - prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) - self._adapter_manager: LRUCachePromptAdapterModelManager = ( - prompt_adapter_manager) - return prompt_adapter_manager.model - - def _apply_adapters( - self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: - prompt_adapters_map = { - prompt_adapter_request.prompt_adapter_id: prompt_adapter_request - for prompt_adapter_request in prompt_adapter_requests - if prompt_adapter_request - } - if len(prompt_adapters_map - ) > self._adapter_manager.prompt_adapter_slots: - raise RuntimeError( - f"Number of requested prompt_adapters " - f"({len(prompt_adapters_map)}) is greater " - "than the number of GPU prompt_adapter slots " - f"({self._adapter_manager.prompt_adapter_slots}).") - for prompt_adapter in prompt_adapters_map.values(): - self.add_adapter(prompt_adapter) - - def add_adapter(self, - prompt_adapter_request: PromptAdapterRequest) -> bool: - if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( - ): - # Remove before we load the new prompt_adapter to save memory - if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: - self._adapter_manager.remove_oldest_adapter() - prompt_adapter = self._load_adapter(prompt_adapter_request) - loaded = self._adapter_manager.add_adapter(prompt_adapter) - else: - # If the prompt_adapter is already loaded, just touch it to - # update its position in the caches - loaded = self._adapter_manager.get_adapter( - prompt_adapter_request.prompt_adapter_id) is not None - self._adapter_manager.activate_adapter( - prompt_adapter_request.prompt_adapter_id) - return loaded diff --git a/vllm/ray/__init__.py b/vllm/ray/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/ray/ray_env.py b/vllm/ray/ray_env.py new file mode 100644 index 000000000000..f6a994bb3c22 --- /dev/null +++ b/vllm/ray/ray_env.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os +from typing import Optional + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +CONFIG_HOME = envs.VLLM_CONFIG_ROOT + +# This file contains a list of env vars that should not be copied +# from the driver to the Ray workers. +RAY_NON_CARRY_OVER_ENV_VARS_FILE = os.path.join( + CONFIG_HOME, "ray_non_carry_over_env_vars.json") + +try: + if os.path.exists(RAY_NON_CARRY_OVER_ENV_VARS_FILE): + with open(RAY_NON_CARRY_OVER_ENV_VARS_FILE) as f: + RAY_NON_CARRY_OVER_ENV_VARS = set(json.load(f)) + else: + RAY_NON_CARRY_OVER_ENV_VARS = set() +except json.JSONDecodeError: + logger.warning( + "Failed to parse %s. 
Using an empty set for non-carry-over env vars.", + RAY_NON_CARRY_OVER_ENV_VARS_FILE) + RAY_NON_CARRY_OVER_ENV_VARS = set() + + +def get_env_vars_to_copy(exclude_vars: Optional[set[str]] = None, + additional_vars: Optional[set[str]] = None, + destination: Optional[str] = None) -> set[str]: + """ + Get the environment variables to copy to downstream Ray actors. + + Example use cases: + - Copy environment variables from RayDistributedExecutor to Ray workers. + - Copy environment variables from RayDPClient to Ray DPEngineCoreActor. + + Args: + exclude_vars: A set of vllm defined environment variables to exclude + from copying. + additional_vars: A set of additional environment variables to copy. + If a variable is in both exclude_vars and additional_vars, it will + be excluded. + destination: The destination of the environment variables. + Returns: + A set of environment variables to copy. + """ + exclude_vars = exclude_vars or set() + additional_vars = additional_vars or set() + + env_vars_to_copy = { + v + for v in set(envs.environment_variables).union(additional_vars) + if v not in exclude_vars and v not in RAY_NON_CARRY_OVER_ENV_VARS + } + + to_destination = " to " + destination if destination is not None else "" + + logger.info("RAY_NON_CARRY_OVER_ENV_VARS from config: %s", + RAY_NON_CARRY_OVER_ENV_VARS) + logger.info("Copying the following environment variables%s: %s", + to_destination, + [v for v in env_vars_to_copy if v in os.environ]) + logger.info( + "If certain env vars should NOT be copied, add them to " + "%s file", RAY_NON_CARRY_OVER_ENV_VARS_FILE) + + return env_vars_to_copy diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index e8cd565519f3..d61e4f11dfa2 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -3,7 +3,10 @@ from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser from .granite_reasoning_parser import GraniteReasoningParser +from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser +from .mistral_reasoning_parser import MistralReasoningParser from .qwen3_reasoning_parser import Qwen3ReasoningParser __all__ = [ @@ -11,5 +14,8 @@ "ReasoningParserManager", "DeepSeekR1ReasoningParser", "GraniteReasoningParser", + "HunyuanA13BReasoningParser", "Qwen3ReasoningParser", + "Glm4MoeModelReasoningParser", + "MistralReasoningParser", ] diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py index e827d381ca1d..4f4522d726e8 100644 --- a/vllm/reasoning/abs_reasoning_parsers.py +++ b/vllm/reasoning/abs_reasoning_parsers.py @@ -7,14 +7,22 @@ from abc import abstractmethod from collections.abc import Sequence from functools import cached_property -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - DeltaMessage) from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import import_from_path, is_list_of +if TYPE_CHECKING: + from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ResponsesRequest) + from vllm.transformers_utils.tokenizer import AnyTokenizer +else: + ChatCompletionRequest = Any + DeltaMessage = Any + ResponsesRequest = Any + AnyTokenizer = Any + logger = init_logger(__name__) @@ -66,7 +74,9 @@ def 
extract_content_ids(self, input_ids: list[int]) -> list[int]: @abstractmethod def extract_reasoning_content( - self, model_output: str, request: ChatCompletionRequest + self, + model_output: str, + request: Union[ChatCompletionRequest, ResponsesRequest], ) -> tuple[Optional[str], Optional[str]]: """ Extract reasoning content from a complete model-generated string. diff --git a/vllm/reasoning/glm4_moe_reasoning_parser.py b/vllm/reasoning/glm4_moe_reasoning_parser.py new file mode 100644 index 000000000000..6511fb49d10e --- /dev/null +++ b/vllm/reasoning/glm4_moe_reasoning_parser.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("glm4_moe") +class Glm4MoeModelReasoningParser(ReasoningParser): + """ + Reasoning parser for the Glm4MoeModel model. + + The Glm4MoeModel model uses <think>...</think> tokens to denote reasoning + text within its output. The model provides a strict switch to disable + reasoning output via the 'enable_thinking=False' parameter. This parser + extracts the reasoning content enclosed by <think> and </think> tokens + from the model's output. + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_start_token = "<think>" + self.think_end_token = "</think>" + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + self.think_start_token_id = self.vocab.get(self.think_start_token) + self.think_end_token_id = self.vocab.get(self.think_end_token) + if (self.think_start_token_id is None + or self.think_end_token_id is None): + raise RuntimeError( + "Glm4MoeModel reasoning parser could not locate " + "think start/end tokens in the tokenizer!") + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.think_end_token_id in input_ids + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + """ + Extract the content after the end tokens + """ + if self.think_end_token_id not in input_ids[:-1]: + return [] + else: + return input_ids[input_ids.index(self.think_end_token_id) + 1:] + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Extract reasoning content from a delta message. + Handles streaming output where previous + delta = current. + Uses token IDs for faster processing. 
+ For text <think>abc</think>xyz: + - 'abc' goes to reasoning_content + - 'xyz' goes to content + """ + # Skip single special tokens + if len(delta_token_ids) == 1 and (delta_token_ids[0] in [ + self.think_start_token_id, self.think_end_token_id + ]): + return None + + if self.think_start_token_id in previous_token_ids: + if self.think_end_token_id in delta_token_ids: + # <think> in previous, </think> in delta, + # extract reasoning content + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[:end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + elif self.think_end_token_id in previous_token_ids: + # <think> in previous, </think> in previous, + # reasoning content continues + return DeltaMessage(content=delta_text) + else: + # <think> in previous, no </think> in previous or delta, + # reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + elif self.think_start_token_id in delta_token_ids: + if self.think_end_token_id in delta_token_ids: + # <think> in delta, </think> in delta, extract reasoning content + start_index = delta_text.find(self.think_start_token) + end_index = delta_text.find(self.think_end_token) + reasoning_content = delta_text[start_index + + len(self.think_start_token + ):end_index] + content = delta_text[end_index + len(self.think_end_token):] + return DeltaMessage(reasoning_content=reasoning_content, + content=content if content else None) + else: + # <think> in delta, no </think> in delta, + # reasoning content continues + return DeltaMessage(reasoning_content=delta_text) + else: + # thinking is disabled, just content + return DeltaMessage(content=delta_text) + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from the model output. + + For text <think>abc</think>xyz: + - 'abc' goes to reasoning_content + - 'xyz' goes to content + + Returns: + tuple[Optional[str], Optional[str]]: reasoning content and content + """ + + # Check if the model output contains the <think> and </think> tokens. + if (self.think_start_token not in model_output + or self.think_end_token not in model_output): + return None, model_output + # Check if the <think> is present in the model output, remove it + # if it is present. + model_output_parts = model_output.partition(self.think_start_token) + model_output = model_output_parts[2] if model_output_parts[ + 1] else model_output_parts[0] + # Check if the model output contains the </think> tokens. + # If the end token is not found, return the model output as is. + if self.think_end_token not in model_output: + return None, model_output + + # Extract reasoning content from the model output. 
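# Illustrative sketch (not part of the patch): the partition step below splits
# everything before the first "</think>" into reasoning_content and everything
# after it into content. Sample string is hypothetical; the "<think>" prefix is
# assumed to have been stripped by the preceding partition on think_start_token.
_example = "first reason about it</think>then answer"
_reasoning, _, _content = _example.partition("</think>")
assert (_reasoning, _content) == ("first reason about it", "then answer")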
+ reasoning_content, _, content = model_output.partition( + self.think_end_token) + + final_content = content or None + return reasoning_content, final_content diff --git a/vllm/reasoning/hunyuan_a13b_reasoning_parser.py b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py new file mode 100644 index 000000000000..b2452b95c1c6 --- /dev/null +++ b/vllm/reasoning/hunyuan_a13b_reasoning_parser.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Optional, Union + +import regex as re +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("hunyuan_a13b") +class HunyuanA13BReasoningParser(ReasoningParser): + """ + Reasoning parser for Hunyuan A13B Model + + HunyuanReasoningParser + + This class implements a reasoning parser specifically designed + for the Hunyuan A13B Model. It is responsible for parsing and + extracting structured reasoning and answer segments from model + outputs that follow a specific pattern. + + Key Features: + - For non-stream output , Recognizes and extracts reasoning ("think") + and answer ("answer") sections from text using regular expressions. + - For stream process, it require a token id sequences to change the + reasoning state and other state so it maintains internal state to + manage parsing across multiple token. + + + think start: "<think>\n": [14023, 771, 397] + think ends: "\n</think>\n<answer>\n": [198, 524, 27963, 397, 27, 9399, 397] + response ends: "\n</answer>": [524, 9399, 29] + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + self.think_start_expr = r"<think>\n" + self.think_end_expr = r"\n</think>\n" + + self.response_start_expr = r"\n</think>\n<answer>\n" + self.response_end_expr = r"\n</answer>" + + self.full_match_reasoning_regex = re.compile( + rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?(.*?){self.response_end_expr}", + re.DOTALL) + + self.half_match_reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", + re.DOTALL) + + self.think_start_ids = [14023, 771, 397] + self.think_start_ids_fast = [14023, 771, 1363] + self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397] + self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397] + self.response_end_ids = [198, 524, 9399, 29] + self.fast_think_ids = [ + 14023, 771, 1363, 524, 27963, 397, 27, 9399, 397 + ] + + # when state change, send out all the buffered text in last state + self.buffered_text = [] + self.buffered_ids = [] + + self.current_state = "reasoning" + self.all_states = ["reasoning", "response"] + + self.current_state = "idle" + self.expected_sequence = self.think_start_ids + # this sequence only for the think start, it has two way to start. 
+ self.expected_sequence_side = self.think_start_ids_fast + self.sequence_index = 0 + self.token_buffer = [] + self.text_buffer = "" + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self.current_state == "response" + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + # for hunyuan streaming reason parsing, the stream parse + # will call first, and the same token will be called in + # is_reasoning_end and extract_content_ids + # this id is not part of content, so just return [] here. + return [] + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> tuple[Optional[str], Optional[str]]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (ChatCompletionRequest): Request being processed. + + Returns: + tuple[Optional[str], Optional[str]]: Tuple pair containing the + reasoning content and non-reasoning content. + """ + + re_match = self.full_match_reasoning_regex.findall(model_output) + if re_match: + reasoning_content, response_content = re_match[0] + if len(reasoning_content) == 0: + reasoning_content = None + if len(response_content) == 0: + response_content = None + return reasoning_content, response_content + + fallback_regex = self.half_match_reasoning_regex + fallback_match = fallback_regex.findall(model_output) + if fallback_match: + reasoning_content, response_content = fallback_match[0] + + if response_content.endswith(self.response_end_expr): + response_content = response_content[:-len(self. + response_end_expr)] + + if len(reasoning_content) == 0: + reasoning_content = None + if len(response_content) == 0: + response_content = None + + return reasoning_content, response_content + + return None, model_output + + def _is_strict_increasing_subsequence(self, subsequence: Sequence[int], + sequence: Sequence[int]) -> bool: + if not subsequence: + return False + + sub_idx = 0 + for num in sequence: + if sub_idx < len(subsequence) and num == subsequence[sub_idx]: + sub_idx += 1 + return sub_idx == len(subsequence) + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """Extract content using token ID sequence state machine""" + # Define sequences + think_start_sequence = self.think_start_ids + response_start_sequence = self.response_start_ids + response_end_sequence = self.response_end_ids + + assert (len(delta_token_ids) == 1) + # Process each token in the delta + token = delta_token_ids[0] + + def check_token_with_sequence(token): + if self.current_state == "idle" or self.current_state == "think": + return (token == self.expected_sequence[self.sequence_index] + or token == \ + self.expected_sequence_side[self.sequence_index]) + else: + return token == self.expected_sequence[self.sequence_index] + + def check_last_token(token): + if self.current_state == "idle" or self.current_state == "think": + # only return true if it's judge using a side sequence. 
+ if (self.sequence_index - 1 < len(self.expected_sequence_side) + and token + == self.expected_sequence_side[self.sequence_index - + 1]): + return self.sequence_index == len( + self.expected_sequence_side) + else: + return self.sequence_index == len(self.expected_sequence) + else: + return self.sequence_index == len(self.expected_sequence) + + # Check if token matches expected sequence + token_in_state_seq = check_token_with_sequence(token) + + if token_in_state_seq: + # Store matching token + self.token_buffer.append(token) + self.text_buffer += delta_text + self.sequence_index += 1 + ## state change from idle->think->response->idle + + # Check if sequence fully matched + if check_last_token(token): + # State transition + if self.current_state == "idle": + self.current_state = "think" + self.expected_sequence = response_start_sequence + self.expected_sequence_side = self.response_start_ids_fast + elif self.current_state == "think": + self.current_state = "response" + self.expected_sequence = response_end_sequence + elif self.current_state == "response": + self.current_state = "idle" + self.expected_sequence = think_start_sequence + self.expected_sequence_side = self.think_start_ids_fast + + # Reset matching state + self.sequence_index = 0 + self.token_buffer = [] + self.text_buffer = "" + # Do not send content for state transition texts. + else: + # Sequence broken - handle buffered content + if self.token_buffer and len(self.token_buffer) > 0: + # Send buffered tokens + buffered_content = self.text_buffer + delta_text + # Reset matching state + self.sequence_index = 0 + self.token_buffer = [] + self.text_buffer = "" + + # Return content based on current state + if self.current_state == "think": + return DeltaMessage(reasoning_content=buffered_content, + content=None) + else: + return DeltaMessage(reasoning_content=None, + content=buffered_content) + else: + # No buffered content, send normally + if self.current_state == "think": + return DeltaMessage(reasoning_content=delta_text, + content=None) + else: + return DeltaMessage(reasoning_content=None, + content=delta_text) + + # If no content to send in this delta + return None diff --git a/vllm/reasoning/mistral_reasoning_parser.py b/vllm/reasoning/mistral_reasoning_parser.py new file mode 100644 index 000000000000..6c707a4079fa --- /dev/null +++ b/vllm/reasoning/mistral_reasoning_parser.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.logger import init_logger +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.reasoning.deepseek_r1_reasoning_parser import ( + DeepSeekR1ReasoningParser) +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer + +logger = init_logger(__name__) + + +@ReasoningParserManager.register_module("mistral") +class MistralReasoningParser(DeepSeekR1ReasoningParser): + """ + Reasoning parser for Mistral models. + + The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning + text. This parser extracts the reasoning content from the model output. 
+ """ + + def __init__(self, tokenizer: MistralTokenizer): + if not isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "The tokenizer must be an instance of MistralTokenizer.") + + ReasoningParser.__init__(self, tokenizer) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ReasoningParser " + "constructor during construction.") + + from mistral_common.tokens.tokenizers.base import SpecialTokens + + self.start_token = SpecialTokens.begin_think + self.end_token = SpecialTokens.end_think + + self.start_token_id = tokenizer.tokenizer.get_control_token( + self.start_token) + self.end_token_id = tokenizer.tokenizer.get_control_token( + self.end_token) + + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + "Mistral reasoning parser could not locate think start/end " + "tokens in the tokenizer!") diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a9a862384d11..322e53b75394 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -9,7 +9,6 @@ import msgspec from pydantic import BaseModel -from typing_extensions import deprecated from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor @@ -84,27 +83,6 @@ def __post_init__(self): "You can only use one kind of guided decoding but multiple are " f"specified: {self.__dict__}") - if self.backend is not None and ":" in self.backend: - self._extract_backend_options() - - @deprecated( - "Passing guided decoding backend options inside backend in the format " - "'backend:...' is deprecated. This will be removed in v0.10.0. Please " - "use the dedicated arguments '--disable-fallback', " - "'--disable-any-whitespace' and '--disable-additional-properties' " - "instead.") - def _extract_backend_options(self): - """Extract backend options from the backend string.""" - assert isinstance(self.backend, str) - self.backend, options = self.backend.split(":") - options_set = set(options.strip().split(",")) - if "no-fallback" in options_set: - self.disable_fallback = True - if "disable-any-whitespace" in options_set: - self.disable_any_whitespace = True - if "no-additional-properties" in options_set: - self.disable_additional_properties = True - class RequestOutputKind(Enum): # Return entire output so far in every RequestOutput diff --git a/vllm/sequence.py b/vllm/sequence.py index ffe890eb2dab..fe87b52f9df1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -19,7 +19,6 @@ from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -112,13 +111,6 @@ class RequestMetrics: model_execute_time: The time spent in the model execute function. This will include model forward, block/sync across workers, cpu-gpu sync time and sampling time. - spec_token_acceptance_counts: number of accepted speculative tokens at - each position; the first token is from - the target model and is always accepted; - e.g., when it's [10, 8, 4, 2] for a req, - it means there were 10 forward passes in - total, and there were 8, 4, 2 accepted - tokens at 1st, 2nd, 3rd speculation step. 
""" arrival_time: float last_token_time: float @@ -129,7 +121,6 @@ class RequestMetrics: scheduler_time: Optional[float] = None model_forward_time: Optional[float] = None model_execute_time: Optional[float] = None - spec_token_acceptance_counts: Optional[list[int]] = None class SequenceDataDelta( @@ -466,7 +457,6 @@ class Sequence: block size used by the block manager and cache engine. eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. lora_request: LoRA request. - prompt_adapter_request: Prompt Adapter request. """ def __init__( @@ -476,14 +466,12 @@ def __init__( block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.seq_id = seq_id self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.prompt_adapter_request = prompt_adapter_request self.data = SequenceData.from_seqs( self.prompt_token_ids, @@ -545,11 +533,6 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - @property - def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 - def get_output_text_to_return(self, buffer_length: int, delta: bool) -> str: """If delta is True, only new text since the last call to @@ -609,12 +592,12 @@ def extra_hash(self) -> Optional[int]: designed for prefix caching mode. The final sequence hash is determined by applying token_ids from the sequence's blocks. """ - if self.prompt_adapter_id == 0 and self.lora_int_id == 0: + if self.lora_int_id == 0: return None # NOTE: If there are additional factors influencing the block aside from # token_ids, include them as input parameters to the hash. - return hash((self.prompt_adapter_id, self.lora_int_id)) + return hash(self.lora_int_id) def num_hashed_tokens_of_block(self, logical_idx: int): return logical_idx * self.block_size + self.block_size @@ -715,7 +698,6 @@ class SequenceGroup: encoder_seq: Optional, the single encoder sequence. Should be None unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request. priority: User-defined priority of the request. 
draft_size: The number of speculative tokens plus one from the target model; equal to max number of tokens a step can generate @@ -733,7 +715,6 @@ def __init__(self, pooled_data: Optional[torch.Tensor] = None, encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, draft_size: int = 1) -> None: self.request_id = request_id @@ -748,16 +729,13 @@ def __init__(self, last_token_time=arrival_time, first_scheduled_time=None, first_token_time=None, - time_in_queue=None, - spec_token_acceptance_counts=[0] * - draft_size) + time_in_queue=None) self.last_token_latency = 0.0 self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() self.pooling_params = pooling_params self.pooled_data = pooled_data - self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers self.priority = priority @@ -812,16 +790,6 @@ def multi_modal_placeholders(self) -> MultiModalPlaceholderDict: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - @property - def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 - - @property - def prompt_adapter_num_virtual_tokens(self) -> int: - return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ - if self.prompt_adapter_request else 0 - def init_multi_step(self, num_steps: int) -> None: self.state.num_steps = num_steps self.state.current_step = 0 @@ -1021,7 +989,6 @@ class SequenceGroupMetadata( (SequenceGroup.encoder_seq). Should be None unless you are working with an encoder/decoder model. - prompt_adapter_request: Prompt Adapter request. """ request_id: str @@ -1040,7 +1007,6 @@ class SequenceGroupMetadata( multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[list[int]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None token_chunk_size: Optional[int] = None ### Stateful fields that are lazily defined. ### @@ -1062,16 +1028,6 @@ def __post_init__(self): def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - @property - def prompt_adapter_id(self) -> int: - return self.prompt_adapter_request.prompt_adapter_id \ - if self.prompt_adapter_request else 0 - - @property - def prompt_adapter_num_virtual_tokens(self) -> int: - return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ - if self.prompt_adapter_request else 0 - # Multi-Step Chunked-Prefill property @property def is_single_step_prompt(self) -> bool: @@ -1183,6 +1139,10 @@ class PoolingSequenceGroupOutput( # The actual type is in SequenceGroup.pooled_data data: Any + def get_data_nbytes(self) -> int: + data: torch.Tensor = self.data + return data.nbytes + def __repr__(self) -> str: return f"PoolingSequenceGroupOutput(data={self.data}" @@ -1198,9 +1158,15 @@ class IntermediateTensors: """For all pipeline stages except the last, we need to return the hidden states and residuals to be sent to the next stage. This data structure contains the hidden states and residuals for a request. + + Each stage also needs to handle its own finished_sending and + finished_recving in case of kv transfer. 
""" tensors: dict[str, torch.Tensor] + # [req_ids] + finished_sending: Optional[set[str]] = None + finished_recving: Optional[set[str]] = None def __init__(self, tensors): # manually define this function, so that @@ -1238,6 +1204,9 @@ class PoolerOutput( """The output from a pooling operation in the pooling model.""" outputs: list[PoolingSequenceGroupOutput] + def get_data_nbytes(self) -> int: + return sum(o.get_data_nbytes() for o in self.outputs) + def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput: return self.outputs[idx] @@ -1390,8 +1359,6 @@ class ExecuteModelRequest( previous_hidden_states: Optional[HiddenStates] = None # The number of forward steps to run. num_steps: int = 1 - # The step index for spec model input. - spec_step_idx: Optional[int] = None # Finished request ids since last step. finished_requests_ids: list[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. @@ -1524,7 +1491,6 @@ def add_request(request_id: str, engine, params, **kwargs): pooled_data=seq_group.pooled_data, encoder_seq=seq_group.encoder_seq, trace_headers=seq_group.trace_headers, - prompt_adapter_request=seq_group.prompt_adapter_request, priority=seq_group.priority, ) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py deleted file mode 100644 index f9b882469a4d..000000000000 --- a/vllm/spec_decode/batch_expansion.py +++ /dev/null @@ -1,506 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from array import array -from itertools import chain, count -from typing import Iterator, List, Optional, Tuple - -import torch - -from vllm import SamplingParams -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, - ExecuteModelRequest, SequenceData, - SequenceGroupMetadata, get_all_seq_ids) -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len - -SeqId = int -TargetSeqId = int -TokenId = int - -DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams() - - -class BatchExpansionTop1Scorer(SpeculativeScorer): - """Implements a speculative scorer that uses batch expansion to get - probabilities of speculative tokens according to the scoring model. - - Batch expansion converts a list of sequences and multiple query positions - to a new batch of sequences, each with a single query position. This allows - for MQA-like scoring in speculative decoding without requiring an MQA - kernel. - - It is strictly less efficient than MQA scoring. - - It only supports scoring the top1 proposal tokens of the proposer, instead - of topk/tree. - """ - - @nvtx_range("BatchExpansionTop1Scorer.score_proposals") - def score_proposals( - self, - execute_model_req: ExecuteModelRequest, - proposals: SpeculativeProposals, - ) -> SpeculativeScores: - """Score the proposed tokens via the scorer model. - - This converts each input sequence to a set of k+1 target sequences. The - target sequences have the unique continuations to be scored and a - unique sequence ID that is different from all input sequence ids. - - If a speculative sequence length would exceed the max model length, then - no speculation is produced for that sequence. - - Args: - execute_model_req: The execution request. - proposals: The speculative proposals to score. 
- Returns: - SpeculativeScores: The scores of each speculative token, along with - which sequences were ignored during scoring. - """ - - # TODO(cade) perform this on GPU to remove blocking call. - proposal_lens_list = proposals.proposal_lens.tolist() - proposal_token_ids_list = proposals.proposal_token_ids.tolist() - - # Filter the list to ignore invalid proposals. - proposal_token_ids_list_without_skips = [ - proposals for proposals in proposal_token_ids_list - if VLLM_INVALID_TOKEN_ID not in proposals - ] - - (spec_indices, non_spec_indices, target_seq_group_metadata_list, - num_scoring_tokens) = self._expand_batch( - seq_group_metadata_list=execute_model_req.seq_group_metadata_list, - proposal_token_ids_list=proposal_token_ids_list_without_skips, - proposal_lens_list=proposal_lens_list, - ) - - target_sampler_output = self._scorer_worker.execute_model( - execute_model_req=execute_model_req.clone( - seq_group_metadata_list=target_seq_group_metadata_list)) - assert len(target_sampler_output) == 1, "expected single-step output" - target_sampler_output = target_sampler_output[0] - - if not non_spec_indices: - # All sequence groups in batch have spec decoding enabled - return self._contract_batch_all_spec( - target_sampler_output=target_sampler_output, - proposals=proposals, - ) - else: - # Batch has a mix of spec decode enabled and disabled seq groups - return self._contract_batch( - execute_model_req.seq_group_metadata_list, - target_sampler_output=target_sampler_output, - proposals=proposals, - num_scoring_tokens=num_scoring_tokens, - non_spec_indices=non_spec_indices, - spec_indices=spec_indices, - k=execute_model_req.num_lookahead_slots, - ) - - def _expand_batch( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids_list: List[List[TokenId]], - proposal_lens_list: List[int], - ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: - """Given the input sequences and potentially multiple corresponding - proposal tokens, create a new batch where each sequence has a single - query token. - """ - - # vLLM currently only supports proposal lens equal to zero or the batch - # proposal len. This adds some complexity (splitting the batch into spec - # and non spec sequences) and should be removed in the future. It can be - # done by supporting per-sequence proposal lens. - (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \ - split_batch_by_proposal_len( - seq_group_metadata_list, proposal_lens_list) - - spec_expanded_seqs = self._create_scoring_model_input( - seq_group_metadata_list=spec_seqs, - proposal_token_ids=proposal_token_ids_list, - # NOTE: We determine the seq ids in the expanded batch using the - # full seq_group_metadata_list, instead of only spec_seqs. - target_seq_ids_iter=self._create_target_seq_id_iterator( - seq_ids=get_all_seq_ids(seq_group_metadata_list)), - ) - - num_scoring_tokens = len(spec_expanded_seqs) - # Batch speculative and non-speculative (e.g. chunked prefill) requests - # but make sure order is prefill|decode due to backend requirement. - target_seq_group_metadata_list = non_spec_seqs + spec_expanded_seqs - - return (spec_indices, non_spec_indices, target_seq_group_metadata_list, - num_scoring_tokens) - - def _contract_non_speculative( - self, scores: SpeculativeScores, - seq_group_metadata_list: List[SequenceGroupMetadata], - non_spec_indices: List[int], non_spec_outputs: SpeculativeScores, - has_prompt_log: bool) -> SpeculativeScores: - """ - Augment input `scores` with non-speculative requests outputs. 
- This includes decode requests with speculation turned off, as well - as prefill requests when `enable_chunked_prefill` is set. - For the latter, prefills are further separated into terminal and - non-terminal chunks (from which no token is sampled). - """ - if not non_spec_indices: - return scores - - if has_prompt_log: - # When prompt_logprobs is enabled, prefills yield output token - # (and respective prob) in the last entry (prompt|out): - # [.|.|.|prefill0_out|.|prefill1_out|decode0_out|..]. - # With chunked prefill, non-terminal chunks have -1 on each - # position: they're still picked, but they're discarded later. - seq_meta = seq_group_metadata_list - nospec_sizes = torch.tensor([ - seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1 - for i in non_spec_indices - ]) - nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1) - else: - # In this case only sampled tokens are returned, select all. - nospec_sampled_token_idxs = list( - range(len(non_spec_outputs.token_ids))) - - scores.token_ids[non_spec_indices, :1] = \ - non_spec_outputs.token_ids[nospec_sampled_token_idxs].unsqueeze(1) - scores.probs[non_spec_indices, :1, :] = \ - non_spec_outputs.probs[nospec_sampled_token_idxs].unsqueeze(1) - scores.logprobs[non_spec_indices, :1, :] = \ - non_spec_outputs.logprobs[nospec_sampled_token_idxs].unsqueeze(1) - if scores.hidden_states is not None: - assert non_spec_outputs.hidden_states is not None - scores.hidden_states[non_spec_indices, :1, :] = \ - non_spec_outputs.hidden_states[nospec_sampled_token_idxs].unsqueeze(1) - return scores - - def _contract_batch( - self, - contracted_seq_group_metadata_list: List[SequenceGroupMetadata], - target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, num_scoring_tokens: int, - non_spec_indices: List[int], spec_indices: List[int], - k: int) -> SpeculativeScores: - """Contract the expanded batch back into its original size. - This maps the scores of speculative tokens back to their original - sequences. - - contracted_bs is the original batch size, and the batch size that the - target_sampler_output will be contracted to. - """ - contracted_bs = len(contracted_seq_group_metadata_list) - (target_token_ids, target_probs, target_logprobs, target_hidden_states, - non_spec_target_token_ids, non_spec_target_probs, - non_spec_target_logprobs, - non_spec_target_hidden_states) = self._split_scoring_output( - target_sampler_output, num_scoring_tokens) - - # Map distinct sequences used to score each token - # of shape [batch_size * k + 1] back to [batch_size, k + 1]. 
- expanded_batch_size, k = proposals.proposal_token_ids.shape - - # The number of tokens in the expanded batch used for speculation is - # equal to the total expanded batch size minus the number of samples for - # non-speculative sequences, prefill chunks with no out tokens included - non_spec_expanded_bs = len(non_spec_indices) - spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs - - target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) - target_probs = target_probs.reshape(*target_token_ids.shape, - self._vocab_size) - target_logprobs = target_logprobs.reshape(target_probs.shape) - - if target_hidden_states is not None: - target_hidden_states = target_hidden_states.reshape( - *target_token_ids.shape, target_hidden_states.shape[-1]) - - all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), - fill_value=-1) - all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) - all_logprobs = target_logprobs.new_full(size=all_probs.shape, - fill_value=-float("inf")) - - if target_sampler_output.hidden_states is not None: - all_hidden_states = target_hidden_states.new_zeros( - size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) - else: - all_hidden_states = None - - has_prompt_log = any((sg.sampling_params.prompt_logprobs - and sg.sampling_params.prompt_logprobs > 0) - for sg in contracted_seq_group_metadata_list) - # When prompt logprobs is enabled, lens of returned tensors go from - # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. - # We adjust stride accordingly to get the generated tokens and - # their probs, but pass on prompt_logprobs as is. - prompt_logprobs = None - if (not self._scorer_worker.model_runner.disable_logprobs\ - and has_prompt_log): - prompt_logprobs = [ - o.prompt_logprobs for o in target_sampler_output.outputs - ] - elif not has_prompt_log: - # When prompt logprobs are not to be returned, - # we can ignore non-terminal chunks (no out token). - non_spec_indices = [ - idx for idx in non_spec_indices - if contracted_seq_group_metadata_list[idx].do_sample - ] - - # "Contract" speculative. - if spec_indices: - all_tokens[spec_indices] = target_token_ids - all_probs[spec_indices] = target_probs - all_logprobs[spec_indices] = target_logprobs - if all_hidden_states is not None: - all_hidden_states[spec_indices] = target_hidden_states - - spec_scores = SpeculativeScores(probs=all_probs, - token_ids=all_tokens, - logprobs=all_logprobs, - hidden_states=all_hidden_states, - prompt_logprobs=prompt_logprobs) - - non_spec_outputs = SpeculativeScores( - probs=non_spec_target_probs, - token_ids=non_spec_target_token_ids, - logprobs=non_spec_target_logprobs, - hidden_states=non_spec_target_hidden_states) - # Contract remaining nonspec entries based on non_spec_indices, if any. - return self._contract_non_speculative( - spec_scores, contracted_seq_group_metadata_list, non_spec_indices, - non_spec_outputs, has_prompt_log) - - def _contract_batch_all_spec( - self, - target_sampler_output: SamplerOutput, - proposals: SpeculativeProposals, - ) -> SpeculativeScores: - """Contract the expanded batch back into its original size. - This maps the scores of speculative tokens back to their original - sequences. - - It assumes all sequences in the batch were previously expanded. - """ - - # Map distinct sequences used to score each token - # of shape [batch_size * k + 1] back to [batch_size, k + 1]. 
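# [Editor's illustrative sketch, not part of the removed file] The
# reshape-and-scatter step used by _contract_batch above, with toy numbers:
# flat scores from the expanded batch are viewed as [spec_bs, k + 1] and
# written back into a [contracted_bs, k + 1] tensor at the speculative rows;
# non-speculative rows keep the -1 fill until _contract_non_speculative
# populates them.
import torch

contracted_bs, k = 4, 2
spec_indices = [0, 2, 3]                           # rows that were expanded
flat = torch.arange(len(spec_indices) * (k + 1))   # stand-in for sampled ids
spec_tokens = flat.reshape(len(spec_indices), k + 1)

all_tokens = flat.new_full((contracted_bs, k + 1), fill_value=-1)
all_tokens[spec_indices] = spec_tokens             # row 1 stays [-1, -1, -1]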
- contracted_bs, k = proposals.proposal_token_ids.shape - - # Reshape tensors to original batch size - target_token_ids = target_sampler_output.sampled_token_ids.reshape( - contracted_bs, k + 1) - target_probs = target_sampler_output.sampled_token_probs.reshape( - *target_token_ids.shape, self._vocab_size) - target_logprobs = target_sampler_output.logprobs.reshape( - target_probs.shape) - target_hidden_states = target_sampler_output.hidden_states - if target_hidden_states is not None: - target_hidden_states = target_hidden_states.reshape( - *target_token_ids.shape, target_hidden_states.shape[-1]) - - return SpeculativeScores(probs=target_probs, - token_ids=target_token_ids, - logprobs=target_logprobs, - hidden_states=target_hidden_states, - prompt_logprobs=None) - - def _create_scoring_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] - target_seq_ids_iter: Iterator[TargetSeqId], - ) -> List[SequenceGroupMetadata]: - """Given the original input sequences and proposed tokens from the draft - model, create a list of target sequences that can be used for scoring. - - target_seq_ids_iter provides sequence ids for the expanded batch, - fulfilling the requirement that no seq id in the expanded batch is equal - to the seq id in the original batch. - """ - - if not seq_group_metadata_list: - return [] - - target_seq_group_metadata = list( - chain.from_iterable( - self._create_target_seq_group_metadata( - seq_group_metadata, - proposal_token_ids, - i, - target_seq_ids_iter, - ) for i, seq_group_metadata in enumerate( - seq_group_metadata_list))) - - return target_seq_group_metadata - - def _create_target_seq_group_metadata( - self, - input_seq_group_metadata: SequenceGroupMetadata, - proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] - batch_index: int, - target_seq_ids_iter: Iterator[TargetSeqId], - ) -> List[SequenceGroupMetadata]: - """Given an input sequence group metadata and a list of draft tokens, - create a list of target SequenceGroupMetadata, one for each - token id that needs to be scored. - - Naive speculative decoding requires K target model scores, one for each - draft model token. However one can add a bonus token such that if each - token is accepted, then a final token may be sampled from the model. - This function creates K+1 target SequenceGroupMetadata to take - advantage of the bonus token. - """ - assert len(input_seq_group_metadata.seq_data) == 1, ( - "Beam search " - "not supported in speculative decoding") - input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys())) - - token_ids_to_score = self._get_token_ids_to_score( - proposal_token_ids[batch_index]) - - sampling_params = input_seq_group_metadata.sampling_params - target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - for i, token_ids in enumerate(token_ids_to_score): - target_seq_group_metadata_list.append( - self._create_single_target_seq_group_metadata( - input_seq_group_metadata, - input_seq_id, - next(target_seq_ids_iter), - token_ids, - sampling_params=sampling_params, - )) - - return target_seq_group_metadata_list - - @staticmethod - def _create_single_target_seq_group_metadata( - seq_group_metadata: SequenceGroupMetadata, - seq_id: SeqId, - target_seq_id: TargetSeqId, - token_ids: List[TokenId], - sampling_params: SamplingParams, - ) -> SequenceGroupMetadata: - """Create a single target SequenceGroupMetadata. - - Args: - seq_group_metadata: The metadata for the input sequence. 
- seq_id: The input sequence ID. - target_seq_id: The corresponding target sequence ID. - token_ids: The list of token ids that are to be appended to the - input sequence. - """ - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_token_ids = seq_data.prompt_token_ids_array - new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] - mrope_position_delta = seq_data.mrope_position_delta - - new_seq_data_dict = { - target_seq_id: - SequenceData( - prompt_token_ids, - _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, - new_output_token_ids), - ), - } - # This is a hack. Technically, spec decoding should compute - # num_lookahead slots at one shot, but instead, it expands the batch - # and evaluate one by one right now. context_len is seq_len - 1 because - # the kv cache is filled by a previous batch in the batch expansion. - for data in new_seq_data_dict.values(): - data.update_num_computed_tokens(data.get_len() - 1) - data.mrope_position_delta = mrope_position_delta - - return SequenceGroupMetadata( - request_id=seq_group_metadata.request_id, - is_prompt=seq_group_metadata.is_prompt, - seq_data=new_seq_data_dict, - sampling_params=sampling_params, - block_tables={ - target_seq_id: seq_group_metadata.block_tables[seq_id], - }, - lora_request=None, - token_chunk_size=1, - ) - - @staticmethod - def _split_scoring_output( - sampler_output: SamplerOutput, num_scoring_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, - Optional[torch.Tensor], torch.Tensor, torch.Tensor, - torch.Tensor, Optional[torch.Tensor]]: - """Split the target model output into speculative and non-speculative - output. - """ - - # vLLM currently only supports proposal lens equal to zero or the batch - # proposal len. This adds some complexity (splitting the batch into spec - # and non spec sequences) and should be removed in the future. It can be - # done by supporting per-sequence proposal lens. - # - # First samples are non-speculative, latter samples are from speculative - # scoring (prefill|decode order). - split_sizes = (sampler_output.sampled_token_ids.numel() - - num_scoring_tokens, num_scoring_tokens) - (non_spec_probs, - spec_probs) = sampler_output.sampled_token_probs.split(split_sizes) - (non_spec_sampled_tokens, spec_sampled_tokens - ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) - (non_spec_logprobs, - spec_logprobs) = sampler_output.logprobs.split(split_sizes) - - if sampler_output.hidden_states is not None: - (non_spec_hidden_states, spec_hidden_states - ) = sampler_output.hidden_states.split(split_sizes) - else: - non_spec_hidden_states, spec_hidden_states = None, None - - return (spec_sampled_tokens, spec_probs, spec_logprobs, - spec_hidden_states, non_spec_sampled_tokens, non_spec_probs, - non_spec_logprobs, non_spec_hidden_states) - - @staticmethod - def _create_target_seq_id_iterator( - seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: - """Create an iterator for creating target sequence ids. - Target sequence ids are distinct from sequence ids because we create a - distinct target sequence id for each proposal token to be scored. - - This implementation increments a counter starting at 1 + max of all - provided input sequence ids. - """ - return count(start=max(seq_ids) + 1) - - @staticmethod - def _get_token_ids_to_score( - full_spec_token_ids: List[TokenId] # shape: [k] - ) -> List[List[TokenId]]: - """Given an int tensor of proposal token ids, return a list of - token ids that should be scored. - - Returns k+1 output lists. 
The additional one is used for generating the - bonus token. - - Example: - Input: [0, 1, 2, 3] (k=4) - Output: (k+1 lists) - [] - [0] - [0, 1] - [0, 1, 2] - [0, 1, 2, 3] - """ - empty_token_ids: List[TokenId] = [] - - token_ids_to_score = [empty_token_ids] - token_ids_to_score.extend(full_spec_token_ids[:i + 1] - for i in range(len(full_spec_token_ids))) - return token_ids_to_score diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py deleted file mode 100644 index 96646ec94718..000000000000 --- a/vllm/spec_decode/draft_model_runner.py +++ /dev/null @@ -1,349 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional - -import torch - -from vllm.forward_context import set_forward_context -from vllm.model_executor.layers.sampler import SamplerOutput - -try: - try: - from vllm.attention.backends.flash_attn import FlashAttentionMetadata - except (ModuleNotFoundError, ImportError): - # vllm_flash_attn is not installed, try the ROCm FA metadata - from vllm.attention.backends.rocm_flash_attn import ( - ROCmFlashAttentionMetadata as FlashAttentionMetadata) -except (ModuleNotFoundError, ImportError) as err: - raise RuntimeError( - "Draft model speculative decoding currently only supports " - "CUDA and ROCm flash attention backend.") from err - -from vllm.logger import init_logger -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import ExecuteModelRequest, IntermediateTensors -from vllm.worker.model_runner_base import (ModelRunnerBase, - ModelRunnerInputBase, - ModelRunnerWrapperBase) - -logger = init_logger(__name__) - -# A flag to enable debug prints for the updated input tensors -# before each step. -debug_advance_input = False -# A flag to allow GPU advance step for draft model runner. -# Set to False for debugging. -allow_gpu_advance_step = True - - -class TP1DraftModelRunner(ModelRunnerWrapperBase): - """Specialized model runner for speculative decoding draft model. - Since the draft model always execute k forward passes consecutively to - generate k speculative tokens in a single speculative decoding step, - we could get rid of most CPU-GPU synchronization and data transfer - overheads by keeping model input and output tensors on GPU all the time. - - TODOs: - 1. Currently supports only flash-attn, add support for other attn_backends. - 2. Support TP > 1 (this requires some designs because we do not expect - any broadcasting inside execute_model). 
- """ - - def __init__(self, model_runner: ModelRunnerBase): - super().__init__(model_runner) - - self.indices_of_seq_with_bonus_tokens = None - - def _update_sampling_metadata(self, sampling_metadata, num_seqs, - num_queries): - - assert sampling_metadata.num_prompts == 0 - assert len(sampling_metadata.seq_groups) == num_queries - assert sampling_metadata.selected_token_indices.shape == ( - num_queries, ) - # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 - - # Verify that all sequences are decodes - for i in range(num_queries): - seq_group = sampling_metadata.seq_groups[i] - - assert seq_group.is_prompt is False # No prompt - assert seq_group.prompt_logprob_indices == [] # No prompt - assert seq_group.sample_indices == [i] # Simple - - def _gpu_advance_step(self, model_input: ModelRunnerInputBase, - last_output: SamplerOutput) -> ModelRunnerInputBase: - # Currently, we expect "decode mode" only - assert not model_input.is_prompt - - # Get num_seqs - num_seqs = len(model_input.seq_lens) - num_queries = len(model_input.query_lens) - - # Get output tokens GPU tensor - sampled_token_ids = last_output.sampled_token_ids - assert sampled_token_ids is not None - - # Update attn_metadata - attn_metadata = model_input.attn_metadata - assert isinstance(attn_metadata, FlashAttentionMetadata) - - attn_metadata.advance_step(model_input, sampled_token_ids, - self.block_size, num_seqs, num_queries) - - # Update sampling_metadata - sampling_metadata = model_input.sampling_metadata - self._update_sampling_metadata(sampling_metadata, num_seqs, - num_queries) - - # Create new input - new_model_input = self._model_input_cls( - input_tokens=model_input.input_tokens, - input_positions=model_input.input_positions, - attn_metadata=attn_metadata, - seq_lens=attn_metadata.seq_lens, - query_lens=model_input.query_lens, - lora_mapping=model_input.lora_mapping, - lora_requests=model_input.lora_requests, - multi_modal_kwargs=model_input.multi_modal_kwargs, - sampling_metadata=model_input.sampling_metadata, - is_prompt=False, - ) - - # Ensure we skip CPU samples - assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True - # We can reuse sampling tensors since every decode iteration is the same - new_model_input.sampling_metadata.reuse_sampling_tensors = True - - if debug_advance_input: - logger.debug("NEW INPUT: ") - logger.debug(" input_tokens = %s", new_model_input.input_tokens) - logger.debug(" input_positions = %s", - new_model_input.input_positions) - logger.debug(" seq_lens = %d", new_model_input.seq_lens) - logger.debug(" query_lens = %d", new_model_input.query_lens) - logger.debug(" attn_metadata:") - logger.debug(" seq_lens_tensor: %s", - attn_metadata.seq_lens_tensor) - logger.debug(" slot_mapping: %s", attn_metadata.slot_mapping) - logger.debug(" block_tables: %s", attn_metadata.block_tables) - - return new_model_input - - def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest): - """Determines if draft_model_runner GPU multi-step can be used. - Currently required conditions are: - 1. Only decodes - 2. Only flash-attn - 3. No LORA - 4. 
No prompt_adapter_config - """ - if not allow_gpu_advance_step: - return False - - # We allow multi-step GPU only in decode mode - for seq_group in execute_model_req.seq_group_metadata_list: - if seq_group.is_prompt: - return False - - # TODO: Add support for other attn backends - if self.attn_backend.get_name() not in ("FLASH_ATTN", ): - return False - - # TODO: Add support for LORA - if self.lora_config: - return False - - # TODO: Add soft-tuning prompt adapter support - return not self.prompt_adapter_config - - def set_indices_of_seq_with_bonus_tokens(self, - indices_of_seq_with_bonus_tokens): - self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelRunnerInputBase, - kv_caches: List[torch.Tensor], - previous_hidden_states: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - **kwargs, - ) -> Optional[List[SamplerOutput]]: - """Executes num_steps forward passes with advacement of input tensors - on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions. - - Optimizations used: - 1. Input tensors are updated on the GPU directly - 2. Skips GPU=>CPU serialization of sampler outputs (we don't need - them since we do batch expansion later that uses GPU outputs) - 3. Reuses sampling tensors (since we run only decodes and they have - a repeating sampling logic) - """ - - # When num_steps == 1, we execute the fallback here for the GPU - # advance_step, which runs prepare_inputs on CPU and for each spec - # iteration invokes this function only once - # (Look at multi-step-worker code) - is_fallback = num_steps == 1 - if not is_fallback: - # Since we do not broadcast data inside execute_model anymore, - # we need to figure out the best way to support TP > 1 in this - # case, because we will at least need to broadcast the sampled - # tokens to all workers. - if not self.is_driver_worker: - raise ValueError("TP1DraftModelRunner only supports TP=1.") - - # Sanity - if self.lora_config is not None: - raise ValueError("TP1DraftModelRunner has no support for LORA") - if self.prompt_adapter_config is not None: - raise ValueError("TP1DraftModelRunner has no support for " - "prompt_adapter_config") - if model_input.inputs_embeds is not None: - raise ValueError("TP1DraftModelRunner has no support for " - "inputs_embeds") - if model_input.multi_modal_kwargs: - raise ValueError( - "TP1DraftModelRunner has no support for multi_modal_kwargs" - ) - else: - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) - - self.attn_state.begin_forward(model_input) - - # Detect exec mode - assert model_input.attn_metadata is not None - use_cuda_graph = False - if model_input.attn_metadata.num_prefills > 0: - # In this case, execute_model(..) was called directly - if num_steps > 1: - raise ValueError( - "execute_model(..) of draft_model_runner can be called " - "directly only with a single-step prefill") - else: - # We can skip CPU samples for spec token generation. 
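# [Editor's illustrative sketch, not part of the removed file] The multi-step
# drafting loop described in this runner's execute_model docstring: run the
# draft model num_steps times and feed each step's GPU-resident sampled tokens
# back in through a GPU-side input update instead of re-preparing inputs on
# the CPU. `model` and `gpu_advance_step` are placeholders, not APIs of the
# removed class.
def run_draft_steps(model, model_input, gpu_advance_step, num_steps):
    outputs = []
    for step in range(num_steps):
        output = model(model_input)        # one forward pass plus sampling
        outputs.append(output)
        if step != num_steps - 1:
            model_input = gpu_advance_step(model_input, output)
    return outputs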
- # (We do allow CPU samples for num_steps == 1 to support the - # fallback case, where supports_gpu_multi_step(..) does not pass) - model_input.sampling_metadata.skip_sampler_cpu_output = ( - not is_fallback) - - # Attn attr defines if we use cuda graphs - use_cuda_graph = model_input.attn_metadata.use_cuda_graph - - # Get model - if use_cuda_graph: - if model_input.inputs_embeds is None: - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, False)]) - else: - graph_batch_size = model_input.inputs_embeds.shape[0] - model_executable = ( - self.graph_runners[model_input.virtual_engine][( - graph_batch_size, True)]) - - if previous_hidden_states is not None: - hidden_states = torch.cat([ - previous_hidden_states, - torch.empty([ - graph_batch_size - previous_hidden_states.shape[0], - *previous_hidden_states.shape[1:] - ], - dtype=previous_hidden_states.dtype, - device=previous_hidden_states.device) - ]) - else: - hidden_states = None - else: - model_executable = self.model - hidden_states = previous_hidden_states - - outputs: List[SamplerOutput] = [] - for step in range(num_steps): - multi_modal_kwargs = model_input.multi_modal_kwargs or {} - - model_execute_kwargs = {"previous_hidden_states": hidden_states} \ - if previous_hidden_states is not None else {} - - compute_logits_kwargs = {} - # Run model - if hasattr(self.model.config, "num_nextn_predict_layers"): - # for DeepSeek MTP only to use the corresponding layer for - # each step - spec_step_idx = kwargs.get("spec_step_idx", step) - model_execute_kwargs["spec_step_idx"] = spec_step_idx - compute_logits_kwargs["spec_step_idx"] = spec_step_idx - with set_forward_context(model_input.attn_metadata, - self.vllm_config): - hidden_states = model_executable( - input_ids=model_input.input_tokens, - inputs_embeds=None, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **MultiModalKwargs.as_kwargs( - multi_modal_kwargs, - device=self.device, - ), - **model_execute_kwargs, - ) - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata, - **compute_logits_kwargs) - if not self.is_driver_worker: - return [] - # Sample the next token. - output = self.model_runner.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - outputs.append(output) - - if self.return_hidden_states and is_fallback: - if use_cuda_graph: - indices = model_input.sampling_metadata\ - .selected_token_indices - output.hidden_states = hidden_states[:len(indices)] - else: - output.hidden_states = hidden_states - - if model_input.attn_metadata.num_prefills == 0 \ - and self.indices_of_seq_with_bonus_tokens is not None: - assert output.sampled_token_ids is not None - # output.sampled_token_ids should be of shape (num_seqs, 1) - nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape - assert num_tokens_per_seq == 1 - count = 0 - for i in range(nums_seqs): - bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[ - count] - if i != bonus_seq_idx: - # The following might cause a cpu->gpu sync - # However, the performance impact is negligible as we - # benchmarked on H100. 
- output.sampled_token_ids[ - i, :] = model_input.input_tokens[bonus_seq_idx] - else: - count += 1 - - # Prepare inputs for the next step - if step != num_steps - 1: - model_input = self._gpu_advance_step(model_input, outputs[-1]) - - return outputs diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py deleted file mode 100644 index 70ec1590e7ad..000000000000 --- a/vllm/spec_decode/interfaces.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import List, Optional, Set, Union - -import torch - -from vllm.sequence import ExecuteModelRequest, PromptLogprobs -from vllm.worker.worker_base import WorkerBase - - -@dataclass -class SpeculativeProposals: - """Datastructure used to represent proposal tokens from some proposer. It - also tracks how many speculative tokens each sequence has. - """ - - # Speculative proposal tokens. - proposal_token_ids: torch.Tensor - - # Probabilities of the proposal tokens according to the proposer. - proposal_probs: torch.Tensor - - # The valid length of each proposal; can be zero. - proposal_lens: torch.Tensor - - # A flag to mark that there's no available proposals - no_proposals: bool = False - - def __repr__(self): - return (f"SpeculativeProposals(" - f"proposal_token_ids={self.proposal_token_ids}, " - f"proposal_probs={self.proposal_probs.shape}, " - f"proposal_lens={self.proposal_lens})") - - -@dataclass -class SpeculativeScores: - """Datastructure used to represent the scores of speculative tokens - according to the scoring model. - """ - - # Probabilities of the speculative tokens according to the scoring model. - probs: torch.Tensor - - # Log-probabilities of the speculative tokens according to the scoring - # model. These values can be used to generate Logprob objects that are - # returned to the user. - logprobs: torch.Tensor - - # Token ids sampled from the scoring model. Used for speculative bonus - # tokens and also non-speculative normal decoding. - token_ids: torch.Tensor - - # Optional last hidden states from the scoring model. - hidden_states: Optional[torch.Tensor] = None - - # Scoring model may also return logprobs for prompt tokens - # for each request, when chunked prefill is enabled. - prompt_logprobs: Optional[List[PromptLogprobs]] = None - - def __repr__(self): - return (f"SpeculativeScores(" - f"probs={self.probs.shape}, " - f"token_ids={self.token_ids.shape})") - - -class SpeculativeProposer(ABC): - - @abstractmethod - def get_spec_proposals( - self, - execute_model_req: ExecuteModelRequest, - # If set, this contains all sequence IDs that were assigned - # bonus tokens in their last forward pass. 
- seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> SpeculativeProposals: - raise NotImplementedError - - -class SpeculativeScorer(ABC): - - def __init__(self, scorer_worker: WorkerBase, - device: Union[torch.device, str], vocab_size: int): - self._scorer_worker = scorer_worker - if isinstance(device, torch.device): - device = device.type - self._device = device - self._vocab_size = vocab_size - - @abstractmethod - def score_proposals( - self, - execute_model_req: ExecuteModelRequest, - proposals: SpeculativeProposals, - ) -> SpeculativeScores: - raise NotImplementedError diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py deleted file mode 100644 index 82b5a79fa7cb..000000000000 --- a/vllm/spec_decode/medusa_worker.py +++ /dev/null @@ -1,138 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import weakref -from typing import List, Optional, Set, Tuple - -import torch - -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata -from vllm.spec_decode.interfaces import SpeculativeProposals -from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase -from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker_base import DelegateWorkerBase - - -class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase): - """Worker for Medusa. - """ - - def __init__(self, *args, **kwargs): - DelegateWorkerBase.__init__(self, *args, **kwargs) - # Lazy initialization list. - self._proposer: Top1Proposer - - def init_device(self): - self.worker.init_device() - - self._proposer = Top1Proposer( - weakref.proxy(self), # type: ignore[arg-type] - self.device, - self.vocab_size, - max_proposal_len=self.max_model_len, - ) - - def set_include_gpu_probs_tensor(self): - pass - - def set_should_modify_greedy_probs_inplace(self): - pass - - @torch.inference_mode() - def sampler_output( - self, - execute_model_req: ExecuteModelRequest, - sample_len: int, - # Unused parameter. - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> Tuple[List[SamplerOutput], bool]: - """Run the model forward pass to generate sample_len future tokens. - Returns the list of sampler output, one per layer, along with indicator - of whether torch tensor in sampler output need to be transposed in - latter sampler_output_to_torch logic. - - For medusa worker, this indicator shall be False. - """ - self._raise_if_unsupported(execute_model_req) - - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - - seq_lens, query_lens = self._prepare_input_tensors( - seq_group_metadata_list) - - generators = self.model_runner.get_generators( - execute_model_req.finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.model_runner.pin_memory, generators) - - model_outputs = self.model_runner.model.generate_proposals( - previous_hidden_states=execute_model_req.previous_hidden_states. 
- hidden_states, - sampling_metadata=sampling_metadata) - - return model_outputs, False - - def _prepare_input_tensors( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[List[int], List[int]]: - if not seq_group_metadata_list: - return [], [] - - seq_lens: List[int] = [] - query_lens: List[int] = [] - - for seq_group_metadata in seq_group_metadata_list: - is_prompt = seq_group_metadata.is_prompt - - for seq_data in seq_group_metadata.seq_data.values(): - seq_data_len = seq_data.get_len() - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - seq_len = min( - seq_data_len, - context_len + seq_group_metadata.token_chunk_size) - seq_lens.append(seq_len) - query_lens.append(seq_len - context_len) - else: - seq_lens.append(seq_data_len) - query_lens.append(1) - - return seq_lens, query_lens - - def get_spec_proposals( - self, - execute_model_req: ExecuteModelRequest, - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> SpeculativeProposals: - """Produce speculations given an input batch of sequences. The number of - speculative tokens per sequence is determined by max_proposal_len. - """ - - return self._proposer.get_spec_proposals( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - - def _raise_if_unsupported( - self, - execute_model_req: ExecuteModelRequest, - ) -> None: - """MedusaWorker does not yet implement support for cache swap - operations or beam search. - """ - if any([ - execute_model_req.blocks_to_swap_in, - execute_model_req.blocks_to_swap_out, - execute_model_req.blocks_to_copy - ]): - raise NotImplementedError( - "MedusaWorker does not support cache operations") - - if any( - len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in - execute_model_req.seq_group_metadata_list): - raise NotImplementedError( - "MedusaWorker does not support beam search.") diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py deleted file mode 100644 index a4784cad962d..000000000000 --- a/vllm/spec_decode/metrics.py +++ /dev/null @@ -1,213 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from typing import Callable, Optional, Union - -import msgspec -import torch - -from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeBaseSampler) -from vllm.platforms import current_platform -from vllm.utils import is_pin_memory_available - - -class SpecDecodeWorkerMetrics( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """Dataclass holding metrics emitted from the spec decode worker. - """ - - # The empirical acceptance rate of the proposal method on a per-token basis. - # This is useful for evaluating how well the proposal method aligns with the - # scoring method. - draft_acceptance_rate: float - - # The empirical efficiency, measured as the number of tokens emitted by the - # system divided by the number of tokens that could be emitted by the system - # if the proposal method were perfect. - system_efficiency: float - - # The number of speculative tokens produced by the proposal method. - draft_tokens: int - - # The number of tokens emitted by the entire system. - emitted_tokens: int - - # The number of tokens accepted by the scoring model and verification - # routine, e.g. Llama2-70B and lossless rejection sampling. 
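# [Editor's illustrative sketch, not part of the removed file] How the metrics
# defined here relate to each other, with made-up numbers. With k = 4
# speculative tokens per sequence and 2 sequences, the draft model proposed 8
# tokens; suppose 5 were accepted and 7 tokens were emitted overall. The
# max-emitted bound follows get_max_num_emitted_tokens further below.
k = 4
draft_tokens = 8
accepted_tokens = 5
emitted_tokens = 7

draft_acceptance_rate = accepted_tokens / draft_tokens            # 0.625
max_num_emitted_tokens = (draft_tokens // k) * (k + 1)            # 2 * 5 = 10
system_efficiency = emitted_tokens / max_num_emitted_tokens       # 0.7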
- # - # NOTE: Any token accepted by the verification routine is considered - # accepted (regardless of if the speculative prefix is also accepted). The - # user will usually see less accepted tokens. This metric is helpful when - # evaluating alignment of the proposal method with the scoring model. - accepted_tokens: int - - # The number of speculative tokens per sequence. - num_spec_tokens: int - - -Timer = Callable[[], float] - - -class AsyncMetricsCollector: - """Class which copies rejection/typical-acceptance sampler metrics - from the device to CPU on a non-default Torch stream. - """ - - def __init__(self, - spec_decode_sampler: SpecDecodeBaseSampler, - timer: Optional[Timer] = None, - collect_interval_s: float = 5.0): - self.spec_decode_sampler = spec_decode_sampler - self._timer = time.time if timer is None else timer - - self._rank: Optional[int] = None - - # We don't have a device set yet. - self._copy_stream: Optional[torch.cuda.Stream] = None - - self._in_flight_copy: Optional[torch.cuda.Event] = None - - pin_memory = is_pin_memory_available() - self._aggregate_num_accepted_tokens = torch.tensor( - 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) - self._aggregate_num_emitted_tokens = torch.tensor( - 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) - self._aggregate_num_draft_tokens = 0 - - self._rejsample_metrics_collect_interval_s = collect_interval_s - self._last_metrics_collect_time = self._timer() - - def init_gpu_tensors(self, rank: int) -> None: - self._rank = rank - self._copy_stream = torch.cuda.Stream() - - def init_tensors(self, - rank: int, - device_type: Union[torch.device, str] = 'cuda') -> None: - self._rank = rank - if isinstance(device_type, torch.device): - device_type = device_type.type - stream = current_platform.Stream - if stream is not None: - self._copy_stream = stream() - - def maybe_collect_rejsample_metrics( - self, k: int) -> Optional[SpecDecodeWorkerMetrics]: - # Skip for any platform that doesn't have device Event - if current_platform.Event is None: - return None - - # If a copy was initiated in the previous call, collect and return. - if self._in_flight_copy is not None: - ready_event = self._in_flight_copy - self._in_flight_copy = None - return self._collect_rejsample_metrics(k, ready_event) - - # Otherwise, check if we should start a new copy. - if self._should_collect_rejsample_metrics(self._timer()): - assert self._in_flight_copy is None - self._in_flight_copy = self._copy_rejsample_metrics_async() - - return None - - def _should_collect_rejsample_metrics(self, now: float) -> bool: - """Return whether or not this iteration should print sampling - metrics. - """ - if self._rank != 0: - return False - - return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501 - - def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: - """Copy rejection/typical-acceptance sampling metrics - (number of accepted tokens, etc) to CPU asynchronously. - - Returns a device event recording when the copy is complete. - """ - assert self._copy_stream is not None - self._copy_stream.wait_stream(current_platform.current_stream()) - - with current_platform.stream(self._copy_stream): - self._aggregate_num_accepted_tokens.copy_( - self.spec_decode_sampler.num_accepted_tokens, - non_blocking=True) - self._aggregate_num_emitted_tokens.copy_( - self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) - # Number of draft tokens is calculated on CPU, so no copy is - # required. 
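# [Editor's illustrative sketch, not part of the removed file] The async
# device-to-CPU copy pattern this collector relies on: stage the copy on a
# side stream into pinned host memory, record an event, and read the value
# only after the event has been synchronized on a later call. Assumes a CUDA
# device is available.
import torch

counter_gpu = torch.zeros(1, dtype=torch.long, device="cuda")
counter_cpu = torch.zeros(1, dtype=torch.long, pin_memory=True)

copy_stream = torch.cuda.Stream()
copy_stream.wait_stream(torch.cuda.current_stream())  # order after producers
with torch.cuda.stream(copy_stream):
    counter_cpu.copy_(counter_gpu, non_blocking=True)
ready = torch.cuda.Event()
ready.record(copy_stream)

# ...later, typically on a subsequent call...
ready.synchronize()
value = counter_cpu.item()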
- self._aggregate_num_draft_tokens = ( - self.spec_decode_sampler.num_draft_tokens) - - aggregate_metrics_ready = current_platform.Event() - aggregate_metrics_ready.record(self._copy_stream) - - return aggregate_metrics_ready - - def _collect_rejsample_metrics( - self, k: int, - ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics: - """Create metrics object from statistics copied asynchronously. - - Args: - k: int. The number of speculative tokens; used to determine system - efficiency. - ready_event: torch.cuda.Event. The CUDA event recording when the - async GPU->CPU copy is complete. - """ - - ready_event.synchronize() - - # update time of last collection - self._last_metrics_collect_time = self._timer() - - accepted_tokens = self._aggregate_num_accepted_tokens.item() - emitted_tokens = self._aggregate_num_emitted_tokens.item() - draft_tokens = self._aggregate_num_draft_tokens - - max_num_emitted_tokens = self.get_max_num_emitted_tokens( - draft_tokens, k) - - if draft_tokens > 0: - draft_acceptance_rate = accepted_tokens / draft_tokens - else: - draft_acceptance_rate = float("nan") - - if max_num_emitted_tokens > 0: - system_efficiency = emitted_tokens / max_num_emitted_tokens - else: - system_efficiency = float("nan") - - return SpecDecodeWorkerMetrics( - num_spec_tokens=k, - draft_acceptance_rate=draft_acceptance_rate, - system_efficiency=system_efficiency, - accepted_tokens=accepted_tokens, - draft_tokens=draft_tokens, - emitted_tokens=emitted_tokens, - ) - - @staticmethod - def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: - """Calculate the number of emitted tokens, assuming all tokens are - accepted. - - This is equal to the number of sequences that have been speculated on, - times (speculation len + 1). The +1 comes from the bonus token. - """ - # Determine the number of sequences that have been speculated on. Since - # the batch size can be variable, we divide by k. - assert draft_tokens % k == 0 - total_num_spec_seqs = draft_tokens // k - - # A single sequence may emit k accepted tokens and one bonus token in - # the best case. - num_emitted_per_seq_if_all_accepted = k + 1 - - # The max num of emitted tokens is the number of speculated sequences - # times the max emitted per seq. - return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py deleted file mode 100644 index 8e8c05d26361..000000000000 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Set, Tuple - -import torch - -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata -from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase - - -class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): - """Worker for MLPSpeculator models. - - Not currently compatible with LoRA or chunked prefill. - """ - - @torch.inference_mode() - def sampler_output( - self, - execute_model_req: ExecuteModelRequest, - sample_len: int, - # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and - # therefore does not need this parameter. 
- seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> Tuple[List[SamplerOutput], bool]: - """Run the model forward pass to generate sample_len future tokens. - Returns the list of sampler output, one per layer, along with indicator - of whether torch tensor in sampler output need to be transposed in - latter sampler_output_to_torch logic. - - For mlp spec worker, this indicator shall be True. - """ - self._raise_if_unsupported(execute_model_req) - - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - - (input_tokens, seq_lens, - query_lens) = self._prepare_input_tensors(seq_group_metadata_list) - - generators = self.model_runner.get_generators( - execute_model_req.finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, seq_lens, query_lens, self.device, - self.model_runner.pin_memory, generators) - - model_outputs = self.model_runner.model.generate_proposals( - input_ids=input_tokens, - previous_hidden_states=execute_model_req.previous_hidden_states. - hidden_states, - num_predict_tokens=sample_len, - sampling_metadata=sampling_metadata) - - assert len(model_outputs) == sample_len - - return model_outputs, True - - def _prepare_input_tensors( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, List[int], List[int]]: - if not seq_group_metadata_list: - return torch.empty(0, device=self.device), [], [] - - input_tokens: List[int] = [] - seq_lens: List[int] = [] - query_lens: List[int] = [] - - for seq_group_metadata in seq_group_metadata_list: - is_prompt = seq_group_metadata.is_prompt - - for seq_data in seq_group_metadata.seq_data.values(): - seq_data_len = seq_data.get_len() - if is_prompt: - context_len = seq_data.get_num_computed_tokens() - seq_len = min( - seq_data_len, - context_len + seq_group_metadata.token_chunk_size) - tokens = seq_data.get_token_ids()[context_len:seq_len] - seq_lens.append(seq_len) - input_tokens.extend(tokens) - query_lens.append(seq_len - context_len) - else: - seq_lens.append(seq_data_len) - input_tokens.append(seq_data.get_last_token_id()) - query_lens.append(1) - - input_tokens_tensor = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - return input_tokens_tensor, seq_lens, query_lens diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py deleted file mode 100644 index 18e7b055a678..000000000000 --- a/vllm/spec_decode/mqa_scorer.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm.sequence import (ExecuteModelRequest, SequenceData, - SequenceGroupMetadata, get_all_seq_ids) -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeScorer, SpeculativeScores) - -SeqId = int -TargetSeqId = int - - -class MQAScorer(SpeculativeScorer): - - def score_proposals( - self, - execute_model_req: ExecuteModelRequest, - proposals: SpeculativeProposals, - ) -> SpeculativeScores: - target_seq_group_metadata_list = [] - target_seq_id_start = max( - get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 - all_proposal_tokens = proposals.proposal_token_ids.tolist() - all_proposal_lengths = proposals.proposal_lens.tolist() - for i, seq_group_metadata in enumerate( - execute_model_req.seq_group_metadata_list): - if all_proposal_lengths[i] == 0: - # Keep prompt seqs untouched (keep computed_tokens for chunks). 
- target_seq_group_metadata_list.append(seq_group_metadata) - continue - - seq_data_dict = seq_group_metadata.seq_data - assert len(seq_data_dict) == 1 - seq_id = next(iter(seq_data_dict.keys())) - - seq_data: SequenceData = seq_data_dict[seq_id] - prompt_token_ids = seq_data.get_prompt_token_ids() - output_token_ids = seq_data.get_output_token_ids() - proposal_token_ids = all_proposal_tokens[ - i][:all_proposal_lengths[i]] - new_output_token_ids = [*output_token_ids, *proposal_token_ids] - - target_seq_id = target_seq_id_start + i - new_seq_data = SequenceData.from_seqs( - prompt_token_ids=prompt_token_ids, - output_token_ids=new_output_token_ids, - ) - new_seq_data.update_num_computed_tokens( - len(prompt_token_ids) + len(output_token_ids) - 1) - - # Ensure that the new decode sequence has at least one token. - assert len(output_token_ids) >= 1 - new_seq_data_dict = {target_seq_id: new_seq_data} - - new_seq_group_metadata = SequenceGroupMetadata( - request_id=seq_group_metadata.request_id, - is_prompt=seq_group_metadata.is_prompt, - seq_data=new_seq_data_dict, - sampling_params=seq_group_metadata.sampling_params, - block_tables={ - target_seq_id: seq_group_metadata.block_tables[seq_id], - }, - lora_request=None, - ) - target_seq_group_metadata_list.append(new_seq_group_metadata) - - target_sampler_output = self._scorer_worker.execute_model( - execute_model_req=execute_model_req.clone( - seq_group_metadata_list=target_seq_group_metadata_list)) - - target_sampler_output = target_sampler_output[0] - - k = execute_model_req.num_lookahead_slots - bs = len(execute_model_req.seq_group_metadata_list) - target_token_ids = target_sampler_output.sampled_token_ids - target_probs = target_sampler_output.sampled_token_probs - target_logprobs = target_sampler_output.logprobs - prompt_logprobs = None - - # If all requests have the same number of query tokens, we can avoid - # the for loop to build output for better performance. - if min(all_proposal_lengths) == k: - # Regular decodes only. - assert all(not sg.is_prompt - for sg in target_seq_group_metadata_list - if sg.is_prompt) - bs, _ = proposals.proposal_token_ids.shape - all_tokens = target_token_ids.reshape(bs, k + 1) - all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) - all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size) - else: - # We either have decodes with different lens or prefill+decodes. - all_tokens = target_token_ids.new_full(size=(bs, k + 1), - fill_value=-1) - all_probs = target_probs.new_zeros(*all_tokens.shape, - self._vocab_size) - all_logprobs = target_logprobs.new_full(size=all_probs.shape, - fill_value=-float("inf")) - target_token_ids = target_token_ids.flatten() - - # When prompt logprobs is enabled, lens of returned tensors go from - # n_sampled (requests with do_sample=True) to n_prompt+n_prefills. - # We adjust stride accordingly to get the generated tokens and - # their probs, but pass on prompt_logprobs as is, since it may be - # that n_prompts >> K. - has_prompt_log = any((sg.sampling_params.prompt_logprobs - and sg.sampling_params.prompt_logprobs > 0) - for sg in target_seq_group_metadata_list) - # TODO (NickLucche) we should surface `disable_logprobs` as to not - # break abstraction to get its value. - if (not self._scorer_worker.model_runner.disable_logprobs\ - and has_prompt_log): - prompt_logprobs = [ - o.prompt_logprobs for o in target_sampler_output.outputs - ] - - # Split loop into prefill|decode for readability. 
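# [Editor's illustrative sketch, not part of the removed file] The sequence
# bookkeeping above with toy numbers: the proposal tokens are appended to the
# sequence's output tokens and num_computed_tokens covers everything except
# the last already-generated token, so one target forward pass queries exactly
# proposal_len + 1 positions.
prompt_len, output_len, proposal_len = 16, 3, 4

new_seq_len = prompt_len + output_len + proposal_len   # 23
num_computed_tokens = prompt_len + output_len - 1      # 18
query_len = new_seq_len - num_computed_tokens          # 5
assert query_len == proposal_len + 1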
- start_loc, i = 0, 0 - while i < len(target_seq_group_metadata_list - ) and target_seq_group_metadata_list[i].is_prompt: - seq_meta = target_seq_group_metadata_list[i] - end_loc = start_loc - if has_prompt_log: - end_loc += seq_meta.token_chunk_size - elif seq_meta.do_sample: - end_loc += 1 - - # Skip chunks with no output tokens. - if seq_meta.do_sample: - # Get sampled token (last position in chunk) and its prob. - all_tokens[i, 0] = target_token_ids[end_loc - 1] - all_probs[i, 0] = target_probs[end_loc - 1] - all_logprobs[i, 0] = target_logprobs[end_loc - 1] - - i += 1 - start_loc = end_loc - # Decodes. - while i < len(target_seq_group_metadata_list): - proposed_len, seq_meta = all_proposal_lengths[ - i], target_seq_group_metadata_list[i] - output_len = proposed_len + 1 - end_loc = start_loc + output_len - all_tokens[ - i, :output_len] = target_token_ids[start_loc:end_loc] - all_probs[i, :output_len] = target_probs[start_loc:end_loc] - all_logprobs[ - i, :output_len] = target_logprobs[start_loc:end_loc] - start_loc = end_loc - i += 1 - - hidden_states = None - if target_sampler_output.hidden_states is not None: - hidden_states = target_sampler_output.hidden_states.reshape( - bs, (k + 1), -1) - - return SpeculativeScores(probs=all_probs, - token_ids=all_tokens, - logprobs=all_logprobs, - hidden_states=hidden_states, - prompt_logprobs=prompt_logprobs) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py deleted file mode 100644 index 4a9bbe44d89a..000000000000 --- a/vllm/spec_decode/multi_step_worker.py +++ /dev/null @@ -1,423 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import copy -import weakref -from typing import Dict, List, Set, Tuple - -import torch - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, - SequenceGroupMetadata) - -if current_platform.is_cuda_alike(): - from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner - -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeProposer) -from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase -from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker_base import DelegateWorkerBase - - -class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): - """The MultiStepWorker is equivalent to a Worker except that it allows - multiple forward passes in a single call, assuming the scheduler has - allocated enough space to store the additional KV. This reduces overhead - by invoking the scheduler less. - - The MultiStepWorker does not support cache swap operations, or beam search. - Cache swap operations do not require large modifications. On the other hand, - beam search requires memory allocations during sequence forks and thus - requires more thought for MultiStepWorker support. - """ - - def __init__(self, *args, **kwargs): - DelegateWorkerBase.__init__(self, *args, **kwargs) - # Lazy initialization list. 
- self._proposer: SpeculativeProposer - - def init_device(self) -> None: - self.worker.init_device() - self._proposer = Top1Proposer( - weakref.proxy(self), # type: ignore[arg-type] - self.device, - self.vocab_size, - max_proposal_len=self.max_model_len, - ) - - def set_include_gpu_probs_tensor(self) -> None: - # Need include_gpu_probs_tensor for MultiStepWorker - self.model_runner.sampler.include_gpu_probs_tensor = True - if hasattr(self.model_runner.model, "sampler"): - (self.model_runner.model.sampler.include_gpu_probs_tensor) = True - - def set_should_modify_greedy_probs_inplace(self) -> None: - self.model_runner.sampler.should_modify_greedy_probs_inplace = True - if hasattr(self.model_runner.model, "sampler"): - (self.model_runner.model.sampler.should_modify_greedy_probs_inplace - ) = True - - @torch.inference_mode() - def sampler_output( - self, - execute_model_req: ExecuteModelRequest, - sample_len: int, - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> Tuple[List[SamplerOutput], bool]: - """Run the model forward pass sample_len times. Returns the list of - sampler output, one per model forward pass, along with indicator of - whether torch tensor in sampler output need to be transposed in latter - sampler_output_to_torch logic. - - For multi step worker, this indicator shall be True. - """ - self._raise_if_unsupported(execute_model_req) - # Expand the batch for sequences with a bonus token. - # Perform a forward pass on the expanded batch and filter the - # response to retain only the original sequences' responses. - expanded_request, indices_of_seq_with_bonus_tokens =\ - self._expand_execute_model_request( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - - # Run model sample_len times. - model_outputs: List[SamplerOutput] = [] - if current_platform.is_cuda_alike() and isinstance( - self.model_runner, TP1DraftModelRunner - ) and self.model_runner.supports_gpu_multi_step(expanded_request): - # Here we run the draft_model_runner with multi-step prepare - # on the GPU directly - expanded_request.num_steps = sample_len - self.model_runner.set_indices_of_seq_with_bonus_tokens( - indices_of_seq_with_bonus_tokens) - model_outputs = self.execute_model( - execute_model_req=expanded_request) - else: - # Here we run multi-step directly, with every step prepared - # on the CPU. - # TODO: Remove this branch once DraftModelRunner supports TP>1 - # and other restrictions that are part of DraftModelRunner's - # supports_gpu_multi_step(..) 
- if expanded_request.previous_hidden_states is not None: - self.worker.model_runner.return_hidden_states = True - for _ in range(sample_len): - model_output: List[SamplerOutput] = self.worker.execute_model( - execute_model_req=expanded_request) - assert (len(model_output) == 1 - ), "composing multistep workers not supported" - model_output = model_output[0] - self._maybe_update_previous_hidden_states( - model_output, expanded_request) - - self._append_new_tokens( - model_output, expanded_request.seq_group_metadata_list, - indices_of_seq_with_bonus_tokens) - model_outputs.append(model_output) - - # move indices to device to avoid stream sync - indices_of_seq_with_bonus_tokens = torch.tensor( - indices_of_seq_with_bonus_tokens, device=self.device) - filtered_model_outputs = self._filter_model_output( - model_outputs, indices_of_seq_with_bonus_tokens) - return filtered_model_outputs, True - - @staticmethod - def _maybe_update_previous_hidden_states( - model_output: SamplerOutput, - expanded_request: ExecuteModelRequest) -> None: - """ - Updates the previous hidden states in an expanded request - in-place with the hidden states from the model output. - """ - if expanded_request.previous_hidden_states is not None: - expanded_request.previous_hidden_states = HiddenStates( - model_output.hidden_states, - expanded_request.seq_group_metadata_list) - - @staticmethod - def _expand_execute_model_request( - execute_model_req: ExecuteModelRequest, - seq_with_bonus_token_in_last_step: set, - ) -> Tuple[ExecuteModelRequest, List[int]]: - """ - Expands the execute model request based on sequences with bonus - tokens. - - For each sequence with a bonus token, this method creates a new - sequence without the bonus token and adds it to the execute model - request. The original sequence groups are also retained. The indices - of the original sequence groups are returned for further processing. - - Args: - execute_model_req (ExecuteModelRequest): The original execute - model request. - seq_with_bonus_token_in_last_step (set): Set of sequence IDs that - contain bonus tokens. - - Returns: - Tuple[ExecuteModelRequest, List[int]]: The updated execute model - request with expanded sequences and a list of indices corresponding - to the original sequence groups. - """ - updated_seq_group_metadata_list: List[SequenceGroupMetadata] = [] - updated_execute_model_req = execute_model_req.clone( - updated_seq_group_metadata_list) - indices_of_original_sequence_groups = [] - for seq_group in execute_model_req.seq_group_metadata_list: - seq_group_has_bonus_tokens = False - for seq_id, _ in seq_group.seq_data.items(): - # Identify sequences with bonus tokens in the sequence group. - if seq_id in seq_with_bonus_token_in_last_step: - seq_group_has_bonus_tokens = True - break - if seq_group_has_bonus_tokens: - #Create new sequences without the last bonus token. These new - # sequence have the same sequence id as the original sequence. - # We create a new sequence group and add them there. - updated_seq_group_without_bonus_token = \ - MultiStepWorker._copy_seq_metadata_excluding_last_token( - seq_group, seq_with_bonus_token_in_last_step) - updated_seq_group_metadata_list.append( - updated_seq_group_without_bonus_token) - # Add the original sequence group. - updated_seq_group_metadata_list.append( - MultiStepWorker._shallow_copy_seq_group_metadata(seq_group)) - # Record the index of the original sequence group. 
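# [Editor's illustrative sketch, not part of the removed file] The index
# bookkeeping of _expand_execute_model_request with a toy batch [A, B, C]
# where A and C got a bonus token last step: the expanded batch becomes
# [A_no_bonus, A, B, C_no_bonus, C] and the recorded indices of the original
# sequence groups are [1, 2, 4]. The helper is made up for the example.
def original_group_indices(has_bonus):
    expanded_len, original_indices = 0, []
    for bonus in has_bonus:
        if bonus:
            expanded_len += 1  # the copy without the bonus token comes first
        original_indices.append(expanded_len)
        expanded_len += 1      # the original sequence group itself
    return original_indices

assert original_group_indices([True, False, True]) == [1, 2, 4]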
- indices_of_original_sequence_groups.append( - len(updated_seq_group_metadata_list) - 1) - - updated_execute_model_req.seq_group_metadata_list =\ - updated_seq_group_metadata_list - - if isinstance(updated_execute_model_req.previous_hidden_states, - HiddenStates): - updated_execute_model_req.previous_hidden_states\ - .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) - - return updated_execute_model_req, indices_of_original_sequence_groups - - @staticmethod - def _filter_model_output( - expanded_batch_outputs: List[SamplerOutput], - output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]: - """ - Filters the model output to include only the specified sequence - outputs. This method contracts the expanded batch output from the - model to retain the outputs of only those sequences indicated by the - provided indices. - - Args: - expanded_batch_output (List[SamplerOutput]): The expanded output - batch from the model. - output_indices_to_retain (torch.Tensor): Indices of the model - outputs to retain. - - Returns: - List[SamplerOutput]: A list containing the filtered model - outputs for the specified indices. - """ - return [ - SamplerOutput( - outputs=[ - expanded_batch_output.outputs[i] - for i in output_indices_to_retain - ] if len(expanded_batch_output.outputs) > 0 else [], - sampled_token_probs=( - expanded_batch_output. - sampled_token_probs[output_indices_to_retain] - if expanded_batch_output.sampled_token_probs is not None - else None), - logprobs=( - expanded_batch_output.logprobs[output_indices_to_retain] - if expanded_batch_output.logprobs is not None else None), - sampled_token_ids=(expanded_batch_output. - sampled_token_ids[output_indices_to_retain] - if expanded_batch_output.sampled_token_ids - is not None else None)) - for expanded_batch_output in expanded_batch_outputs - ] - - def get_spec_proposals( - self, - execute_model_req: ExecuteModelRequest, - seq_ids_with_bonus_token_in_last_step: set, - ) -> SpeculativeProposals: - """Produce speculations given an input batch of sequences. The number of - speculative tokens per sequence is determined by max_proposal_len. - """ - return self._proposer.get_spec_proposals( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - - @staticmethod - def _append_new_tokens( - model_output: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - indices_of_seq_with_bonus_tokens: List[int]) -> None: - """Given model output from a single run, append the tokens to the - sequences. This is normally done outside of the worker, but it is - required if the worker is to perform multiple forward passes. - """ - count = 0 - for index, (seq_group_metadata, sequence_group_outputs) in enumerate( - zip(seq_group_metadata_list, model_output)): - seq_group_metadata.is_prompt = False - - for seq_output in sequence_group_outputs.samples: - # NOTE: Beam search is not supported, so we can assume that - # parent_seq_id == seq_id. 
- seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] - - token_id = seq_output.output_token - token_logprob = seq_output.logprobs[token_id] - # Determine the actual token ID to be generated, - # considering bonus tokens - if index != indices_of_seq_with_bonus_tokens[count]: - bonus_seq_metadata = seq_group_metadata_list[ - indices_of_seq_with_bonus_tokens[count]] - _, bonus_token_seq_data = next( - iter(bonus_seq_metadata.seq_data.items())) - token_id = bonus_token_seq_data.output_token_ids[-1] - else: - count += 1 - - seq.append_token_id(token_id, token_logprob.logprob, - seq_output.output_embed) - seq.update_num_computed_tokens(1) - - @staticmethod - def _shallow_copy_seq_group_metadata( - seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata: - """Copy input data structures to remove side-effects when input data - structures are shared with other modules. - - Helpful when the vLLM scheduler runs in the same process as the worker. - The alternative is deep-copying (or other form of deep copy); this has - performance downsides. - """ - # Shallow-copy the SequenceGroupMetadata. This allows us to - # append tokens and change is_prompt without external side-effects. - # We must shallow-copy seq_group_metadata as is_prompt could change. - new_seq_group_metadata = copy.copy(seq_group_metadata) - - # We must shallow-copy seq_data as we will append token ids - new_seq_data: Dict[int, SequenceData] = {} - for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - new_seq_data[seq_id] = copy.copy(old_seq_data) - new_seq_data[seq_id].output_token_ids =\ - old_seq_data.output_token_ids[:] - - new_seq_group_metadata.seq_data = new_seq_data - return new_seq_group_metadata - - @staticmethod - def _copy_seq_metadata_excluding_last_token( - seq_group_metadata: SequenceGroupMetadata, - seq_ids_to_copy: Set[int], - ) -> SequenceGroupMetadata: - """ - Creates a shallow copy of the given SequenceGroupMetadata, retaining - only the sequence IDs specified in seq_ids_to_copy. For each of these - sequence IDs, all output_token_ids except the last one are copied. - Sequence IDs not in seq_ids_to_copy are excluded from the copy. - - Parameters: - seq_group_metadata (SequenceGroupMetadata): The original sequence - group metadata. - seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the - copy. - - Returns: - SequenceGroupMetadata: A shallow copy of the sequence group metadata - with the specified modifications. - """ - # Shallow-copy the SequenceGroupMetadata. - new_seq_group_metadata = copy.copy(seq_group_metadata) - # Shallow-copy seq_data and modify the output_token_ids. - new_seq_data: Dict[int, SequenceData] = {} - for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): - if (seq_id in seq_ids_to_copy): - new_seq_data[seq_id] = copy.copy(old_seq_data) - # Copy all the output token ids except the last. - # Also reduce num_computed_tokens by 1 since we are not - # including the last output token. - # NOTE: num_computed_tokens is not directly used by the - # speculative decoding workers, as it is only relevant for - # chunked prefill, which is disabled for speculative decoding. - # However, to maintain consistency in num_computed_tokens, - # we update it here. 
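Both copy helpers above depend on the same idea: shallow-copy the metadata object, then give the copy its own `seq_data` dict and its own `output_token_ids` list, so the worker can append tokens and flip `is_prompt` without mutating the scheduler's objects. A small sketch of why each level of copying matters; the dataclasses are made up for illustration:

```python
import copy
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class ToySeqData:
    output_token_ids: List[int] = field(default_factory=list)

@dataclass
class ToySeqGroup:
    seq_data: Dict[int, ToySeqData] = field(default_factory=dict)
    is_prompt: bool = True

def shallow_copy_group(group: ToySeqGroup) -> ToySeqGroup:
    new_group = copy.copy(group)               # top level: is_prompt may change safely
    new_group.seq_data = {
        seq_id: copy.copy(data) for seq_id, data in group.seq_data.items()
    }
    for data in new_group.seq_data.values():
        data.output_token_ids = data.output_token_ids[:]   # detach the token list
    return new_group

original = ToySeqGroup(seq_data={0: ToySeqData([10, 11])})
worker_view = shallow_copy_group(original)
worker_view.seq_data[0].output_token_ids.append(12)
worker_view.is_prompt = False
print(original.seq_data[0].output_token_ids)   # [10, 11] -- scheduler copy untouched
```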
- new_seq_data[seq_id].output_token_ids =\ - old_seq_data.output_token_ids[:-1] - new_seq_data[seq_id].update_num_computed_tokens(-1) - new_seq_group_metadata.seq_data = new_seq_data - return new_seq_group_metadata - - def _assert_enough_kv_space( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - num_steps: int) -> None: - """Assert there are enough physical blocks per sequence to store the - current KV plus additional KV from num_steps tokens. - """ - assert self.model_runner.block_size is not None - for seq_group_metadata in seq_group_metadata_list: - # Only one seq_id is guaranteed because there is no beam search. - seq_id = list(seq_group_metadata.seq_data.keys())[0] - seq = seq_group_metadata.seq_data[seq_id] - - # After num_steps, the seq len will be the current seq len - # plus one token per step. - final_seq_len = seq.get_len() + num_steps - - # We will have final_seq_len - 1 KV because vLLM saves KV for a - # token in the iteration after the token was generated. - required_num_kv_slots = final_seq_len - 1 - - # The allocated number of kv slots is the number of allocated blocks - # times the number of slots of block. - number_physical_blocks = len( - seq_group_metadata.block_tables[seq_id]) - allocated_kv_slots = (number_physical_blocks * - self.model_runner.block_size) - - if required_num_kv_slots > allocated_kv_slots: - request_id = seq_group_metadata.request_id - raise ValueError( - "The worker attempted to run " - f"{num_steps} times but found insufficient KV space for " - f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " - f"{required_num_kv_slots=}).") - - def _raise_if_unsupported( - self, - execute_model_req: ExecuteModelRequest, - ) -> None: - """MultiStepWorker does not yet implement support for cache swap - operations or beam search. - """ - if any([ - execute_model_req.blocks_to_swap_in, - execute_model_req.blocks_to_swap_out, - execute_model_req.blocks_to_copy - ]): - raise NotImplementedError( - "MultiStepWorker does not support cache operations") - - if any( - len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in - execute_model_req.seq_group_metadata_list): - raise NotImplementedError( - "MultiStepWorker does not support beam search.") - - def maybe_load_lm_head_weight( - self, - lm_head_weight: torch.Tensor, - ) -> None: - weight_loader = getattr( - self.worker.model_runner.model_runner.model.lm_head.weight, - "weight_loader", default_weight_loader) - weight_loader( - self.worker.model_runner.model_runner.model.lm_head.weight, - lm_head_weight) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py deleted file mode 100644 index 7a1a0e56dc00..000000000000 --- a/vllm/spec_decode/ngram_worker.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import weakref -from typing import List, Optional, Set, Tuple - -import torch -import torch.nn as nn - -from vllm.config import VllmConfig -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.spec_decode.interfaces import SpeculativeProposals -from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase -from vllm.spec_decode.top1_proposer import Top1Proposer - - -class _DummyModel(nn.Module): - pass - - -class NGramWorker(NonLLMProposerWorkerBase): - """NGramWorker provides a light drafter without need for model. 
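The KV-space assertion above is plain block arithmetic: after `num_steps` extra tokens, a sequence of current length L needs `L + num_steps - 1` KV slots (the slot for the newest token is written on the following iteration), and the allocation is `num_blocks * block_size`. A one-function sketch of the same check:

```python
def has_enough_kv_space(seq_len: int, num_steps: int,
                        num_allocated_blocks: int, block_size: int) -> bool:
    """KV is written one iteration after a token is generated, so a final
    length of (seq_len + num_steps) needs one slot fewer than that."""
    required_kv_slots = seq_len + num_steps - 1
    allocated_kv_slots = num_allocated_blocks * block_size
    return required_kv_slots <= allocated_kv_slots

# A 100-token sequence speculating 5 steps with 7 blocks of 16 slots each:
print(has_enough_kv_space(100, 5, 7, 16))   # True (104 <= 112)
```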
- - Current NGramWorker only implements prompt lookup decoding, - and in future we may also do RAG type drafter and other scenarios - which don't rely on LLM model to give proposals. - """ - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - device_type: str = "cuda", - **kwargs, - ): - super().__init__(vllm_config) - - # Get local_rank/vocab_size from kwargs attribute - self.local_rank = local_rank - self.device_type = device_type - - # Lazy initialization list. - self._proposer: Top1Proposer - - def set_ngram_window_size(self, ngram_prompt_lookup_min: int, - ngram_prompt_lookup_max: int): - # Search valid candidate window between - # ngram_prompt_lookup_min/ngram_prompt_lookup_max - self.ngram_prompt_lookup_max = ngram_prompt_lookup_max - self.ngram_prompt_lookup_min = ngram_prompt_lookup_min - - def init_device(self): - self.device = torch.device(f"{self.device_type}:{self.local_rank}") - - # Current NGramWorker only supports Top1Proposer - self._proposer = Top1Proposer( - weakref.proxy(self), # type: ignore[arg-type] - device=self.device, - vocab_size=self.vocab_size, - ) - - def load_model(self) -> None: - pass # Dummy - - def get_model(self) -> nn.Module: - return _DummyModel() - - def sampler_output( - self, - execute_model_req: ExecuteModelRequest, - sample_len: int, - # Unused parameter. NGramWorker does not use the KV Cache and - # therefore does not need this parameter. - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]: - """NGram match algo to pick proposal candidate. Returns the list of - sampler output, one per SequenceGroupMetadata. - - For ngram worker, we already done needed transposed internal, so the - indicator pass to sampler_output_to_torch shall be False. - """ - self._raise_if_unsupported(execute_model_req) - - has_spec_out = False - token_id_list: List[Optional[torch.Tensor]] = [] - token_prob_list: List[Optional[torch.Tensor]] = [] - for idx, seq_group_metadata in enumerate( - execute_model_req.seq_group_metadata_list): - seq_data = next(iter(seq_group_metadata.seq_data.values())) - - seq_len = seq_data.get_len() - # When seq_len is less than 3072 (3K), we use CPU to perform - # the ngram match. Otherwise, we use the device specified in - # the model config (normally GPU). 3072 is a rough threshold - # based on profiling on H100, and it can be adjusted based - # on the actual performance on different hardware. - cur_device = "cpu" if seq_len < 3072 else self.device - input_ids = torch.as_tensor(seq_data.get_token_ids(), - dtype=torch.long, - device=cur_device) - input_length = seq_data.get_len() - - for ngram_size in range( - min(self.ngram_prompt_lookup_max, input_length - 1), - self.ngram_prompt_lookup_min - 1, - -1, - ): - ngram_tensor = input_ids[-ngram_size:] - if ngram_size == 1: - # Do not match itself and do not use unfold and all - matches = (input_ids[:-1] == ngram_tensor) - else: - windows = input_ids.unfold(dimension=0, - size=ngram_size, - step=1) - # Do not match itself - matches = (windows[:-1] == ngram_tensor).all(dim=-1) - - # first_match includes "values" (bool), indicating whether - # the match is found, and "indices", indicating the index - # of the first match. 
- first_match = matches.max(dim=-1) - if first_match.values.item(): - proposal_start_idx = first_match.indices.add_(ngram_size) - spec_indices = ( - proposal_start_idx).repeat(sample_len) + torch.arange( - sample_len, device=cur_device) - spec_indices.clamp_(max=input_ids.shape[-1] - 1) - res = input_ids.gather(dim=-1, - index=spec_indices).to(self.device) - token_id_list.append(res) - token_prob_list.append( - torch.nn.functional.one_hot( - res, - num_classes=self.vocab_size).to(torch.float32)) - has_spec_out = True - break - else: - token_id_list.append(None) - token_prob_list.append(None) - - if not has_spec_out: - return None, False - - outputs: List[Optional[SamplerOutput]] = [] - for idx in range(len(execute_model_req.seq_group_metadata_list)): - if token_id_list[idx] is None: - outputs.append(None) - else: - outputs.append( - SamplerOutput( - outputs=None, - sampled_token_probs=token_prob_list[idx], - logprobs=torch.zeros((sample_len, self.vocab_size), - dtype=torch.float32, - device=self.device), - sampled_token_ids=token_id_list[idx], - )) - - return outputs, False - - def get_spec_proposals( - self, - execute_model_req: ExecuteModelRequest, - # Unused parameter. NGramWorker does not use the KV Cache and - # therefore does not need this parameter. - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> SpeculativeProposals: - """Produce speculations given an input batch of sequences. The number of - speculative tokens per sequence is determined by max_proposal_len. - """ - return self._proposer.get_spec_proposals( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - - def _raise_if_unsupported( - self, - execute_model_req: ExecuteModelRequest, - ) -> None: - """NGramWorker does not yet implement support for cache swap - operations or beam search. - """ - if any([ - execute_model_req.blocks_to_swap_in, - execute_model_req.blocks_to_swap_out, - execute_model_req.blocks_to_copy - ]): - raise NotImplementedError( - "NGramWorker does not support cache operations") - - if any( - len(seq_group_metadata.seq_data.keys()) != 1 - for seq_group_metadata in - execute_model_req.seq_group_metadata_list): - raise NotImplementedError( - "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py deleted file mode 100644 index fb44275aa935..000000000000 --- a/vllm/spec_decode/proposer_worker_base.py +++ /dev/null @@ -1,59 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from abc import ABC, abstractmethod -from typing import List, Optional, Set, Tuple - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest -from vllm.spec_decode.interfaces import SpeculativeProposer -from vllm.worker.worker_base import LoRANotSupportedWorkerBase - - -class ProposerWorkerBase(LoRANotSupportedWorkerBase, SpeculativeProposer): - """Interface for proposer workers""" - - @abstractmethod - def sampler_output( - self, - execute_model_req: ExecuteModelRequest, - sample_len: int, - # A set containing all sequence IDs that were assigned bonus tokens - # in their last forward pass. This set is used to backfill the KV cache - # with the key-value pairs of the penultimate token in the sequences. - # This parameter is only used by the MultiStepWorker, which relies on - # the KV cache for token generation. It is not used by workers that - # do not utilize the KV cache. 
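The matching loop above is classic prompt-lookup decoding: starting from the largest window, find an earlier occurrence of the trailing n-gram and propose the tokens that followed it. A self-contained simplification for a single sequence, with the `ngram_size == 1` special case folded into the general path and the probability tensors omitted:

```python
from typing import Optional

import torch

def prompt_lookup(input_ids: torch.Tensor, sample_len: int,
                  ngram_max: int, ngram_min: int) -> Optional[torch.Tensor]:
    """Return `sample_len` proposed token ids, or None if no n-gram matches."""
    n = input_ids.shape[-1]
    for ngram_size in range(min(ngram_max, n - 1), ngram_min - 1, -1):
        ngram = input_ids[-ngram_size:]
        # Sliding windows over the sequence; drop the last window so the
        # trailing n-gram cannot match itself.
        windows = input_ids.unfold(dimension=0, size=ngram_size, step=1)
        matches = (windows[:-1] == ngram).all(dim=-1)
        first = matches.max(dim=-1)
        if first.values.item():
            start = first.indices.item() + ngram_size
            idx = (start + torch.arange(sample_len)).clamp(max=n - 1)
            return input_ids[idx]
    return None

ids = torch.tensor([5, 6, 7, 8, 9, 5, 6, 7])
print(prompt_lookup(ids, sample_len=3, ngram_max=3, ngram_min=1))
# tensor([8, 9, 5]) -- the tokens that followed the earlier occurrence of "5 6 7"
```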
- seq_ids_with_bonus_token_in_last_step: Set[int] - ) -> Tuple[Optional[List[SamplerOutput]], bool]: - raise NotImplementedError - - def set_include_gpu_probs_tensor(self) -> None: - """Implementation optional""" - pass - - def set_should_modify_greedy_probs_inplace(self) -> None: - """Implementation optional""" - pass - - -class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC): - """Proposer worker which does not use a model with kvcache""" - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - """get_spec_proposals is used to get the proposals""" - return [] - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """This is never called on the proposer, only the target model""" - raise NotImplementedError - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - pass - - def get_cache_block_size_bytes(self) -> int: - return 0 diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py deleted file mode 100644 index 91256cab6e79..000000000000 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Set, Tuple - -import torch -import torch.nn as nn - -from vllm.distributed.parallel_state import (get_tp_group, - init_model_parallel_group, - patch_tensor_parallel_group) -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import ExecuteModelRequest -from vllm.spec_decode.interfaces import SpeculativeProposals -from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase - -logger = init_logger(__name__) - - -class _DummyModel(nn.Module): - pass - - -class SmallerTpProposerWorker(ProposerWorkerBase): - """Class which allows a speculative draft model to run with smaller tensor - parallel degree than target model. - This reduces the communication overhead of small draft models. - - To implement this feature, this class differs behavior based on is_dummy - flag, where dummy means worker that does not participate draft generation. - Participating workers use a smaller tp group by patching vLLM's tensor - parallel group temporarily during forward passes of draft models. - """ - - @classmethod - def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int, - target_tensor_parallel_size: int): - """Wrap the worker in a SmallerTpProposerWorker if necessary. - """ - if draft_tensor_parallel_size == target_tensor_parallel_size: - return worker - - # gpu ranks that will generate draft tokens together - draft_ranks = list(range(draft_tensor_parallel_size)) - - logger.info("Wrapping {%s} in {%s}", type(worker), cls) - return cls(worker, draft_ranks) - - def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]): - """Create a SmallerTpProposerWorker. 
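The class docstring above captures the trick: only the first `draft_tensor_parallel_size` ranks take part in draft generation, and the remaining ranks become dummies that skip draft work and return empty results. A stripped-down sketch of the wrap-or-not decision and the dummy check; no real process groups are created and the class name is made up:

```python
from typing import List

class TinySmallerTpWrapper:
    """Illustrative stand-in: ranks outside `draft_ranks` skip draft work."""

    @classmethod
    def maybe_wrap(cls, worker, draft_tp: int, target_tp: int):
        if draft_tp == target_tp:
            return worker                                  # nothing to do
        return cls(worker, draft_ranks=list(range(draft_tp)))

    def __init__(self, worker, draft_ranks: List[int]):
        self._worker, self._draft_ranks = worker, draft_ranks
        self._is_dummy = False

    def init_device(self, my_tp_rank: int) -> None:
        self._is_dummy = my_tp_rank not in self._draft_ranks
        if self._is_dummy:
            return                                         # dummy ranks do nothing
        self._worker.init_device()                         # real ranks run the draft model

    def get_spec_proposals(self, *args, **kwargs):
        if self._is_dummy:
            return None      # the real worker returns an empty SpeculativeProposals
        return self._worker.get_spec_proposals(*args, **kwargs)
```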
- - Args: - worker (~vllm.spec_decode.multi_step_worker.MultiStepWorker): an - actual worker wrapped with this class - draft_ranks (List[int]): if this value is given, only the GPU ranks - written in this value participate in draft generation - """ - self._worker = worker - self._draft_ranks = draft_ranks - - # init during init_device - self._is_dummy = False - self._tp_group = None - - def _patch_tensor_parallel_group(self): - """Temporarily patch the global tp group state with its own tp group - state. - """ - return patch_tensor_parallel_group(self._tp_group) - - def init_device(self) -> None: - self._is_dummy = get_tp_group().rank not in self._draft_ranks - - # dummy workers do nothing - if self._is_dummy: - return - - # creates tp process group containing only a subset of gpu ranks - local_rank = get_tp_group().local_rank - tp_backend = torch.distributed.get_backend(get_tp_group().device_group) - self._tp_group = init_model_parallel_group([self._draft_ranks], - local_rank, tp_backend) - - with self._patch_tensor_parallel_group(): - self._worker.init_device() - - def set_include_gpu_probs_tensor(self) -> None: - if self._is_dummy: - return - - # Need include_gpu_probs_tensor for multi_step_worker - self._worker.set_include_gpu_probs_tensor() - - def set_should_modify_greedy_probs_inplace(self) -> None: - if self._is_dummy: - return - - self._worker.set_should_modify_greedy_probs_inplace() - - def load_model(self) -> None: - if self._is_dummy: - return - - with self._patch_tensor_parallel_group(): - self._worker.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - if self._is_dummy: - # this case is not used now - return -1, -1 - - with self._patch_tensor_parallel_group(): - return self._worker.determine_num_available_blocks() - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - if self._is_dummy: - return - - with self._patch_tensor_parallel_group(): - self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) - - def sampler_output( - self, - execute_model_req: ExecuteModelRequest, - sample_len: int, - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> Tuple[List[SamplerOutput], bool]: - # Do not check _is_dummy, as it's always called by get_spec_proposals - return self._worker.sampler_output( - execute_model_req, sample_len, - seq_ids_with_bonus_token_in_last_step) - - def get_spec_proposals( - self, - execute_model_req: ExecuteModelRequest, - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> SpeculativeProposals: - """Produce speculations given an input batch of sequences. The number of - speculative tokens per sequence is determined by max_proposal_len. 
- """ - if self._is_dummy: - return SpeculativeProposals(None, None, None) - - with self._patch_tensor_parallel_group(): - return self._worker.get_spec_proposals( - execute_model_req, seq_ids_with_bonus_token_in_last_step) - - def get_model(self) -> nn.Module: - if self._is_dummy: - return _DummyModel() - - with self._patch_tensor_parallel_group(): - return self._worker.get_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - if self._is_dummy: - return [] - - with self._patch_tensor_parallel_group(): - return self._worker.execute_model(execute_model_req) - - def get_cache_block_size_bytes(self) -> int: - if self._is_dummy: - # by returning zero, target worker can use the entire kv cache space - return 0 - - return self._worker.get_cache_block_size_bytes() - - @property - def vocab_size(self) -> int: - return self._worker.vocab_size - - def maybe_load_lm_head_weight( - self, - lm_head_weight: torch.Tensor, - ) -> None: - if self._is_dummy: - return - - with self._patch_tensor_parallel_group(): - weight_loader = getattr( - self._worker.worker.model_runner.model_runner.model.\ - lm_head.weight, - "weight_loader", - default_weight_loader) - weight_loader( - self._worker.worker.model_runner.model_runner.model.\ - lm_head.weight, - lm_head_weight) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py deleted file mode 100644 index 7dda1cbfe230..000000000000 --- a/vllm/spec_decode/spec_decode_worker.py +++ /dev/null @@ -1,1326 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import copy -from collections import defaultdict -from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Tuple, Type - -import torch -import torch.nn as nn - -from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig -from vllm.distributed.communication_op import (broadcast_tensor_dict, - get_tp_group, - tensor_model_parallel_gather) -from vllm.distributed.parallel_state import model_parallel_is_initialized -from vllm.logger import init_logger -from vllm.model_executor.layers.rejection_sampler import RejectionSampler -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.layers.spec_decode_base_sampler import ( - SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) -from vllm.model_executor.layers.typical_acceptance_sampler import ( - TypicalAcceptanceSampler) -from vllm.platforms import current_platform -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, - CompletionSequenceGroupOutput, ExecuteModelRequest, - HiddenStates, SequenceGroupMetadata, - get_all_seq_ids_and_request_ids) -from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer - -if current_platform.is_cuda_alike(): - from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner - -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeScorer, SpeculativeScores) -from vllm.spec_decode.medusa_worker import MedusaWorker -from vllm.spec_decode.metrics import AsyncMetricsCollector -from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker -from vllm.spec_decode.mqa_scorer import MQAScorer -from vllm.spec_decode.multi_step_worker import MultiStepWorker -from vllm.spec_decode.ngram_worker import NGramWorker -from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase -from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker 
-from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (Timer, create_logprobs_output, - create_sequence_group_output, - get_all_num_logprobs, - get_sampled_token_logprobs, nvtx_range, - split_batch_by_proposal_len) -from vllm.utils import resolve_obj_by_qualname -from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase - -logger = init_logger(__name__) - - -def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": - """Helper method that is the entrypoint for Executors which use - WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. - """ - vllm_config: VllmConfig = kwargs.get("vllm_config") - speculative_config: SpeculativeConfig = vllm_config.speculative_config - assert speculative_config is not None - - if vllm_config.parallel_config.pipeline_parallel_size > 1: - raise NotImplementedError("Speculative decoding is currently " - "incompatible with pipeline parallelism") - - draft_worker_kwargs = kwargs.copy() - - kwargs["model_runner_cls"] = TargetModelRunner - target_worker_config = copy.deepcopy(vllm_config) - target_worker_config.parallel_config.worker_cls =\ - target_worker_config.parallel_config.sd_worker_cls - cls = resolve_obj_by_qualname( - target_worker_config.parallel_config.worker_cls) - target_worker = cls(*args, **kwargs) - # Set the disable_logprobs variable in the TargetModelRunner instance - # as per its value specified in the SpeculativeConfig. - target_worker.model_runner.disable_logprobs =\ - speculative_config.disable_logprobs - - draft_worker_config = copy.deepcopy(vllm_config) - draft_worker_config.model_config = speculative_config.draft_model_config - draft_worker_config.quant_config = VllmConfig._get_quantization_config( - draft_worker_config.model_config, - vllm_config.load_config, - ) - speculative_config.draft_parallel_config.worker_cls =\ - draft_worker_config.parallel_config.sd_worker_cls - draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa - # TODO allow draft-model specific load config. - - # Override draft-model specific worker args. - draft_worker_kwargs.update( - vllm_config=draft_worker_config, - ngram_prompt_lookup_max=speculative_config.prompt_lookup_max, - ngram_prompt_lookup_min=speculative_config.prompt_lookup_min, - ) - - spec_decode_worker = SpecDecodeWorker.create_worker( - scorer_worker=target_worker, - draft_worker_kwargs=draft_worker_kwargs, - disable_mqa_scorer=speculative_config.disable_mqa_scorer, - disable_by_batch_size=speculative_config.disable_by_batch_size, - draft_token_acceptance_method=speculative_config.acceptance_method, - typical_acceptance_sampler_posterior_threshold=speculative_config. - posterior_threshold, - typical_acceptance_sampler_posterior_alpha=speculative_config. - posterior_alpha, - disable_logprobs=speculative_config.disable_logprobs, - disable_log_stats=speculative_config.disable_log_stats, - num_speculative_tokens=speculative_config.num_speculative_tokens, - ) - - return spec_decode_worker - - -# Reminder: Please update docs/features/compatibility_matrix.md -# If the feature combo become valid -class SpecDecodeWorker(LoRANotSupportedWorkerBase): - """Worker which implements speculative decoding. - - Speculative decoding reduces decoding per-token latency by using a proposal - method, such as a small draft model, to speculate ahead of a larger LLM. 
The - probabilities of the speculative tokens are then determined by the larger - LLM, after which some verification routine determines which (if any) of the - speculative tokens are accepted by the larger LLM. - - See https://github.com/vllm-project/vllm/pull/2188 and - https://github.com/vllm-project/vllm/pull/3103 for more info. - - The current implementation has the following limitations: - * Only draft-model proposal is implemented (contributions for more forms are - welcome!). - * Only top-1 proposal and scoring are implemented. Tree-attention is left as - future work. - * All sequences in a batch must have the same proposal length, or zero. This - can be improved by having per-sequence speculation in the future. - * The scoring forward pass is done without an MQA kernel, which is - suboptimal especially as the batch size, proposal length, and sequence - lengths grow. Contributions to add a MQA scoring are welcome once - correctness tests pass. - More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. - """ - - @classmethod - def create_worker( - cls, - scorer_worker: WorkerBase, - draft_worker_kwargs: Dict[str, Any], - disable_mqa_scorer: bool, - disable_by_batch_size: Optional[int], - draft_token_acceptance_method: str, - typical_acceptance_sampler_posterior_threshold: float, - typical_acceptance_sampler_posterior_alpha: float, - disable_logprobs: bool, - disable_log_stats: bool, - num_speculative_tokens: int, - ) -> "SpecDecodeWorker": - - allow_zero_draft_token_step = True - enable_lm_head_weight_load = False - num_spec_prefill_steps = 1 - ngram_prompt_lookup_max = ( - draft_worker_kwargs.pop("ngram_prompt_lookup_max")) - ngram_prompt_lookup_min = ( - draft_worker_kwargs.pop("ngram_prompt_lookup_min")) - draft_model_config = draft_worker_kwargs["vllm_config"].model_config - draft_parallel_config: ParallelConfig = draft_worker_kwargs[ - 'vllm_config'].parallel_config - if ngram_prompt_lookup_max > 0: - draft_worker_kwargs[ - "device_type"] = scorer_worker.device_config.device.type - proposer_worker = NGramWorker(**draft_worker_kwargs) - proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, - ngram_prompt_lookup_max) - else: - draft_tp = draft_parallel_config.tensor_parallel_size - target_tp = scorer_worker.parallel_config.tensor_parallel_size - - if draft_model_config.hf_config.model_type == "mlp_speculator": - proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) - elif draft_model_config.hf_config.model_type == "medusa": - proposer_worker = MedusaWorker(**draft_worker_kwargs) - else: - if draft_tp == 1: - if current_platform.is_cuda_alike(): - draft_worker_kwargs[ - "model_runner_cls"] = TP1DraftModelRunner - else: - if draft_model_config.hf_config.model_type == "eagle": - raise NotImplementedError( - f"{draft_model_config.hf_config.model_type} " - "does not support TP > 1 yet") - - allow_zero_draft_token_step = False - - # Load lm_head weight for eagle in init_device - if draft_model_config.hf_config.model_type == "eagle": - enable_lm_head_weight_load = True - - proposer_worker = MultiStepWorker(**draft_worker_kwargs) - if draft_model_config.hf_config.model_type == "deepseek_mtp": - num_spec_prefill_steps = \ - draft_model_config.hf_config.n_predict - - proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( - proposer_worker, draft_tp, target_tp) - - logger.info("Configuring SpecDecodeWorker with proposer=%s", - type(proposer_worker)) - - spec_decode_sampler: SpecDecodeBaseSampler = None - if 
draft_token_acceptance_method == "rejection_sampler": - spec_decode_sampler = RejectionSampler() - elif draft_token_acceptance_method == "typical_acceptance_sampler": - spec_decode_sampler = TypicalAcceptanceSampler( - posterior_threshold=\ - typical_acceptance_sampler_posterior_threshold, - posterior_alpha=typical_acceptance_sampler_posterior_alpha, - ) - logger.info( - "[Speculative Decoding] Configuring" - " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) - - if not disable_mqa_scorer: - if scorer_worker.model_runner.attn_backend.get_name( - ) != "FLASH_ATTN": - disable_mqa_scorer = True - logger.info( - "[Speculative Decoding] Disabling MQA scorer as the " - "MQA is only available with flash attn backend.") - - if draft_model_config and \ - draft_model_config.max_model_len < \ - scorer_worker.model_config.max_model_len: - disable_mqa_scorer = True - logger.info( - "[Speculative Decoding] Disabling MQA scorer as the " - "draft model max_model_len is smaller than the target " - "model max_model_len.") - - if not scorer_worker.model_runner.model_config.enforce_eager: - disable_mqa_scorer = True - logger.info( - "[Speculative Decoding] Disabling MQA scorer as the " - "target model is not running in eager mode.") - - return SpecDecodeWorker( - proposer_worker, - scorer_worker, - disable_mqa_scorer=disable_mqa_scorer, - disable_logprobs=disable_logprobs, - disable_log_stats=disable_log_stats, - disable_by_batch_size=disable_by_batch_size, - spec_decode_sampler=spec_decode_sampler, - allow_zero_draft_token_step=allow_zero_draft_token_step, - enable_lm_head_weight_load=enable_lm_head_weight_load, - num_spec_prefill_steps=num_spec_prefill_steps) - - def __init__( - self, - proposer_worker: ProposerWorkerBase, - scorer_worker: WorkerBase, - spec_decode_sampler: SpecDecodeBaseSampler, - disable_mqa_scorer: bool = False, - disable_logprobs: bool = False, - disable_log_stats: bool = False, - metrics_collector: Optional[AsyncMetricsCollector] = None, - disable_by_batch_size: Optional[int] = None, - allow_zero_draft_token_step: Optional[bool] = True, - enable_lm_head_weight_load: Optional[bool] = False, - num_spec_prefill_steps: int = 1, - ): - """ - Create a SpecDecodeWorker. - - Args: - proposer_worker: A worker that can produce speculative tokens for - sequences. - scorer_worker: A worker that produces probabilities of speculative - tokens according to some base model. Typically a vanilla vLLM - Worker. - spec_decode_sampler: A Torch module used to perform acceptance - sampling of the draft tokens in the verification step of - speculative decoding. Currently we support two different - types of sampler namely RejectionSampler and - TypicalAcceptanceSampler. 'spec_decode_sampler' is either an - instance of RejectionSampler or TypicalAcceptanceSampler. - disable_mqa_scorer: If set to True, disable the MQA scorer and use - the BatchExpansionTop1Scorer instead. - disable_logprobs: If set to True, token log probabilities will - not be output in both the draft worker and the target worker. - If set to False, log probabilities will be output by both. - disable_log_stats: If set to True, disable periodic printing of - speculative stage times. - disable_by_batch_size: If the batch size is larger than this, - disable speculative decoding for new incoming requests. - metrics_collector: Helper class for collecting metrics; can be set - for testing purposes. 
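The two samplers configured above implement the verification step described in the class docstring. For reference, the standard rejection rule for speculative decoding (the textbook formulation, not vLLM's batched kernel) accepts a draft token x with probability min(1, p_target(x)/p_draft(x)) and otherwise resamples from the clipped residual distribution:

```python
from typing import Optional, Tuple

import torch

def accept_or_resample(p_target: torch.Tensor,   # [vocab] target probs at this position
                       p_draft: torch.Tensor,    # [vocab] draft probs at this position
                       draft_token: int,
                       generator: Optional[torch.Generator] = None
                       ) -> Tuple[int, bool]:
    """Textbook speculative-decoding acceptance test for a single position."""
    accept_prob = torch.clamp(p_target[draft_token] / p_draft[draft_token], max=1.0)
    if torch.rand((), generator=generator) < accept_prob:
        return draft_token, True
    # Rejected: resample from the clipped, renormalized residual distribution.
    residual = torch.clamp(p_target - p_draft, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, 1, generator=generator).item()), False

# Tiny 4-token vocabulary: the draft guessed token 1, which the target prefers.
p_t = torch.tensor([0.1, 0.6, 0.2, 0.1])
p_d = torch.tensor([0.25, 0.25, 0.25, 0.25])
print(accept_or_resample(p_t, p_d, draft_token=1))   # (1, True) -- accepted w.p. 1.0
```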
- allow_zero_draft_token_step: whether to allow a step where the draft - model generates no draft token; should disallow when the tp of - draft model is larger than 1 (TODO: #5814) - enable_lm_head_weight_load: whether to load lm_head weight for - draft models like eagle. - num_spec_prefill_steps: number of speculative prefill steps to run - before the speculative decoding starts. This is only used when - the draft model is a deepseek_mtp model that requires prefill - kv cache separately for each MTP layer. - """ - self.proposer_worker = proposer_worker - self.scorer_worker = scorer_worker - scorer_runner = getattr(self.scorer_worker, "model_runner", None) - self.generators = scorer_runner.get_generators( - ) if scorer_runner else None - self.disable_by_batch_size = disable_by_batch_size or float("inf") - self.spec_decode_sampler = spec_decode_sampler - self._allow_zero_draft_token_step = allow_zero_draft_token_step - self._enable_lm_head_weight_load = enable_lm_head_weight_load - self._metrics = AsyncMetricsCollector( - self.spec_decode_sampler - ) if metrics_collector is None else metrics_collector - # Tracks the sequence IDs that received a bonus token ID in - # their last forward pass. Needed only if KV cache is being - # used for token generation such as in the case of MultiStepWorker. - self._seq_with_bonus_token_in_last_step: Set[int] = set() - # Tracks the currently active request ids and the sequence IDs - # corresponding to them - self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set) - # Tracks if the proposer worker uses the KV cache or not. - - self.probs_dtype = self.spec_decode_sampler.probs_dtype - self.token_id_dtype = self.spec_decode_sampler.token_id_dtype - # Lazy initialization. - self.scorer: SpeculativeScorer - self.disable_mqa_scorer = disable_mqa_scorer - - # Hidden states from target model to pass to proposer - # in the subsequent step. - self.previous_hidden_states: Optional[HiddenStates] = None - self._disable_logprobs = disable_logprobs - self._disable_log_stats = disable_log_stats - self._num_spec_prefill_steps = num_spec_prefill_steps - - def init_device(self) -> None: - """Initialize both scorer and proposer models. - """ - # The scorer worker model is initialized first in case the proposer - # model has a smaller TP degree than the target worker. - self.scorer_worker.init_device() - self.proposer_worker.init_device() - - # NOTE(cade): load_model is not part of the WorkerBase interface. 
- self.scorer_worker.load_model() - self.proposer_worker.load_model() - - if self._enable_lm_head_weight_load: - # NOTE(Shangming): gather lm_head weight when tp enabled - target_lm_head_weight: torch.Tensor = tensor_model_parallel_gather( - self.scorer_worker.model_runner.model_runner.model.lm_head.\ - weight.data, - dim=0, - ) - - self.proposer_worker.maybe_load_lm_head_weight( - target_lm_head_weight) - - self._metrics.init_tensors(self.rank, device_type=self.device) - if model_parallel_is_initialized(): - self.spec_decode_sampler.init_tensors(get_tp_group().local_rank, - device_type=self.device) - else: - self.spec_decode_sampler.init_tensors(self.rank, - device_type=self.device) - - scorer_cls: Type[SpeculativeScorer] - if self.disable_mqa_scorer: - scorer_cls = BatchExpansionTop1Scorer - logger.info("[Speculative Decoding] Use batch " - "expansion for scoring proposals.") - else: - scorer_cls = MQAScorer - logger.info( - "[Speculative Decoding] Use MQA scorer for scoring proposals.") - - self.scorer = scorer_cls(scorer_worker=self.scorer_worker, - device=self.device, - vocab_size=self._vocab_size) - - self._configure_model_sampler_for_spec_decode() - - def load_model(self, *args, **kwargs): - pass - - def _configure_model_sampler_for_spec_decode(self): - """Configure model sampler to emit GPU tensors. This allows spec decode - to keep data on device without transferring to CPU and serializing, - which significantly reduces overhead of sampling during verification. - - NOTE(cade): This breaks abstraction boundaries pretty badly. The better - design is to have the "move to CPU and serialize" sampling decision be - done outside of the model/sampler; this way the "last-mile" worker - object which interfaces with the scheduler can serialize and incur the - performance hit as necessary. This allows us to run the worker several - iterations in a row without incurring the "move to CPU and serialize" - performance penalty. - - Since this requires a large change to vLLM, we defer it to later and - temporarily accept this broken abstraction boundary. - - NOTE(cade): This will require a special check if the proposer worker - does not have a sampler (e.g. ngram speculation). - """ - (self.scorer_worker.model_runner.sampler.include_gpu_probs_tensor - ) = True - (self.scorer_worker.model_runner.sampler. - should_modify_greedy_probs_inplace) = True - self.proposer_worker.set_include_gpu_probs_tensor() - self.proposer_worker.set_should_modify_greedy_probs_inplace() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of cache blocks to use. - - This is done by profiling the scorer model (which is typically the - larger of the two). Then the total memory which would be used by the - scorer cache is divided evenly between the proposer and scorer model KV, - such that the number of blocks is equal in both KV caches. - """ - num_gpu_blocks, num_cpu_blocks = ( - self.scorer_worker.determine_num_available_blocks()) - - scorer_cache_block_size_bytes = ( - self.scorer_worker.get_cache_block_size_bytes()) - proposer_cache_block_size_bytes = ( - self.proposer_worker.get_cache_block_size_bytes()) - - new_num_gpu_blocks = split_num_cache_blocks_evenly( - scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, - num_gpu_blocks) - return new_num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the cache engine of the scorer and proposer workers. 
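`determine_num_available_blocks` above hands the profiled GPU block count to `split_num_cache_blocks_evenly` (not shown in this hunk) so that the proposer and scorer end up with the same number of blocks inside the memory the scorer alone was profiled to have. The arithmetic that description implies is simple; a hedged sketch, not the exact helper:

```python
def split_num_cache_blocks_evenly_sketch(scorer_block_bytes: int,
                                         proposer_block_bytes: int,
                                         total_scorer_blocks: int) -> int:
    """Largest per-model block count that fits in the bytes the scorer-only
    profile reported (illustrative arithmetic, not vLLM's exact helper)."""
    total_bytes = total_scorer_blocks * scorer_block_bytes
    return total_bytes // (scorer_block_bytes + proposer_block_bytes)

# Scorer blocks 4x the proposer's: 1000 scorer-only blocks become 800 shared ones.
print(split_num_cache_blocks_evenly_sketch(4096, 1024, 1000))   # 800
```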
- """ - self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) - - def get_model(self) -> nn.Module: - return self.scorer_worker.get_model() - - @torch.inference_mode() - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None - ) -> List[SamplerOutput]: - """Perform speculative decoding on the input batch. - """ - if self.rank != self._driver_rank: - self._run_non_driver_rank() - return [] - - if execute_model_req is None: - # This signals that there's no more requests to process for now. - # All workers are running infinite loop with broadcast_tensor_dict, - # and it stops the loop when the driver broadcasts an empty input. - # Send an empty input to notify all other workers to stop their - # execution loop. - broadcast_tensor_dict({}, src=0) - return [] - - self._track_finished_requests(execute_model_req) - disable_all_speculation = self._should_disable_all_speculation( - execute_model_req) - num_lookahead_slots = execute_model_req.num_lookahead_slots - all_prompt = True - atleast_one_prompt = False - all_zero_spec_tokens = True - for sgm in execute_model_req.seq_group_metadata_list: - all_prompt = all_prompt and sgm.is_prompt - atleast_one_prompt = atleast_one_prompt or sgm.is_prompt - all_zero_spec_tokens = all_zero_spec_tokens and ( - sgm.num_speculative_tokens == 0) - - if all_prompt and execute_model_req.seq_group_metadata_list: - assert num_lookahead_slots == 0, ( - "Prompt only runs should have num_lookahead_slots equal to 0. " - "This should never happen, please file a bug at " - "https://github.com/vllm-project/vllm/issues") - # Speculative decoding is disabled in the following cases: - # 1. Prefill phase: Speculative decoding is not - # used during the prefill phase. - # 2. Auto-disable enabled: The running queue size exceeds - # the specified threshold. - # 3. No request: There are no requests in the batch, or - # none of the requests in the batch have spec decoding enabled. - # In any of these cases, the proposer and scorer workers - # are called normally. - # We expect `num_speculative_tokens` to be None for prefills. - no_spec = (num_lookahead_slots == 0 or disable_all_speculation - or all_zero_spec_tokens) - - # Broadcast how many lookahead slots are scheduled for this step, and - # whether all speculation is disabled, to all non-driver workers. - - # This is required as if the number of draft model runs changes - # dynamically, the non-driver workers won't know unless we perform a - # communication to inform them. - - # no_spec is used to signal non-driver worker about prefill vs decode - # stage. This is needed to ensure that order of execution of proposer - # and scorer is same in both driver and non-driver workers (i.e., - # scorer -> proposer for prefill and proposer -> scorer in decode). This - # order is needed to support models like EAGLE that take scorer states - # as inputs. - broadcast_dict = dict( - num_lookahead_slots=num_lookahead_slots, - no_spec=no_spec, - disable_all_speculation=disable_all_speculation, - # When both chunked prefill and speculative decoding are enabled - # it is possible that the same batch contains both prefill - # and decodes. If that happens in the scorer we run the batch - # as one single forward pass. However, in the proposer we - # run them as 2 different batches - one for prefill and - # the other for decodes. 
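`execute_model` above folds three conditions (prefill-only batch, auto-disable once the running queue grows past the threshold, and every request having zero speculative tokens) into a single `no_spec` flag before broadcasting it to the other workers. A tiny sketch of that decision on plain dictionaries; the field names mirror the code, but this is not the real metadata class:

```python
def should_skip_speculation(seq_groups, num_lookahead_slots: int,
                            running_queue_size: int,
                            disable_by_batch_size: int) -> bool:
    # Auto-disable: too many running requests trades latency back for throughput.
    disable_all = running_queue_size >= disable_by_batch_size
    # No request in the batch has speculative tokens scheduled.
    all_zero_spec = all(sg["num_speculative_tokens"] == 0 for sg in seq_groups)
    # Prefill-only steps schedule zero lookahead slots.
    return num_lookahead_slots == 0 or disable_all or all_zero_spec

batch = [{"is_prompt": False, "num_speculative_tokens": 4},
         {"is_prompt": False, "num_speculative_tokens": 0}]
print(should_skip_speculation(batch, num_lookahead_slots=4,
                              running_queue_size=2,
                              disable_by_batch_size=64))   # False -> run speculation
```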
The variable indicates to the non-driver - # worker that there are prefills as part of the speculative batch - # and hence it needs to run an extra prefill forward pass. - run_spec_proposer_for_prefill=atleast_one_prompt, - ) - broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) - - assert execute_model_req.seq_group_metadata_list is not None, ( - "speculative decoding requires non-None seq_group_metadata_list") - - self._maybe_disable_speculative_tokens( - disable_all_speculation, execute_model_req.seq_group_metadata_list) - - if no_spec: - return self._run_no_spec(execute_model_req, - skip_proposer=disable_all_speculation) - return self._run_speculative_decoding_step(execute_model_req, - num_lookahead_slots) - - @torch.inference_mode() - def start_worker_execution_loop(self) -> None: - """Execute model loop to perform speculative decoding - in parallel worker.""" - while self._run_non_driver_rank(): - pass - - def _should_disable_all_speculation( - self, execute_model_req: ExecuteModelRequest) -> bool: - # When the batch size is too large, disable speculative decoding - # to stop trading off throughput for latency. - return (execute_model_req.running_queue_size - >= self.disable_by_batch_size) - - def _maybe_disable_speculative_tokens( - self, disable_all_speculation: bool, - seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: - if not disable_all_speculation: - return - - for seq_group_metadata in seq_group_metadata_list: - # Once num_speculative_tokens is set to 0, the spec decode - # of this request will be disabled forever. - # TODO(comaniac): We currently store spec decoding specific - # state in the global data structure, but we should maintain - # this state within spec decode worker. - seq_group_metadata.num_speculative_tokens = 0 - - def _serialize_sampler_output_no_logprobs( - self, execute_model_req: ExecuteModelRequest, - sampler_output: SamplerOutput) -> List[SamplerOutput]: - """ - Creates and returns a `SamplerOutput` with only the token IDs being - serialized to CPU and populated in `CompletionSequenceGroupOutput`. - All other parameters in `CompletionSequenceGroupOutput` related to log - probabilities are skipped. - - Args: - execute_model_req (ExecuteModelRequest): The model request that - was executed. - sampler_output (SamplerOutput): The output from the sampler with - only GPU tensors populated. - - Returns: - SamplerOutput: A new `SamplerOutput` instance containing a list of - `CompletionSequenceGroupOutput` objects with only token IDs - populated. - """ - seq_output_prompt_logprobs = [ - seq.is_prompt and seq.sampling_params.prompt_logprobs is not None - and seq.sampling_params.prompt_logprobs > 0 - for seq in execute_model_req.seq_group_metadata_list - ] - # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID - sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( - # subtracting is faster than testing for equality - sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ - if any(seq_output_prompt_logprobs) else \ - sampler_output.sampled_token_ids).tolist() - - seq_data_entries = [ - (seq_id, seq_data) for sg in \ - execute_model_req.seq_group_metadata_list \ - for seq_id, seq_data in sg.seq_data.items() - ] - completion_seq_group_output_list: List[ - CompletionSequenceGroupOutput] = [] - output_index = 0 - # Make sure the non-terminal prefill chunks are still aligned with - # their own empty output. 
- for idx, seq_group_meta in enumerate( - execute_model_req.seq_group_metadata_list): - needs_prompt_logprobs = seq_output_prompt_logprobs[idx] - seq_id, seq_data = seq_data_entries[idx] - if needs_prompt_logprobs: - prompt_token_ids = seq_data.get_prompt_token_ids() - - # Some of these sequences may belong to non-terminal chunks, - # which may still have to report logprobs for prompts. - start = 1 if seq_data._num_computed_tokens == 0 \ - else seq_data._num_computed_tokens - end = (seq_data._num_computed_tokens + \ - seq_group_meta.token_chunk_size) - prompt_token_ids = prompt_token_ids[start:end] - prompt_logprobs = [ - create_logprobs_output( - token_id=p_token_id, - token_id_logprob_rank=-1, - token_id_logprob=0.0, - topk_token_ids=[], - topk_logprobs=[], - ) for p_token_id in prompt_token_ids - ] - else: - prompt_logprobs = None - - # Since we can get chunks here, we dont always have a sampled token - # (only on last chunk) but we still have to provide an output. - if not seq_group_meta.do_sample: - completion_seq_group_output_list.append( - CompletionSequenceGroupOutput( - samples=[], prompt_logprobs=prompt_logprobs)) - continue - - # Sequence with output. - completion_seq_group_output_list.append( - create_sequence_group_output( - token_id=sampled_token_ids_list[output_index][0], - token_id_logprob_rank=-1, - token_id_logprob=0.0, - seq_id=seq_id, - topk_token_ids=[], - topk_logprobs=[], - prompt_logprobs=prompt_logprobs)) - output_index += 1 - - return [SamplerOutput(outputs=completion_seq_group_output_list)] - - @nvtx_range("spec_decode_worker._run_no_spec") - def _run_no_spec(self, execute_model_req: ExecuteModelRequest, - skip_proposer: bool) -> List[SamplerOutput]: - """Run a single generation step without any speculation. The input is - sent to the proposer and scorer model so that the KV cache is consistent - between the two. When skip_proposer is True, the proposer model is - not called, meaning that the kv-cache in proposer for requests is not - updated, so they cannot enable spec decode in the rest decoding. - """ - - sampler_output = self.scorer_worker.execute_model(execute_model_req) - assert len(sampler_output) == 1 - sampler_output = sampler_output[0] - - # Store hidden states from target model execution, BxD. - hidden_states = sampler_output.hidden_states - if hidden_states is not None: - # Only decodes and prefill terminal chunks need a hidden state. - seq_group_meta_with_hidden = [ - sg for sg in execute_model_req.seq_group_metadata_list - if sg.do_sample - ] - if any(seq.is_prompt for seq in seq_group_meta_with_hidden): - # Drop hidden_states with no prediction (eg non-terminal chunks) - hidden_states = hidden_states[ - torch.where(sampler_output.sampled_token_ids - - VLLM_INVALID_TOKEN_ID)[0]] - if self.previous_hidden_states is None and len( - seq_group_meta_with_hidden): - self.previous_hidden_states = HiddenStates( - hidden_states, seq_group_meta_with_hidden) - elif self.previous_hidden_states and len( - seq_group_meta_with_hidden): - self.previous_hidden_states.update(hidden_states, - seq_group_meta_with_hidden) - self.previous_hidden_states.prune(seq_group_meta_with_hidden) - - if not skip_proposer: - # We prepare the prefill hidden states here so that there no - # additional complexity in worker for spec_decode vs non_spec_decode - # flow and execute_model doesn't need additional modifications. 
- execute_model_req.previous_hidden_states = \ - prepare_prefill_hidden_states( - sampler_output.prefill_hidden_states) - for i in range(self._num_spec_prefill_steps): - execute_model_req.spec_step_idx = i - self.proposer_worker.execute_model(execute_model_req) - - sampler_output_to_return = (self._serialize_sampler_output_no_logprobs( - execute_model_req=execute_model_req, sampler_output=sampler_output) - if self._disable_logprobs else - [sampler_output]) - - # Clear device tensors from sampler output. This reduces communication - # overhead when the engine runs in a different process than the workers. - sampler_output.sampled_token_probs = None - sampler_output.sampled_token_ids = None - sampler_output.logprobs = None - return sampler_output_to_return - - def _run_non_driver_rank(self) -> bool: - """Run proposer and verifier model in non-driver workers. This is used - for both speculation cases (num_lookahead_slots>0) and non-speculation - cases (e.g. prefill). - - Returns True if there are remaining sequences to process. - """ - assert self.rank != self._driver_rank - - data = broadcast_tensor_dict(src=self._driver_rank) - if not data: - return False - num_lookahead_slots = data["num_lookahead_slots"] - - # In case of prefill, scorer_worker has to be run before proposer so - # that the hidden states can be propagated to proposer when needed. - if data["no_spec"]: - self.scorer_worker.execute_model() - - if not data["disable_all_speculation"]: - # Even if num_lookahead_slots is zero, we want to run the - # proposer model as it may have KV. - # - # We run the proposer once per lookahead slot. In the future we - # should delegate how many times it runs to the proposer. - for _ in range(max(num_lookahead_slots, 1)): - self.proposer_worker.execute_model() - - if not data["no_spec"]: - self.scorer_worker.execute_model() - if data["run_spec_proposer_for_prefill"]: - self.proposer_worker.execute_model() - - return True - - @nvtx_range("spec_decode_worker._run_speculative_decoding_step") - def _run_speculative_decoding_step( - self, execute_model_req: ExecuteModelRequest, - num_lookahead_slots: int) -> List[SamplerOutput]: - """Execute a single step of speculative decoding. - - This invokes the proposer worker to get k speculative tokens for each - sequence, then scores each speculative token using the scoring worker. - - When `enable_chunked_prefill` is set, scorer will batch decodes and - prefills, while proposer will sync its KV-cache by running an extra - forward on prefills. - - Returns a list of SamplerOutput, each containing a single token per - sequence. - """ - # With prefill chunking, expect requests to have prompts first - # so that backend gets prefill|decode. - assert num_lookahead_slots == execute_model_req.num_lookahead_slots - - # Pass last hidden states from target model to proposer - execute_model_req.previous_hidden_states = self.previous_hidden_states - self.previous_hidden_states = None - - with Timer() as proposal_timer: - # Generate proposals using draft worker. 
- proposals = self.proposer_worker.get_spec_proposals( - execute_model_req, self._seq_with_bonus_token_in_last_step) - - if not self._allow_zero_draft_token_step and proposals.no_proposals: - #TODO: Fix it #5814 - raise RuntimeError("Cannot handle cases where distributed draft " - "workers generate no tokens") - - execute_model_req.previous_hidden_states = None - - with Timer() as scoring_timer: - proposal_scores = self.scorer.score_proposals( - execute_model_req, - proposals, - ) - - _, (non_spec_seqs, non_spec_indices) = split_batch_by_proposal_len( - execute_model_req.seq_group_metadata_list, proposals.proposal_lens) - # With prefill chunking enabled, `non_spec_seqs` contains prefills too: - # discard decodes that have already been processed by proposer. - non_spec_indices = [ - idx for idx in non_spec_indices - if execute_model_req.seq_group_metadata_list[idx].is_prompt - ] - if len(non_spec_indices): - all_hidden_states = proposal_scores.hidden_states - if all_hidden_states is not None: - prefill_hidden_states = all_hidden_states[non_spec_indices] - execute_model_req.previous_hidden_states = \ - prepare_prefill_hidden_states(prefill_hidden_states) - # Sync proposer KV cache for prefills. - prefill_req = execute_model_req.clone(non_spec_seqs) - # TODO avoid sampling here? - self.proposer_worker.execute_model(prefill_req) - - with Timer() as verification_timer: - accepted_token_ids, target_logprobs = self._verify_tokens( - execute_model_req.seq_group_metadata_list, proposal_scores, - proposals, execute_model_req.num_lookahead_slots) - - stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, - scoring_timer.elapsed_time_ms, - verification_timer.elapsed_time_ms) - - return self._create_output_sampler_list( - execute_model_req.seq_group_metadata_list, - accepted_token_ids, - target_logprobs=target_logprobs, - prompt_logprobs=proposal_scores.prompt_logprobs - if not self._disable_logprobs else None, - k=execute_model_req.num_lookahead_slots, - stage_times=stage_times) - - @nvtx_range("spec_decode_worker._verify_tokens") - def _verify_tokens( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_scores: SpeculativeScores, - proposals: SpeculativeProposals, - max_proposal_len: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Determine which speculative tokens are accepted using the - probabilities of each token according to the proposer and scorer models. - - Returns a tuple of Tensors, one for the accepted token ids and one for - the logprobs according to the scoring model. - """ - proposal_lens_list = proposals.proposal_lens.tolist() - - # vLLM currently only supports proposal lens equal to zero or the batch - # proposal len. This adds some complexity (splitting the batch into spec - # and non spec sequences) and should be removed in the future. It can be - # done by supporting per-sequence proposal lens. - (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len( - seq_group_metadata_list, proposal_lens_list) - original_indices = spec_indices + non_spec_indices - - # Get probabilities of target model, including bonus tokens. - proposal_verifier_probs = proposal_scores.probs[spec_indices] - - # Get non-speculative sampled tokens from target model. - non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] - - # Get bonus tokens from target model. - bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] - - # Get probabilities according to proposal method. 
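Both `_run_speculative_decoding_step` and `_verify_tokens` lean on `split_batch_by_proposal_len`, which simply partitions the batch into sequences that carry proposals and those that do not (prefills and zero-proposal decodes). A minimal stand-in with the same return shape:

```python
from typing import List, Sequence, Tuple

def split_by_proposal_len(items: Sequence, proposal_lens: Sequence[int]
                          ) -> Tuple[Tuple[List, List[int]], Tuple[List, List[int]]]:
    """Return ((spec_items, spec_indices), (non_spec_items, non_spec_indices))."""
    spec, spec_idx, non_spec, non_spec_idx = [], [], [], []
    for i, (item, plen) in enumerate(zip(items, proposal_lens)):
        if plen > 0:
            spec.append(item)
            spec_idx.append(i)
        else:
            non_spec.append(item)
            non_spec_idx.append(i)
    return (spec, spec_idx), (non_spec, non_spec_idx)

(_, spec_idx), (_, non_spec_idx) = split_by_proposal_len("abcd", [3, 0, 3, 0])
print(spec_idx, non_spec_idx)   # [0, 2] [1, 3]
```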
- proposal_probs = proposals.proposal_probs[spec_indices] - - # Get proposed tokens. - proposal_token_ids = proposals.proposal_token_ids[spec_indices] - - # Sampler arguments - sampler_extra_kwargs: Dict[str, Any] = {} - if self.generators and isinstance(self.spec_decode_sampler, - SpecDecodeStochasticBaseSampler): - sampler_extra_kwargs["seeded_seqs"] = { - idx: self.generators[sgm.request_id] - for idx, sgm in enumerate(seq_group_metadata_list) - if sgm.sampling_params.seed is not None - } - - accepted_token_ids = self.spec_decode_sampler( - target_with_bonus_probs=proposal_verifier_probs, - bonus_token_ids=bonus_token_ids, - draft_probs=proposal_probs, - draft_token_ids=proposal_token_ids, - **sampler_extra_kwargs, - ) - # Append output tokens from non-speculative sequences to - # the accepted token ids tensor. - non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + - 1).clone() - non_spec_token_ids[:, 1:] = -1 - accepted_token_ids = torch.cat( - [accepted_token_ids, non_spec_token_ids]) - logprobs = proposal_scores.logprobs - # Rearrange so that results are in the order of the original seq group - # metadata. - accepted_token_ids[original_indices] = accepted_token_ids.clone() - - # B x K+1 x D - hidden_states = proposal_scores.hidden_states - if hidden_states is not None: - # Only get terminal hidden states for next step - terminal_metadata = [ - sg for sg in seq_group_metadata_list if sg.do_sample - ] - - # Contract hidden states based on accepted tokens - hs_size = hidden_states.shape[-1] - accepted_index = accepted_token_ids + 1 # Convert -1 to 0 - accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) # b - # Drop non-terminal prefill chunks hidden states. - hidden_states = hidden_states[accepted_index != - VLLM_INVALID_TOKEN_ID] - accepted_index = accepted_index[accepted_index != - VLLM_INVALID_TOKEN_ID] - assert len(accepted_index) == hidden_states.shape[0] == len( - terminal_metadata) - index = accepted_index[:, None, None].expand(-1, 1, - hs_size) # b x 1 x d - second_last_token_hidden_states = hidden_states[:, -2] # b x d - hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d - # Store hidden states from target model for subsequent decode step - self.previous_hidden_states = HiddenStates( - hidden_states, terminal_metadata, - second_last_token_hidden_states) - return accepted_token_ids, logprobs - - def _create_output_sampler_list( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] - target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] - prompt_logprobs: Optional[ - torch.Tensor], # shape: [nprompt_tokens, vocab_size] - k: int, - stage_times: Tuple[float, float, float], - ) -> List[SamplerOutput]: - """Given the accepted token ids, create a list of SamplerOutput. - - The output is padded with -1 tokens such that each sequence has - the same number of outputs. - """ - batch_size, num_steps = accepted_token_ids.shape - accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) - if self._disable_logprobs: - # We are skipping the logprobs. Hence don't serialize the - # logprobs related tensors from the GPU. Instead create - # empty/dummy lists. - (accepted_token_id_ranks_by_step, - accepted_token_id_logprobs_by_step, - topk_logprobs_by_step, topk_indices_by_step) =\ - self._create_dummy_logprob_lists( - batch_size, num_steps, - self.scorer_worker.model_config.max_logprobs) - else: - # Organize input tensors by step instead of by sequence. 
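`_verify_tokens` pads the non-speculative rows with -1, concatenates them after the speculative rows, and then scatters everything back into the original batch order with a single index assignment. The same trick on a toy tensor:

```python
import torch

# Rows were produced as [speculative rows..., non-speculative rows...]; the list
# spec_indices + non_spec_indices gives, for each row in that order, where the
# sequence sat in the original batch.
spec_indices, non_spec_indices = [0, 2], [1, 3]
original_indices = torch.tensor(spec_indices + non_spec_indices)

accepted = torch.tensor([[10, 11],    # spec seq (originally batch index 0)
                         [20, 21],    # spec seq (originally batch index 2)
                         [30, -1],    # non-spec seq (originally index 1), padded
                         [40, -1]])   # non-spec seq (originally index 3), padded
accepted[original_indices] = accepted.clone()
print(accepted)
# tensor([[10, 11], [30, -1], [20, 21], [40, -1]]) -- rows back in batch order 0..3
```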
- target_logprobs_by_step = target_logprobs.transpose(0, 1) - # Serialize all tensors into Python lists. - (accepted_token_id_ranks_by_step, - accepted_token_id_logprobs_by_step, - topk_logprobs_by_step, topk_indices_by_step) =\ - self._create_logprob_lists_from_tensors( - target_logprobs_by_step, accepted_token_ids_by_step, - self.scorer_worker.model_config.max_logprobs) - - # Get the sequence ids and num_logprobs (sampling parameter) in the - # batch. - seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids( - seq_group_metadata_list) - - num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) - - # Serialize tensor to CPU Python list. - accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() - - # Construct the output on a per-step, per-sequence basis. - # Non-terminal prefill chunks will end up here as rows with just -1s - # i.e mixed-batch [[-1, 1576], [-1, 29884], [-1, -1], [-1, -1]] while - # terminal chunks will only have one generated token at time 0. - sampler_output_list: List[SamplerOutput] = [] - - # Prefills are not multi-step (return at most 1 token), in order to - # avoid padding or repetition to fit decodes, we separate them. - for i, sg in enumerate(seq_group_metadata_list): - if not sg.is_prompt: - # Requests are ordered as prefills|decodes=>no more prefills. - break - num_logprobs = num_logprobs_per_seq[i] - seq_kwargs = dict(token_id=-1, - token_id_logprob_rank=0, - token_id_logprob=-float('inf'), - topk_token_ids=[-1] * num_logprobs, - topk_logprobs=[-float('inf')] * num_logprobs, - seq_id=seq_ids[i]) - # Terminal chunk, has token. - if sg.do_sample: - seq_kwargs.update( - dict( - token_id=accepted_token_ids[i][0].item(), - token_id_logprob_rank=accepted_token_id_ranks_by_step[ - 0][i], - token_id_logprob=accepted_token_id_logprobs_by_step[0] - [i], - topk_token_ids=topk_indices_by_step[0][i] - [:num_logprobs], - # output only so step is 0 - topk_logprobs=topk_logprobs_by_step[0][i] - [:num_logprobs], - )) - needs_plogs = (sg.sampling_params.prompt_logprobs - and sg.sampling_params.prompt_logprobs > 0) - plogs = None - if prompt_logprobs is not None: - # Even non-terminal prompt chunks can have logprobs here. - plogs = prompt_logprobs[i] - elif needs_plogs: - # Prompt logprobs are requested but `_disable_logprobs` is set. - seq_data = next(iter(sg.seq_data.values())) - # Get only the tokens in this chunk! - prompt_token_ids = seq_data.get_prompt_token_ids() - prompt_token_ids = prompt_token_ids[ - seq_data. - _num_computed_tokens:seq_data._num_computed_tokens + - sg.token_chunk_size] - - is_first_chunk = seq_data._num_computed_tokens == 0 - # There's no prob generated for the first token in a sequence. - if is_first_chunk: - prompt_token_ids = prompt_token_ids[1:] - plogs = [ - create_logprobs_output( - token_id=p_token_id, - token_id_logprob_rank=-1, - token_id_logprob=0.0, - topk_token_ids=[], - topk_logprobs=[], - ) for p_token_id in prompt_token_ids - ] - seq_kwargs.update(dict(prompt_logprobs=plogs)) - - sampler_output_list.append( - SamplerOutput( - outputs=[create_sequence_group_output( - **seq_kwargs)])) # type: ignore - - # Decodes, create one SamplerOutput per-step (at most K+1). 
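- # For illustration: with two decode sequences and accepted_token_ids_by_step
- # of [[5, 7], [9, -1], [-1, -1]], steps 0 and 1 each emit a SamplerOutput and
- # the loop below stops at step 2, where every decode entry is -1.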
- for step_index in range(num_steps): - if all(token_id == -1 for sg, token_id in zip( - seq_group_metadata_list, - accepted_token_ids_by_step[step_index]) - if not sg.is_prompt): - break - - step_output_token_ids: List[CompletionSequenceGroupOutput] = [] - for sequence_index in range(batch_size): - seq_meta = seq_group_metadata_list[sequence_index] - # Prompts already processed above. - if seq_meta.is_prompt: - continue - - # Each sequence may have a different num_logprobs; retrieve it. - num_logprobs = num_logprobs_per_seq[sequence_index] - step_output_token_ids.append( - create_sequence_group_output( - token_id=accepted_token_ids_by_step[step_index] - [sequence_index], - token_id_logprob_rank=accepted_token_id_ranks_by_step[ - step_index][sequence_index], - token_id_logprob=accepted_token_id_logprobs_by_step[ - step_index][sequence_index], - seq_id=seq_ids[sequence_index], - topk_token_ids=topk_indices_by_step[step_index] - [sequence_index][:num_logprobs], - topk_logprobs=topk_logprobs_by_step[step_index] - [sequence_index][:num_logprobs], - step_index=step_index)) - sampler_output_list.append( - SamplerOutput(outputs=step_output_token_ids)) - - # Populate the data structures needed to keep track of sequences with - # bonus tokens. - self._track_sequences_with_bonus_tokens(seq_ids, - request_ids_seq_ids_mapping, - accepted_token_ids_by_step) - maybe_rejsample_metrics = ( - self._metrics.maybe_collect_rejsample_metrics(k)) - if maybe_rejsample_metrics is not None: - sampler_output_list[ - 0].spec_decode_worker_metrics = maybe_rejsample_metrics - - # Log time spent in each stage periodically. - # This is periodic because the rejection sampler emits metrics - # periodically. - self._maybe_log_stage_times(*stage_times) - # First `n_prefills` entries will contain prefills SamplerOutput when - # chunked prefill is enabled, the rest is decodes in multi-step format. - return sampler_output_list - - def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, - scoring_time_ms: float, - verification_time_ms: float) -> None: - """Log the speculative stage times. If stat logging is disabled, do - nothing. - """ - if self._disable_log_stats: - return - - logger.info( - "SpecDecodeWorker stage times: " - "average_time_per_proposal_tok_ms=%.02f " - "scoring_time_ms=%.02f verification_time_ms=%.02f", - average_time_per_proposal_tok_ms, scoring_time_ms, - verification_time_ms) - - def _create_dummy_logprob_lists( - self, - batch_size: int, - num_steps: int, - num_top_k: int, - ) -> Tuple[List[List[int]], List[List[float]], - List[List[List[Optional[float]]]], - List[List[List[Optional[int]]]]]: - """ - Creates and returns four dummy lists representing token probabilities - and their ranks. - - This method initializes and returns: - - The ranks of the accepted tokens, shaped (num_steps, batch_size) - - The log probabilities of the accepted tokens, - shaped (num_steps, batch_size) - - The log probabilities of the top k tokens, - shaped (num_steps, batch_size, num_top_k) - - The token IDs of the top k tokens, - shaped (num_steps, batch_size, num_top_k) - - Args: - batch_size (int): The size of the batch. - num_steps (int): The number of steps in the sequence. - num_top_k (int): The number of top-k token log probabilities to - return. - - Returns: - A tuple containing four dummy lists as described above. 
- """ - accepted_token_id_ranks_by_step = [[-1] * batch_size - for _ in range(num_steps)] - accepted_token_id_logprobs_by_step = [[0.0] * batch_size - for _ in range(num_steps)] - topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[ - [None] * num_top_k for _ in range(batch_size) - ] for _ in range(num_steps)] - topk_indices_by_step: List[List[List[Optional[int]]]] = [[ - [None] * num_top_k for _ in range(batch_size) - ] for _ in range(num_steps)] - return (accepted_token_id_ranks_by_step, - accepted_token_id_logprobs_by_step, topk_logprobs_by_step, - topk_indices_by_step) - - def _create_logprob_lists_from_tensors( - self, - target_logprobs_by_step: torch.Tensor, - accepted_token_ids_by_step: torch.Tensor, - num_top_k: int, - ) -> Tuple[List[List[int]], List[List[float]], - List[List[List[Optional[float]]]], - List[List[List[Optional[int]]]]]: - """ - Creates and returns four lists representing token probabilities and - their ranks. - - This method initializes and returns four lists containing: - - The ranks of the accepted tokens, shaped (num_steps, batch_size) - - The log probabilities of the accepted tokens, - shaped (num_steps, batch_size) - - The log probabilities of the top k tokens, - shaped (num_steps, batch_size, num_top_k) - - The token IDs of the top k tokens, - shaped (num_steps, batch_size, num_top_k) - - Args: - target_logprobs_by_step (torch.Tensor): Tensor representing the - log probabilities of the target model, - shaped (num_steps, batch_size, vocab_size) - accepted_token_ids_by_step (torch.Tensor): Tensor representing - the accepted token_ids, shaped (num_steps, batch_size) - num_top_k (int): The number of top-k token log probabilities to - return. - - Returns: - A tuple containing the lists as described above. - """ - # Serialize all tensors to CPU Python lists. - # Get the logprobs/rank of the accepted tokens. - (accepted_token_id_ranks_by_step_tensor, - accepted_token_id_logprobs_by_step_tensor - ) = get_sampled_token_logprobs( - logprob_tensor=target_logprobs_by_step, - sampled_token_ids=accepted_token_ids_by_step, - ) - # Get the top-k logprobs (which may or may not include the - # logprob of the accepted token). - (topk_logprobs_by_step_tensor, - topk_indices_by_step_tensor) = target_logprobs_by_step.topk( - k=num_top_k, - dim=-1, - ) - accepted_token_id_ranks_by_step = ( - accepted_token_id_ranks_by_step_tensor.tolist()) - accepted_token_id_logprobs_by_step = ( - accepted_token_id_logprobs_by_step_tensor.tolist()) - topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist() - topk_indices_by_step = topk_indices_by_step_tensor.tolist() - return (accepted_token_id_ranks_by_step, - accepted_token_id_logprobs_by_step, topk_logprobs_by_step, - topk_indices_by_step) - - def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): - """ - Removes the finished requests and their associated sequence ids from - internal book keeping data structures. - """ - for finished_request in execute_model_req.finished_requests_ids: - for seq_id in self._request_id_seq_id_mapping[finished_request]: - self._seq_with_bonus_token_in_last_step.discard(seq_id) - del self._request_id_seq_id_mapping[finished_request] - - def _track_sequences_with_bonus_tokens( - self, seq_ids: List[int], - request_ids_seq_ids_mapping: Dict[str, Set[int]], - accepted_token_ids_by_step: List[List[int]]): - """ - Updates the internal data structures which keep track of sequences - which have been assigned bonus tokens in their last forward pass. 
- """ - for seq_index, seq_id in enumerate(seq_ids): - last_token_id = accepted_token_ids_by_step[-1][seq_index] - if last_token_id == -1: - self._seq_with_bonus_token_in_last_step.discard(seq_id) - else: - self._seq_with_bonus_token_in_last_step.add(seq_id) - for request_id, sequences in request_ids_seq_ids_mapping.items(): - self._request_id_seq_id_mapping[request_id].update(sequences) - - @cached_property - def _vocab_size(self) -> int: - """Get the vocab size of the model and make sure it's consistent between - draft and target workers. - """ - vocab_sizes = [ - worker.vocab_size - for worker in [self.proposer_worker, self.scorer_worker] - ] - assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes) - return vocab_sizes[0] - - @property - def rank(self): - return self.scorer_worker.rank - - @property - def device(self): - return self.scorer_worker.device - - @property - def _driver_rank(self) -> int: - return 0 - - def get_cache_block_size_bytes(self): - """Return the size of a cache block in bytes. - - This function is only used to compose workers within a SpecDecodeWorker. - We leave composing a SpecDecodeWorker within a SpecDecodeWorker - undefined for now, although it could be implemented in the future. - See https://arxiv.org/abs/2308.04623. - """ - raise NotImplementedError - - def start_profile(self): - if isinstance(self.scorer_worker, WorkerBase): - self.scorer_worker.start_profile() - - def stop_profile(self): - if isinstance(self.scorer_worker, WorkerBase): - self.scorer_worker.stop_profile() - - -def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, - proposer_cache_block_size_bytes: int, - total_num_gpu_blocks: int) -> int: - """Given total_num_gpu_blocks, the number of GPU blocks that could be - allocate to the target model, this function calculates how many blocks - should be given to the draft and target model. - - Note that usually the block size, in bytes, of each model is different, - as it's a function of number of KV/layer, number of heads, and hidden - dimension size. - - Since the target and draft models allocate the same number of blocks, we - simply calculate the number of blocks where if allocated by both models, - the total memory usage from KV cache is no larger than the number of - blocks allocatable by the target model alone. - """ - new_num_gpu_blocks = int( - total_num_gpu_blocks * scorer_cache_block_size_bytes / - (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) - - return new_num_gpu_blocks - - -def prepare_prefill_hidden_states( - prefill_hidden_states: torch.Tensor) -> HiddenStates: - # For prefill step in proposer, we run the model for N-1 tokens - # because Nth token will be processed in the first decode step. For - # N-1 tokens, the input should be 0:N-1 hidden states which should - # be concatanated with 1:N token (since output of scorer has to be - # the input for proposer). Therefore, we shift the hidden states to - # align n-1th hidden state with nth token. 
- return HiddenStates(prefill_hidden_states.roll( - shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py deleted file mode 100644 index ca89eb60ac58..000000000000 --- a/vllm/spec_decode/target_model_runner.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional - -from vllm.sequence import SequenceGroupMetadata -from vllm.worker.model_runner_base import (ModelRunnerBase, - ModelRunnerInputBase, - ModelRunnerWrapperBase) - - -class TargetModelRunner(ModelRunnerWrapperBase): - """Specialized model runner for speculative decoding target model. - In speculative decoding, the log probabilities selected finally may not - be the same ones as selected by the target model sampling. This means - that the time spent in the log probability calculation of the target model - is time wasted, since we calculate log probabilities after deciding which - tokens are accepted. For this reason disabling log probabilities in the - target model will make decode faster. The model runner sets the - SamplingMetadata parameters according to whether log probabilities are - requested or not. - """ - - def __init__(self, model_runner: ModelRunnerBase): - # An internal boolean member variable to indicate if token log - # probabilities are needed or not. - super().__init__(model_runner) - self.disable_logprobs = True - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelRunnerInputBase: - model_input: ModelRunnerInputBase =\ - self.model_runner.prepare_model_input( - seq_group_metadata_list, virtual_engine, finished_requests_ids) - # If token log probabilities is disabled then skip generating sampler - # CPU output. We directly serialize the GPU sampled_token_id tensors - # as needed. If log probabilities is enabled then synchronize all the - # sampling related tensors which includes the logprobs tensors. - model_input.sampling_metadata.skip_sampler_cpu_output = ( - self.disable_logprobs) - return model_input diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py deleted file mode 100644 index afd91b42b943..000000000000 --- a/vllm/spec_decode/top1_proposer.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import List, Optional, Set, Tuple - -import torch - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata -from vllm.spec_decode.interfaces import (SpeculativeProposals, - SpeculativeProposer) -from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase -from vllm.spec_decode.util import sampler_output_to_torch - - -class Top1Proposer(SpeculativeProposer): - """Helper class which separates out sequences which would exceed the max - model length when speculated upon. - - This allows combinations of models such as JackFram/llama-68m draft with - meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of - 2048 while Llama2-13b has max_position_embeddings of 4096. - - We treat the sequences which exceed the proposal draft model length as - "non-spec sequences". 
Essentially they skip the draft model and go through - normal decoding in the target model. - - Currently, only proposal_lens of 0 and k are supported, where k is a global - batch proposal length. In the future vLLM should support per-sequence - proposal lengths. - """ - - def __init__( - self, - worker: ProposerWorkerBase, - device: str, - vocab_size: int, - max_proposal_len: Optional[int] = None, - ): - self._worker = worker - self._device = device - self.max_proposal_len = max_proposal_len - self._vocab_size = vocab_size - - def get_spec_proposals( - self, - execute_model_req: ExecuteModelRequest, - seq_ids_with_bonus_token_in_last_step: Set[int], - ) -> SpeculativeProposals: - """Get speculative proposals given the input batch. - - Sequences which would exceed the max model length are skipped during - speculation. - """ - proposal_len = execute_model_req.num_lookahead_slots - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - - # Split speculative- and non-speculative- sequences. - ( - proposal_lens, - nonzero_proposal_len_seqs, - nonzero_proposal_len_indices, - ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) - - if nonzero_proposal_len_seqs: - # Speculate tokens using the draft worker for the speculative - # sequences. - # If sampler_transposed is true, then maybe_sampler_output's - # token_ids is like [batch] format in proposal_len size list, - # while if it is false, the format would be [proposal_len] - # in batch size list - hidden_states = execute_model_req.previous_hidden_states - if hidden_states is not None: - hidden_states.prune(nonzero_proposal_len_seqs) - nonzero_execute_model_req = ExecuteModelRequest( - seq_group_metadata_list=nonzero_proposal_len_seqs, - num_lookahead_slots=proposal_len, - previous_hidden_states=hidden_states, - ) - maybe_sampler_output, transposed = self._worker.sampler_output( - execute_model_req=nonzero_execute_model_req, - sample_len=proposal_len, - seq_ids_with_bonus_token_in_last_step=\ - seq_ids_with_bonus_token_in_last_step, - ) - ( - proposal_lens, - maybe_sampler_output, - nonzero_proposal_len_indices, - ) = self._remove_no_proposal_seqs(proposal_lens, - maybe_sampler_output, - nonzero_proposal_len_indices, - transposed) - else: - # If no sequences can be speculated, set sampler output to None. - maybe_sampler_output = None - transposed = False - - # Combine speculative- and non-speculative sequences into the same - # representation. - proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs( - batch_size=len(seq_group_metadata_list), - proposal_len=proposal_len, - maybe_sampler_output=maybe_sampler_output, - proposal_lens=proposal_lens, - nonzero_proposal_len_indices=nonzero_proposal_len_indices, - sampler_transposed=transposed, - ) - - proposals = SpeculativeProposals(proposal_token_ids=proposal_tokens, - proposal_probs=proposal_probs, - proposal_lens=proposal_lens, - no_proposals=maybe_sampler_output - is None) - return proposals - - def _split_by_proposal_len( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_len: int, - ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]: - """Split sequences by two groups: - 1. Sequences with non-zero proposal length. - 2. Sequences with zero proposal length (due to disabled speculation - or exceed the maximum model length). 
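- For example, if sequence 0 is a prompt, sequence 1 can be speculated, and
- sequence 2 would exceed max_proposal_len, the result is proposal lens
- [0, proposal_len, 0], the metadata of sequence 1, and indices [1].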
- """ - - proposal_lens: List[int] = [] - nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = [] - nonzero_proposal_len_indices: List[int] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - # The speculative decoding for this request has either been disabled - # (e.g. due to high traffic) or this is a prompt request. - if (seq_group_metadata.is_prompt - or seq_group_metadata.num_speculative_tokens == 0): - proposal_lens.append(0) - continue - - seq_data = next(iter(seq_group_metadata.seq_data.values())) - seq_len = seq_data.get_len() - - # Currently only proposal lens of 0 or the global batch proposal len - # are supported. - # If max_proposal_len is defined, then we shall not exceed this - # quota for nonzero_proposal - new_k = 0 - if (self.max_proposal_len is None - or seq_len + proposal_len < self.max_proposal_len): - new_k = proposal_len - nonzero_proposal_len_seqs.append(seq_group_metadata) - nonzero_proposal_len_indices.append(i) - proposal_lens.append(new_k) - seq_group_metadata.num_speculative_tokens = new_k - - return ( - proposal_lens, - nonzero_proposal_len_seqs, - nonzero_proposal_len_indices, - ) - - @staticmethod - def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output, - nonzero_proposal_len_indices, transposed): - """Remove sequences from nonzero_proposal_len_indices and reset - their proposal_len to 0 the draft worker does not provide a proposal - (maybe_sampler_output=None). This can avoid scoring overheads. - """ - - # If maybe_sampler_output is None, then the draft worker did not - # provide a proposal for any sequence and thus no action needed. - # Also we do not support transposed maybe_sampler_output for now - # because it seems not straightforward for draft workers outputting - # transposed sampler outputs to handle the case of no proposal. - if maybe_sampler_output is None or transposed: - return (proposal_lens, maybe_sampler_output, - nonzero_proposal_len_indices) - - new_proposal_lens: List[int] = [] - new_nonzero_proposal_len_indices: List[int] = [] - new_maybe_sampler_output: List[SamplerOutput] = [] - nonzero_proposal_len_idx_ptr = 0 - seq_idx = 0 - while seq_idx < len( - proposal_lens) and nonzero_proposal_len_idx_ptr < len( - nonzero_proposal_len_indices): - if seq_idx < nonzero_proposal_len_indices[ - nonzero_proposal_len_idx_ptr]: - # Sequence is not in the original nonzero_proposal_len_indices, - # meaning that it has a proposal length of 0 before sending to - # the draft worker. - assert proposal_lens[seq_idx] == 0 - new_proposal_lens.append(0) - else: - # Sequence is in the original nonzero_proposal_len_indices - if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: - # but does not have a proposal from the draft worker. - new_proposal_lens.append(0) - else: - # and has a proposal from the draft worker. Add it to the - # new nonzero proposal list and keep the sampler output. - new_proposal_lens.append(proposal_lens[seq_idx]) - new_nonzero_proposal_len_indices.append(seq_idx) - new_maybe_sampler_output.append( - maybe_sampler_output[nonzero_proposal_len_idx_ptr]) - nonzero_proposal_len_idx_ptr += 1 - seq_idx += 1 - - # The remaining sequences should have proposal length of 0. - new_proposal_lens.extend(proposal_lens[seq_idx:]) - - # We assume sampler_output will not be a list of all Nones. - # In this case this function should not be called. 
- assert new_maybe_sampler_output - return (new_proposal_lens, new_maybe_sampler_output, - new_nonzero_proposal_len_indices) - - def _merge_outputs( - self, - batch_size: int, - proposal_len: int, - maybe_sampler_output: Optional[List[SamplerOutput]], - proposal_lens: List[int], - nonzero_proposal_len_indices: List[int], - sampler_transposed: bool, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """After speculations are produced, merge the speculation results with - the skipped sequences. - """ - if maybe_sampler_output is None: - # If no speculative tokens, the sampler output will be None. - # In this case we return empty proposals. - proposal_tokens = torch.tensor(-1, - dtype=torch.long, - device=self._device).expand( - batch_size, proposal_len) - proposal_probs = torch.tensor(0, - dtype=torch.float32, - device=self._device).expand( - batch_size, proposal_len, - self._vocab_size) - proposal_lens_tensor = torch.tensor(0, - dtype=torch.long, - device=self._device).expand( - len(proposal_lens)) - return proposal_tokens, proposal_probs, proposal_lens_tensor - - sampler_output = maybe_sampler_output - proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( - sampler_output, sampler_transposed) - - # Now, reformat the output GPU tensors such that each sequence has - # a proposal. the proposal can be empty, e.g. [-1, -1, -1] - - entire_proposal_tokens = proposal_tokens.new_full( - size=(batch_size, *proposal_tokens.shape[1:]), - fill_value=-1, - ) - entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens - entire_proposal_probs = proposal_probs.new_zeros( - batch_size, - *proposal_probs.shape[1:], - ) - entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs - - proposal_tokens, proposal_probs = ( - entire_proposal_tokens, - entire_proposal_probs, - ) - - proposal_lens_tensor = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len - - return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py deleted file mode 100644 index 22d2a4833acf..000000000000 --- a/vllm/spec_decode/util.py +++ /dev/null @@ -1,277 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import time -from contextlib import contextmanager -from typing import Dict, List, Optional, Sequence, Tuple - -import torch - -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.platforms import current_platform -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SequenceGroupMetadata, - SequenceOutput) - -SeqId = int - - -def get_all_num_logprobs( - seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: - """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. - - If the sampling params do not call for any logprobs, return 0 for that - sequence. - """ - - all_num_logprobs: List[int] = [] - for seq_group_metadata in seq_group_metadata_list: - num_logprobs = seq_group_metadata.sampling_params.logprobs - if num_logprobs is None: - num_logprobs = 0 - all_num_logprobs.append(num_logprobs) - - return all_num_logprobs - - -def get_sampled_token_logprobs( - # shape [num_steps, batch_size, vocab_size] - logprob_tensor: torch.Tensor, - sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] -) -> Tuple[torch.Tensor, torch.Tensor]: - """Get the logprobs for the sampled tokens. 
Returns the ranks and logprobs. - """ - num_steps, batch_size, vocab_size = logprob_tensor.shape - - selected_logprobs = logprob_tensor[ - torch.arange(num_steps).unsqueeze(1), - torch.arange(batch_size), - sampled_token_ids, - ] - expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( - -1, -1, vocab_size) - sampled_token_ids_ranks = (logprob_tensor - > expanded_selected_logprobs).sum(-1).add_(1) - - return sampled_token_ids_ranks, selected_logprobs - - -def create_logprobs_output( - token_id: int, - token_id_logprob_rank: int, - token_id_logprob: float, - topk_token_ids: List[Optional[int]], - topk_logprobs: List[Optional[float]], -) -> Dict[int, Logprob]: - """Create a Logprob Dict for a token given the sampling results. - - Args: - token_id (int): The sampled token for the sequence. - token_id_logprob_rank (int): The logprob rank of the sampled token. - token_id_logprob (float): The logprob value of the sampled token. - topk_token_ids (List[Optional[int]]): The list of top-k token ids. - topk_logprobs (List[Optional[float]]): The list of top-k logprobs. - """ - # vLLM logprobs always include the sampled token. In addition, the user may - # request topk-logprobs (where top-k varies per user up to max_logprobs). - logprobs: Dict[int, Logprob] = { - token_id: Logprob( - logprob=token_id_logprob, - rank=token_id_logprob_rank, - ), - } - logprobs.update({ - topk_token_id: Logprob( - logprob=topk_logprob if topk_logprob is not None else 0.0, - rank=topk_index + 1, - ) - for topk_index, (topk_token_id, topk_logprob) \ - in enumerate(zip(topk_token_ids, topk_logprobs)) \ - if topk_token_id is not None - }) - - return logprobs - - -def create_sequence_group_output( - token_id: int, - token_id_logprob_rank: int, - token_id_logprob: float, - seq_id: SeqId, - topk_token_ids: List[Optional[int]], - topk_logprobs: List[Optional[float]], - prompt_logprobs: Optional[PromptLogprobs] = None, - step_index: Optional[int] = 0) -> CompletionSequenceGroupOutput: - """Create a SequenceGroupOutput given the sampling results. - - Args: - token_id (int): The sampled token for the sequence. - token_id_logprob_rank (int): The logprob rank of the sampled token. - token_id_logprob (float): The logprob value of the sampled token. - seq_id (int): The sequence id. - topk_token_ids (List[Optional[int]]): The list of top-k token ids. - topk_logprobs (List[Optional[float]]): The list of top-k logprobs. - step_index: (Optional[int]): The index of the speculative token. - """ - - logprobs = create_logprobs_output( - token_id, - token_id_logprob_rank, - token_id_logprob, - topk_token_ids, - topk_logprobs, - ) - - return CompletionSequenceGroupOutput(samples=[ - SequenceOutput(parent_seq_id=seq_id, - output_token=token_id, - logprobs=logprobs) - ], - prompt_logprobs=prompt_logprobs, - step_index=step_index) - - -def split_batch_by_proposal_len( - seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_lens: List[int], -) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[ - List[SequenceGroupMetadata], List[int]]]: - """Utility function that splits a batch based on whether the proposal len is - zero or not. We should remove this once vLLM supports per-sequence proposal - lens in a batch. 
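- For example, proposal_lens=[3, 0, 3] returns
- (([seq_group_0, seq_group_2], [0, 2]), ([seq_group_1], [1])), where
- seq_group_i is the i-th entry of seq_group_metadata_list.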
- """ - - nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) - zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) - for i, (seq_group, proposal_len) in enumerate( - zip(seq_group_metadata_list, proposal_lens)): - seq_groups, indices = nonzero_lists if proposal_len else zero_lists - seq_groups.append(seq_group) - indices.append(i) - return nonzero_lists, zero_lists - - -def sampler_output_to_torch( - sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """Utility function which converts a list of SamplerOutput to tensors. - - sampler_transposed here is used as the indicator for whether - we need do additional tensor transpose logic here. - - Returns: - sampled_token_ids: torch.Tensor - shape: [batch_size, len(sampler_output_list)] - - sampled_token_probs: torch.Tensor - shape: [batch_size, len(sampler_output_list), vocab_size] - """ - - # shape: [batch_size, num_sampler_output, vocab_size] - sampled_token_probs = torch.stack( - [ - sampler_output.sampled_token_probs - for sampler_output in sampler_output_list - ], - dim=0, - ) - - # shape: [batch_size, num_sampler_output, vocab_size] - sampled_token_logprobs = torch.stack( - [sampler_output.logprobs for sampler_output in sampler_output_list], - dim=0, - ) - - # shape: [batch_size, num_sampler_output] - sampled_token_ids = torch.stack( - [ - sampler_output.sampled_token_ids.flatten() - for sampler_output in sampler_output_list - ], - dim=0, - ) - - if sampler_transposed: - sampled_token_probs = sampled_token_probs.transpose(0, 1) - sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) - sampled_token_ids = sampled_token_ids.transpose(0, 1) - - if sampler_output_list[0].hidden_states is not None: - # shape: [batch_size, num_sampler_output, hidden_dim] - sampled_hidden_states = torch.stack( - [ - sampler_output.hidden_states - for sampler_output in sampler_output_list - ], - dim=0, - ) - - if sampler_transposed: - sampled_hidden_states = sampled_hidden_states.transpose(0, 1) - else: - sampled_hidden_states = None - - return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, - sampled_hidden_states) - - -def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, - vocab_size: int, device: str) -> None: - """Helper method which mocks out the GPU tensors in SamplerOutput with dummy - values. This will be removed in PR 7/9. - https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer - """ - values = [ - sampler_output.sampled_token_probs, sampler_output.sampled_token_ids - ] - assert all(v is None for v in values) or not any(v is None for v in values) - if not any(v is None for v in values): - # Do nothing if the tensors are already created (usually in unit tests). - return - - # Softmax to ensure valid probs. - sampler_output.sampled_token_probs = torch.nn.functional.softmax( - torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), - dim=-1) - - sampler_output.sampled_token_ids = torch.randint(low=10, - high=100, - size=(batch_size, ), - dtype=torch.long, - device=device) - - -@contextmanager -def nvtx_range(msg, *args, **kwargs): - """ - Context manager / decorator that pushes an NVTX range at the beginning - of its scope, and pops it at the end. If extra arguments are given, - they are passed as arguments to msg.format(). - - If running with cuda graphs, you must enable nsys cuda graph profiling. 
- - Arguments: - msg (string): message to associate with the range - """ - if current_platform.is_cuda_alike(): - torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) - try: - yield - finally: - torch.cuda.nvtx.range_pop() - else: - yield - - -class Timer: - """Basic timer context manager for measuring CPU time. - """ - - def __enter__(self): - self.start_time = time.time() - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.end_time = time.time() - self.elapsed_time_s = self.end_time - self.start_time - self.elapsed_time_ms = self.elapsed_time_s * 1000 diff --git a/vllm/test_utils.py b/vllm/test_utils.py index c6b126d002b2..1e61ca6b3dea 100644 --- a/vllm/test_utils.py +++ b/vllm/test_utils.py @@ -10,7 +10,7 @@ "allenai/OLMoE-1B-7B-0924-Instruct", "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test", "AMead10/Llama-3.2-1B-Instruct-AWQ", - "ArthurZ/Ilama-3.2-1B", + "hmellor/Ilama-3.2-1B", "BAAI/bge-base-en-v1.5", "BAAI/bge-multilingual-gemma2", "BAAI/bge-reranker-v2-m3", diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 9ccde292974c..8d1f59e6eadf 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -7,7 +7,7 @@ import time from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Literal, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union import huggingface_hub from huggingface_hub import get_safetensors_metadata, hf_hub_download @@ -17,7 +17,6 @@ HFValidationError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) -from torch import nn from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -32,18 +31,20 @@ # yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, DbrxConfig, DeepseekVLV2Config, - EAGLEConfig, ExaoneConfig, - JAISConfig, KimiVLConfig, - MedusaConfig, MiniMaxText01Config, + EAGLEConfig, Exaone4Config, + ExaoneConfig, JAISConfig, + KimiVLConfig, MedusaConfig, + MiniMaxText01Config, MiniMaxVL01Config, MllamaConfig, MLPSpeculatorConfig, MPTConfig, + Nemotron_Nano_VL_Config, NemotronConfig, NVLM_D_Config, OvisConfig, RWConfig, SkyworkR1VChatConfig, SolarConfig, Telechat2Config, UltravoxConfig) # yapf: enable +from vllm.transformers_utils.configs.mistral import adapt_config_dict from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import resolve_obj_by_qualname if envs.VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -80,6 +81,7 @@ def _get_hf_token() -> Optional[str]: "dbrx": DbrxConfig, "deepseek_vl_v2": DeepseekVLV2Config, "kimi_vl": KimiVLConfig, + "Llama_Nemotron_Nano_VL": Nemotron_Nano_VL_Config, "mpt": MPTConfig, "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) @@ -88,6 +90,7 @@ def _get_hf_token() -> Optional[str]: "medusa": MedusaConfig, "eagle": EAGLEConfig, "exaone": ExaoneConfig, + "exaone4": Exaone4Config, "minimax_text_01": MiniMaxText01Config, "minimax_vl_01": MiniMaxVL01Config, "nemotron": NemotronConfig, @@ -304,6 +307,9 @@ def get_config( revision: Optional[str] = None, code_revision: Optional[str] = None, config_format: ConfigFormat = ConfigFormat.AUTO, + hf_overrides_kw: Optional[dict[str, Any]] = None, + hf_overrides_fn: Optional[Callable[[PretrainedConfig], + PretrainedConfig]] = None, **kwargs, ) -> PretrainedConfig: # Separate model 
folder from file path for GGUF models @@ -394,7 +400,16 @@ def get_config( config = _maybe_remap_hf_config_attrs(config) elif config_format == ConfigFormat.MISTRAL: - config = load_params_config(model, revision, **kwargs) + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = _download_mistral_config_file(model, revision) + if (max_position_embeddings := + config_dict.get("max_position_embeddings")) is None: + max_position_embeddings = _maybe_retrieve_max_pos_from_hf( + model, revision, **kwargs) + config_dict["max_position_embeddings"] = max_position_embeddings + + config = adapt_config_dict(config_dict) else: supported_formats = [ fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO @@ -413,6 +428,13 @@ def get_config( model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) + if hf_overrides_kw: + logger.debug("Overriding HF config with %s", hf_overrides_kw) + config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.debug("Overriding HF config with %s", hf_overrides_fn) + config = hf_overrides_fn(config) + patch_rope_scaling(config) if trust_remote_code: @@ -693,117 +715,6 @@ def _reduce_config(config: VllmConfig): exc_info=e) -def load_params_config(model: Union[str, Path], revision: Optional[str], - **kwargs) -> PretrainedConfig: - # This function loads a params.json config which - # should be used when loading models in mistral format - - config_file_name = "params.json" - - config_dict = get_hf_file_to_dict(config_file_name, model, revision) - if config_dict is None: - raise ValueError( - f"Failed to load mistral '{config_file_name}' config for model " - f"{model}. Please check if the model is a mistral-format model " - f"and if the config file exists.") - assert isinstance(config_dict, dict) - - config_mapping = { - "dim": "hidden_size", - "norm_eps": "rms_norm_eps", - "n_kv_heads": "num_key_value_heads", - "n_layers": "num_hidden_layers", - "n_heads": "num_attention_heads", - "hidden_dim": "intermediate_size", - } - - def recurse_elems(elem: Any): - if isinstance(elem, dict): - config_dict = {} - for key, value in elem.items(): - key = config_mapping.get(key, key) - config_dict[key] = recurse_elems(value) - - return config_dict - else: - return elem - - config_dict["model_type"] = config_dict.get("model_type", "transformer") - config_dict["hidden_act"] = config_dict.get("activation", "silu") - config_dict["tie_word_embeddings"] = config_dict.get( - "tie_embeddings", False) - - if config_dict.get("max_position_embeddings") is None: - max_position_embeddings = 128_000 - try: - trust_remote_code_val = kwargs.get("trust_remote_code", False) - hf_config = get_config(model=model, - trust_remote_code=trust_remote_code_val, - revision=revision, - config_format=ConfigFormat.HF) - if hf_value := hf_config.get_text_config().max_position_embeddings: - max_position_embeddings = hf_value - except Exception as e: - logger.warning( - "The params.json file is missing 'max_position_embeddings'" - " and could not get a value from the HF config." 
- " Defaulting to 128000", - exc_info=e) - config_dict["max_position_embeddings"] = max_position_embeddings - - if config_dict.get("quantization") is not None: - quantization = config_dict.get("quantization", {}) - if quantization.get("qformat_weight") == "fp8_e4m3": - # This maps to the FP8 static per-tensor quantization scheme - quantization_config = { - "quant_method": "fp8", - "activation_scheme": "static" - } - elif quantization.get("quant_method") == "compressed-tensors": - # Pass through the quantization config to compressed-tensors - quantization_config = quantization - else: - raise ValueError( - f"Found unknown quantization='{quantization}' in config") - - config_dict["quantization_config"] = quantization_config - - config_type: Literal["text", - "multimodal"] = "multimodal" if config_dict.get( - "vision_encoder") is not None else "text" - - if config_dict.get("moe") is not None: - config_dict["architectures"] = ["MixtralForCausalLM"] - else: - config_dict["architectures"] = ["MistralForCausalLM"] - - if config_type == "multimodal": - multimodal_config = config_dict.pop("vision_encoder") - quantization_config = config_dict.get("quantization_config", {}) - - config_dict = { - "text_config": config_dict, - "vision_config": multimodal_config - } - config_dict["architectures"] = ["PixtralForConditionalGeneration"] - config_dict["model_type"] = "pixtral" - if quantization_config: - config_dict["quantization_config"] = quantization_config - - config_dict.update(kwargs) - - config_dict = recurse_elems(config_dict) - - # transform to HF config format - if config_type == "multimodal": - config_dict["text_config"] = PretrainedConfig( - **config_dict["text_config"]) - config_dict["vision_config"] = PretrainedConfig( - **config_dict["vision_config"]) - - return PretrainedConfig(**config_dict) - - def get_hf_image_processor_config( model: Union[str, Path], hf_token: Optional[Union[bool, str]] = None, @@ -826,13 +737,6 @@ def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. No op for pure text models. """ - # This block should be unnecessary after https://github.com/huggingface/transformers/pull/37517 - if hasattr(config, "thinker_config"): - # TODO(suyang.fy): Refactor code. - # For Qwen2.5-Omni, change hf_text_config to - # thinker_config.text_config. 
- return config.thinker_config.text_config - text_config = config.get_text_config() if text_config is not config: @@ -866,28 +770,6 @@ def try_get_generation_config( return None -def get_classification_activation_function(config: PretrainedConfig): - return nn.Sigmoid() if config.num_labels == 1 else nn.Softmax() - - -def get_cross_encoder_activation_function(config: PretrainedConfig): - function_name: Optional[str] = None - if (hasattr(config, "sentence_transformers") - and "activation_fn" in config.sentence_transformers): - function_name = config.sentence_transformers["activation_fn"] - elif (hasattr(config, "sbert_ce_default_activation_function") - and config.sbert_ce_default_activation_function is not None): - function_name = config.sbert_ce_default_activation_function - - if function_name is not None: - assert function_name.startswith("torch.nn.modules."), ( - "Loading of activation functions is restricted to " - "torch.nn.modules for security reasons") - return resolve_obj_by_qualname(function_name)() - - return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() - - def try_get_safetensors_metadata( model: str, *, @@ -920,3 +802,35 @@ def try_get_tokenizer_config( ) except Exception: return None + + +def _download_mistral_config_file(model, revision) -> dict: + config_file_name = "params.json" + config_dict = get_hf_file_to_dict(config_file_name, model, revision) + if config_dict is None: + raise ValueError( + f"Failed to load mistral '{config_file_name}' config for model " + f"{model}. Please check if the model is a mistral-format model " + f"and if the config file exists.") + assert isinstance(config_dict, dict) + return config_dict + + +def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: + max_position_embeddings = 128_000 + try: + trust_remote_code_val = kwargs.get("trust_remote_code", False) + hf_config = get_config(model=model, + trust_remote_code=trust_remote_code_val, + revision=revision, + config_format=ConfigFormat.HF) + if hf_value := hf_config.get_text_config().max_position_embeddings: + max_position_embeddings = hf_value + except Exception as e: + logger.warning( + "The params.json file is missing 'max_position_embeddings'" + " and could not get a value from the HF config." + " Defaulting to 128000", + exc_info=e) + + return max_position_embeddings diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 734f1e09d0fd..89303213a27e 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -7,6 +7,7 @@ from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekVLV2Config from vllm.transformers_utils.configs.eagle import EAGLEConfig from vllm.transformers_utils.configs.exaone import ExaoneConfig +from vllm.transformers_utils.configs.exaone4 import Exaone4Config # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. 
@@ -22,6 +23,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.nemotron_h import NemotronHConfig +from vllm.transformers_utils.configs.nemotron_vl import Nemotron_Nano_VL_Config from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config from vllm.transformers_utils.configs.ovis import OvisConfig from vllm.transformers_utils.configs.skyworkr1v import SkyworkR1VChatConfig @@ -40,6 +42,7 @@ "MedusaConfig", "EAGLEConfig", "ExaoneConfig", + "Exaone4Config", "MiniMaxText01Config", "MiniMaxVL01Config", "MllamaConfig", @@ -48,6 +51,7 @@ "KimiVLConfig", "NemotronConfig", "NemotronHConfig", + "Nemotron_Nano_VL_Config", "NVLM_D_Config", "OvisConfig", "SkyworkR1VChatConfig", diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index fb2e8a1df705..5445a333c493 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -6,7 +6,6 @@ from transformers import AutoConfig, PretrainedConfig -import vllm.envs as envs from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config @@ -44,28 +43,25 @@ def __init__(self, self.truncated_vocab_size = self.model.vocab_size if \ truncated_vocab_size is None else truncated_vocab_size - if not envs.VLLM_USE_V1: - kwargs["architectures"] = ["EAGLEModel"] + # Eagle model name should follow naming convention of + # LlamaForCausalLM -> EagleLlamaForCausalLM + if method == "eagle": + assert self.model is not None, \ + "model should not be None when method is eagle" + kwargs["architectures"] = [ + f"Eagle{arch}" if not arch.startswith("Eagle") \ + else arch for arch in self.model.architectures + ] + elif method == "eagle3": + assert self.model is not None, \ + "model should not be None when method is eagle3" + kwargs["architectures"] = [ + f"Eagle3{arch}" if not arch.startswith("Eagle3") \ + else arch for arch in self.model.architectures + ] else: - # Eagle model name should follow naming convention of - # LlamaForCausalLM -> EagleLlamaForCausalLM - if method == "eagle": - assert self.model is not None, \ - "model should not be None when method is eagle" - kwargs["architectures"] = [ - f"Eagle{arch}" if not arch.startswith("Eagle") \ - else arch for arch in self.model.architectures - ] - elif method == "eagle3": - assert self.model is not None, \ - "model should not be None when method is eagle3" - kwargs["architectures"] = [ - f"Eagle3{arch}" if not arch.startswith("Eagle3") \ - else arch for arch in self.model.architectures - ] - else: - raise ValueError(f"Invalid method {method}. \ - Supported methods are eagle and eagle3.") + raise ValueError(f"Invalid method {method}. \ + Supported methods are eagle and eagle3.") super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/exaone4.py b/vllm/transformers_utils/configs/exaone4.py new file mode 100644 index 000000000000..a22ebaa6bd6b --- /dev/null +++ b/vllm/transformers_utils/configs/exaone4.py @@ -0,0 +1,252 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# ruff: noqa: E501 + +# Copied from +# https://github.com/lgai-exaone/transformers/blob/add-exaone4/src/transformers/models/exaone4/configuration_exaone4.py +# Copyright 2025 The LG CNS Gen AI Solution Delivery Team. +# Copyright 2025 The LG AI Research and HuggingFace Inc. team. All rights reserved. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from transformers.configuration_utils import (PretrainedConfig, + layer_type_validation) +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +def check_is_sliding(config, layer_idx): + """ + Check if the current layer is a sliding window attention (local attention) layer. + """ + if config.sliding_window is None: + return False + if config.layer_types is not None: + return config.layer_types[layer_idx] == "sliding_attention" + if isinstance(config.sliding_window_pattern, int): + return ((layer_idx + 1) % config.sliding_window_pattern) != 0 + elif isinstance(config.sliding_window_pattern, str): + assert isinstance(config.sliding_window, int), ( + f"Sliding window must be positive integer, but got {config.sliding_window}" + ) + return (layer_idx != config.num_hidden_layers - 1 + and config.sliding_window_pattern[layer_idx % len( + config.sliding_window_pattern)] == "L") + else: + logger.warning_once( + "Sliding window is set, but none of `sliding_window_pattern` or `layer_types` is set. " + "Defaulting to use 'full_attention' for all layers.") + return False + + +class Exaone4Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Exaone4Model`]. It is used to + instantiate a EXAONE 4.0 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the EXAONE-4.0-Instruct [LGAI-EXAONE/EXAONE-4.0-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-Instruct) + NOTE: `EXAONE-4.0-Instruct` is a placeholder model ID. The exact model ID will be updated in the future. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model + outputs. Read the documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the EXAONE 4.0 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Exaone4Model`]. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to `hidden_size * 4`): + Dimensionality of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. 
When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 32768 for EXAONE 3.5). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the layer normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if ``config.is_decoder=True``. + bos_token_id (`int`, *optional*, defaults to 0): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + sliding_window (`int`, *optional*): + The size of the sliding window for the sliding window attention. + sliding_window_pattern (`str`, *optional*): + The pattern to use for sliding window attention. Can be one of: + - `None`: No sliding window attention is used + - `int`: Every `sliding_window` layers, use global attention, else use local attention. + - `str`: A sequence of "L" (local attention) and "G" (global attention) characters that defines the + attention pattern. The pattern starts from layer 0 and repeats every `sliding_window` layers. The + final layer always uses global attention regardless of the pattern. + For instance, sliding_window_pattern="LLLG" same as sliding_window=4, which means: + - Layer 0, 1, 2: local attention, + - Layer 3: global attention, + ...(repeated) + layer_types (`list`, *optional*): + Attention pattern for each layer. Prioritized over `sliding_window_pattern`. + + Example: + + ```python + >>> from transformers import Exaone4Model, Exaone4Config + + >>> # Initializing a EXAONE configuration + >>> configuration = Exaone4Config() + + >>> # Initializing a model from configuration + >>> model = Exaone4Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "exaone4" + keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `LlamaModel` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=None, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_dropout=0.0, + sliding_window=None, + sliding_window_pattern=None, + layer_types=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + 
self.attention_dropout = attention_dropout + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.sliding_window = sliding_window + self.sliding_window_pattern = sliding_window_pattern + + self.layer_types = layer_types + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" + if check_is_sliding(self, i) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs) + + +__all__ = ["Exaone4Config"] diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py new file mode 100644 index 000000000000..8a9c660b882f --- /dev/null +++ b/vllm/transformers_utils/configs/mistral.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from transformers import PretrainedConfig, WhisperConfig + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def adapt_config_dict(config_dict: dict[str, Any], + **kwargs) -> PretrainedConfig: + config_dict.update(kwargs) + config_dict = _remap_general_mistral_args(config_dict) + + if bool(config_dict.get("quantization")): + config_dict = _remap_mistral_quantization_args(config_dict) + + if bool(config_dict.get("moe")): + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if bool(config_dict.get("yarn")): + config_dict = _remap_mistral_yarn_args(config_dict) + + is_vision = ((config_dict.get("multimodal") + or {}).get("vision_encoder_args") + or config_dict.get("vision_encoder")) + is_audio = bool( + ((config_dict.get("multimodal") or {}).get("whisper_model_args") + or {}).get("encoder_args")) + + assert not (is_vision and is_audio), \ + "Vision and audio are mutually exclusive" + + if is_vision: + config_dict = _remap_mistral_vision_args(config_dict) + if is_audio: + config_dict = _remap_mistral_audio_args(config_dict) + + config = PretrainedConfig.from_dict(config_dict) + + logger.debug("Initialized config %s", config) + + return config + + +def _remap_mistral_vision_args(config: dict) -> dict: + if config.get("multimodal"): + vision_config = config.pop("multimodal") + else: + vision_config = config.pop("vision_encoder") + + quant_config = config.get("quantization_config") + config = { + "model_type": "pixtral", + "architectures": ["PixtralForConditionalGeneration"], + "text_config": PretrainedConfig.from_dict(config), + "vision_config": PretrainedConfig.from_dict(vision_config), + } + if quant_config: + config["quantization_config"] = quant_config + return config + + +def _remap_mistral_yarn_args(config: dict) -> dict: + # Direct remaps: yarn.X -> rope_scaling.Y + # Source keys are from mistral.model.args.YarnArgs + _map = { + "beta": "beta_fast", + "alpha": "beta_slow", + } + yarn_config = config.get("yarn") or {} + renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()} + config["rope_scaling"] = { + "rope_type": "yarn", + "mscale_all_dim": 1, # We hardcoded this to 1 + **renamed_yarn_config + } + return config + + +def _remap_general_mistral_args(config: dict) -> dict: + # Mistral key -> HF key + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": 
"intermediate_size", + } + # HF key -> (Mistral key, default value) + top_level_mapping_with_default = { + "model_type": ("model_type", "transformer"), + "hidden_act": ("activation", "silu"), + "tie_word_embeddings": ("tied_embeddings", False), + "max_seq_len": ("max_seq_len", 128_000), + "max_position_embeddings": ("max_position_embeddings", 128_000), + } + + for key, new_key in config_mapping.items(): + if key in config: + config[new_key] = config.pop(key) + + for new_key, (key, + default_value) in top_level_mapping_with_default.items(): + config[new_key] = config.pop(key, default_value) + + return config + + +def _remap_mistral_quantization_args(config: dict) -> dict: + quantization = config.get("quantization", {}) + if quantization.get("qformat_weight") == "fp8_e4m3": + # This maps to the FP8 static per-tensor quantization scheme + quantization_config = { + "quant_method": "fp8", + "activation_scheme": "static" + } + elif quantization.get("quant_method") == "compressed-tensors": + # Pass through the quantization config to compressed-tensors + quantization_config = quantization + else: + raise ValueError( + f"Found unknown quantization='{quantization}' in config") + + config["quantization_config"] = quantization_config + + return config + + +def _remap_mistral_audio_args(config: dict) -> dict: + whisper_args = config["multimodal"].pop("whisper_model_args") + encoder_args = whisper_args["encoder_args"] + downsample_args = whisper_args["downsample_args"] + + quant_config = config.get("quantization_config") + config = { + "model_type": + "whixtral", + "architectures": ["VoxtralForConditionalGeneration"], + "text_config": + PretrainedConfig.from_dict(config), + "audio_config": + WhisperConfig( + num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"], + window_size=encoder_args["audio_encoding_args"]["window_size"], + sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"], + hop_length=encoder_args["audio_encoding_args"]["hop_length"], + downsample_factor=downsample_args["downsample_factor"], + d_model=encoder_args["dim"], + encoder_layers=encoder_args["n_layers"], + encoder_ffn_dim=encoder_args["hidden_dim"], + encoder_attention_heads=encoder_args["n_heads"], + vocab_size=encoder_args["vocab_size"], + max_source_positions=encoder_args["max_source_positions"], + ) + } + if quant_config: + config["quantization_config"] = quant_config + return config diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py index d65b572dc7f2..9a7243b1262c 100644 --- a/vllm/transformers_utils/configs/nemotron.py +++ b/vllm/transformers_utils/configs/nemotron.py @@ -202,4 +202,4 @@ def _rope_scaling_validation(self): rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError( "`rope_scaling`'s factor field must be a float > 1, got " - f"{rope_scaling_factor}") + f"{rope_scaling_factor}") \ No newline at end of file diff --git a/vllm/transformers_utils/configs/nemotron_vl.py b/vllm/transformers_utils/configs/nemotron_vl.py new file mode 100644 index 000000000000..6a642f26b82a --- /dev/null +++ b/vllm/transformers_utils/configs/nemotron_vl.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# yapf: disable +# ruff: noqa: E501 +# Adapted from +# https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1/blob/main/configuration.py +# -------------------------------------------------------- +# Adapted from 
https://huggingface.co/OpenGVLab/InternVL2-Llama3-76B under MIT License +# LICENSE is in incl_licenses directory. +# -------------------------------------------------------- + +from transformers import LlamaConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.dynamic_module_utils import get_class_from_dynamic_module + + +class Nemotron_Nano_VL_Config(PretrainedConfig): + model_type = 'Llama_Nemotron_Nano_VL' + is_composition = True + + def __init__( + self, + vision_config=None, + llm_config=None, + force_image_size=None, + downsample_ratio=0.5, + template=None, + ps_version='v1', + image_tag_type="internvl", + projector_hidden_size=4096, + vit_hidden_size=1280, + **kwargs + ): + super().__init__(**kwargs) + + if vision_config is not None: + assert "auto_map" in vision_config and "AutoConfig" in vision_config["auto_map"] + vision_auto_config = get_class_from_dynamic_module(*vision_config["auto_map"]["AutoConfig"].split("--")[::-1]) + self.vision_config = vision_auto_config(**vision_config) + else: + self.vision_config = PretrainedConfig() + + if llm_config is None: + self.text_config = LlamaConfig() + else: + self.text_config = LlamaConfig(**llm_config) + + # Assign configuration values + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template # TODO move out of here and into the tokenizer + self.ps_version = ps_version # Pixel shuffle version + self.image_tag_type = image_tag_type # TODO: into the tokenizer too? + self.projector_hidden_size = projector_hidden_size + self.vit_hidden_size = vit_hidden_size diff --git a/vllm/transformers_utils/configs/ovis.py b/vllm/transformers_utils/configs/ovis.py index c2728f0ed64c..021d402a71f4 100644 --- a/vllm/transformers_utils/configs/ovis.py +++ b/vllm/transformers_utils/configs/ovis.py @@ -73,8 +73,6 @@ def __init__( IMAGE_ATOM_ID = -300 IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305] -AutoConfig.register("aimv2", AIMv2Config) - # ---------------------------------------------------------------------- # Visual Tokenizer Configuration @@ -105,9 +103,11 @@ def __init__(self, f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type" if not isinstance(backbone_config, PretrainedConfig): model_type = backbone_config['model_type'] - backbone_config.pop('model_type') - backbone_config = AutoConfig.for_model(model_type, - **backbone_config) + if model_type != "aimv2": + backbone_config.pop('model_type') + backbone_config = AutoConfig.for_model(model_type, **backbone_config) + else: + backbone_config = AIMv2Config(**backbone_config) self.backbone_config = backbone_config self.hidden_stride = hidden_stride diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py index 342632989d57..be1040c3e014 100644 --- a/vllm/transformers_utils/detokenizer_utils.py +++ b/vllm/transformers_utils/detokenizer_utils.py @@ -89,8 +89,13 @@ def convert_ids_list_to_tokens( Python list of token string representations """ - token_str_lst = tokenizer.convert_ids_to_tokens(token_ids) - _replace_none_with_empty(token_str_lst) # type: ignore + token_str_lst = [] + for token_id in token_ids: + # use default skip_special_tokens. 
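The per-id `decode` loop above behaves differently from `convert_ids_to_tokens`, which is worth spelling out: `decode` returns display text while `convert_ids_to_tokens` returns raw vocabulary pieces. The model name below is only an example and the exact strings depend on the tokenizer.

```python
# Why decode() per id differs from convert_ids_to_tokens(): display text vs.
# raw BPE pieces (with markers such as "Ġ"). "gpt2" is just an example model.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok.encode(" hello")
print(tok.convert_ids_to_tokens(ids))   # e.g. ['Ġhello']
print([tok.decode([i]) for i in ids])   # e.g. [' hello']
```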
+ token_str = tokenizer.decode([token_id]) + if token_str is None: + token_str = "" + token_str_lst.append(token_str) return token_str_lst diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index ae96ebe4eaa2..25dd71d877fb 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -16,15 +16,20 @@ from vllm import envs from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer_base import (TokenizerBase, - TokenizerRegistry) +from vllm.transformers_utils.config import ( + get_sentence_transformer_tokenizer_config) from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import make_async if TYPE_CHECKING: from vllm.config import ModelConfig + from vllm.lora.request import LoRARequest + from vllm.transformers_utils.tokenizer_base import TokenizerBase +else: + ModelConfig = Any + LoRARequest = Any + TokenizerBase = Any logger = init_logger(__name__) @@ -222,6 +227,7 @@ def get_tokenizer( tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) elif tokenizer_mode == "custom": + from vllm.transformers_utils.tokenizer_base import TokenizerRegistry tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name), *args, revision=revision, @@ -252,6 +258,18 @@ def get_tokenizer( else: raise e + # The special_tokens in tokenizer should also be + # controlled by do_lower_case in encoder_config + encoder_config = get_sentence_transformer_tokenizer_config( + tokenizer_name, revision) + if isinstance(encoder_config, dict) and encoder_config.get( + "do_lower_case", False): + special_tokens_map = { + k: v.lower() + for k, v in tokenizer.special_tokens_map.items() + } + tokenizer.add_special_tokens(special_tokens_map) + # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324 if type(tokenizer).__name__ in ("ChatGLMTokenizer", "ChatGLM4Tokenizer"): @@ -271,7 +289,7 @@ def get_tokenizer( def cached_tokenizer_from_config( - model_config: "ModelConfig", + model_config: ModelConfig, **kwargs: Any, ): return cached_get_tokenizer( diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 24ac4580d670..f83405cfc016 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -145,6 +145,21 @@ def find_tokenizer_file(files: list[str]): return matched_files[0] +def _aggregate_content(content: list) -> list[dict[str, Any]]: + aggregated_content: list[dict[str, Any]] = [] + for chunk in content: + if chunk.get("type" + ) == "text" and aggregated_content and aggregated_content[ + -1].get("type") == "text": + aggregated_content[-1]["text"] += "\n\n" + chunk.get("text") + else: + aggregated_content.append(chunk) + if len(aggregated_content) == 1 and aggregated_content[0].get( + "type") == "text": + content = aggregated_content[0]["text"] + return content + + def make_mistral_chat_completion_request( messages: list["ChatCompletionMessageParam"], tools: Optional[list[dict[str, @@ -162,10 +177,10 @@ def make_mistral_chat_completion_request( # Convert list text content to string if message.get("role") in ("assistant", "tool"): - content = message.get("content") + content: Any = message.get("content") if isinstance(content, list): - content = "\n".join(chunk.get("text") for chunk in content) - message["content"] = content + content = 
_aggregate_content(content) + message["content"] = content # The Mistral client, in comparison to the OpenAI client, requires the # "parameters" dict to be present, even if it's empty. @@ -465,6 +480,8 @@ def convert_ids_to_tokens( skip_special_tokens: bool = True, ) -> list[str]: from mistral_common.tokens.tokenizers.base import SpecialTokens + from mistral_common.tokens.tokenizers.instruct import ( + InstructTokenizerV13) # TODO(Patrick) - potentially allow special tokens to not be skipped assert ( @@ -474,10 +491,18 @@ def convert_ids_to_tokens( assert self.is_tekken or self.is_spm, type(self.tokenizer) if self.is_tekken: - # skip special tokens except tool call - ids = [ - i for i in ids if i > self.tokenizer.num_special_tokens or i == + # skip special tokens except tool call and think tokens + non_skip_special_tokens = { self.tokenizer.get_control_token(SpecialTokens.tool_calls) + } + if isinstance(self.instruct, InstructTokenizerV13): + if self.instruct.BEGIN_THINK: + non_skip_special_tokens.add(self.instruct.BEGIN_THINK) + if self.instruct.END_THINK: + non_skip_special_tokens.add(self.instruct.END_THINK) + ids = [ + i for i in ids if i > self.tokenizer.num_special_tokens + or i in non_skip_special_tokens ] tokens = [self.tokenizer.id_to_piece(id) for id in ids] diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 6cc8429d76c3..372200027bf9 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -92,3 +92,4 @@ def __init__(self): self.constexpr = None self.dtype = None self.int64 = None + self.int32 = None diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 9550b056fbba..5b9c3b6a50cd 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -41,6 +41,7 @@ from collections.abc import (AsyncGenerator, Awaitable, Collection, Generator, Hashable, Iterable, Iterator, KeysView, Mapping, Sequence) +from concurrent.futures import ThreadPoolExecutor from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps @@ -51,6 +52,7 @@ from uuid import uuid4 import cachetools +import cbor2 import cloudpickle import numpy as np import numpy.typing as npt @@ -64,6 +66,7 @@ from packaging import version from packaging.version import Version from torch.library import Library +from transformers.tokenization_utils_base import BatchEncoding from typing_extensions import Never, ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -125,10 +128,6 @@ "backends currently supported with encoder/" "decoder models.") -STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " - "currently supported with encoder/" - "decoder models.") - # Efficiently import all enc/dec error strings # rather than having to import all of the above STR_NOT_IMPL_ENC_DEC_ERR_STRS = { @@ -142,7 +141,6 @@ "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, - "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, } # Constants related to forcing the attention backend selection @@ -176,6 +174,7 @@ "fp8_e4m3": torch.uint8, "fp8_e5m2": torch.uint8, "int8": torch.int8, + "fp8_inc": torch.float8_e4m3fn, } TORCH_DTYPE_TO_NUMPY_DTYPE = { @@ -507,6 +506,196 @@ def random_uuid() -> str: return str(uuid.uuid4().hex) +class AsyncMicrobatchTokenizer: + """Asynchronous tokenizer with micro-batching. 
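A quick illustration of what `_aggregate_content` above does to a chunked assistant message; the input is fabricated.

```python
# Fabricated example of the aggregation rule above: adjacent "text" chunks are
# joined with a blank line, and a lone text chunk collapses to a plain string.
content = [
    {"type": "text", "text": "Step 1: parse the request."},
    {"type": "text", "text": "Step 2: call the tool."},
]
merged: list[dict] = []
for chunk in content:
    if chunk.get("type") == "text" and merged and merged[-1].get("type") == "text":
        merged[-1]["text"] += "\n\n" + chunk["text"]
    else:
        merged.append(dict(chunk))
result = (merged[0]["text"]
          if len(merged) == 1 and merged[0].get("type") == "text" else merged)
print(result)  # "Step 1: parse the request.\n\nStep 2: call the tool."
```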
+ + Pulls pending encode/decode requests from a queue and batches them + up to reduce overhead. A single-thread ThreadPoolExecutor is used + so the event loop stays responsive. + """ + + def __init__( + self, + tokenizer, + max_batch_size: int = 32, + batch_wait_timeout_s: float = 0.002, + ) -> None: + self.tokenizer = tokenizer + self.max_batch_size = max_batch_size + self.batch_wait_timeout_s = batch_wait_timeout_s + + self._loop = asyncio.get_running_loop() + self._queues: dict[tuple, + asyncio.Queue[Union[tuple[str, dict, + asyncio.Future], + tuple[list[int], + asyncio.Future]]]] = {} + self._batcher_tasks: list[asyncio.Task] = [] + + # Single-thread executor for blocking tokenizer calls. + self._executor = ThreadPoolExecutor(max_workers=1) + + # === Public async API === + async def __call__(self, prompt, **kwargs): + result_future: asyncio.Future = self._loop.create_future() + key = self._queue_key("encode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((prompt, kwargs, result_future)) + return await result_future + + async def decode(self, token_ids, **kwargs): + result_future: asyncio.Future = self._loop.create_future() + key = self._queue_key("decode", kwargs) + queue = self._get_queue(self._loop, key) + await queue.put((token_ids, result_future)) + return await result_future + + # === Internal helpers === + def _get_queue( + self, loop: asyncio.AbstractEventLoop, key: tuple + ) -> asyncio.Queue[Union[tuple[str, dict, asyncio.Future], tuple[ + list[int], asyncio.Future]]]: + """Get the request queue for the given operation key, creating a new + queue and batcher task if needed.""" + queue = self._queues.get(key) + if queue is None: + self._queues[key] = queue = asyncio.Queue() + if key[0] == "encode": + can_batch = key[1] != "other" + coro = self._batch_encode_loop(queue, can_batch) + else: + assert key[0] == "decode", \ + f"Unknown operation type: {key[0]}." + coro = self._batch_decode_loop(queue) + self._batcher_tasks.append(loop.create_task(coro)) + return queue + + async def _batch_encode_loop(self, queue: asyncio.Queue, can_batch: bool): + """Batch incoming encode requests for efficiency.""" + while True: + prompt, kwargs, result_future = await queue.get() + prompts = [prompt] + kwargs_list = [kwargs] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(prompts) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + prompt, kwargs, result_future = await asyncio.wait_for( + queue.get(), timeout) + prompts.append(prompt) + result_futures.append(result_future) + if not can_batch: + kwargs_list.append(kwargs) + except asyncio.TimeoutError: + break + + try: + # If every request uses identical kwargs we can run a single + # batched tokenizer call for a big speed-up. 
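A minimal usage sketch for `AsyncMicrobatchTokenizer`. It assumes `hf_tok` is a Hugging Face tokenizer instance, and the wrapper must be constructed inside a running event loop because it captures `asyncio.get_running_loop()`.

```python
# Usage sketch (assumes `hf_tok` is a Hugging Face tokenizer instance).
import asyncio

async def demo(hf_tok):
    tok = AsyncMicrobatchTokenizer(hf_tok, max_batch_size=32,
                                   batch_wait_timeout_s=0.002)
    # Concurrent encodes with identical kwargs land in the same queue and are
    # executed as a single batched tokenizer call on the worker thread.
    encodings = await asyncio.gather(*(tok(p) for p in ["a", "b", "c"]))
    text = await tok.decode(encodings[0]["input_ids"])
    return text
```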
+ if can_batch and len(prompts) > 1: + encode_fn = partial(self.tokenizer, prompts, **kwargs) + results = await self._loop.run_in_executor( + self._executor, encode_fn) + + for i, fut in enumerate(result_futures): + if not fut.done(): + data = {k: v[i] for k, v in results.items()} + fut.set_result(BatchEncoding(data)) + else: + encode_fn = lambda prompts=prompts, kwargs=kwargs_list: [ + self.tokenizer(p, **kw) + for p, kw in zip(prompts, kwargs) + ] + results = await self._loop.run_in_executor( + self._executor, encode_fn) + + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + async def _batch_decode_loop(self, queue: asyncio.Queue): + """Batch incoming decode requests for efficiency.""" + while True: + token_ids, result_future = await queue.get() + token_ids_list = [token_ids] + result_futures = [result_future] + deadline = self._loop.time() + self.batch_wait_timeout_s + + while len(token_ids_list) < self.max_batch_size: + timeout = deadline - self._loop.time() + if timeout <= 0: + break + try: + token_ids, result_future = await asyncio.wait_for( + queue.get(), timeout) + token_ids_list.append(token_ids) + result_futures.append(result_future) + except asyncio.TimeoutError: + break + + try: + # Perform a single batched decode call for all requests + results = await self._loop.run_in_executor( + self._executor, self.tokenizer.batch_decode, + token_ids_list) + for fut, res in zip(result_futures, results): + if not fut.done(): + fut.set_result(res) + except Exception as e: + for fut in result_futures: + if not fut.done(): + fut.set_exception(e) + + def _queue_key(self, op: str, kwargs: dict) -> tuple: + """ + Return a normalized key describing operation + kwargs. + + - `add_special_tokens`: {True/False} + - `truncation`: {True/False} + - If `truncation` is False (`max_length` is None), + returns a key for a can_batch queue. + - If `truncation` is True and `max_length` is None or equals + `tokenizer.model_max_length`, returns a key for a can_batch queue. + - Otherwise, returns a key for a cannot_batch queue. 
+ + Examples: + - Decode: ("decode",) + - Encode typical: + ("encode", add_special_tokens, bool_truncation, max_length_label) + - Fallback: ("encode", "other") + """ + + if op == "decode": + return ("decode", ) + + add_special_tokens = kwargs.get("add_special_tokens", True) + truncation = kwargs.get("truncation", False) + max_length = kwargs.get("max_length") + + if not truncation: + return ("encode", add_special_tokens, False, None) + + model_max = getattr(self.tokenizer, "model_max_length", None) + if max_length is None or (model_max is not None + and max_length == model_max): + return ("encode", add_special_tokens, True, "model_max") + + return ("encode", "other") + + def __del__(self): + for task in self._batcher_tasks: + if not task.done(): + task.cancel() + + def make_async( func: Callable[P, T], executor: Optional[concurrent.futures.Executor] = None @@ -620,6 +809,33 @@ def get_ip() -> str: return "0.0.0.0" +def test_loopback_bind(address, family): + try: + s = socket.socket(family, socket.SOCK_DGRAM) + s.bind((address, 0)) # Port 0 = auto assign + s.close() + return True + except OSError: + return False + + +def get_loopback_ip() -> str: + loopback_ip = envs.VLLM_LOOPBACK_IP + if loopback_ip: + return loopback_ip + + # VLLM_LOOPBACK_IP is not set, try to get it based on network interface + + if test_loopback_bind("127.0.0.1", socket.AF_INET): + return "127.0.0.1" + elif test_loopback_bind("::1", socket.AF_INET6): + return "::1" + else: + raise RuntimeError( + "Neither 127.0.0.1 nor ::1 are bound to a local interface. " + "Set the VLLM_LOOPBACK_IP environment variable explicitly.") + + def is_valid_ipv6_address(address: str) -> bool: try: ipaddress.IPv6Address(address) @@ -754,6 +970,13 @@ def next_power_of_2(n) -> int: return 1 << (n - 1).bit_length() +def prev_power_of_2(n: int) -> int: + """The previous power of 2 (inclusive)""" + if n <= 0: + return 0 + return 1 << (n.bit_length() - 1) + + def round_up(x: int, y: int) -> int: return ((x + y - 1) // y) * y @@ -1155,12 +1378,11 @@ def find_nccl_library() -> str: prev_set_stream = torch.cuda.set_stream -_current_stream = None +_current_stream_tls = threading.local() def _patched_set_stream(stream: torch.cuda.Stream) -> None: - global _current_stream - _current_stream = stream + _current_stream_tls.value = stream prev_set_stream(stream) @@ -1179,16 +1401,16 @@ def current_stream() -> torch.cuda.Stream: from C/C++ code. """ from vllm.platforms import current_platform - global _current_stream - if _current_stream is None: + if not hasattr(_current_stream_tls, + "value") or _current_stream_tls.value is None: # when this function is called before any stream is set, # we return the default stream. # On ROCm using the default 0 stream in combination with RCCL # is hurting performance. 
Therefore creating a dedicated stream # per process - _current_stream = torch.cuda.Stream() if current_platform.is_rocm( - ) else torch.cuda.current_stream() - return _current_stream + _current_stream_tls.value = torch.cuda.Stream( + ) if current_platform.is_rocm() else torch.cuda.current_stream() + return _current_stream_tls.value def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None: @@ -1343,6 +1565,13 @@ def cuda_is_initialized() -> bool: return torch.cuda.is_initialized() +def xpu_is_initialized() -> bool: + """Check if XPU is initialized.""" + if not torch.xpu._is_compiled(): + return False + return torch.xpu.is_initialized() + + def cuda_get_device_properties(device, names: Sequence[str], init_cuda=False) -> tuple[Any, ...]: @@ -1886,6 +2115,12 @@ def supports_dynamo() -> bool: return base_torch_version >= Version("2.4.0") +# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform +def supports_xccl() -> bool: + return is_torch_equal_or_newer( + "2.8.0.dev") and torch.distributed.is_xccl_available() + + # Some backends use pytorch version < 2.4.0 which doesn't # support `torch.library.custom_op`. def supports_custom_op() -> bool: @@ -2650,6 +2885,8 @@ def _maybe_force_spawn(): reason = None if cuda_is_initialized(): reason = "CUDA is initialized" + elif xpu_is_initialized(): + reason = "XPU is initialized" elif is_in_ray_actor(): # even if we choose to spawn, we need to pass the ray address # to the subprocess so that it knows how to connect to the ray cluster. @@ -2681,8 +2918,9 @@ def get_mp_context(): def bind_kv_cache( - ctx: dict[str, Any], - kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] + ctx: dict[str, Any], + kv_cache: list[list[torch.Tensor]], # [virtual_engine][layer_index] + shared_kv_cache_layers: Optional[dict[str, str]] = None ) -> None: # Bind the kv_cache tensor to Attention modules, similar to # ctx[layer_name].kv_cache[ve]=kv_cache[ve][extract_layer_index(layer_name)] @@ -2694,12 +2932,17 @@ def bind_kv_cache( # attention of the same layer (e.g., bart's decoder.layers.1.self_attn # and decoder.layers.1.encoder_attn) is mapped to the same kv cache # tensor + # 5. 
Some models have attention layers that share kv cache with previous + # layers, this is specified through shared_kv_cache_layers + if shared_kv_cache_layers is None: + shared_kv_cache_layers = {} from vllm.attention import AttentionType from vllm.model_executor.models.utils import extract_layer_index layer_need_kv_cache = [ layer_name for layer_name in ctx if (hasattr(ctx[layer_name], 'attn_type') and ctx[layer_name].attn_type - in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) + in (AttentionType.DECODER, AttentionType.ENCODER_DECODER)) \ + and ctx[layer_name].kv_sharing_target_layer_name is None ] layer_index_sorted = sorted( set( @@ -2712,6 +2955,12 @@ def bind_kv_cache( assert len(forward_ctx.kv_cache) == len(kv_cache) for ve, ve_kv_cache in enumerate(kv_cache): forward_ctx.kv_cache[ve] = ve_kv_cache[kv_cache_idx] + if shared_kv_cache_layers is not None: + for layer_name, target_layer_name in shared_kv_cache_layers.items(): + assert extract_layer_index(target_layer_name) < \ + extract_layer_index(layer_name), \ + "v0 doesn't support interleaving kv sharing" + ctx[layer_name].kv_cache = ctx[target_layer_name].kv_cache def run_method(obj: Any, method: Union[str, bytes, Callable], args: tuple[Any], @@ -2958,6 +3207,29 @@ def sha256(input) -> int: byteorder="big") +def sha256_cbor_64bit(input) -> int: + """ + Hash objects using CBOR serialization and SHA-256, then truncate to 64bits. + + This option is useful for non-Python-dependent serialization and hashing. + + Args: + input: Object to be serialized and hashed. Supported types include + basic Python types and complex structures like lists, tuples, and + dictionaries. + Custom classes must implement CBOR serialization methods. + + Returns: + An integer in the range [0, 2^64-1] representing the lower 64 bits + of the SHA-256 hash of the CBOR serialized input. + """ + input_bytes = cbor2.dumps(input, canonical=True) + full_hash = int.from_bytes(hashlib.sha256(input_bytes).digest(), + byteorder="big") + + return full_hash & ((1 << 64) - 1) + + def is_torch_equal_or_newer(target: str) -> bool: """Check if the installed torch version is >= the target version. diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py new file mode 100644 index 000000000000..09a12a8c11c5 --- /dev/null +++ b/vllm/utils/deep_gemm.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for DeepGEMM API changes. + +Users of vLLM should always import **only** these wrappers. +""" +from __future__ import annotations + +import functools +import importlib +from typing import Any, Callable, NoReturn + +import torch + +import vllm.envs as envs +from vllm.utils import cuda_get_device_properties, has_deep_gemm + + +@functools.cache +def is_blackwell_deep_gemm_used() -> bool: + """Return ``True`` if vLLM is configured to use DeepGEMM on a + Blackwell-class GPU. + """ + + if not (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() + and _per_block_cast_impl is not None): + return False + + return cuda_get_device_properties(0, ("major", ))[0] == 10 + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable DeepGEMM backend.""" + raise RuntimeError( + "DeepGEMM backend is not available. 
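A short standalone check of `sha256_cbor_64bit` above; it only needs `cbor2`, and the result always fits in an unsigned 64-bit integer.

```python
# Standalone check of the 64-bit CBOR hash shown above.
import hashlib
import cbor2

def sha256_cbor_64bit(obj) -> int:
    digest = hashlib.sha256(cbor2.dumps(obj, canonical=True)).digest()
    return int.from_bytes(digest, byteorder="big") & ((1 << 64) - 1)

h = sha256_cbor_64bit(("block-hash", [1, 2, 3], {"k": "v"}))
assert 0 <= h < 2**64
```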
Please install the `deep_gemm` " + "package to enable FP8 kernels.") + + +def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None: + """Return the *new* symbol if it exists, otherwise the *old* one.""" + if hasattr(module, new): + return getattr(module, new) + if hasattr(module, old): + return getattr(module, old) + return None + + +_fp8_gemm_nt_impl: Callable[..., Any] | None = None +_grouped_impl: Callable[..., Any] | None = None +_grouped_masked_impl: Callable[..., Any] | None = None +_per_block_cast_impl: Callable[..., Any] | None = None + + +def _lazy_init() -> None: + """Import deep_gemm and resolve symbols on first use.""" + global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl, \ + _per_block_cast_impl + + # fast path + if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None + or _grouped_masked_impl is not None + or _per_block_cast_impl is not None): + return + + if not has_deep_gemm(): + return + + _dg = importlib.import_module("deep_gemm") + + _fp8_gemm_nt_impl = _resolve_symbol(_dg, "fp8_gemm_nt", + "gemm_fp8_fp8_bf16_nt") + _grouped_impl = _resolve_symbol( + _dg, "m_grouped_fp8_gemm_nt_contiguous", + "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous") + _grouped_masked_impl = _resolve_symbol( + _dg, "fp8_m_grouped_gemm_nt_masked", + "m_grouped_gemm_fp8_fp8_bf16_nt_masked") + # Try to get per_token_cast_to_fp8 from DeepGEMM math utils. + try: + _math_mod = importlib.import_module( + "deep_gemm.utils.math") # type: ignore + _per_block_cast_impl = getattr(_math_mod, "per_block_cast_to_fp8", + None) + except ModuleNotFoundError: + _per_block_cast_impl = None + + +def fp8_gemm_nt(*args, **kwargs): + _lazy_init() + if _fp8_gemm_nt_impl is None: + return _missing(*args, **kwargs) + return _fp8_gemm_nt_impl(*args, **kwargs) + + +def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs): + _lazy_init() + if _grouped_impl is None: + return _missing(*args, **kwargs) + return _grouped_impl(*args, **kwargs) + + +def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): + _lazy_init() + if _grouped_masked_impl is None: + return _missing(*args, **kwargs) + return _grouped_masked_impl(*args, **kwargs) + + +def per_block_cast_to_fp8(x, *args, **kwargs): + _lazy_init() + if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used(): + return _per_block_cast_impl(x, use_ue8m0=True) + # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils + from tests.kernels.quant_utils import per_block_cast_to_fp8 as _pbcf + return _pbcf(x, *args, **kwargs) + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + """Return a global difference metric for unit tests. + + DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element + error, causing ``torch.testing.assert_close`` to fail. Instead of checking + every element, we compute a cosine-style similarity over the whole tensor + and report ``1 - sim``. Once kernel accuracy improves this helper can be + removed. 
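The metric described in the `calc_diff` docstring above reduces to `1 - 2(x·y) / (‖x‖² + ‖y‖²)`; a quick numeric check under that definition:

```python
# Quick numeric check of the similarity metric described above.
import torch

def calc_diff(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim

a = torch.tensor([1.0, 2.0, 3.0])
assert calc_diff(a, a).item() == 0.0                       # identical tensors
assert calc_diff(torch.tensor([1.0, 0.0]),
                 torch.tensor([0.0, 1.0])).item() == 1.0   # orthogonal tensors
```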
+ """ + + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +__all__ = [ + "calc_diff", + "fp8_gemm_nt", + "m_grouped_fp8_gemm_nt_contiguous", + "fp8_m_grouped_gemm_nt_masked", + "per_block_cast_to_fp8", + "is_blackwell_deep_gemm_used", +] diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py new file mode 100644 index 000000000000..1ddafbae7fc0 --- /dev/null +++ b/vllm/utils/flashinfer.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compatibility wrapper for FlashInfer API changes. + +Users of vLLM should always import **only** these wrappers. +""" +from __future__ import annotations + +import contextlib +import functools +import importlib +import importlib.util +from typing import Any, Callable, NoReturn + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +@functools.cache +def has_flashinfer() -> bool: + """Return ``True`` if FlashInfer is available.""" + # Use find_spec to check if the module exists without importing it + # This avoids potential CUDA initialization side effects + return importlib.util.find_spec("flashinfer") is not None + + +def _missing(*_: Any, **__: Any) -> NoReturn: + """Placeholder for unavailable FlashInfer backend.""" + raise RuntimeError( + "FlashInfer backend is not available. Please install the package " + "to enable FlashInfer kernels: " + "https://github.com/flashinfer-ai/flashinfer") + + +def _get_submodule(module_name: str) -> Any | None: + """Safely import a submodule and return it, or None if not available.""" + try: + return importlib.import_module(module_name) + except (ImportError, ModuleNotFoundError): + return None + + +# General lazy import wrapper +def _lazy_import_wrapper(module_name: str, + attr_name: str, + fallback_fn: Callable[..., Any] = _missing): + """Create a lazy import wrapper for a specific function.""" + + @functools.cache + def _get_impl(): + if not has_flashinfer(): + return None + mod = _get_submodule(module_name) + return getattr(mod, attr_name, None) if mod else None + + def wrapper(*args, **kwargs): + impl = _get_impl() + if impl is None: + return fallback_fn(*args, **kwargs) + return impl(*args, **kwargs) + + return wrapper + + +# Create lazy wrappers for each function +flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper( + "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe") +flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe", + "cutlass_fused_moe") +fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") +block_scale_interleave = _lazy_import_wrapper("flashinfer", + "block_scale_interleave") + +# Special case for autotune since it returns a context manager +autotune = _lazy_import_wrapper( + "flashinfer.autotuner", + "autotune", + fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) + + +@functools.cache +def has_flashinfer_moe() -> bool: + """Return ``True`` if FlashInfer MoE module is available.""" + return importlib.util.find_spec("flashinfer.fused_moe") is not None + + +@functools.cache +def has_flashinfer_cutlass_fused_moe() -> bool: + """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" + if not has_flashinfer_moe(): + return False + + # Check if all required functions are available + required_functions = [ + ("flashinfer.fused_moe", "cutlass_fused_moe"), + ("flashinfer", "fp4_quantize"), + ("flashinfer", "block_scale_interleave"), + ] + + for 
module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + +__all__ = [ + "has_flashinfer", + "flashinfer_trtllm_fp8_block_scale_moe", + "flashinfer_cutlass_fused_moe", + "fp4_quantize", + "block_scale_interleave", + "autotune", + "has_flashinfer_moe", + "has_flashinfer_cutlass_fused_moe", +] diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 37c04c7a029e..3b6d753863d0 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -1,34 +1,51 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import Optional + import numpy as np import torch +from torch.nn.functional import scaled_dot_product_attention -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata) -from vllm.attention.backends.torch_sdpa import (TorchSDPABackendImpl, - TorchSDPAMetadata) +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType, + is_quantized_kv_cache) from vllm.attention.backends.utils import CommonAttentionState -from vllm.attention.ops.ipex_attn import PagedAttention +from vllm.config import VllmConfig +from vllm.logger import init_logger from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable -from vllm.v1.worker.cpu_model_runner import CPUModelRunner from vllm.v1.worker.gpu_input_batch import InputBatch +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +# AttributeError is to handle a bug in ipex +# https://github.com/intel/intel-extension-for-pytorch/pull/813 +except (ImportError, AttributeError): + _use_ipex = False + +from vllm import _custom_ops as ops + +logger = init_logger(__name__) + class TorchSDPABackend(AttentionBackend): accept_output_buffer: bool = False @classmethod - def get_supported_head_sizes(cls) -> list[int]: - return PagedAttention.get_supported_head_sizes() + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] @classmethod def validate_head_size(cls, head_size: int) -> None: - supported_head_sizes = cls.get_supported_head_sizes() - if head_size not in supported_head_sizes: + attn_impl = _get_paged_attn_impl() + is_valid, supported_head_sizes = attn_impl.validate_head_size( + head_size) + if not is_valid: attn_type = cls.__name__.removesuffix("Backend") raise ValueError( f"Head size {head_size} is not supported by {attn_type}. " @@ -63,30 +80,256 @@ def get_kv_cache_shape( num_kv_heads: int, head_size: int, ) -> tuple[int, ...]: - return PagedAttention.get_kv_cache_shape(num_blocks, block_size, - num_kv_heads, head_size) + return _get_paged_attn_impl().get_kv_cache_shape( + num_blocks, block_size, num_kv_heads, head_size) @staticmethod def use_cascade_attention(*args, **kwargs) -> bool: return False +@dataclass +class TorchSDPAMetadata(AttentionMetadata): + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. 
+ max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + """Metadata for TorchSDPABackend. + """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + chunked_prefill: bool + seq_lens: Optional[list[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + prefill_query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None + + # For V1 logits index only + query_start_loc: Optional[torch.Tensor] = None + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[list[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[list[torch.Tensor]] = None + self.encoder_attn_bias: Optional[list[torch.Tensor]] = None + self.cross_attn_bias: Optional[list[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_prefill_tokens == 0: + return None + return self + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + if self.num_decode_tokens == 0: + return None + return self + + def get_seq_lens( + self, + attn_type: str, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. 
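The `block_tables` comment above ("[0, 1, 2] means tokens are stored in the 0th, 1st, and 2nd blocks") implies the usual paged-KV addressing scheme; a toy example with made-up block numbers:

```python
# Toy paged-KV addressing implied by the block_tables comment above
# (block numbers are made up).
block_size = 16
block_table = [7, 3, 11]          # logical block index -> physical block
pos = 37                          # token position within the sequence
physical_block = block_table[pos // block_size]    # 11
offset = pos % block_size                          # 5
slot = physical_block * block_size + offset        # 181
assert slot == 181
```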
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: str, + ) -> Optional[list[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: list[torch.Tensor], + attn_type: str, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: str, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]): - def __init__(self, runner: CPUModelRunner, kv_cache_spec: AttentionSpec, - block_table: BlockTable) -> None: - self.runner = runner - self.block_table = block_table + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device) -> None: + self.kv_cache_spec = kv_cache_spec + self.vllm_config = vllm_config + self.scheduler_config = vllm_config.scheduler_config # For reorder - self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs, - dtype=np.int64) - self.reorder_decode_req_index_list = np.empty(self.runner.max_num_reqs, - dtype=np.int64) + self.reorder_prompt_req_index_list = np.empty( + vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) + self.reorder_decode_req_index_list = np.empty( + vllm_config.scheduler_config.max_num_seqs, dtype=np.int64) self.num_prompt_req: int = 0 self.seq_start_loc_cpu = torch.zeros( - runner.max_num_reqs + 1, + vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, device="cpu", ) @@ -136,15 +379,15 @@ def reorder_batch(self, input_batch: InputBatch, return True - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> TorchSDPAMetadata: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - runner = self.runner - block_table = self.block_table - seq_lens_np = runner.seq_lens_np[:num_reqs] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + seq_lens_np = seq_lens_cpu.numpy() num_prompt_req = self.num_prompt_req max_prefill_seq_len = seq_lens_np[:num_prompt_req].max().item( ) if num_prompt_req > 0 else 0 @@ -152,33 +395,526 @@ def build(self, common_prefix_len: int, ) if num_prompt_req < num_reqs else 0 self.seq_start_loc_np[0] = 0 np.cumsum(seq_lens_np, out=self.seq_start_loc_np[1:num_reqs + 1]) - num_prefill_tokens = runner.query_start_loc_np[num_prompt_req].item() - num_decode_tokens = 
runner.query_start_loc_np[num_reqs].item( - ) - num_prefill_tokens - slot_mapping = block_table.slot_mapping_cpu[:num_actual_tokens].long() - block_table_tensor = block_table.get_device_tensor() + + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + num_prefill_tokens = int(query_start_loc_cpu[num_prompt_req].item()) + num_decode_tokens = int(query_start_loc_cpu[num_reqs].item() - + num_prefill_tokens) + + slot_mapping = common_attn_metadata.slot_mapping.long() + block_table_tensor = common_attn_metadata.block_table_tensor + attn_metadata = TorchSDPAMetadata( num_prefills=num_prompt_req, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, slot_mapping=slot_mapping, - seq_lens_tensor=runner. - seq_lens_cpu[num_prompt_req:num_reqs], # decode + # to ensure inference when chunked_prefill is disabled + seq_lens=seq_lens_cpu.tolist(), + seq_lens_tensor=seq_lens_cpu[num_prompt_req:num_reqs], # decode max_decode_seq_len=max_decode_seq_len, # decode block_tables=block_table_tensor[num_prompt_req:num_reqs], # decode - chunked_prefill=True, + chunked_prefill=self.scheduler_config.chunked_prefill_enabled, max_query_len=max_query_len, max_kv_len=max_prefill_seq_len, - prefill_query_start_loc=runner. - query_start_loc_cpu[:num_prompt_req + 1], # prefill + prefill_query_start_loc=query_start_loc_cpu[:num_prompt_req + + 1], # prefill kv_start_loc=self.seq_start_loc_cpu[:num_prompt_req + 1], # prefill prefill_block_tables=block_table_tensor[: num_prompt_req], # prefill - query_start_loc=runner.query_start_loc_cpu[:num_reqs + - 1], # for logits index + query_start_loc=query_start_loc_cpu[:num_reqs + + 1], # for logits index multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, ) return attn_metadata + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + ) -> None: + if kv_sharing_target_layer_name is not None: + raise NotImplementedError("KV sharing is not supported in V0.") + if logits_soft_cap is not None: + logger.warning_once("Torch SPDA does not support logits soft cap. " + "Outputs may be slightly off.") + self.paged_attn_impl = _get_paged_attn_impl() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + if is_quantized_kv_cache(kv_cache_dtype) and not _use_ipex: + raise NotImplementedError( + "Torch SDPA backend FP8 KV cache requires " + "intel_extension_for_pytorch support.") + self.attn_type = attn_type + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. 
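The prefill/decode token split in `build()` above follows directly from the cumulative `query_start_loc`; a worked example with invented request sizes:

```python
# Worked example of the prefill/decode split above (request sizes invented):
# 3 requests, the first 2 are prefills with query lengths [5, 4], the last is
# a decode with query length 1.
query_start_loc = [0, 5, 9, 10]     # cumulative sum of per-request lengths
num_prompt_req, num_reqs = 2, 3
num_prefill_tokens = query_start_loc[num_prompt_req]                # 9
num_decode_tokens = query_start_loc[num_reqs] - num_prefill_tokens  # 1
assert (num_prefill_tokens, num_decode_tokens) == (9, 1)
```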
+ + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if output_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for TorchSDPABackendImpl") + + # For warming-up + if attn_metadata is None: + return query + + attn_type = self.attn_type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = self.paged_attn_impl.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + self.paged_attn_impl.write_to_paged_cache( + key, value, key_cache, value_cache, updated_slot_mapping, + self.kv_cache_dtype, layer._k_scale, layer._v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. 
+ # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + if prefill_meta := attn_metadata.prefill_metadata: + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + assert attn_metadata.seq_lens is not None + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.prefill_query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) + + if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + self.paged_attn_impl.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + layer._k_scale, + layer._v_scale, + ) + + # Reshape the output tensor. 
+ return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: str = AttentionType.DECODER, + ) -> None: + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and mask is None, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: list[int], +) -> list[torch.Tensor]: + attn_biases: list[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
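+ # e.g. for seq_len = 3 the relative-position matrix built below is + # [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]; it is then scaled per head by + # alibi_slopes and masked above the diagonal with -inf.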
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: list[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> list[torch.Tensor]: + attn_biases: list[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases + + +class _PagedAttention: + + @staticmethod + def validate_head_size(head_size: int) -> tuple[bool, list[int]]: + SUPPORT_HS = [32, 64, 80, 96, 112, 128, 192, 256] + return head_size in SUPPORT_HS, SUPPORT_HS + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> tuple[int, ...]: + return 2, num_blocks, block_size * num_kv_heads * head_size + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: list[torch.Tensor], + src_to_dists: torch.Tensor, + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + + @staticmethod + def validate_head_size(head_size: 
int) -> tuple[bool, list[int]]: + return True, [] + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: torch.Tensor, + v_scale: torch.Tensor, + *args, + ) -> None: + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + +def _get_paged_attn_impl(): + if _use_ipex: + return _IPEXPagedAttention + else: + return _PagedAttention diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index fbc13c06c65a..5fe274f2c65b 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import ClassVar, Optional import numpy as np import torch @@ -25,14 +25,10 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.logger import init_logger from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout, - make_local_attention_virtual_batches) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata, + get_kv_cache_layout) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable - -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -44,6 +40,10 @@ class FlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @@ -130,18 +130,6 @@ class FlashAttentionMetadata: prefix_scheduler_metadata: Optional[torch.Tensor] = None max_num_splits: int = 0 - # for local attention - @dataclass - class LocalAttentionMetadata: - local_query_start_loc: torch.Tensor - local_seqused_k: 
torch.Tensor - local_block_table: torch.Tensor - local_max_query_len: int - local_max_seq_len: int - local_scheduler_metadata: Optional[torch.Tensor] - - local_attn_metadata: Optional[LocalAttentionMetadata] = None - def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: @@ -158,29 +146,30 @@ class FlashAttentionMetadataBuilder( AttentionMetadataBuilder[FlashAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3 - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - model_config = runner.model_config - compilation_config = runner.vllm_config.compilation_config - - self.runner = runner - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + self.compilation_config = vllm_config.compilation_config + self.device = device + + self.num_heads_q = self.model_config.get_num_attention_heads( + self.parallel_config) + self.num_heads_kv = self.model_config.get_num_kv_heads( + self.parallel_config) + self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size - self.kv_cache_spec = kv_cache_spec - self.block_table = block_table self.max_num_splits = 0 # No upper bound on the number of splits. self.aot_schedule = (get_flash_attn_version() == 3) - self.use_full_cuda_graph = compilation_config.full_cuda_graph + self.use_full_cuda_graph = self.compilation_config.full_cuda_graph if self.use_full_cuda_graph: if not self.aot_schedule: raise ValueError( "AoT scheduling is required for full cuda graph.") - capture_sizes = compilation_config.cudagraph_capture_sizes + capture_sizes = self.compilation_config.cudagraph_capture_sizes if not capture_sizes: raise ValueError( "cudagraph_capture_sizes should not be None when " @@ -194,9 +183,9 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, "full cuda graph.") self.scheduler_metadata = torch.zeros( - self.runner.max_num_reqs + 1, + vllm_config.scheduler_config.max_num_seqs + 1, dtype=torch.int32, - device=self.runner.device, + device=self.device, ) # When using cuda graph, we need to set the upper bound of the # number of splits so that large enough intermediate buffers are @@ -207,28 +196,26 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None - def build( - self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata - ) -> FlashAttentionMetadata: + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashAttentionMetadata: + """ + fast_build disables AOT scheduling, used when there will be few + iterations i.e. 
spec-decode + """ num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. - block_table.slot_mapping[num_actual_tokens:].fill_(-1) + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + # the overhead of the aot schedule is not worth it for spec-decode + aot_schedule = self.aot_schedule and not fast_build if self.aot_sliding_window is None: self.aot_sliding_window = (-1, -1) @@ -236,19 +223,20 @@ def build( # constant for all layers to. We have to populate this on the first # build() call so the layers are constructed (cannot populate) # in __init__. - if self.aot_schedule: + if aot_schedule: sliding_window_configs = _get_sliding_window_configs( - self.runner.vllm_config) + self.vllm_config) if len(sliding_window_configs) == 1: sliding_window_config = sliding_window_configs.pop() if sliding_window_config is not None: self.aot_sliding_window = sliding_window_config elif len(sliding_window_configs) > 1: self.aot_schedule = False + aot_schedule = False def schedule(batch_size, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): - if self.aot_schedule: + if aot_schedule: return get_scheduler_metadata( batch_size=batch_size, max_seqlen_q=max_query_len, @@ -265,53 +253,17 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, ) return None - # for local attention - local_attn_metadata = None - if self.runner.attention_chunk_size is not None: - seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table_tensor = make_local_attention_virtual_batches( - self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], - self.runner.seq_lens_np[:num_reqs], - block_table_tensor, - self.block_size, - ) - local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( - self.runner.device, non_blocking=True) - local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True) - local_max_query_len = seqlens_q_local_np.max() - local_max_seq_len = virt_k_seqlens_np.max() - local_scheduler_metadata = schedule( - batch_size=local_query_start_loc.shape[0] - 1, - cu_query_lens=local_query_start_loc, - max_query_len=local_max_query_len, - seqlens=local_seqused_k, - max_seq_len=local_max_seq_len, - causal=True) - - local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( - local_query_start_loc=local_query_start_loc, - local_seqused_k=local_seqused_k, - local_block_table=virt_block_table_tensor, - local_max_query_len=local_max_query_len, - local_max_seq_len=local_max_seq_len, - local_scheduler_metadata=local_scheduler_metadata, - ) - use_cascade = common_prefix_len > 0 if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, - device=self.runner.device) + device=self.device) 
prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32, - device=self.runner.device) - suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - - common_prefix_len) - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( - self.runner.device) + device=self.device) + suffix_kv_lens = (seq_lens_cpu[:num_reqs] - common_prefix_len).to( + self.device, non_blocking=True) prefix_scheduler_metadata = schedule( batch_size=1, cu_query_lens=cu_prefix_query_lens, @@ -372,7 +324,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, - local_attn_metadata=local_attn_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata, max_num_splits=max_num_splits, ) @@ -398,15 +349,10 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, - use_irope: bool = False, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -434,7 +380,6 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "FlashAttentionImpl") - self.use_irope = use_irope self.vllm_flash_attn_version = get_flash_attn_version() if is_quantized_kv_cache(self.kv_cache_dtype) \ and not flash_attn_supports_fp8(): @@ -518,27 +463,13 @@ def forward( layer._q_scale) query = query.reshape((num_tokens, num_heads, head_size)) - # Compute attention and update output up to `num_actual_tokens`. - use_local_attn = \ - (self.use_irope and attn_metadata.local_attn_metadata is not None) - - if not attn_metadata.use_cascade or use_local_attn: - if use_local_attn: - assert attn_metadata.local_attn_metadata is not None - local_metadata = attn_metadata.local_attn_metadata - cu_seqlens_q = local_metadata.local_query_start_loc - seqused_k = local_metadata.local_seqused_k - max_seqlen_q = local_metadata.local_max_query_len - max_seqlen_k = local_metadata.local_max_seq_len - block_table = local_metadata.local_block_table - scheduler_metadata = local_metadata.local_scheduler_metadata - else: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table - scheduler_metadata = attn_metadata.scheduler_metadata + if not attn_metadata.use_cascade: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + scheduler_metadata = attn_metadata.scheduler_metadata descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) @@ -566,8 +497,6 @@ def forward( ) return output - assert not use_local_attn, ( - "Cascade attention does not support local attention.") # Cascade attention (rare case). cascade_attention( output[:num_actual_tokens], @@ -603,6 +532,7 @@ def use_cascade_attention( num_kv_heads: int, use_alibi: bool, use_sliding_window: bool, + use_local_attention: bool, num_sms: int, ) -> bool: """Decide whether to use cascade attention. 
@@ -618,7 +548,7 @@ def use_cascade_attention( if common_prefix_len < 256: return False # Cascade attention is currently not supported with these variants. - if use_alibi or use_sliding_window: + if use_alibi or use_sliding_window or use_local_attention: return False # Too few queries. Probably not worth using cascade attention. # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold. diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 860309faa905..953ef26c8143 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -4,30 +4,31 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Optional import torch from flashinfer import (BatchDecodeWithPagedKVCacheWrapper, BatchPrefillWithPagedKVCacheWrapper, MultiLevelCascadeAttentionWrapper) +from flashinfer.decode import trtllm_batch_decode_with_kv_cache import vllm.envs as envs from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionType) -from vllm.attention.layer import Attention -from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.config import VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import use_cascade_attention -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata, - get_kv_cache_layout) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, PerLayerParameters, + get_kv_cache_layout, get_per_layer_parameters, + infer_global_hyperparameters, reorder_batch_to_split_decodes_and_prefills, + split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 @@ -37,6 +38,11 @@ class FlashInferBackend(AttentionBackend): accept_output_buffer: bool = True + cached_sm100a_supported: Optional[bool] = None + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] @classmethod def get_supported_head_sizes(cls) -> list[int]: @@ -92,69 +98,56 @@ def get_kv_cache_stride_order() -> tuple[int, ...]: raise ValueError(f"Unknown cache layout format {cache_layout}.") return stride_order + @staticmethod + def use_trtllm_decode_attention( + batch_size: int, + max_seq_len: int, + kv_cache_dtype: str, + num_qo_heads: int, + num_kv_heads: int, + attn_head_size: int, + ) -> bool: + if FlashInferBackend.cached_sm100a_supported is None: + FlashInferBackend.cached_sm100a_supported = ( + current_platform.has_device_capability(100)) + if not FlashInferBackend.cached_sm100a_supported: + return False + if (num_qo_heads // num_kv_heads > 8 + or num_qo_heads % num_kv_heads != 0 or attn_head_size != 128): + return False + env_value = envs.VLLM_USE_TRTLLM_DECODE_ATTENTION + if env_value is not None: + logger.info_once("VLLM_USE_TRTLLM_DECODE_ATTENTION is set to %s", + env_value) + # Environment variable is set - respect it + # Making the conditional check for zero because + # the path is automatically enabled if the batch size condition + # is satisfied. 
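+ # i.e. "0" disables the TRTLLM decode path; any other value forces it on.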
+ no_use_trtllm = env_value == "0" + if not no_use_trtllm: + logger.info_once( + "VLLM_USE_TRTLLM_DECODE_ATTENTION is set to 1, " + "using TRTLLM decode attention.") + return not no_use_trtllm + else: + # Environment variable not set - use auto-detection + # Only supports attention head size of 128 + use_trtllm = (FlashInferBackend.cached_sm100a_supported + and batch_size <= 256 and max_seq_len < 131072 + and kv_cache_dtype == "auto") + if use_trtllm: + logger.warning_once( + "Using TRTLLM decode attention (auto-detected).") + return use_trtllm -@dataclass -class PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters. - """ - - window_left: int - logits_soft_cap: Optional[float] - sm_scale: float - - -def get_per_layer_parameters( - vllm_config: VllmConfig) -> dict[str, PerLayerParameters]: - """ - Scan all attention layers and determine some hyperparameters - to use during `plan`. - """ - - layers = get_layers_from_vllm_config(vllm_config, Attention) - per_layer_params: dict[str, PerLayerParameters] = {} - - for key, layer in layers.items(): - impl = layer.impl - assert isinstance(impl, FlashInferImpl) - - # Infer hyperparameters from the attention layer - window_size = impl.sliding_window - window_left = window_size[0] if window_size is not None else -1 - logits_soft_cap = impl.logits_soft_cap - sm_scale = impl.scale - - per_layer_params[key] = PerLayerParameters(window_left, - logits_soft_cap, sm_scale) - - return per_layer_params - - -def infer_global_hyperparameters( - per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: - """ - Currently, FlashInfer backend only support models in which all layers share - the same values for the following hyperparameters: - - `window_left` - - `logits_soft_cap` - - `sm_scale` - - So this function asserts that all layers share the same values for these - hyperparameters and returns the global values. - """ - - assert len(per_layer_params) > 0, "No attention layers found in the model." 
- - param_sets = list(per_layer_params.values()) - global_params = param_sets[0] - for params in param_sets: - assert params == global_params, ( - "FlashInfer backend currently only supports models in which all " - "layers share the same values for the following hyperparameters: " - "`window_left`, `logits_soft_cap`, `sm_scale`.") - - return global_params + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") @dataclass @@ -190,12 +183,18 @@ class FlashInferMetadata: # Block size of vllm page_size: int # The data type of the paged kv cache - data_type: torch.dtype + kv_data_type: torch.dtype # The data type of the query q_data_type: torch.dtype slot_mapping: torch.Tensor + # For flashinfer trtllm batch decode + max_seq_len: int + seq_lens: torch.Tensor + block_table_tensor: torch.Tensor + workspace_buffer: torch.Tensor + # For handling prefill decode split num_decodes: int num_decode_tokens: int @@ -225,9 +224,9 @@ def __post_init__(self): class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): - def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, - block_table: BlockTable): - self.runner = runner + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.device = device self._workspace_buffer = None self._prefill_wrapper = None # Wrapper for prefill/append self._decode_wrapper = None # Wrapper for decode @@ -236,75 +235,22 @@ def __init__(self, runner: GPUModelRunner, kv_cache_spec: AttentionSpec, # Global hyperparameters shared by all attention layers self.global_hyperparameters: Optional[PerLayerParameters] = None - self.vllm_config = runner.vllm_config + self.vllm_config = vllm_config + self.cache_config = vllm_config.cache_config self.kv_cache_spec = kv_cache_spec - self.block_table = block_table def reorder_batch(self, input_batch: InputBatch, scheduler_output: SchedulerOutput) -> bool: - # We now want to reorder the batch so that the "decode" requests are and - # the front and the "prefill" requests are at the using the least amount - # swaps possible. (NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the decode run only supports num_tokens = 1 - if num_tokens == 1: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. 
past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - - return modified_batch + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) def _get_workspace_buffer(self): if self._workspace_buffer is None: self._workspace_buffer = torch.empty( FLASHINFER_WORKSPACE_BUFFER_SIZE, dtype=torch.uint8, - device=self.runner.device) + device=self.device) return self._workspace_buffer def _get_prefill_wrapper(self): @@ -315,10 +261,11 @@ def _get_prefill_wrapper(self): def _get_decode_wrapper(self): if self._decode_wrapper is None: - num_qo_heads = (self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config)) - num_kv_heads = self.runner.model_config.get_num_kv_heads( - self.runner.parallel_config) + num_qo_heads = ( + self.vllm_config.model_config.get_num_attention_heads( + self.vllm_config.parallel_config)) + num_kv_heads = self.vllm_config.model_config.get_num_kv_heads( + self.vllm_config.parallel_config) use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( num_qo_heads // num_kv_heads > 4) self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( @@ -333,10 +280,11 @@ def _get_cascade_wrapper(self): 2, self._get_workspace_buffer(), get_kv_cache_layout()) return self._cascade_wrapper - def _plan(self, attn_metadata: FlashInferMetadata): + def _plan(self, num_prefills: int, num_decodes: int, + attn_metadata: FlashInferMetadata): if self.global_hyperparameters is None: self.global_hyperparameters = infer_global_hyperparameters( - get_per_layer_parameters(self.vllm_config)) + get_per_layer_parameters(self.vllm_config, FlashInferImpl)) if attn_metadata.use_cascade: attn_metadata.cascade_wrapper = self._get_cascade_wrapper() attn_metadata.cascade_wrapper.plan( @@ -362,21 +310,22 @@ def _plan(self, attn_metadata: FlashInferMetadata): window_left=self.global_hyperparameters.window_left, logits_soft_cap=self.global_hyperparameters.logits_soft_cap, q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.kv_data_type, ) else: # Regular attention (common case). 
# Decodes are at the front and prefills are at the back, # according to reorder_batch() - if self._num_prefills > 0: + if num_prefills > 0: # Decodes are first so prefills start after the last decode - prefill_start = self._num_decodes + prefill_start = num_decodes attn_metadata.prefill_wrapper = self._get_prefill_wrapper() assert attn_metadata.qo_indptr[prefill_start:].shape[ - 0] == self._num_prefills + 1 + 0] == num_prefills + 1 assert attn_metadata.paged_kv_indptr[prefill_start:].shape[ - 0] == self._num_prefills + 1 + 0] == num_prefills + 1 assert attn_metadata.paged_kv_last_page_len[ - prefill_start:].shape[0] == self._num_prefills + prefill_start:].shape[0] == num_prefills # Since prefill_wrapper.run() will be called with # query[num_decode_tokens:] we need to adjust the qo_indptr # to be relative to the start of the prefill queries. @@ -397,44 +346,48 @@ def _plan(self, attn_metadata: FlashInferMetadata): logits_soft_cap=self.global_hyperparameters. logits_soft_cap, q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type, + kv_data_type=attn_metadata.kv_data_type, ) - if self._num_decodes > 0: + if num_decodes > 0: attn_metadata.decode_wrapper = self._get_decode_wrapper() - attn_metadata.decode_wrapper.plan( - attn_metadata.paged_kv_indptr[:self._num_decodes + 1], - attn_metadata.paged_kv_indices, - attn_metadata.paged_kv_last_page_len[:self._num_decodes], - attn_metadata.num_qo_heads, - attn_metadata.num_kv_heads, - attn_metadata.head_dim, - attn_metadata.page_size, - # Disable flashinfer's pos encoding and use vllm's rope. - pos_encoding_mode="NONE", - sm_scale=self.global_hyperparameters.sm_scale, - window_left=self.global_hyperparameters.window_left, - logits_soft_cap=self.global_hyperparameters. - logits_soft_cap, - q_data_type=attn_metadata.q_data_type, - kv_data_type=attn_metadata.data_type, - ) - - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): - num_reqs = common_attn_metadata.num_reqs + if not FlashInferBackend.use_trtllm_decode_attention( + num_decodes, attn_metadata.max_seq_len, + self.cache_config.cache_dtype, + attn_metadata.num_qo_heads, attn_metadata.num_kv_heads, + attn_metadata.head_dim): + attn_metadata.decode_wrapper.plan( + attn_metadata.paged_kv_indptr[:num_decodes + 1], + attn_metadata.paged_kv_indices, + attn_metadata.paged_kv_last_page_len[:num_decodes], + attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, + attn_metadata.head_dim, + attn_metadata.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + sm_scale=self.global_hyperparameters.sm_scale, + window_left=self.global_hyperparameters.window_left, + logits_soft_cap=self.global_hyperparameters. 
+ logits_soft_cap, + q_data_type=attn_metadata.q_data_type, + kv_data_type=attn_metadata.kv_data_type, + ) + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashInferMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ + split_decodes_and_prefills(common_attn_metadata) - assert self._num_decodes + self._num_prefills == num_reqs - assert (self._num_decode_tokens + - self._num_prefill_tokens == num_actual_tokens) page_size = self.kv_cache_spec.block_size - device = self.runner.device + device = self.device qo_indptr = common_attn_metadata.query_start_loc + max_seq_len = common_attn_metadata.seq_lens_cpu.max() seq_lens = common_attn_metadata.seq_lens - block_table_tensor = self.block_table.get_device_tensor()[:num_reqs] - slot_mapping = self.block_table.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True).long() + block_table_tensor = common_attn_metadata.block_table_tensor block_table_bounds = (seq_lens + page_size - 1) // page_size @@ -479,37 +432,47 @@ def build(self, common_prefix_len: int, paged_kv_last_page_len = seq_lens % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) - + cache_dtype = self.cache_config.cache_dtype + if cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + cache_dtype) + else: + kv_cache_dtype = self.kv_cache_spec.dtype attn_metadata = FlashInferMetadata( num_actual_tokens=num_actual_tokens, qo_indptr=qo_indptr, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - num_qo_heads=self.runner.num_query_heads, + num_qo_heads=self.vllm_config.model_config.get_num_attention_heads( + self.vllm_config.parallel_config), num_kv_heads=self.kv_cache_spec.num_kv_heads, head_dim=self.kv_cache_spec.head_size, page_size=page_size, - data_type=self.kv_cache_spec.dtype, - q_data_type=self.runner.dtype, - slot_mapping=slot_mapping, - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, - num_prefill_tokens=self._num_prefill_tokens, + kv_data_type=kv_cache_dtype, + q_data_type=self.vllm_config.model_config.dtype, + slot_mapping=common_attn_metadata.slot_mapping, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, use_cascade=use_cascade, shared_qo_indptr=shared_qo_indptr, shared_kv_page_indptr=shared_kv_page_indptr, shared_kv_page_indices=shared_kv_page_indices, shared_kv_last_page_len=shared_kv_last_page_len, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table_tensor=block_table_tensor, + workspace_buffer=self._workspace_buffer, ) - self._plan(attn_metadata) + self._plan(num_prefills, num_decodes, attn_metadata) return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: - if self.kv_cache_spec.dtype != self.runner.model_config.dtype: + if self.kv_cache_spec.dtype != self.vllm_config.model_config.dtype: # TODO: The cascade wrapper currently does not support setting # kv cache dtype to something different from query dtype. 
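+ # (e.g. an fp8 KV cache with an fp16/bf16 model falls back to the non-cascade path)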
return False @@ -527,16 +490,10 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, ) -> None: - if use_irope: - logger.warning_once( - "Using irope in FlashInfer is not supported yet, it will fall" - " back to global attention for long context.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -577,7 +534,11 @@ def forward( query: shape = [num_tokens, num_heads, head_size] key: shape = [num_tokens, num_kv_heads, head_size] value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache = [num_blocks, 2, block_size, num_kv_heads, head_size] + kv_cache: shape - + # NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] + # HND: [num_blocks, 2, num_kv_heads, block_size, head_size] + + attn_metadata: Metadata for attention. Returns: shape = [num_tokens, num_heads * head_size] @@ -623,6 +584,13 @@ def forward( layer._v_scale, ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + window_left = (self.sliding_window[0] if self.sliding_window is not None else -1) @@ -641,6 +609,7 @@ def forward( num_prefill_tokens = attn_metadata.num_prefill_tokens stride_order = FlashInferBackend.get_kv_cache_stride_order() + kv_cache_permute = kv_cache.permute(*stride_order) # Regular attention (common case). # Decodes are at the front and prefills are at the back, # according to reorder_batch() @@ -655,26 +624,60 @@ def forward( assert prefill_wrapper._sm_scale == self.scale prefill_wrapper.run( prefill_query, - kv_cache.permute(*stride_order), + kv_cache_permute, k_scale=layer._k_scale_float, v_scale=layer._v_scale_float, out=output[num_decode_tokens:], ) - if decode_wrapper := attn_metadata.decode_wrapper: decode_query = query[:num_decode_tokens] assert decode_query.shape[0] == num_decode_tokens - assert decode_wrapper is not None - assert decode_wrapper._window_left == window_left - assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap - or 0.0) - assert decode_wrapper._sm_scale == self.scale - decode_wrapper.run( - decode_query, - kv_cache.permute(*stride_order), - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - out=output[:num_decode_tokens], - ) - + if not FlashInferBackend.use_trtllm_decode_attention( + attn_metadata.num_decodes, attn_metadata.max_seq_len, + self.kv_cache_dtype, attn_metadata.num_qo_heads, + attn_metadata.num_kv_heads, attn_metadata.head_dim): + assert decode_wrapper is not None + assert decode_wrapper._window_left == window_left + assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap + or 0.0) + assert decode_wrapper._sm_scale == self.scale + decode_wrapper.run( + decode_query, + kv_cache_permute, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + out=output[:num_decode_tokens], + ) + else: + # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND + if num_decode_tokens > 0: + # decode_query may be non-contiguous + decode_query = decode_query.contiguous() + block_tables_decode = attn_metadata.block_table_tensor[: + num_decode_tokens] + seq_lens_decode = attn_metadata.seq_lens[: 
+ num_decode_tokens] + + assert get_kv_cache_layout() == "HND" + assert decode_query.is_contiguous() + assert kv_cache_permute.is_contiguous() + assert block_tables_decode.is_contiguous() + assert seq_lens_decode.is_contiguous() + + output[:num_decode_tokens] = ( + trtllm_batch_decode_with_kv_cache( + query=decode_query, + kv_cache=kv_cache_permute, + workspace_buffer=attn_metadata.workspace_buffer, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + scale=self.scale, + block_tables=block_tables_decode, + seq_lens=seq_lens_decode, + block_size=attn_metadata.page_size, + max_seq_len=attn_metadata.max_seq_len, + kv_cache_dtype=self.kv_cache_dtype, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, + )) return output_padded diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py index a8c5f464aa32..ad63f92cd88a 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -3,7 +3,7 @@ """Attention layer with FlashAttention.""" from collections import defaultdict from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import Optional import torch from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, @@ -14,18 +14,15 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner - create_block_mask_compiled = torch.compile(create_block_mask, fullgraph=True, mode="reduce-overhead") @@ -42,6 +39,10 @@ def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor: class FlexAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16, torch.float32] + @classmethod def validate_head_size(cls, head_size: int) -> None: return # FlexAttention supports any head size @@ -257,36 +258,34 @@ def __post_init__(self): class FlexAttentionMetadataBuilder( AttentionMetadataBuilder[FlexAttentionMetadata]): - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - model_config = runner.model_config - - self.runner = runner - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + + self.num_heads_q = self.model_config.get_num_attention_heads( + vllm_config.parallel_config) + self.num_heads_kv = self.model_config.get_num_kv_heads( + vllm_config.parallel_config) + self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_table = block_table + self.device = device - def build(self, 
common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlexAttentionMetadata: num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = self.runner.seq_lens_np[:num_reqs].max() + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -296,17 +295,15 @@ def build(self, common_prefix_len: int, raise NotImplementedError("Not yet my friend") block_size = self.kv_cache_spec.block_size - max_possible_seq_len = self.runner.model_config.max_model_len - total_cache_tokens = (self.runner.cache_config.num_gpu_blocks * - block_size) + max_possible_seq_len = self.model_config.max_model_len + total_cache_tokens = self.cache_config.num_gpu_blocks * block_size inverse_block_table = physical_to_logical_mapping( - block_table_tensor, self.runner.cache_config.num_gpu_blocks) + block_table_tensor, self.cache_config.num_gpu_blocks) # Get the original offset tensor - offset_tensor = torch.tensor( - self.runner.input_batch.num_computed_tokens_cpu[:num_reqs]).to( - self.runner.device, non_blocking=True) + offset_tensor = common_attn_metadata.num_computed_tokens_cpu.to( + self.device, non_blocking=True) out = FlexAttentionMetadata( num_actual_tokens=num_actual_tokens, @@ -345,15 +342,10 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, ) -> None: - if blocksparse_params is not None: - # TODO we should support this :think - raise ValueError( - "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py index 74d619aadbdc..dca5de46c065 100644 --- a/vllm/v1/attention/backends/mamba_attn.py +++ b/vllm/v1/attention/backends/mamba_attn.py @@ -1,32 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.config import VllmConfig, get_layers_from_vllm_config -from vllm.model_executor.layers.mamba.mamba2_metadata import ( - _query_start_loc_to_chunk_indices_offsets) -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) -from vllm.v1.kv_cache_interface import MambaSpec -from vllm.v1.worker.block_table import BlockTable +from vllm.config import VllmConfig +from 
vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) +from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner -def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int: - from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 - layers = get_layers_from_vllm_config(vllm_config, MambaMixer2) - chunk_sizes = set(layer.chunk_size for layer in layers.values()) - assert len( - chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size" - return chunk_sizes.pop() +def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, + chunk_size: int, + total_seqlens: int): + + cu_seqlens = query_start_loc[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, + dtype=torch.int, + device=query_start_loc.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=query_start_loc.device) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size + > 0) + + # adjust indices and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets class Mamba2AttentionBackend(AttentionBackend): @@ -53,82 +78,33 @@ class Mamba2AttentionMetadata: chunk_offsets: torch.Tensor state_indices_tensor: torch.Tensor # shape: [batch,] + nums_dict: Optional[dict] = None + cu_seqlen: Optional[int] = None + batch_ptr: Optional[torch.tensor] = None + token_chunk_offset_ptr: Optional[torch.tensor] = None class Mamba2AttentionMetadataBuilder( AttentionMetadataBuilder[Mamba2AttentionMetadata]): - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec, - block_table: BlockTable): - self.runner = runner + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + assert isinstance(kv_cache_spec, MambaSpec) self.kv_cache_spec = kv_cache_spec - self.block_table = block_table - self.chunk_size = get_mamba2_chunk_size(runner.vllm_config) + self.chunk_size = vllm_config.model_config.get_mamba_chunk_size() + assert self.chunk_size is not None, ( + "chunk_size needs to be set in the model config for Mamba2 models") def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - # NOTE (Chen): Copied from MLACommonMetadataBuilder and - # FlashInferMetadataBuilder. Should be refactored later to avoid code - # duplication of these 3 functions. - # We now want to reorder the batch so that the "decode" requests are and - # the front and the "prefill" requests are at the using the least amount - # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the decode run only supports num_tokens = 1 - if num_tokens == 1: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - - return modified_batch - - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> Mamba2AttentionMetadata: num_reqs = common_attn_metadata.num_reqs query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -140,29 +116,31 @@ def build(self, common_prefix_len: int, has_initial_states = None prep_initial_states = False - state_indices_tensor = self.block_table.block_table[:num_reqs, 0] + state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=1)) # Compute seq_idx, chunk_indices and chunk_offsets for prefill only - if self._num_prefills > 0: + if num_prefills > 0: #[batch,] has_initial_states_cpu = ( - self.runner.input_batch. - num_computed_tokens_cpu_tensor[num_reqs - - self._num_prefills:num_reqs] - > 0) + common_attn_metadata. 
+ num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) prep_initial_states = torch.any(has_initial_states_cpu).item() has_initial_states = has_initial_states_cpu.to( query_start_loc.device) query_start_loc_p = common_attn_metadata.query_start_loc[ - -self._num_prefills - 1:] - self._num_decode_tokens - - seq_idx = torch.repeat_interleave( - torch.arange(self._num_prefills, - dtype=torch.int32, - device=query_start_loc_p.device), - query_start_loc_p.diff(), - output_size=self._num_prefill_tokens) + -num_prefills - 1:] - num_decode_tokens + + seq_idx = torch.repeat_interleave(torch.arange( + num_prefills, + dtype=torch.int32, + device=query_start_loc_p.device), + query_start_loc_p.diff(), + output_size=num_prefill_tokens) seq_idx.unsqueeze_(0) # We compute metadata for chunked prefill once at the top level @@ -172,13 +150,13 @@ def build(self, common_prefix_len: int, chunk_indices, chunk_offsets = ( _query_start_loc_to_chunk_indices_offsets( query_start_loc_p, self.chunk_size, - self._num_prefill_tokens)) + num_prefill_tokens)) attn_metadata = Mamba2AttentionMetadata( - num_prefills=self._num_prefills, - num_prefill_tokens=self._num_prefill_tokens, - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, query_start_loc=query_start_loc, seq_lens=seq_lens, has_initial_states=has_initial_states, diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py old mode 100644 new mode 100755 index f2aaf59a40f8..cf17d9330239 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -189,11 +189,12 @@ import functools from abc import abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, AttentionMetadata, @@ -201,16 +202,18 @@ from vllm.attention.backends.utils import get_mla_dims from vllm.attention.ops.merge_attn_states import merge_attn_states from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, UnquantizedLinearMethod) from vllm.platforms import current_platform from vllm.utils import cdiv, round_down -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + get_per_layer_parameters, infer_global_hyperparameters, + reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable try: from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -221,13 +224,22 @@ from flash_attn import flash_attn_varlen_func is_vllm_fa = False +try: + from flashinfer import BatchPrefillWithRaggedKVCacheWrapper + from flashinfer.prefill import ( # noqa: F401 + cudnn_batch_prefill_with_kv_cache) + flashinfer_available = True +except ImportError: + flashinfer_available = False + if TYPE_CHECKING: from vllm.v1.core.sched.output import 
SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) +CUDNN_WORKSPACE_SIZE = 12800 + class MLACommonBackend(AttentionBackend): @@ -254,6 +266,10 @@ def get_kv_cache_shape( ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [576] @@ -282,6 +298,7 @@ class ChunkedContextMetadata: starts: torch.Tensor seq_tot: list[int] max_seq_lens: list[int] + seq_lens: torch.Tensor workspace: torch.Tensor block_table: torch.Tensor @@ -290,6 +307,24 @@ class ChunkedContextMetadata: chunked_context: Optional[ChunkedContextMetadata] = None +@dataclass +class FlashInferPrefillMetadata(MLACommonPrefillMetadata): + prefill_main: Optional['BatchPrefillWithRaggedKVCacheWrapper'] = None + prefill_chunks: list['BatchPrefillWithRaggedKVCacheWrapper'] = field( + default_factory=list) + + +@dataclass +class CudnnPrefillMetadata(MLACommonPrefillMetadata): + + class ChunkedContextMetadata( + MLACommonPrefillMetadata.ChunkedContextMetadata): + seq_lens: torch.Tensor + + query_seq_lens: Optional[torch.Tensor] = None + cudnn_workspace: Optional[torch.Tensor] = None + + @dataclass class MLACommonDecodeMetadata: block_table: torch.Tensor @@ -314,6 +349,9 @@ class MLACommonMetadata(Generic[D]): # |-------------------- seq_len ---------------------| # |-- query_len ---| + num_reqs: int + max_query_len: int + num_actual_tokens: int # Number of tokens excluding padding. query_start_loc: torch.Tensor slot_mapping: torch.Tensor @@ -328,7 +366,9 @@ class MLACommonMetadata(Generic[D]): head_dim: Optional[int] = None decode: Optional[D] = None - prefill: Optional[MLACommonPrefillMetadata] = None + prefill: Optional[Union[MLACommonPrefillMetadata, + FlashInferPrefillMetadata, + CudnnPrefillMetadata]] = None def __post_init__(self): if self.head_dim is not None: @@ -338,6 +378,26 @@ def __post_init__(self): M = TypeVar("M", bound=MLACommonMetadata) +def use_flashinfer_prefill() -> bool: + if flashinfer_available and not envs.VLLM_USE_CUDNN_PREFILL: + # For blackwell default to flashinfer prefill if its available since + # its faster than FA2. + return current_platform.has_device_capability(100) + return False + + +def use_cudnn_prefill() -> bool: + if flashinfer_available and envs.VLLM_USE_CUDNN_PREFILL: + return current_platform.has_device_capability(100) + return False + + +# Currently 394MB, this can be tuned based on GEMM sizes used. 
+# Chosen to be the same as sglang: +# https://github.com/sgl-project/sglang/blob/766392c6bda2558b61ce6d1c1bfd8081a549e1f1/python/sglang/global_config.py#L37 +FLASHINFER_WORKSPACE_BUFFER_SIZE = 394 * 1024 * 1024 + + class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -345,22 +405,23 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ def __init__(self, - runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable, + vllm_config: VllmConfig, + device: torch.device, metadata_cls: Optional[type[M]] = None): self.metadata_cls = metadata_cls \ if metadata_cls is not None else MLACommonMetadata - self.runner = runner - scheduler_config = runner.scheduler_config - model_config = runner.model_config - cache_config = runner.cache_config + self.kv_cache_spec = kv_cache_spec + self.device = device + scheduler_config = vllm_config.scheduler_config + self.model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + parallel_config = vllm_config.parallel_config self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - self.num_heads = model_config.get_num_attention_heads( - runner.parallel_config) - self.mla_dims = get_mla_dims(model_config) + self.num_heads = self.model_config.get_num_attention_heads( + parallel_config) + self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() - self.kv_cache_spec = kv_cache_spec # Dont try to access the runner on AMD if self.aot_schedule: @@ -371,7 +432,7 @@ def __init__(self, # Max sure there is enough for 8 full length request or at least # 4 pages of cache per request max( - 8 * model_config.max_model_len, 4 * + 8 * self.model_config.max_model_len, 4 * scheduler_config.max_num_seqs * cache_config.block_size), # For long-context models try not to over-allocate limiting # kv-cache space, limiting it to 64k tokens, @@ -386,71 +447,121 @@ def __init__(self, scheduler_config.max_num_seqs * cache_config.block_size self.chunked_prefill_workspace = torch.empty( (self.chunked_prefill_workspace_size, - model_config.get_head_size()), - dtype=model_config.dtype, - device=runner.device, + self.model_config.get_head_size()), + dtype=self.model_config.dtype, + device=device, + ) + + self._use_cudnn_prefill = use_cudnn_prefill() + self._use_fi_prefill = use_flashinfer_prefill() + self.prefill_metadata_cls = ( + FlashInferPrefillMetadata + if self._use_fi_prefill else CudnnPrefillMetadata + if self._use_cudnn_prefill else MLACommonPrefillMetadata) + + if self._use_fi_prefill: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=device) + + self._fi_prefill_main: Optional[ + BatchPrefillWithRaggedKVCacheWrapper] = None + self._fi_prefill_chunks: list[ + BatchPrefillWithRaggedKVCacheWrapper] = [] + + self._global_hyperparameters = infer_global_hyperparameters( + get_per_layer_parameters(vllm_config, MLACommonImpl)) + + if self._use_cudnn_prefill: + self.cudnn_workspace = torch.empty( + CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs, + dtype=torch.int8, + device=device, ) - self.block_table = block_table + + def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): + qo_indptr = prefill.query_start_loc + + has_context = False + if prefill.chunked_context is not None: + chunked_context = prefill.chunked_context + has_context = True + + if self._fi_prefill_main is None: + self._fi_prefill_main = 
BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass") + + if has_context: + num_chunks = chunked_context.cu_seq_lens.shape[0] + # Allocate more prefill chunk wrappers if needed + if len(self._fi_prefill_chunks) < num_chunks: + for _ in range(len(self._fi_prefill_chunks), num_chunks): + self._fi_prefill_chunks.append( + BatchPrefillWithRaggedKVCacheWrapper( + self._workspace_buffer, "NHD", backend="cutlass")) + assert num_chunks <= len(self._fi_prefill_chunks) + + # In MLA, the non-latent num_qo_heads == num_kv_heads + num_qo_heads = self.num_heads + num_kv_heads = num_qo_heads + + # Sanity: Verify that num_kv_heads == 1 since it is latent space + assert self.kv_cache_spec.num_kv_heads == 1 + + # Get non-latent head_dim_qk and head_dim_vo + head_dim_qk = (self.mla_dims.qk_nope_head_dim + + self.mla_dims.qk_rope_head_dim) + head_dim_vo = self.mla_dims.v_head_dim + + # For main run, qo_indptr == kv_indptr + kv_indptr = qo_indptr.clone() + + # Prepare main prefill + self._fi_prefill_main.plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=True, # This is main run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters.logits_soft_cap, + q_data_type=self.model_config.dtype, + kv_data_type=self.kv_cache_spec.dtype, + ) + + # Prepare context prefills + if has_context: + for i in range(num_chunks): + kv_indptr_chunk = chunked_context.cu_seq_lens[i] + + self._fi_prefill_chunks[i].plan( + qo_indptr=qo_indptr, + kv_indptr=kv_indptr_chunk, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim_qk=head_dim_qk, + head_dim_vo=head_dim_vo, + causal=False, # This is context run + sm_scale=self._global_hyperparameters.sm_scale, + window_left=self._global_hyperparameters.window_left, + logits_soft_cap=self._global_hyperparameters. + logits_soft_cap, + q_data_type=self.model_config.dtype, + kv_data_type=self.kv_cache_spec.dtype, + ) + + prefill.prefill_main = self._fi_prefill_main + prefill.prefill_chunks = self._fi_prefill_chunks def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - # We now want to reorder the batch so that the "decode" requests are and - # the front and the "prefill" requests are at the using the least amount - # swaps possible. 
(NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - num_decode_tokens = 0 - num_prefill_tokens = 0 - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - # for now treat 1 scheduled token as "decode" even if its not, - # we should update this to something like < 8 in the future but - # currently the TritonMLA._forward_decode only supports - # num_tokens = 1 - if num_tokens == 1: - decodes.append(i) - num_decode_tokens += num_tokens - else: - prefills.append(i) - num_prefill_tokens += num_tokens - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - decode_idx = decodes[num_decodes - i] - if decode_idx < num_decodes: - break - - input_batch.swap_states(prefills[i - 1], decode_idx) - modified_batch = True - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - self._num_decodes = num_decodes - self._num_prefills = num_prefills - self._num_decode_tokens = num_decode_tokens - self._num_prefill_tokens = num_prefill_tokens - - return modified_batch + return reorder_batch_to_split_decodes_and_prefills(input_batch, + scheduler_output, + decode_threshold=1) def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor): @@ -472,49 +583,50 @@ def build_for_cudagraph_capture( m.max_query_len = 1 # decode-only - # Update state usually set in reorder_batch. - self._num_decodes = m.num_reqs - self._num_decode_tokens = m.num_actual_tokens - self._num_prefills = 0 - self._num_prefill_tokens = 0 return self.build(0, m) - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> M: num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens + num_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - assert self._num_decodes + self._num_prefills == num_reqs - # Note(simon): be careful about the CPU <> GPU memory movement in this # function. We should avoid GPU -> CPU sync as much as possible because # it blocks on all previous kernels. 
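The hand-rolled reordering above is replaced by the shared `reorder_batch_to_split_decodes_and_prefills` helper, and `build()` now recovers the per-batch counts via `split_decodes_and_prefills` instead of stashing them on the builder. A rough, self-contained sketch of the bookkeeping these helpers perform, assuming the batch has already been reordered decodes-first (the function below is illustrative, not the vLLM API):

```python
# Illustrative only: the real helpers operate on InputBatch /
# CommonAttentionMetadata; here we assume plain per-request token counts.
def split_decodes_and_prefills_sketch(num_scheduled_tokens: list[int],
                                      decode_threshold: int = 1):
    """Return (num_decodes, num_prefills, num_decode_tokens,
    num_prefill_tokens) for a decodes-first batch."""
    decodes = [n for n in num_scheduled_tokens if n <= decode_threshold]
    prefills = [n for n in num_scheduled_tokens if n > decode_threshold]
    return len(decodes), len(prefills), sum(decodes), sum(prefills)

# Three single-token "decodes" followed by two "prefills":
assert split_decodes_and_prefills_sketch([1, 1, 1, 17, 256]) == (3, 2, 3, 273)
```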
- device = self.runner.device - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + device = self.device + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping query_start_loc = common_attn_metadata.query_start_loc + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + + num_computed_tokens_cpu = (common_attn_metadata.seq_lens_cpu - + query_seq_lens_cpu) + + num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ + split_decodes_and_prefills(common_attn_metadata) + + assert num_decodes + num_prefills == num_reqs + assert num_decode_tokens + num_prefill_tokens == num_tokens + prefill_metadata = None - if self._num_prefills > 0: - reqs_start = self._num_decodes # prefill_start + if num_prefills > 0: + reqs_start = num_decodes # prefill_start - context_lens_cpu = self.runner.input_batch.\ - num_computed_tokens_cpu_tensor[reqs_start:num_reqs] + context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] chunked_context_metadata = None - if self.chunked_prefill_enabled and self._num_prefills > 0 \ + if self.chunked_prefill_enabled and num_prefills > 0 \ and max_context_len_cpu > 0: # NOTE: it is recommend you read the `Chunked Prefill` section # in the comment at the top of the file before trying to @@ -545,14 +657,14 @@ def build(self, common_prefix_len: int, # of `to_list`. 
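For intuition, the chunked-context bookkeeping computed just below behaves as follows on made-up numbers (two prefill requests with 300 and 50 cached context tokens and a 128-token context chunk; `num_chunks` and `max_context_chunk` are stand-ins for the workspace-derived values, and `pin_memory` is omitted):

```python
import torch

num_prefills = 2
context_lens_cpu = torch.tensor([300, 50], dtype=torch.int32)
max_context_chunk = 128  # stand-in for the workspace-derived chunk size
num_chunks = 3  # in the real code: cdiv(max_context_len_cpu, max_context_chunk)

chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
# -> [[128, 50], [128, 0], [44, 0]]: request 0 is split over three chunks,
#    request 1 fits entirely in the first chunk.

cu_seq_lens_cpu = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:],
             dtype=torch.int32)
# -> [[0, 128, 178], [0, 128, 128], [0, 44, 44]]
```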
chunk_starts = \ torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) \ + .unsqueeze(1).expand(-1, num_prefills) \ * max_context_chunk chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk) chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) cu_seq_lens_cpu = torch.zeros(num_chunks, - self._num_prefills + 1, + num_prefills + 1, dtype=torch.int32, pin_memory=True) torch.cumsum(chunk_seq_lens, @@ -560,45 +672,68 @@ def build(self, common_prefix_len: int, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32) + chunked_context_metadata_cls = \ + CudnnPrefillMetadata.ChunkedContextMetadata \ + if self._use_cudnn_prefill else \ + MLACommonPrefillMetadata.ChunkedContextMetadata + chunked_context_metadata = \ - MLACommonPrefillMetadata.ChunkedContextMetadata( + chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), starts=chunk_starts.to(device, non_blocking=True), seq_tot=chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + seq_lens=chunk_seq_lens, workspace=self.chunked_prefill_workspace, ) + if self._use_cudnn_prefill: + chunked_context_metadata.seq_lens = chunk_seq_lens + assert max(chunked_context_metadata.max_seq_lens) <= \ self.chunked_prefill_workspace_size - prefill_metadata = MLACommonPrefillMetadata( + prefill_metadata = self.prefill_metadata_cls( block_table=block_table_tensor[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, ) + if self._use_cudnn_prefill: + assert isinstance(prefill_metadata, CudnnPrefillMetadata) + prefill_metadata.query_seq_lens = prefill_query_start_loc[1:] \ + - prefill_query_start_loc[:-1] + prefill_metadata.cudnn_workspace = self.cudnn_workspace + decode_metadata = None - if self._num_decodes > 0: + if num_decodes > 0: decode_metadata = self._build_decode( - block_table_tensor=block_table_tensor[:self._num_decodes, ...], - seq_lens=seq_lens[:self._num_decodes], + block_table_tensor=block_table_tensor[:num_decodes, ...], + seq_lens=seq_lens[:num_decodes], ) - return self.metadata_cls( - num_actual_tokens=num_actual_tokens, + attn_metadata = self.metadata_cls( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + num_actual_tokens=num_tokens, query_start_loc=query_start_loc, slot_mapping=slot_mapping, - head_dim=self.runner.model_config.get_head_size(), + head_dim=self.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific - num_decodes=self._num_decodes, - num_decode_tokens=self._num_decode_tokens, - num_prefills=self._num_prefills, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, prefill=prefill_metadata, decode=decode_metadata, ) + if self._use_fi_prefill and num_prefills > 0: + assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata) + self._build_fi_prefill_wrappers(attn_metadata.prefill) + + return attn_metadata + def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: return common_attn_metadata.max_query_len == 1 @@ -619,7 +754,6 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -649,23 +783,40 @@ def __init__( self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj - # Handle the 
differences between the flash_attn_varlen from flash_attn - # and the one from vllm_flash_attn. The former is used on RoCM and the - # latter has an additional parameter to control FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() - if self.vllm_flash_attn_version is not None: - self.flash_attn_varlen_func = \ - functools.partial(flash_attn_varlen_func, - fa_version=self.vllm_flash_attn_version) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 - self._pad_v = self.vllm_flash_attn_version is None or not ( - self.vllm_flash_attn_version == 3 - and current_platform.get_device_capability()[0] == 9) + if use_flashinfer_prefill(): + logger.debug_once("Using FlashInfer prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi + self._pad_v = False + elif use_cudnn_prefill(): + logger.debug_once("Using CUDNN prefill for MLA") + self._run_prefill_context_chunk = \ + self._run_prefill_context_chunk_cudnn + self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn + self._pad_v = False + else: # Use FlashAttention + logger.debug_once("Using FlashAttention prefill for MLA") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa + + # Handle the differences between the flash_attn_varlen from + # flash_attn and the one from vllm_flash_attn. The former is used on + # RoCM and the latter has an additional parameter to control + # FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + if self.vllm_flash_attn_version is not None: + self.flash_attn_varlen_func = \ + functools.partial(flash_attn_varlen_func, + fa_version=self.vllm_flash_attn_version) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + self._pad_v = self.vllm_flash_attn_version is None or not ( + self.vllm_flash_attn_version == 3 + and current_platform.get_device_capability()[0] == 9) def _flash_attn_varlen_diff_headdims(self, q, @@ -705,6 +856,105 @@ def _flash_attn_varlen_diff_headdims(self, return attn_out, lse return attn_out + def _run_prefill_new_tokens_fa(self, prefill: MLACommonPrefillMetadata, q, + k, v, return_softmax_lse): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + ) + + def _run_prefill_new_tokens_fi(self, prefill: MLACommonPrefillMetadata, q, + k, v, return_softmax_lse): + assert isinstance(prefill, FlashInferPrefillMetadata) + assert prefill.prefill_main is not None + return prefill.prefill_main.run( + q=q, + k=k, + v=v, + return_lse=return_softmax_lse, + ) + + def _run_prefill_new_tokens_cudnn(self, prefill: MLACommonPrefillMetadata, + q, k, v, return_softmax_lse): + assert isinstance(prefill, CudnnPrefillMetadata) + assert prefill.query_seq_lens is not None + 
output, lse = cudnn_batch_prefill_with_kv_cache( + q=q, + k_cache=k, + v_cache=v, + scale=self.scale, + workspace_buffer=prefill.cudnn_workspace, + max_token_per_sequence=prefill.max_query_len, + max_sequence_kv=prefill.max_query_len, + actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), + causal=True, + return_lse=True, # do not support False for now + is_cuda_graph_compatible= + True, #Indicates actual_seq_lens are on GPU or CPU. + ) + if return_softmax_lse: + return output, lse + return output + + def _run_prefill_context_chunk_fa(self, prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + ) + + def _run_prefill_context_chunk_fi(self, prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert isinstance(prefill, FlashInferPrefillMetadata) + return prefill.prefill_chunks[chunk_idx].run( + q=q, + k=k, + v=v, + return_lse=True, + ) + + def _run_prefill_context_chunk_cudnn(self, + prefill: MLACommonPrefillMetadata, + chunk_idx: int, q, k, v): + assert isinstance(prefill, CudnnPrefillMetadata) + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not None + assert prefill.query_seq_lens is not None + return cudnn_batch_prefill_with_kv_cache( + q=q, + k_cache=k, + v_cache=v, + scale=self.scale, + workspace_buffer=prefill.cudnn_workspace, + max_token_per_sequence=prefill.max_query_len, + max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], + actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx]. + view(-1, 1, 1, 1), + causal=False, + return_lse=True, + is_cuda_graph_compatible= + True, #Indicates actual_seq_lens are on GPU or CPU. 
+ ) + def _v_up_proj(self, x): # Convert from (B, N, L) to (N, B, L) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) @@ -803,18 +1053,12 @@ def _compute_prefill_context( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - attn_output, attn_softmax_lse = \ - self._flash_attn_varlen_diff_headdims( + attn_output, attn_softmax_lse = self._run_prefill_context_chunk( + prefill=prefill_metadata, + chunk_idx=i, q=q, k=k, v=v, - cu_seqlens_q=prefill_metadata.query_start_loc, - cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], - max_seqlen_q=prefill_metadata.max_query_len, - max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, ) if output is None: @@ -854,16 +1098,11 @@ def _forward_prefill( k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - output = self._flash_attn_varlen_diff_headdims( + output = self._run_prefill_new_tokens( + prefill=attn_metadata.prefill, q=q, k=k, v=v, - cu_seqlens_q=attn_metadata.prefill.query_start_loc, - cu_seqlens_k=attn_metadata.prefill.query_start_loc, - max_seqlen_q=attn_metadata.prefill.max_query_len, - max_seqlen_k=attn_metadata.prefill.max_query_len, - softmax_scale=self.scale, - causal=True, return_softmax_lse=has_context, ) @@ -908,7 +1147,6 @@ def forward( output: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." if output_scale is not None: diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index db4b9c9537e5..c787f25cd3ad 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +import os +from typing import Optional import torch @@ -27,6 +28,41 @@ def get_impl_cls() -> type["CutlassMLAImpl"]: return CutlassMLAImpl +class SM100Workspace: + + def __init__(self, initial_workspace_size): + self._workspace_buf = torch.empty(initial_workspace_size, + device="cuda", + dtype=torch.uint8) + + self._block_size = 128 # Forced to 128 + + # Pre-compute sm_count to avoid recomputing it. 
Use device 0 as a proxy + # (assumes all devices are similar) + properties = torch.cuda.get_device_properties(torch.device("cuda:0")) + self._sm_count = properties.multi_processor_count + + def get_buf(self): + return self._workspace_buf + + def ensure_size(self, attn_metadata: MLACommonMetadata, + num_kv_splits: int): + batch_size = attn_metadata.num_reqs + max_seq_len = attn_metadata.max_query_len + + workspace_size = ops.sm100_cutlass_mla_get_workspace_size( + max_seq_len * self._block_size, + batch_size, + self._sm_count, + num_kv_splits=num_kv_splits) + + if self._workspace_buf.shape[0] < workspace_size: + self._workspace_buf.resize_(workspace_size) + + +g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB + + class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): def __init__( @@ -38,7 +74,6 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -46,17 +81,14 @@ def __init__( **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "CutlassMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " @@ -68,7 +100,137 @@ def __init__( raise NotImplementedError( "CutlassMLA V1 with FP8 KV cache not yet supported") - def _forward_decode( + self._use_old_cutlass_mla = False + force_old_cutlass = os.environ.get("FORCE_OLD_CUTLASS_MLA", None) + if force_old_cutlass: + logger.warning("Forcing old cutlass mla kernel") + self._use_old_cutlass_mla = True + + # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging + # issues. 
In case the code hangs, use: + # FORCE_NUM_KV_SPLITS=1 + force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS", None) + if force_num_kv_splits: + logger.warning("Forcing num_kv_splits to %d", + int(force_num_kv_splits)) + self._num_kv_splits = int(force_num_kv_splits) + else: + self._num_kv_splits = -1 # => Auto-detect + + # Share workspace buffer across all executions + self._workspace = g_sm100_workspace + + def _sm100_cutlass_mla_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + workspace: torch.Tensor, + sm_scale: float, + num_kv_splits: int, + ) -> torch.Tensor: + assert (q_nope.ndim == 3 + ), f"q_nope must be a 3D tensor, but got {q_nope.ndim}" + assert ( + q_pe.ndim == 3), f"q_pe must be a 3D tensor, but got {q_pe.ndim}" + assert ( + kv_c_and_k_pe_cache.ndim == 3 + ), "kv_c_and_k_pe_cache must be a 3D tensor, but got {}".format( + kv_c_and_k_pe_cache.ndim) + + B_q, H, D_q_nope = q_nope.shape + B_q_2, H_2, D_q_pe = q_pe.shape + assert (B_q == B_q_2) and (H == H_2) + + _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape + + D_latent = 512 + D_rope = 64 + assert D_q_nope == D_latent + assert D_q_pe == D_rope + assert D_ckv == D_latent + D_rope + + MAX_HEADS = 128 + assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" + if H < MAX_HEADS: + q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope)) + q_nope_padded[:, :H] = q_nope + q_nope = q_nope_padded + + q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe)) + q_pe_padded[:, :H] = q_pe + q_pe = q_pe_padded + + assert len(page_table.shape) == 2 + B_block_table, block_num = page_table.shape + assert B_block_table == B_q + assert (block_num + > 0), f"block num must be greater than 0, got {block_num}" + assert block_num % (128 / PAGE_SIZE) == 0 + + # TODO(kaixih@nvidia): support fp8 + assert q_nope.dtype in ( + torch.float16, + torch.bfloat16, + ), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}." + assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype + assert ( + seq_lens.dtype == torch.int32 + ), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}." + assert ( + page_table.dtype == torch.int32 + ), f"page_table.dtype needs to be int32 but got {page_table.dtype}." + + out = q_nope.new_empty((B_q, MAX_HEADS, D_latent)) + + ops.sm100_cutlass_mla_decode( + out, + q_nope, + q_pe, + kv_c_and_k_pe_cache, + seq_lens, + page_table, + workspace, + sm_scale, + num_kv_splits, + ) + return out[:, :H].contiguous() + + def _sm100_forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError("FP8 Cutlass MLA not yet supported") + + # Adjust workspace size (if necessary) + self._workspace.ensure_size(attn_metadata, self._num_kv_splits) + + # Run MLA + # Clone q_nope and q_pe to make sure strides computation is correct. 
+ # TODO: Check if we really need it + q_nope = q_nope.clone() + q_pe = q_pe.clone() + + o = self._sm100_cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache, + attn_metadata.decode.seq_lens, + attn_metadata.decode.block_table, + self._workspace.get_buf(), + self.scale, self._num_kv_splits) + + return self._v_up_proj(o) + + # TODO: Currently we leave it here only for backup in case something is + # wrong with the new SM100 CUTLASS MLA kernel + def _old_forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, @@ -91,8 +253,25 @@ def _forward_decode( # Clone q_nope and q_pe to make sure strides computation is correct. q_nope = q_nope.clone() q_pe = q_pe.clone() + ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache, attn_metadata.decode.seq_lens, attn_metadata.decode.block_table, self.scale) return self._v_up_proj(o) + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: MLACommonMetadata, + ) -> torch.Tensor: + if self._use_old_cutlass_mla: + # TODO: Remove the old cutlass MLA kernel after more extensive + # testing + return self._old_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, + attn_metadata) + + return self._sm100_forward_decode(q_nope, q_pe, kv_c_and_k_pe_cache, + attn_metadata) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index be26e0060db5..d3e5300dbbd6 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import torch @@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -18,7 +19,6 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable logger = init_logger(__name__) @@ -56,12 +56,13 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # Decode-only - def __init__(self, runner, kv_cache_spec: AttentionSpec, - block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table, FlashMLAMetadata) + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + super().__init__(kv_cache_spec, vllm_config, device, FlashMLAMetadata) - self.num_q_heads = self.runner.model_config.get_num_attention_heads( - self.runner.parallel_config) + self.compilation_config = vllm_config.compilation_config + self.num_q_heads = vllm_config.model_config.get_num_attention_heads( + vllm_config.parallel_config) self.cg_buf_tile_scheduler_metadata = None self.cg_buf_num_splits = None @@ -75,7 +76,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, 1, # MQA for the decode path ) - if self.runner.full_cuda_graph: + if self.compilation_config.full_cuda_graph: # First time around (CUDAGraph capture), allocate the static buffer if self.cg_buf_tile_scheduler_metadata is None: self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata @@ -118,7 +119,6 @@ def __init__( alibi_slopes: 
Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -126,20 +126,17 @@ def __init__( **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) assert is_flashmla_supported(), \ "FlashMLA is not supported on this device" - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "FlashMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index d5f9dfaea065..834c23455835 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -2,12 +2,14 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional import torch import vllm.envs as envs from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +from vllm.config import VllmConfig +from vllm.utils import cdiv # yapf conflicts with isort for this docstring # yapf: disable from vllm.v1.attention.backends.mla.common import (MLACommonBackend, @@ -16,7 +18,6 @@ MLACommonMetadata, MLACommonMetadataBuilder) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable # yapf: enable @@ -65,24 +66,26 @@ class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): full_cudagraph_supported: ClassVar[bool] = True # decode only - def __init__(self, runner, kv_cache_spec: AttentionSpec, - block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table, AiterMLAMetadata) + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + super().__init__(kv_cache_spec, vllm_config, device, AiterMLAMetadata) assert self.kv_cache_spec.block_size == 1, "AITER MLA" \ "only supports block size 1." 
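The AITER MLA builder now sizes its persistent `paged_kv_indices` buffer from the config (see the `cdiv` computation just below) rather than from the runner's block table. A small sketch of that upper bound, with illustrative numbers:

```python
from math import ceil

def max_num_pages(max_model_len: int, block_size: int,
                  max_num_seqs: int) -> int:
    # cdiv(max_model_len, block_size) pages per request, for every request
    # that can be scheduled concurrently.
    return max_num_seqs * ceil(max_model_len / block_size)

# e.g. a 32k-token model with block_size=1 (required by AITER MLA) and
# 128 concurrent sequences:
assert max_num_pages(32768, 1, 128) == 32768 * 128
```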
+ self.compilation_config = vllm_config.compilation_config + max_num_pages_per_req = cdiv(vllm_config.model_config.max_model_len, + self.kv_cache_spec.block_size) + max_num_reqs = vllm_config.scheduler_config.max_num_seqs + max_num_pages = max_num_reqs * max_num_pages_per_req + # Preparing persistent buffers - if self.runner.full_cuda_graph: - device = self.runner.device - max_num_reqs = self.runner.max_num_reqs + if vllm_config.compilation_config.full_cuda_graph: self.paged_kv_indptr = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device) - self.paged_kv_indices = torch.zeros( - block_table.get_device_tensor().numel( - ), # max num pages possible - dtype=torch.int32, - device=device) + self.paged_kv_indices = torch.zeros(max_num_pages, + dtype=torch.int32, + device=device) self.paged_kv_last_page_len = torch.zeros(max_num_reqs, dtype=torch.int32, device=device) @@ -96,7 +99,8 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens + page_size - 1) // page_size - device = self.runner.device + device = self.device + num_reqs = seq_lens.size(0) mask = (torch.arange(block_table_tensor.size(1), dtype=block_table_tensor.dtype, @@ -113,8 +117,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, block_table_bounds.cumsum(dim=0, dtype=torch.int32) ]) - if self.runner.full_cuda_graph: - num_reqs = self._num_decodes + if self.compilation_config.full_cuda_graph: num_actual_pages = paged_kv_indices.size(0) @@ -137,7 +140,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, else: qo_indptr = torch.arange(0, - self._num_decodes + 1, + num_reqs + 1, step=1, dtype=torch.int32, device=device) @@ -164,7 +167,6 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -172,20 +174,17 @@ def __init__( **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) assert (num_heads == 16 or num_heads == 128), ( f"Aiter MLA only supports 16 or 128 number of heads.\n" f"Provided {num_heads} number of heads.\n" "Try adjusting tensor_parallel_size value.") - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "Aiter MLA does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") from aiter import flash_attn_varlen_func self.flash_attn_varlen_func = flash_attn_varlen_func diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 99938f22f108..700fce68953e 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any, Optional +from typing import Optional import torch @@ -42,7 +42,6 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], 
kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, kv_sharing_target_layer_name: Optional[str], @@ -50,17 +49,14 @@ **mla_args) -> None: super().__init__(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap, attn_type, + logits_soft_cap, attn_type, kv_sharing_target_layer_name, **mla_args) - unsupported_features = [ - alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap - ] + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] if any(unsupported_features): raise NotImplementedError( "TritonMLAImpl does not support one of the following: " - "alibi_slopes, sliding_window, blocksparse_params, " - "logits_soft_cap") + "alibi_slopes, sliding_window, logits_soft_cap") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 253d79d925ce..9b122136afb7 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch import torch_xla.core.xla_builder as xb @@ -24,6 +24,19 @@ # TPU requires the head size to be a multiple of 128. TPU_HEAD_SIZE_ALIGNMENT = 128 +# Note: TPU can use fp8 as the storage dtype but doesn't support converting +# from uint8 to fp32 directly. That's why it has a dtype mapping different from GPU +TPU_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.float8_e4m3fn, + "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, + "int8": torch.int8, + "uint8": torch.uint8, +} + class PallasAttentionBackend(AttentionBackend): @@ -86,6 +99,12 @@ def get_max_num_seqs(model_len: int, page_size: int) -> int: # spill less likely. Meanwhile we make sure the page size is in [16, 256]. @staticmethod def get_page_size(vllm_config: VllmConfig) -> int: + # TODO: This is a temporary fix for VMEM OOM. + # For long model lengths, we use a page size of 16 to avoid too much + # VMEM spill. A more robust solution should be implemented to + # handle VREG spills.
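Together with the existing power-of-two rule, the override described in this TODO amounts to the heuristic below (a standalone sketch; `next_power_of_2` is a local stand-in for the vLLM utility, and the clamp mirrors the comment that the page size is kept in [16, 256]):

```python
def next_power_of_2(n: int) -> int:
    # Stand-in for vllm.utils.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def tpu_page_size(max_model_len: int) -> int:
    if max_model_len > 8192:
        return 16  # long-context override from this diff, to limit VMEM spill
    page_size = next_power_of_2(max_model_len) // 16
    return min(max(page_size, 16), 256)  # keep the page size in [16, 256]

assert tpu_page_size(2048) == 128
assert tpu_page_size(131072) == 16  # long context -> forced to 16
```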
+ if vllm_config.model_config.max_model_len > 8192: + return 16 page_size = next_power_of_2( vllm_config.model_config.max_model_len) // 16 if page_size <= 16: @@ -126,19 +145,10 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, ) -> None: - if use_irope: - logger.warning_once( - "Using irope in Pallas is not supported yet, it will fall back " - "to global attention for long context.") - if blocksparse_params is not None: - raise ValueError("Paged attention Pallas kernel does " - "not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -150,10 +160,6 @@ def __init__( self.num_queries_per_kv = self.num_heads // self.num_kv_heads if alibi_slopes is not None: raise NotImplementedError("Alibi slopes is not supported.") - if kv_cache_dtype != "auto": - raise NotImplementedError("FP8 KV cache dtype is not supported.") - if blocksparse_params is not None: - raise NotImplementedError("Blocksparse is not supported.") if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " @@ -161,9 +167,10 @@ def __init__( "are not implemented for " "PallasAttentionBackendImpl") - tpu_version = torch_xla.tpu.version() - if tpu_version < 4: - raise NotImplementedError("TPU version must be 4 or higher.") + self.kv_cache_quantized_dtype = None + if kv_cache_dtype != "auto": + self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get( + kv_cache_dtype.lower().strip()) def forward( self, @@ -198,7 +205,6 @@ def forward( output = torch.ones_like(query) return output - assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens, hidden_size = query.shape query = query.view(num_tokens, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) @@ -219,10 +225,21 @@ def forward( # Skip this if sharing KV cache with an earlier attention layer. slot_mapping = attn_metadata.slot_mapping write_to_kv_cache( - key, value, kv_cache, slot_mapping, + key, + value, + kv_cache, + slot_mapping, attn_metadata.num_slices_per_kv_cache_update_block, - attn_metadata.num_kv_update_slices) - + attn_metadata.num_kv_update_slices, + self.kv_cache_quantized_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + + if self.kv_cache_quantized_dtype is not None and ( + layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0): + raise ValueError( + "k_scale_float and v_scale_float must be non-zero") output = torch.ops.xla.ragged_paged_attention( query, kv_cache, @@ -240,6 +257,8 @@ def forward( sm_scale=self.scale, sliding_window=self.sliding_window, soft_cap=self.logits_soft_cap, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, ) if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: @@ -255,18 +274,32 @@ def write_to_kv_cache( slot_mapping: torch.Tensor, num_slices_per_kv_cache_update_block: int, num_kv_update_slices: torch.Tensor, + kv_cache_quantized_dtype: Optional[torch.dtype] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, ) -> None: """ Write the key and values to the KV cache. 
Args: - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] + key: shape = [num_tokens, num_kv_heads, head_size] + value: shape = [num_tokens, num_kv_heads, head_size] kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size] num_slices_per_kv_cache_update_block: int """ _, page_size, num_combined_kv_heads, head_size = kv_cache.shape head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + + if kv_cache_quantized_dtype is not None: + dtype_info = torch.finfo(kv_cache_quantized_dtype) + key = key.to(torch.float32) / k_scale + # NOTE: clamp is added here to avoid out of range of quantized dtype + key = torch.clamp(key, dtype_info.min, dtype_info.max) + key = key.to(kv_cache_quantized_dtype) + value = value.to(torch.float32) / v_scale + value = torch.clamp(value, dtype_info.min, dtype_info.max) + value = value.to(kv_cache_quantized_dtype) + kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) @@ -318,3 +351,56 @@ def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, page_size: int, num_slices_per_block: int) -> torch.Tensor: return kv_cache + + +# We can move this function to a common utils file if it's also useful for other +# hardware. +def dtype_bits(dtype: torch.dtype): + if dtype.is_floating_point: + try: + return torch.finfo(dtype).bits + except TypeError: + pass + elif dtype.is_complex: + if dtype is torch.complex32: + return 32 + elif dtype is torch.complex64: + return 64 + elif dtype is torch.complex128: + return 128 + else: + try: + return torch.iinfo(dtype).bits + # torch.iinfo cannot support int4, int2, bits8... + except TypeError: + pass + str_dtype = str(dtype) + # support torch.int4, torch.int5, torch.uint5... 
+ if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"): + return int(str_dtype[-1]) + raise TypeError(f"Getting the bit width of {dtype} is not supported") + + +def get_dtype_packing(dtype): + bits = dtype_bits(dtype) + if 32 % bits != 0: + raise ValueError( + f"32 must be divisible by the bit width, but got bits={bits}, " + f"dtype={dtype}") + return 32 // bits + + +def get_page_size_bytes(block_size: int, num_kv_heads: int, head_size: int, + kv_cache_dtype: torch.dtype) -> int: + """Returns the size in bytes of one page of the KV cache.""" + padded_head_size = cdiv(head_size, + TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT + num_combined_kv_heads = num_kv_heads * 2 + + # NOTE: account for the implicit padding applied by XLA + packing = get_dtype_packing(kv_cache_dtype) + num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing + + kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) + return (block_size * num_combined_kv_heads * padded_head_size * + kv_cache_dtype_bits // 8) diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 6a78b03dce86..0739d2596676 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with AiterFlashAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import Optional import torch @@ -10,18 +10,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType, is_quantized_kv_cache) +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.v1.attention.backends.flash_attn import ( - make_local_attention_virtual_batches) from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch - from vllm.v1.worker.gpu_model_runner import GPUModelRunner if current_platform.is_rocm(): import aiter @@ -172,110 +165,52 @@ def flash_attn_varlen_func_fake( class AiterFlashAttentionMetadataBuilder: - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - model_config = runner.model_config - - self.runner = runner - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.parallel_config = vllm_config.parallel_config + self.cache_config = vllm_config.cache_config + self.device = device + + self.num_heads_q = self.model_config.get_num_attention_heads( + self.parallel_config) + self.num_heads_kv = self.model_config.get_num_kv_heads( + self.parallel_config) + self.headdim = self.model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_table = block_table # Sliding window size to be used with the AOT scheduler will be # populated on first build() call.
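Returning briefly to the TPU helpers added to pallas.py above, the per-page byte count works out as follows; this is a pure-Python restatement of `get_page_size_bytes` with the dtype bit width passed in directly rather than derived via `dtype_bits`/`get_dtype_packing`:

```python
from math import ceil

TPU_HEAD_SIZE_ALIGNMENT = 128

def page_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                    dtype_bits: int) -> int:
    padded_head_size = (ceil(head_size / TPU_HEAD_SIZE_ALIGNMENT)
                        * TPU_HEAD_SIZE_ALIGNMENT)
    packing = 32 // dtype_bits                 # values packed per 32-bit word
    num_combined_kv_heads = num_kv_heads * 2   # K and V share the page
    num_combined_kv_heads = ceil(num_combined_kv_heads / packing) * packing
    return (block_size * num_combined_kv_heads * padded_head_size
            * dtype_bits // 8)

# bf16 vs. fp8 KV cache, block_size=16, 8 KV heads, head_size=128:
assert page_size_bytes(16, 8, 128, 16) == 65536  # 64 KiB per page
assert page_size_bytes(16, 8, 128, 8) == 32768   # fp8 halves the page size
```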
self.aot_sliding_window: Optional[tuple[int, int]] = None - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: + def reorder_batch(self, input_batch, scheduler_output) -> bool: return False - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> 'AiterFlashAttentionMetadata': - num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) - total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) + total_tokens = int(common_attn_metadata.seq_lens_cpu.sum()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - - slot_mapping = block_table.slot_mapping[:num_actual_tokens] + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32, - device="cuda") + device=self.device) torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:]) - def schedule(batch_size, cu_query_lens, max_query_len, seqlens, - max_seq_len, causal): - return None - - # for local attention - local_attn_metadata = None - if self.runner.attention_chunk_size is not None: - seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table_tensor = make_local_attention_virtual_batches( - self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], - self.runner.seq_lens_np[:num_reqs], - block_table_tensor, - self.block_size, - ) - local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( - self.runner.device, non_blocking=True) - local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True) - local_max_query_len = int(seqlens_q_local_np.max()) - local_max_seq_len = int(virt_k_seqlens_np.max()) - local_scheduler_metadata = schedule( - batch_size=local_query_start_loc.shape[0] - 1, - cu_query_lens=local_query_start_loc, - max_query_len=local_max_query_len, - seqlens=local_seqused_k, - max_seq_len=local_max_seq_len, - causal=True) - - local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1, - dtype=torch.int32, - device=self.runner.device) - local_cu_seq_lens[1:] = torch.cumsum( - torch.from_numpy(virt_k_seqlens_np).to( - device=self.runner.device, - dtype=torch.int32, - non_blocking=True), - dim=0) - - - local_attn_metadata = \ - AiterFlashAttentionMetadata.LocalAttentionMetadata( - local_query_start_loc=local_query_start_loc, - local_seqused_k=local_seqused_k, - local_block_table=virt_block_table_tensor, - local_max_query_len=local_max_query_len, - local_max_seq_len=local_max_seq_len, - local_cu_seq_lens=local_cu_seq_lens, - local_scheduler_metadata=local_scheduler_metadata, - ) - use_cascade = common_prefix_len > 0 cu_prefix_query_lens = None @@ -297,7 +232,6 @@ def 
schedule(batch_size, cu_query_lens, max_query_len, seqlens, cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, - local_attn_metadata=local_attn_metadata, ) return attn_metadata @@ -314,6 +248,10 @@ class AiterFlashAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @@ -384,19 +322,6 @@ class AiterFlashAttentionMetadata: prefix_kv_lens: Optional[torch.Tensor] suffix_kv_lens: Optional[torch.Tensor] - # for local attention - @dataclass - class LocalAttentionMetadata: - local_query_start_loc: torch.Tensor - local_seqused_k: torch.Tensor - local_block_table: torch.Tensor - local_max_query_len: int - local_max_seq_len: int - local_cu_seq_lens: torch.Tensor - local_scheduler_metadata: Optional[torch.Tensor] - - local_attn_metadata: Optional[LocalAttentionMetadata] = None - class AiterFlashAttentionImpl(AttentionImpl): @@ -409,15 +334,10 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "AiterFlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -446,7 +366,6 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "FlashAttentionImpl") - self.use_irope = use_irope if is_quantized_kv_cache(self.kv_cache_dtype): raise NotImplementedError( "AiterFlashAttention does not support fp8 kv-cache on this " @@ -528,25 +447,12 @@ def forward( layer._q_scale) query = query.reshape((num_tokens, num_heads, head_size)) - # Compute attention and update output up to `num_actual_tokens`. 
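The simplified forward path below reads all of its sequence bookkeeping from the metadata built above; for instance, the builder's `cu_seq_lens` is just a prefix sum over `CommonAttentionMetadata.seq_lens`. A toy example of that construction (values made up):

```python
import torch

seq_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(seq_lens, dim=0, dtype=cu_seq_lens.dtype, out=cu_seq_lens[1:])
assert cu_seq_lens.tolist() == [0, 5, 7, 14]
```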
- use_local_attn = \ - (self.use_irope and attn_metadata.local_attn_metadata is not None) - - if not attn_metadata.use_cascade or use_local_attn: - if use_local_attn: - assert attn_metadata.local_attn_metadata is not None - local_metadata = attn_metadata.local_attn_metadata - cu_seqlens_q = local_metadata.local_query_start_loc - seqused_k = local_metadata.local_seqused_k - max_seqlen_q = local_metadata.local_max_query_len - max_seqlen_k = local_metadata.local_max_seq_len - block_table = local_metadata.local_block_table - else: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table + if not attn_metadata.use_cascade: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table if max_seqlen_q > 1: cu_seq_lens = attn_metadata.cu_seq_lens @@ -564,9 +470,7 @@ def forward( alibi_slopes=self.alibi_slopes, window_size=self.sliding_window, block_table=block_table, - cu_seqlens_k=(cu_seq_lens if not use_local_attn else - local_metadata.local_cu_seq_lens), - ) + cu_seqlens_k=cu_seq_lens) _, num_heads, head_size = query.shape _PARTITION_SIZE_ROCM = 256 diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index cdaff2f6a40f..83471ca51b73 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Optional +from typing import ClassVar, Optional import torch @@ -14,17 +14,13 @@ chunked_prefill_paged_decode) from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, CommonAttentionMetadata, - make_local_attention_virtual_batches) +from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, + CommonAttentionMetadata) from vllm.v1.kv_cache_interface import AttentionSpec -from vllm.v1.worker.block_table import BlockTable - -if TYPE_CHECKING: - from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -58,29 +54,23 @@ class TritonAttentionMetadata: scheduler_metadata: Optional[torch.Tensor] = None prefix_scheduler_metadata: Optional[torch.Tensor] = None - # for local attention - @dataclass - class LocalAttentionMetadata: - local_query_start_loc: torch.Tensor - local_seqused_k: torch.Tensor - local_block_table: torch.Tensor - local_max_query_len: int - local_max_seq_len: int - local_scheduler_metadata: Optional[torch.Tensor] - - local_attn_metadata: Optional[LocalAttentionMetadata] = None - class TritonAttentionMetadataBuilder( AttentionMetadataBuilder[TritonAttentionMetadata]): full_cudagraph_supported: ClassVar[bool] = True - def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, - block_table: BlockTable): - self.runner = runner + def __init__(self, kv_cache_spec: AttentionSpec, 
vllm_config: VllmConfig, + device: torch.device): + self.device = device self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec - self.block_table = block_table + + model_config = vllm_config.model_config + self.num_heads_q = model_config.get_num_attention_heads( + vllm_config.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + vllm_config.parallel_config) + self.headdim = model_config.get_head_size() def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata @@ -92,70 +82,31 @@ def build_for_cudagraph_capture( attn_metadata.seq_lens.fill_(1) return attn_metadata - def build( - self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata - ) -> TritonAttentionMetadata: - num_reqs = common_attn_metadata.num_reqs + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> TritonAttentionMetadata: num_actual_tokens = common_attn_metadata.num_actual_tokens max_query_len = common_attn_metadata.max_query_len - max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + max_seq_len = int(common_attn_metadata.seq_lens_cpu.max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens - block_table = self.block_table - block_table_tensor = block_table.get_device_tensor()[:num_reqs] - - block_table.slot_mapping[:num_actual_tokens].copy_( - block_table.slot_mapping_cpu[:num_actual_tokens], - non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache in full cuda graph - # mode. - block_table.slot_mapping[num_actual_tokens:].fill_(-1) - - slot_mapping = block_table.slot_mapping[:num_actual_tokens] - - # for local attention - local_attn_metadata = None - if self.runner.attention_chunk_size is not None: - seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ - virt_block_table_tensor = make_local_attention_virtual_batches( - self.runner.attention_chunk_size, - self.runner.query_start_loc_np[:num_reqs + 1], - self.runner.seq_lens_np[:num_reqs], - block_table_tensor, - self.block_size, - ) - local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( - self.runner.device, non_blocking=True) - local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( - self.runner.device, non_blocking=True) - local_max_query_len = seqlens_q_local_np.max() - local_max_seq_len = virt_k_seqlens_np.max() - - local_attn_metadata = TritonAttentionMetadata \ - .LocalAttentionMetadata( - local_query_start_loc=local_query_start_loc, - local_seqused_k=local_seqused_k, - local_block_table=virt_block_table_tensor, - local_max_query_len=local_max_query_len, - local_max_seq_len=local_max_seq_len, - local_scheduler_metadata=None, - ) + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping use_cascade = common_prefix_len > 0 if use_cascade: cu_prefix_query_lens = torch.tensor([0, num_actual_tokens], dtype=torch.int32, - device=self.runner.device) + device=self.device) prefix_kv_lens = torch.tensor([common_prefix_len], dtype=torch.int32, - device=self.runner.device) - suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + device=self.device) + suffix_kv_lens = (common_attn_metadata.seq_lens_cpu - common_prefix_len) - suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( - self.runner.device) + suffix_kv_lens = suffix_kv_lens.to(self.device) else: cu_prefix_query_lens = None prefix_kv_lens = None @@ -175,7 +126,6 @@ def build( 
cu_prefix_query_lens=cu_prefix_query_lens, prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, - local_attn_metadata=local_attn_metadata, prefix_scheduler_metadata=prefix_scheduler_metadata, ) return attn_metadata @@ -190,6 +140,10 @@ class TritonAttentionBackend(AttentionBackend): accept_output_buffer: bool = True + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.float16, torch.bfloat16] + @classmethod def get_supported_head_sizes(cls) -> list[int]: return [32, 64, 96, 128, 160, 192, 224, 256] @@ -248,15 +202,10 @@ def __init__( alibi_slopes: Optional[list[float]], sliding_window: Optional[int], kv_cache_dtype: str, - blocksparse_params: Optional[dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, attn_type: AttentionType = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[int] = None, - use_irope: bool = False, ) -> None: - if blocksparse_params is not None: - raise ValueError( - "TritonAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -275,8 +224,6 @@ def __init__( self.logits_soft_cap = logits_soft_cap self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - self.use_irope = use_irope - self.num_queries_per_kv = self.num_heads // self.num_kv_heads TritonAttentionBackend.validate_head_size(head_size) @@ -385,23 +332,11 @@ def forward( layer._q_scale) query = query.reshape((num_tokens, num_heads, head_size)) - use_local_attn = \ - (self.use_irope and attn_metadata.local_attn_metadata is not None) - - if use_local_attn: - assert attn_metadata.local_attn_metadata is not None - local_metadata = attn_metadata.local_attn_metadata - cu_seqlens_q = local_metadata.local_query_start_loc - seqused_k = local_metadata.local_seqused_k - max_seqlen_q = local_metadata.local_max_query_len - max_seqlen_k = local_metadata.local_max_seq_len - block_table = local_metadata.local_block_table - else: - cu_seqlens_q = attn_metadata.query_start_loc - seqused_k = attn_metadata.seq_lens - max_seqlen_q = attn_metadata.max_query_len - max_seqlen_k = attn_metadata.max_seq_len - block_table = attn_metadata.block_table + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table if use_prefill_decode_attn: # Compute attention and update output up to `num_actual_tokens`. 
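Both attention-backend hunks above follow the same refactor: the metadata builders stop reaching into the GPU model runner (runner.seq_lens_np, runner.query_start_loc_np, the per-builder BlockTable) and instead read all per-batch state from CommonAttentionMetadata, while construction only takes the KV-cache spec, the VllmConfig, and a device. A minimal sketch of that shape is below, assuming simplified stand-in types; SimpleCommonMetadata and SimpleBuilder are illustrative names, not part of this patch.

from dataclasses import dataclass

import torch


@dataclass
class SimpleCommonMetadata:
    # Illustrative subset of CommonAttentionMetadata: CPU copies travel
    # alongside the device tensors so builders need no runner state.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    seq_lens_cpu: torch.Tensor
    block_table_tensor: torch.Tensor
    slot_mapping: torch.Tensor
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


class SimpleBuilder:
    """Toy analogue of an attention metadata builder after the refactor."""

    def __init__(self, block_size: int, device: torch.device):
        # Only static configuration is captured at construction time.
        self.block_size = block_size
        self.device = device

    def build(self, common: SimpleCommonMetadata,
              fast_build: bool = False) -> dict:
        # Per-batch values come straight from the shared metadata.
        max_seq_len = int(common.seq_lens_cpu.max())
        cu_seq_lens = torch.zeros(common.num_reqs + 1,
                                  dtype=torch.int32,
                                  device=common.seq_lens.device)
        torch.cumsum(common.seq_lens, dim=0, dtype=torch.int32,
                     out=cu_seq_lens[1:])
        return {
            "max_seq_len": max_seq_len,
            "cu_seq_lens": cu_seq_lens,
            "block_table": common.block_table_tensor,
            "slot_mapping": common.slot_mapping,
        }


common = SimpleCommonMetadata(
    query_start_loc=torch.tensor([0, 2, 5], dtype=torch.int32),
    seq_lens=torch.tensor([7, 9], dtype=torch.int32),
    seq_lens_cpu=torch.tensor([7, 9], dtype=torch.int32),
    block_table_tensor=torch.zeros((2, 4), dtype=torch.int32),
    slot_mapping=torch.arange(5, dtype=torch.int64),
    num_reqs=2,
    num_actual_tokens=5,
    max_query_len=3,
)
meta = SimpleBuilder(block_size=16, device=torch.device("cpu")).build(common)
assert meta["max_seq_len"] == 9
assert meta["cu_seq_lens"].tolist() == [0, 7, 16]

Because the CPU copies (seq_lens_cpu, query_start_loc_cpu) now travel with the batch, reductions such as max_seq_len no longer require the runner-owned numpy arrays.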
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index b0ebb00d9e6b..fc8649d587ee 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -4,14 +4,17 @@ import functools from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar +from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar import numpy as np import torch +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.utils import cdiv if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionImpl from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -19,8 +22,10 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger +from vllm.v1.kv_cache_interface import AttentionSpec logger = init_logger(__name__) +_KV_CACHE_LAYOUT_OVERRIDE = None @dataclass @@ -28,14 +33,22 @@ class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. + + For many of the tensors we keep both GPU and CPU versions. """ query_start_loc: torch.Tensor + query_start_loc_cpu: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" + num_computed_tokens_cpu: torch.Tensor + """(batch_size,), the number of computed tokens for each request""" + num_reqs: int """Number of requests""" num_actual_tokens: int @@ -43,6 +56,9 @@ class CommonAttentionMetadata: max_query_len: int """Longest query in batch""" + block_table_tensor: torch.Tensor + slot_mapping: torch.Tensor + M = TypeVar("M") @@ -52,11 +68,25 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): full_cudagraph_supported: ClassVar[bool] = False @abstractmethod - def build(self, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, + device: torch.device): + self.kv_cache_spec = kv_cache_spec + + @abstractmethod + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. + + Args: + common_prefix_len: The length of the common prefix of the batch. + common_attn_metadata: The common attention metadata. + fast_build: The meta-data will prioritize speed of building over + then speed at execution. Can be used for spec-decode where the + result of a build call may only be used for few layers/iters. 
""" raise NotImplementedError @@ -85,6 +115,7 @@ def use_cascade_attention( num_kv_heads: int, use_alibi: bool, use_sliding_window: bool, + use_local_attention: bool, num_sms: int, ) -> bool: return False @@ -98,41 +129,9 @@ def reorder_batch(self, input_batch: "InputBatch", return False -def validate_kv_sharing_target(current_layer_name, target_layer_name, - static_forward_context): - error_msg = (f"Specified KV sharing target layer for {current_layer_name} " - f"is not valid: target layer {target_layer_name} ") - - if current_layer_name == target_layer_name: - raise ValueError(error_msg + - "cannot be the same as the current layer.") - - if target_layer_name not in static_forward_context: - from vllm.model_executor.models.utils import extract_layer_index - - # If target layer name is not in the static fwd context, it means either - # a) the target layer does not come BEFORE the current layer, or - # b) the target layer is not an Attention layer that exists in the model - current_layer_idx = extract_layer_index(current_layer_name) - target_layer_idx = extract_layer_index(target_layer_name) - if current_layer_idx <= target_layer_idx: - raise ValueError(error_msg + "must come before the current layer.") - else: - raise ValueError(error_msg + - "is not a valid Attention layer in the model.") - - # Currently KV sharing is only supported between layers of the same type - target_layer_attn_type = static_forward_context[ - target_layer_name].attn_type - expected = static_forward_context[current_layer_name].attn_type - if target_layer_attn_type != expected: - raise ValueError( - error_msg + - f"must be the same type as the current layer ({expected}).") - - @functools.lru_cache def get_kv_cache_layout(): + global _KV_CACHE_LAYOUT_OVERRIDE # Override with format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT if cache_layout is None: @@ -140,10 +139,81 @@ def get_kv_cache_layout(): else: logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. Setting KV cache layout to %s.", cache_layout) - + if _KV_CACHE_LAYOUT_OVERRIDE is not None: + cache_layout = _KV_CACHE_LAYOUT_OVERRIDE return cache_layout +def set_kv_cache_layout(cache_layout: str): + global _KV_CACHE_LAYOUT_OVERRIDE + _KV_CACHE_LAYOUT_OVERRIDE = cache_layout + + +@dataclass +class PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters. + """ + + window_left: int + logits_soft_cap: Optional[float] + sm_scale: float + + +def get_per_layer_parameters( + vllm_config: VllmConfig, + cls_: type['AttentionImpl']) -> dict[str, PerLayerParameters]: + """ + Scan all attention layers and determine some hyperparameters + to use during `plan`. 
+ """ + + layers = get_layers_from_vllm_config(vllm_config, Attention) + per_layer_params: dict[str, PerLayerParameters] = {} + + for key, layer in layers.items(): + impl = layer.impl + assert isinstance(impl, cls_) + + # Infer hyperparameters from the attention layer + window_size = getattr(impl, "sliding_window", None) + window_left = window_size[0] if window_size is not None else -1 + logits_soft_cap = getattr(impl, "logits_soft_cap", None) + sm_scale = impl.scale + + per_layer_params[key] = PerLayerParameters(window_left, + logits_soft_cap, sm_scale) + + return per_layer_params + + +def infer_global_hyperparameters( + per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters: + """ + Currently, FlashInfer backend only support models in which all layers share + the same values for the following hyperparameters: + - `window_left` + - `logits_soft_cap` + - `sm_scale` + + So this function asserts that all layers share the same values for these + hyperparameters and returns the global values. + """ + + assert len(per_layer_params) > 0, "No attention layers found in the model." + + param_sets = list(per_layer_params.values()) + global_params = param_sets[0] + for params in param_sets: + assert params == global_params, ( + "FlashInfer backend currently only supports models in which all " + "layers share the same values for the following hyperparameters: " + "`window_left`, `logits_soft_cap`, `sm_scale`.") + + return global_params + + # # Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into # local attention blocks, where each block is passed to the attention kernel @@ -198,11 +268,14 @@ def get_kv_cache_layout(): # block_table_local : shape[local_virtual_batches, pages_per_local_batch] def make_local_attention_virtual_batches( attn_chunk_size: int, - query_start_loc_np: np.ndarray, - seq_lens_np: np.ndarray, - block_table: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, block_size: int = 0, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: +) -> CommonAttentionMetadata: + query_start_loc_np = common_attn_metadata.query_start_loc_cpu.numpy() + seq_lens_np = common_attn_metadata.seq_lens_cpu.numpy() + block_table = common_attn_metadata.block_table_tensor + device = common_attn_metadata.query_start_loc.device + q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] actual_batch_size = seq_lens_np.shape[0] @@ -265,6 +338,7 @@ def make_local_attention_virtual_batches( attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + num_computed_tokens_local = seqlens_k_local - seqlens_q_local k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ (rarange * attn_chunk_size + \ @@ -306,5 +380,124 @@ def make_local_attention_virtual_batches( block_table_local = block_table[batch_indices, block_indices]\ .view(virtual_batches, -1) - return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ - block_table_local + query_start_loc_cpu = torch.from_numpy(cu_seqlens_q_local) + seq_lens_cpu = torch.from_numpy(seqlens_k_local) + + return CommonAttentionMetadata( + query_start_loc_cpu=query_start_loc_cpu, + query_start_loc=query_start_loc_cpu.to(device=device, + non_blocking=True), + seq_lens_cpu=seq_lens_cpu, + seq_lens=seq_lens_cpu.to(device=device, non_blocking=True), + num_computed_tokens_cpu=torch.from_numpy(num_computed_tokens_local), + num_reqs=len(seq_lens_cpu), + num_actual_tokens=common_attn_metadata.num_actual_tokens, + max_query_len=seqlens_q_local.max(), + 
block_table_tensor=block_table_local, + slot_mapping=common_attn_metadata.slot_mapping, + ) + + +def split_decodes_and_prefills( + common_attn_metadata: CommonAttentionMetadata, + decode_threshold: int = 1, +) -> tuple[int, int, int, int]: + """ + Assuming a reordered batch, finds the boundary between prefill and decode + requests. + + Args: + common_attn_metadata: CommonAttentionMetadata object containing the + batch metadata. + decode_threshold: The maximum query length to be considered a decode. + + Returns: + num_decodes: The number of decode requests. + num_prefills: The number of prefill requests. + num_decode_tokens: The number of tokens in the decode requests. + num_prefill_tokens: The number of tokens in the prefill requests. + """ + max_query_len = common_attn_metadata.max_query_len + num_reqs = common_attn_metadata.num_reqs + num_tokens = common_attn_metadata.num_actual_tokens + query_start_loc = common_attn_metadata.query_start_loc_cpu + + if max_query_len <= decode_threshold: + return num_reqs, 0, num_tokens, 0 + + query_lens = query_start_loc[1:] - query_start_loc[:-1] + is_prefill = query_lens > decode_threshold + if not torch.any(is_prefill): + return num_reqs, 0, num_tokens, 0 + + first_prefill = is_prefill.int().argmax(dim=-1).item() + assert torch.all(query_lens[first_prefill:] > decode_threshold) + assert torch.all(query_lens[:first_prefill] <= decode_threshold) + num_decodes = first_prefill + num_prefills = num_reqs - num_decodes + num_decode_tokens = query_start_loc[first_prefill].item() + num_prefill_tokens = num_tokens - num_decode_tokens + return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens) + + +def reorder_batch_to_split_decodes_and_prefills( + input_batch: "InputBatch", + scheduler_output: "SchedulerOutput", + decode_threshold: int = 1, +) -> bool: + """ + Reorders the batch to split into prefill and decode requests; places all + requests with <= decode_threshold tokens at the front of the batch. + + Returns: + True if the batch was modified, False otherwise. + """ + # We now want to reorder the batch so that the "decode" requests are at + # the front and the "prefill" requests are at the back using the least + # amount of swaps possible. (NOTE for now we loosely use "decode" to mean + # requests where attention is likely memory-bound and "prefill" to mean + # requests where attention is likely compute-bound, TODO(lucas): figure out + # a better naming here) + decodes = [] + prefills = [] + num_decode_tokens = 0 + num_prefill_tokens = 0 + + for i, req_id in enumerate(input_batch.req_ids): + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + # for now treat 1 scheduled token as "decode" even if it's not, + # we should update this to something like < 8 in the future but + # currently the TritonMLA._forward_decode only supports + # num_tokens = 1 + if num_tokens <= decode_threshold: + decodes.append(i) + num_decode_tokens += num_tokens + else: + prefills.append(i) + num_prefill_tokens += num_tokens + + # We hope that this is fairly minimal since decodes + # should be around for a number of iterations so hopefully they are + # relatively stationary (and new requests are generally appended to the + # persistent batch so already should be at the back) + # To achieve this we loop over the decodes in descending order and + # the prefills in ascending order. We swap decodes from the "back", + # i.e. past where the last decode should be in the reordered batch, with + # prefills from the front of the batch.
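The split_decodes_and_prefills helper added above locates the decode/prefill boundary purely from the CPU query_start_loc prefix sums, assuming the batch has already been reordered. A small self-contained re-derivation with toy numbers follows; split_boundary and the sample tensor are illustrative only and not part of the patch.

import torch


def split_boundary(query_start_loc: torch.Tensor,
                   decode_threshold: int = 1) -> tuple[int, int, int, int]:
    # query_start_loc is the CPU prefix sum of query lengths (num_reqs + 1).
    num_reqs = query_start_loc.numel() - 1
    num_tokens = int(query_start_loc[-1])
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill = query_lens > decode_threshold
    if not torch.any(is_prefill):
        return num_reqs, 0, num_tokens, 0
    first_prefill = int(is_prefill.int().argmax())
    num_decode_tokens = int(query_start_loc[first_prefill])
    return (first_prefill, num_reqs - first_prefill,
            num_decode_tokens, num_tokens - num_decode_tokens)


# Three single-token decodes followed by prefills of 5 and 7 tokens.
qsl = torch.tensor([0, 1, 2, 3, 8, 15], dtype=torch.int32)
assert split_boundary(qsl) == (3, 2, 3, 12)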
+ # `decodes` and `prefills` are already in ascending order just based on + # the above loop + num_decodes = len(decodes) + num_prefills = len(prefills) + modified_batch = False + + for i in range(1, min(num_decodes, num_prefills) + 1): + # If the decode is at the "back" of the batch, i, we can swap it + # with the prefill closest to the front of the batch + decode_idx = decodes[num_decodes - i] + if decode_idx < num_decodes: + break + + input_batch.swap_states(prefills[i - 1], decode_idx) + modified_batch = True + + return modified_batch diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index d21f94727cf6..5bf4d3a2acb4 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -214,21 +214,18 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: raise ValueError( f"Cannot get {num_blocks} free blocks from the pool") - ret: list[KVCacheBlock] = [] - idx = 0 - while idx < num_blocks: - # First allocate blocks. - curr_block = self.free_block_queue.popleft() - assert curr_block.ref_cnt == 0 - - # If the block is cached, evict it. - if self.enable_caching: - self._maybe_evict_cached_block(curr_block) - - curr_block.incr_ref() - ret.append(curr_block) - idx += 1 - + ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks) + + # In order to only iterate the list once, we duplicated code a bit + if self.enable_caching: + for block in ret: + self._maybe_evict_cached_block(block) + assert block.ref_cnt == 0 + block.ref_cnt += 1 + else: + for block in ret: + assert block.ref_cnt == 0 + block.ref_cnt += 1 return ret def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: @@ -243,22 +240,27 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: True if the block is evicted, False otherwise. """ block_hash = block.block_hash - if block_hash and block_hash in self.cached_block_hash_to_block: - block.reset_hash() - del self.cached_block_hash_to_block[block_hash][block.block_id] - - if len(self.cached_block_hash_to_block[block_hash]) == 0: - del self.cached_block_hash_to_block[block_hash] + if block_hash is None: + # The block doesn't have hash, eviction is not needed + return False + blocks_by_id = self.cached_block_hash_to_block.get(block_hash) + if blocks_by_id is None: + # block_hash not found in cached_block_hash_to_block, + # eviction is not needed + return False + block.reset_hash() + blocks_by_id.pop(block.block_id, None) + if len(blocks_by_id) == 0: + del self.cached_block_hash_to_block[block_hash] - if self.enable_kv_cache_events: - # FIXME (Chen): Not sure whether we should return `hash_value` - # or `(hash_value, group_id)` here. But it's fine now because - # we disable hybrid kv cache manager when kv cache event is - # enabled, so there is only one group. - self.kv_event_queue.append( - BlockRemoved(block_hashes=[block_hash.get_hash_value()])) - return True - return False + if self.enable_kv_cache_events: + # FIXME (Chen): Not sure whether we should return `hash_value` + # or `(hash_value, group_id)` here. But it's fine now because + # we disable hybrid kv cache manager when kv cache event is + # enabled, so there is only one group. 
+ self.kv_event_queue.append( + BlockRemoved(block_hashes=[block_hash.get_hash_value()])) + return True def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None: """Touch a block increases its reference count by 1, and may remove @@ -284,11 +286,14 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: ordered_blocks: A list of blocks to free ordered by their eviction priority. """ - for block in ordered_blocks: - block.decr_ref() - # null_block should not be added to the free list. - if block.ref_cnt == 0 and not block.is_null: - self.free_block_queue.append(block) + # Materialize the iterable to allow multiple passes. + blocks_list = list(ordered_blocks) + for block in blocks_list: + block.ref_cnt -= 1 + self.free_block_queue.append_n([ + block for block in blocks_list + if block.ref_cnt == 0 and not block.is_null + ]) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 38de00625e3f..de72e60434ad 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -171,6 +171,35 @@ def find_longest_cache_hit( pass +class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): + """ + KV cache coordinator to use if prefix caching is disabled or unsupported. + In contrast to UnitaryKVCacheCoordinator and HybridKVCacheCoordinator, + supports arbitrary numbers of KV cache groups (including 0 groups). + Does not implement any features related to prefix caching. + """ + + def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int, + use_eagle: bool, caching_hash_fn: Callable, + enable_kv_cache_events: bool): + super().__init__(kv_cache_config, max_model_len, use_eagle, False, + caching_hash_fn, enable_kv_cache_events) + self.num_single_type_manager = len(self.single_type_managers) + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> list[int]: + return [0] * self.num_single_type_manager + + def find_longest_cache_hit( + self, + block_hashes: list[BlockHash], + max_cache_hit_length: int, + ) -> tuple[tuple[list[KVCacheBlock], ...], int]: + blocks: tuple[list[KVCacheBlock], ...] = tuple( + [] for _ in range(self.num_single_type_manager)) + return blocks, 0 + + class UnitaryKVCacheCoordinator(KVCacheCoordinator): """ KV cache coordinator for models with only one KV cache group. 
This is the @@ -359,6 +388,10 @@ def get_kv_cache_coordinator( kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool, enable_caching: bool, caching_hash_fn: Callable, enable_kv_cache_events: bool) -> KVCacheCoordinator: + if not enable_caching: + return KVCacheCoordinatorNoPrefixCache(kv_cache_config, max_model_len, + use_eagle, caching_hash_fn, + enable_kv_cache_events) if len(kv_cache_config.kv_cache_groups) == 1: return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle, enable_caching, diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 6937455e7d85..e820a0ad6d5d 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -7,10 +7,10 @@ from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import sha256 +from vllm.utils import sha256, sha256_cbor_64bit from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock, - hash_request_tokens) + hash_request_tokens, init_none_hash) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -78,8 +78,16 @@ def __init__( ) -> None: self.max_model_len = max_model_len + if len(kv_cache_config.kv_cache_groups) == 0: + # Attention free models don't have kv cache, + # thus don't need prefix caching. + enable_caching = False self.enable_caching = enable_caching - self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + + self.caching_hash_fn = ( + sha256_cbor_64bit if caching_hash_algo == "sha256_cbor_64bit" else + sha256 if caching_hash_algo == "sha256" else hash) + init_none_hash(self.caching_hash_fn) self.use_eagle = use_eagle self.log_stats = log_stats # FIXME: make prefix cache stats conditional on log_stats @@ -98,7 +106,7 @@ def __init__( kv_cache_config=kv_cache_config, max_model_len=self.max_model_len, use_eagle=self.use_eagle, - enable_caching=enable_caching, + enable_caching=self.enable_caching, caching_hash_fn=self.caching_hash_fn, enable_kv_cache_events=enable_kv_cache_events, ) @@ -190,7 +198,6 @@ def allocate_slots( num_new_tokens: int, num_new_computed_tokens: int = 0, new_computed_blocks: Optional[KVCacheBlocks] = None, - num_draft_tokens: int = 0, num_lookahead_tokens: int = 0, delay_cache_blocks: bool = False, ) -> Optional[KVCacheBlocks]: @@ -286,12 +293,17 @@ def allocate_slots( if not self.enable_caching or delay_cache_blocks: return KVCacheBlocks(new_blocks) - # Speculated tokens might be rejected in the future, so we does - # not cache any speculated tokens. We only cache blocks with - # generated (accepted) tokens. + # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + + # num_new_tokens, but must exclude "non-committable" tokens (e.g., + # draft tokens that could be rejected). Therefore, we cap the number + # at `request.num_tokens`, ensuring only "finalized" tokens are cached. 
+ num_tokens_to_cache = min(num_computed_tokens + num_new_tokens, + request.num_tokens) self.coordinator.cache_blocks( - request, self.req_to_block_hashes[request.request_id], - num_computed_tokens + num_new_tokens - num_draft_tokens) + request, + self.req_to_block_hashes[request.request_id], + num_tokens_to_cache, + ) return KVCacheBlocks(new_blocks) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 2fbcb569e3d5..5b0218640a8c 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -10,8 +10,9 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.utils import GiB_bytes, cdiv, sha256 -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, +from vllm.utils import GiB_bytes, cdiv, sha256_cbor_64bit +from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, + FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, KVCacheTensor, SlidingWindowSpec) from vllm.v1.metrics.stats import PrefixCacheStats @@ -46,18 +47,30 @@ def get_hash_value(self) -> int: return self.block_hash.hash_value -# The hash seed for the first block of the prefix block sequence. -# -# Even if the hash function is the builtin hash(), we use sha256 to generate -# the initial hash to simplify the code. This is not performance critical -# as it is done one per process. +# The hash seed for the first block of any prefix block sequence. # # We use a random value to avoid hash collisions or PYTHONHASHSEED environment # variable if set such that processes can share the seed if needed. # This aligns with the behavior of Python's hash() function, which also uses # a random seed if PYTHONHASHSEED is not set. -NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv( - "PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED")) +# +# The function `init_none_hash` initializes this variable globally. +NONE_HASH: int + + +def init_none_hash(hash_fn: Callable): + global NONE_HASH + + hash_seed = os.getenv("PYTHONHASHSEED") + if hash_seed is None and hash_fn is sha256_cbor_64bit: + logger.warning( + "PYTHONHASHSEED is not set. This will lead to non-reproducible " + "block-hashes when using sha256_cbor_64bit as the hash function." + "Consider setting PYTHONHASHSEED to a fixed value for " + "reproducibility.") + + NONE_HASH = (int.from_bytes(os.urandom(32), byteorder="big") + if hash_seed is None else hash_fn(hash_seed)) class PrefixCachingMetrics: @@ -141,6 +154,8 @@ class KVCacheBlock: # Whether the block is a null block that should never be cached. is_null: bool = False + # TODO(Jialin): For performance, let callers handle ref_cnt bumps to + # avoid function calls. def incr_ref(self): self.ref_cnt += 1 @@ -200,27 +215,98 @@ class FreeKVCacheBlockQueue: def __init__(self, blocks: list[KVCacheBlock]) -> None: self.num_free_blocks = len(blocks) - # Initialize the doubly linked list of free blocks. 
- self.free_list_head: Optional[KVCacheBlock] = blocks[0] - self.free_list_tail: Optional[KVCacheBlock] = blocks[-1] + # Initialize the doubly-linked connections between consecutive blocks for i in range(self.num_free_blocks): if i > 0: blocks[i].prev_free_block = blocks[i - 1] if i < self.num_free_blocks - 1: blocks[i].next_free_block = blocks[i + 1] + # Create a fake head and a tail block for the doubly linked list to + # reduce branching in the code + # + # The implementation guarantees that the fake head and tail + # are never popped, so we can safely assume each real block + # in the queue has prev and next blocks. + self.fake_free_list_head = KVCacheBlock(block_id=-1) + self.fake_free_list_tail = KVCacheBlock(block_id=-1) + if self.num_free_blocks > 0: + # Connect fake_head and fake_tail to the first and last block + # respectively. + self.fake_free_list_head.next_free_block = blocks[0] + blocks[0].prev_free_block = self.fake_free_list_head + self.fake_free_list_tail.prev_free_block = blocks[-1] + blocks[-1].next_free_block = self.fake_free_list_tail + else: + # For an empty list, simply connect the fake head and tail. + self.fake_free_list_head.next_free_block = self.fake_free_list_tail + self.fake_free_list_tail.prev_free_block = self.fake_free_list_head + def popleft(self) -> KVCacheBlock: """Pop the first free block and reduce num_free_blocks by 1. Returns: The first free block. """ - if not self.free_list_head: + if (self.fake_free_list_head.next_free_block + is self.fake_free_list_tail + or self.fake_free_list_head.next_free_block is None): + assert self.num_free_blocks == 0, ( + f"num_free_blocks ({self.num_free_blocks}) is out of sync " + "with the free list.") raise ValueError("No free blocks available") - block = self.free_list_head - self.remove(block) - return block + first_block: KVCacheBlock = self.fake_free_list_head.next_free_block + + if first_block.next_free_block is None: + # This should not happen if the block is from the free list. + # It indicates a bug in the caller's logic. + raise RuntimeError("Invalid block found in popleft() " + "which doesn't have a valid next_free_block") + + # Connect fake_head and the next block of first_block (i.e. second block + # or fake tail). + self.fake_free_list_head.next_free_block = first_block.next_free_block + first_block.next_free_block.prev_free_block = self.fake_free_list_head + + # Remove the block from the linked list. + first_block.prev_free_block = first_block.next_free_block = None + + self.num_free_blocks -= 1 + return first_block + + def popleft_n(self, n: int) -> list[KVCacheBlock]: + """Pop the first n free blocks and reduce num_free_blocks by n. + + Args: + n: The number of blocks to pop. + + Returns: + A list of n free blocks. + """ + if n == 0: + return [] + assert self.num_free_blocks >= n + self.num_free_blocks -= n + + curr_block = self.fake_free_list_head.next_free_block + # Pop n blocks from the head of the list + ret = [] + for _ in range(n): + assert curr_block is not None + ret.append(curr_block) + last_block = curr_block + curr_block = curr_block.next_free_block + # Reset prev_free_block and next_free_block of all popped blocks + last_block.prev_free_block = None + last_block.next_free_block = None + + if curr_block is not None: + # The queue is not empty, connect the fake head to + # the new first block.
+ self.fake_free_list_head.next_free_block = curr_block + curr_block.prev_free_block = self.fake_free_list_head + return ret def remove(self, block: KVCacheBlock) -> None: """Remove a block in the free list and reduce num_free_blocks by 1. @@ -228,19 +314,15 @@ def remove(self, block: KVCacheBlock) -> None: Args: block: The block to remove. """ - if block.prev_free_block is not None: - # Link the previous block to the next block. - block.prev_free_block.next_free_block = block.next_free_block - if block.next_free_block is not None: - # Link the next block to the previous block. - block.next_free_block.prev_free_block = block.prev_free_block - - if block == self.free_list_head: - # Update the head if the block is the head. - self.free_list_head = block.next_free_block - if block == self.free_list_tail: - # Update the tail if the block is the tail. - self.free_list_tail = block.prev_free_block + if block.prev_free_block is None or block.next_free_block is None: + # This should not happen if the block is from the free list. + # It indicates a bug in the caller's logic. + raise RuntimeError(f"remove() called on an invalid block: {block}") + + # Link the previous block to the next block. + block.prev_free_block.next_free_block = block.next_free_block + # Link the next block to the previous block. + block.next_free_block.prev_free_block = block.prev_free_block # Remove the block from the linked list. block.prev_free_block = block.next_free_block = None @@ -253,19 +335,44 @@ def append(self, block: KVCacheBlock) -> None: Args: block: The block to append. """ - if self.free_list_tail is not None: - # Link the last block to the new block. - self.free_list_tail.next_free_block = block - block.prev_free_block = self.free_list_tail - self.free_list_tail = block - else: - # The free list is empty. - assert self.free_list_head is None - self.free_list_head = self.free_list_tail = block + if self.fake_free_list_tail.prev_free_block is None: + raise RuntimeError( + "prev_free_block of fake_free_list_tail should always exist") + last_block: KVCacheBlock = self.fake_free_list_tail.prev_free_block + + # Connect the new block after the last block. + last_block.next_free_block = block + block.prev_free_block = last_block + + # Connect the fake tail after the new block. + block.next_free_block = self.fake_free_list_tail + self.fake_free_list_tail.prev_free_block = block - block.next_free_block = None self.num_free_blocks += 1 + def append_n(self, blocks: list[KVCacheBlock]) -> None: + """Put a list of blocks back into the free list + + Args: + blocks: The blocks to append. + """ + if len(blocks) == 0: + return + self.num_free_blocks += len(blocks) + + last_block = self.fake_free_list_tail.prev_free_block + assert last_block is not None, ( + "prev_free_block of fake_free_list_tail should always exist") + # Add inter-connections between consecutive blocks + for block in blocks: + block.prev_free_block = last_block + last_block.next_free_block = block + last_block = block + + # Connect the last block of <blocks> to the fake tail + last_block.next_free_block = self.fake_free_list_tail + self.fake_free_list_tail.prev_free_block = last_block + def get_all_free_blocks(self) -> list[KVCacheBlock]: """Get all free blocks in the free list. Mainly used for testing. @@ -273,8 +380,14 @@ def get_all_free_blocks(self) -> list[KVCacheBlock]: A list of free blocks. 
""" ret = [] - curr_block = self.free_list_head - while curr_block is not None: + if self.fake_free_list_head.next_free_block is None: + raise RuntimeError( + "next_free_block of fake_free_list_head should always exist") + # Start from the first block + curr_block: KVCacheBlock = self.fake_free_list_head.next_free_block + # As long as next_free_block is available, we haven't reached to + # the fake tail yet. + while curr_block.next_free_block is not None: ret.append(curr_block) curr_block = curr_block.next_free_block return ret @@ -293,9 +406,9 @@ def need_extra_keys(request: Request) -> bool: # Multimodal requests need to include the MM hash. # LoRA requests need to include the LoRA ID. # Request with provided cache salt need to include the salt. - return bool(request.mm_positions) or (request.lora_request - is not None) or (request.cache_salt - is not None) + return bool(request.mm_hashes) or (request.lora_request + is not None) or (request.cache_salt + is not None) def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int, @@ -551,6 +664,10 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig, ValueError: If there is not enough memory available for the KV cache. """ + # No need to check for available memory if the kv_cache_spec is empty + if not kv_cache_spec: + return + if available_memory <= 0: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " @@ -737,6 +854,13 @@ def is_kv_cache_page_size_uniform( return len(page_sizes) == 1 +def is_kv_cache_type_attention_free( + kv_cache_spec: dict[str, KVCacheSpec]) -> bool: + + # kv_cache_spec is an empty dict for attention free models + return not kv_cache_spec + + def _get_kv_cache_config_uniform_page_size( vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec], available_memory: int) -> KVCacheConfig: @@ -879,6 +1003,10 @@ def _get_kv_cache_config_uniform_page_size( return kv_cache_config +def _get_kv_cache_config_attention_free() -> KVCacheConfig: + return KVCacheConfig(num_blocks=1, kv_cache_tensors=[], kv_cache_groups=[]) + + def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): """ This function tries to convert the KV cache specs to one type if the model @@ -907,7 +1035,11 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values()) has_sliding_window = any( isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values()) - if has_full_attention and has_sliding_window: + has_chunked_local_attention = any( + isinstance(spec, ChunkedLocalAttentionSpec) + for spec in kv_cache_spec.values()) + if has_full_attention and (has_sliding_window + or has_chunked_local_attention): for layer_name, spec in kv_cache_spec.items(): if isinstance(spec, SlidingWindowSpec): kv_cache_spec[layer_name] = FullAttentionSpec( @@ -918,6 +1050,15 @@ def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool: use_mla=spec.use_mla, sliding_window=spec.sliding_window, ) + elif isinstance(spec, ChunkedLocalAttentionSpec): + kv_cache_spec[layer_name] = FullAttentionSpec( + block_size=spec.block_size, + num_kv_heads=spec.num_kv_heads, + head_size=spec.head_size, + dtype=spec.dtype, + use_mla=spec.use_mla, + attention_chunk_size=spec.attention_chunk_size, + ) if is_hybrid(kv_cache_spec): raise ValueError("Hybrid KV cache manager is disabled but failed to " @@ -941,11 +1082,14 @@ def get_kv_cache_config( The generated KVCacheConfigs """ check_enough_kv_cache_memory(vllm_config, 
kv_cache_spec, available_memory) - if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager: unify_hybrid_kv_cache_specs(kv_cache_spec) - if is_kv_cache_type_uniform(kv_cache_spec): + if is_kv_cache_type_attention_free(kv_cache_spec): + # This returns a kv_cache config with 0 kv_cache groups and 1 block + # to allow for the KVCache manager to handle attention free models. + return _get_kv_cache_config_attention_free() + elif is_kv_cache_type_uniform(kv_cache_spec): # KV cache of all layers are the same, which is true for # most models. Allocate the same amount of memory for # each layer. diff --git a/vllm/v1/core/sched/async_scheduler.py b/vllm/v1/core/sched/async_scheduler.py new file mode 100644 index 000000000000..74ff6261732c --- /dev/null +++ b/vllm/v1/core/sched/async_scheduler.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.core.sched.scheduler import Scheduler +from vllm.v1.request import Request, RequestStatus + +logger = init_logger(__name__) + + +class AsyncScheduler(Scheduler): + + def _update_after_schedule( + self, + scheduler_output: SchedulerOutput, + ) -> None: + super()._update_after_schedule(scheduler_output) + for req_id in scheduler_output.num_scheduled_tokens: + request = self.requests[req_id] + if (request.num_computed_tokens == request.num_tokens + + request.num_output_placeholders): + # The request will generate a new token in this scheduling step. + # TODO(woosuk): Support speculative decoding. + request.num_output_placeholders += 1 + + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + status_before_update = request.status + new_token_ids, stopped = super()._update_request_with_output( + request, new_token_ids) + + # Update the number of output placeholders. + request.num_output_placeholders -= len(new_token_ids) + assert request.num_output_placeholders >= 0 + + # Cache the new tokens. Preempted requests should be skipped. + if status_before_update == RequestStatus.RUNNING: + self.kv_cache_manager.cache_blocks( + request, + request.num_computed_tokens - request.num_output_placeholders) + return new_token_ids, stopped diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index fe552db74e2f..446f98034cb8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -204,7 +204,8 @@ def schedule(self) -> SchedulerOutput: while req_index < len(self.running) and token_budget > 0: request = self.running[req_index] - num_new_tokens = (request.num_tokens_with_spec - + num_new_tokens = (request.num_tokens_with_spec + + request.num_output_placeholders - request.num_computed_tokens) if (0 < self.scheduler_config.long_prefill_token_threshold < num_new_tokens): @@ -230,9 +231,11 @@ def schedule(self) -> SchedulerOutput: if num_new_tokens == 0: # The request cannot be scheduled because one of the following # reasons: - # 1. No new tokens to schedule. This may happen when PP>1 and - # we have already scheduled all prompt tokens but they are - # not finished yet. + # 1. No new tokens to schedule. This may happen when + # (1) PP>1 and we have already scheduled all prompt tokens + # but they are not finished yet. + # (2) Async scheduling and the request has reached to either + # its max_total_tokens or max_model_len. # 2. 
The encoder budget is exhausted. # 3. The encoder cache is exhausted. # NOTE(woosuk): Here, by doing `continue` instead of `break`, @@ -241,15 +244,10 @@ def schedule(self) -> SchedulerOutput: req_index += 1 continue - num_draft_tokens = max( - num_new_tokens + request.num_computed_tokens - - request.num_tokens, 0) - while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, - num_draft_tokens=num_draft_tokens, num_lookahead_tokens=self.num_lookahead_tokens) if new_blocks is None: # The request cannot be scheduled. @@ -603,6 +601,14 @@ def _update_after_schedule( request = self.requests[req_id] request.num_computed_tokens += num_scheduled_token + # NOTE: _free_encoder_inputs relies on num_computed_tokens, which + # may be updated again in _update_from_output for speculative + # decoding. However, it is safe to call the method here because + # encoder inputs are always part of the prompt, not the output, + # and thus are unaffected by speculative decoding. + if request.has_encoder_inputs: + self._free_encoder_inputs(request) + # Clear the finished request IDs. # NOTE: We shouldn't do self.finished_req_ids.clear() here because # it will also affect the scheduler output. @@ -621,6 +627,7 @@ def _make_cached_request_data( new_block_ids: list[tuple[list[int], ...]] = [] num_computed_tokens: list[int] = [] + use_connector = self.connector is not None for req in itertools.chain(running_reqs, resumed_reqs): req_id = req.request_id req_ids.append(req_id) @@ -635,6 +642,11 @@ def _make_cached_request_data( token_ids = req.all_token_ids[req.num_computed_tokens:req. num_computed_tokens + num_tokens] new_token_ids.append(token_ids) + elif use_connector: + # When using a KVConnector, we add a placeholder to avoid index + # out of bounds errors. TODO: Remove this once the KVConnector + # is updated to handle token IDs properly. + new_token_ids.append([]) new_block_ids.append(req_to_new_block_ids[req_id]) num_computed_tokens.append(req.num_computed_tokens) # Because resumed_reqs is usually empty, it is more efficient to do @@ -746,19 +758,21 @@ def update_from_output( pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits - new_running: list[Request] = [] outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None - # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below - # loop can be a performance bottleneck. We should do our best to avoid - # expensive operations inside the loop. - for request in self.running: - req_id = request.request_id - num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) - if num_tokens_scheduled == 0: - # The request was not scheduled in this step. - new_running.append(request) + # NOTE(woosuk): As len(num_scheduled_tokens) can be up to 1K or more, + # the below loop can be a performance bottleneck. We should do our best + # to avoid expensive operations inside the loop. + stopped_running_reqs: set[Request] = set() + stopped_preempted_reqs: set[Request] = set() + for req_id, num_tokens_scheduled in num_scheduled_tokens.items(): + assert num_tokens_scheduled > 0 + request = self.requests.get(req_id) + if request is None: + # The request is already finished. This can happen if the + # request is aborted while the model is executing it (e.g., + # in pipeline parallelism). 
continue req_index = model_runner_output.req_id_to_index[req_id] @@ -782,37 +796,30 @@ def update_from_output( num_draft_tokens=len(scheduled_spec_token_ids), num_accepted_tokens=len(generated_token_ids) - 1) - # NOTE(woosuk): This has to be executed after updating - # `request.num_computed_tokens`. - if request.has_encoder_inputs: - self._free_encoder_inputs(request) - stopped = False new_logprobs = None new_token_ids = generated_token_ids kv_transfer_params = None + status_before_stop = request.status - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner - # to return empty token ids for the request. - for num_new, output_token_id in enumerate(new_token_ids, 1): - request.append_output_token_ids(output_token_id) - - # Check for stop and update request state. - # This must be called before we make the EngineCoreOutput. - stopped = check_stop(request, self.max_model_len) - if stopped: - kv_transfer_params = self._free_request(request) - del new_token_ids[num_new:] # Trim new tokens if needed. - break + # Check for stop and update request status. + if new_token_ids: + new_token_ids, stopped = self._update_request_with_output( + request, new_token_ids) + # Stop checking for pooler models. pooler_output = None if pooler_outputs: pooler_output = pooler_outputs[req_index] stopped = check_stop(request, self.max_model_len, pooler_output) - if stopped: - kv_transfer_params = self._free_request(request) + + if stopped: + kv_transfer_params = self._free_request(request) + if status_before_stop == RequestStatus.RUNNING: + stopped_running_reqs.add(request) + else: + stopped_preempted_reqs.add(request) # Extract sample logprobs if needed. if request.sampling_params is not None \ @@ -867,9 +874,14 @@ def update_from_output( # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors - if not stopped: - new_running.append(request) - self.running = new_running + # Remove the stopped requests from the running and waiting queues. + if stopped_running_reqs: + self.running = [ + req for req in self.running if req not in stopped_running_reqs + ] + if stopped_preempted_reqs: + # This is a rare case and unlikely to impact performance. + self.waiting.remove_requests(stopped_preempted_reqs) # KV Connector: update state for finished KV Transfers. self._update_from_kv_xfer_finished(model_runner_output) @@ -901,6 +913,26 @@ def update_from_output( return engine_core_outputs + def _update_request_with_output( + self, + request: Request, + new_token_ids: list[int], + ) -> tuple[list[int], bool]: + # Append generated tokens and check for stop. Note that if + # a request is still being prefilled, we expect the model runner + # to return empty token ids for the request. + stopped = False + for num_new, output_token_id in enumerate(new_token_ids, 1): + request.append_output_token_ids(output_token_id) + + # Check for stop and update request state. + # This must be called before we make the EngineCoreOutput. + stopped = check_stop(request, self.max_model_len) + if stopped: + del new_token_ids[num_new:] # Trim new tokens if needed. 
+ break + return new_token_ids, stopped + def _free_encoder_inputs(self, request: Request) -> None: cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 5b4718038076..65a196e044ab 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -7,7 +7,8 @@ from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock -from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, +from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, + FullAttentionSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -256,8 +257,10 @@ def find_longest_cache_hit( kv_cache_spec: KVCacheSpec, use_eagle: bool, ) -> tuple[list[KVCacheBlock], ...]: - assert isinstance(kv_cache_spec, FullAttentionSpec), ( - "FullAttentionManager can only be used for full attention groups") + assert isinstance( + kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec) + ), "FullAttentionManager can only be used for full attention " \ + "and chunked local attention groups" computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( [] for _ in range(len(kv_cache_group_ids))) max_num_blocks = max_length // kv_cache_spec.block_size @@ -391,6 +394,129 @@ def get_num_common_prefix_blocks(self, request_id: str, return 0 +class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): + + def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, + block_pool: BlockPool, **kwargs) -> None: + super().__init__(kv_cache_spec, block_pool, **kwargs) + self.attention_chunk_size = kv_cache_spec.attention_chunk_size + self._null_block = block_pool.null_block + + @classmethod + def find_longest_cache_hit( + cls, + block_hashes: list[BlockHash], + max_length: int, + kv_cache_group_ids: list[int], + block_pool: BlockPool, + kv_cache_spec: KVCacheSpec, + use_eagle: bool, + ) -> tuple[list[KVCacheBlock], ...]: + """ + For chunked local attention, we need to find the longest cache hit + prefix of the blocks that is not longer than `max_length`. The prefix + should be a common prefix hit for all the kv cache groups in + `kv_cache_group_ids`. If no cache hit is found, return an empty list. + note we mark as computed if the whole block is outside of the local + window, and set the block as null. Examples: + + 1. Attention chunk size of 8, block size of 4, max length of 15 + for next token at 15th (zero-indexed), 8th - 14th tokens are in + the window(needs lookup), 0th - 7th are not in the window, + so they are already marked as computed. We check the complete + block3 (8th - 11th tokens), Assume block 3 is hit, we will return + [null, null, block 3], otherwise, we return [null, null] + + 2. Attention chunk size of 8, block size of 4, max length of 16 + for next token at 16th (zero-indexed), 0th - 15th tokens are not + in the window, so they are already marked as computed. + we return 4 blocks[null, null, null, null] + + Args: + block_hashes: The block hashes of the request. + max_length: The maximum length of the cache hit prefix. + kv_cache_group_ids: The ids of the kv cache groups. + block_pool: The block pool. + kv_cache_spec: The kv cache spec. + use_eagle: Whether to use eagle. 
+ + Returns: + A list of cached blocks + """ + assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), ( + "ChunkedLocalAttentionManager can only be used for " + + "chunked local attention groups") + assert use_eagle is False, ("Hybrid KV cache is not supported for " + + "eagle + chunked local attention.") + max_num_blocks = max_length // kv_cache_spec.block_size + if max_length > 0: + local_attention_start_idx = (max_length // + kv_cache_spec.attention_chunk_size * + kv_cache_spec.attention_chunk_size) + else: + local_attention_start_idx = 0 + # we marked blocks out of window as computed + # with null blocks, and blocks inside window based on cache lookup + # result [null] [null] ... [null] [hit block 1 (1st block contain + # last window)] [hit block 2] ... [hit block x] + local_attention_start_block_idx = (local_attention_start_idx // + kv_cache_spec.block_size) + computed_blocks: tuple[list[KVCacheBlock], ...] = tuple( + [block_pool.null_block] * local_attention_start_block_idx + for _ in range(len(kv_cache_group_ids))) + for i in range(local_attention_start_block_idx, max_num_blocks): + block_hash = block_hashes[i] + if cached_block := block_pool.get_cached_block( + block_hash, kv_cache_group_ids): + for computed, cached in zip(computed_blocks, cached_block): + computed.append(cached) + else: + break + return computed_blocks + + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: + # Remove the blocks that are no longer be in the chunked attention + # window and skipped during the attention computation. + + # [chunk 0][chunk 1]local_attention_start_idx ... current + # we computed previous number of chunks to get the idx of + # current chunk window starting offset, + # e.g. for computed 1024 tokens, the 1024th token (0 indexed) + # is in the second chunk, there are 1 prev chunk, the start idx + # is 1024. for 1023, it will be 0. + num_cached_block = self.num_cached_block.get(request_id, 0) + local_attention_start_idx = ( + num_computed_tokens + ) // self.attention_chunk_size * self.attention_chunk_size + first_useful_block_idx = local_attention_start_idx // self.block_size + if num_cached_block > 0: + # Make sure we don't delete the last cached block + first_useful_block_idx = min(first_useful_block_idx, + num_cached_block - 1) + # if block size = 128, 0 -> block 0, 1024 (= 128 * 8) -> + # block 8, 372 (= 128 * 2 + 116) -> block 2 + blocks = self.req_to_blocks[request_id] + removed_blocks: list[KVCacheBlock] = [] + # we need to keep the last block to get the previous hash key + for i in range(first_useful_block_idx - 1, -1, -1): + if blocks[i] == self._null_block: + # If the block is already a null block, the blocks before it + # should also have been set to null blocks by the previous calls + # to this function. + break + removed_blocks.append(blocks[i]) + blocks[i] = self._null_block + self.block_pool.free_blocks(removed_blocks) + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + """ + cascade attention is not supported by chunked local attention. 
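# Illustrative sketch (not part of the patch): the block-index arithmetic used by
# ChunkedLocalAttentionManager.find_longest_cache_hit above, reduced to plain
# integers. Blocks that lie entirely before the current local-attention window
# are treated as computed (null blocks); only complete blocks overlapping the
# window are looked up in the cache. Names below are local to this example.


def split_blocks(max_length: int, chunk_size: int,
                 block_size: int) -> tuple[int, range]:
    """Return (num_null_blocks, block indices that need a cache lookup)."""
    # Start of the local-attention window for the token at position `max_length`.
    window_start = (max_length // chunk_size) * chunk_size if max_length > 0 else 0
    null_blocks = window_start // block_size      # blocks fully outside the window
    max_num_blocks = max_length // block_size     # only complete blocks count
    return null_blocks, range(null_blocks, max_num_blocks)


# Example 1 from the docstring: chunk size 8, block size 4, 15 tokens.
# Tokens 0-7 are outside the window -> 2 null blocks; block index 2
# (tokens 8-11) is the only complete block that needs a cache lookup.
assert split_blocks(15, 8, 4) == (2, range(2, 3))

# Example 2: 16 tokens -> the window restarts at token 16, so all 4 complete
# blocks are outside it and no lookup is needed at all.
assert split_blocks(16, 8, 4) == (4, range(4, 4))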
+ """ + return 0 + + class MambaManager(SingleTypeKVCacheManager): @classmethod @@ -433,6 +559,7 @@ def allocate_new_blocks(self, request_id: str, spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, + ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, } diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 921ccd708cdd..79dc80d8fc54 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -177,3 +177,19 @@ class EngineCoreRequestType(enum.Enum): UTILITY = b'\x03' # Sentinel used within EngineCoreProc. EXECUTOR_FAILED = b'\x04' + + +class ReconfigureDistributedRequest(msgspec.Struct): + new_data_parallel_size: int + new_data_parallel_rank: int + new_data_parallel_rank_local: int + new_data_parallel_master_ip: str + new_data_parallel_master_port: int + + +class ReconfigureRankType(enum.IntEnum): + """ + Rank type for reconfiguring distributed request. + """ + KEEP_CURRENT_RANK = -1 + SHUTDOWN_CURRENT_RANK = -2 diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 3754570dfaaa..02cb80197fa4 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import time from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Any, Optional, Union @@ -19,7 +20,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -35,10 +35,9 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, - setup_default_loggers) +from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager from vllm.v1.metrics.prometheus import shutdown_prometheus -from vllm.v1.metrics.stats import IterationStats, SchedulerStats +from vllm.v1.metrics.stats import IterationStats logger = init_logger(__name__) @@ -94,19 +93,14 @@ def __init__( self.log_requests = log_requests self.log_stats = log_stats - # Set up stat loggers; independent set for each DP rank. - self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( - vllm_config=vllm_config, - log_stats=self.log_stats, - engine_num=vllm_config.parallel_config.data_parallel_size, - custom_stat_loggers=stat_loggers, - ) - - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if self.model_config.skip_tokenizer_init: + self.tokenizer = None + else: + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) # Processor (converts Inputs --> EngineCoreRequests). 
self.processor = Processor( @@ -120,7 +114,6 @@ def __init__( log_stats=self.log_stats) # EngineCore (starts the engine in background process). - self.engine_core = EngineCoreClient.make_async_mp_client( vllm_config=vllm_config, executor_class=executor_class, @@ -128,9 +121,17 @@ def __init__( client_addresses=client_addresses, client_index=client_index, ) - if self.stat_loggers: - for stat_logger in self.stat_loggers[0]: - stat_logger.log_engine_initialized() + + # Loggers. + self.logger_manager: Optional[StatLoggerManager] = None + if self.log_stats: + self.logger_manager = StatLoggerManager( + vllm_config=vllm_config, + engine_idxs=self.engine_core.engine_ranks_managed, + custom_stat_loggers=stat_loggers, + ) + self.logger_manager.log_engine_initialized() + self.output_handler: Optional[asyncio.Task] = None try: # Start output handler eagerly if we are in the asyncio eventloop. @@ -219,7 +220,6 @@ async def add_request( lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, ) -> RequestOutputCollector: @@ -236,8 +236,7 @@ async def add_request( # Convert Input --> Request. prompt_str, request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, - tokenization_kwargs, trace_headers, prompt_adapter_request, - priority, data_parallel_rank) + tokenization_kwargs, trace_headers, priority, data_parallel_rank) if is_pooling or params.n == 1: await self._add_request(request, prompt_str, None, 0, queue) @@ -281,7 +280,6 @@ async def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, ) -> AsyncGenerator[RequestOutput, None]: @@ -312,7 +310,6 @@ async def generate( sampling_params, lora_request=lora_request, trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, priority=priority, data_parallel_rank=data_parallel_rank, ) @@ -369,7 +366,7 @@ def _run_output_handler(self): engine_core = self.engine_core output_processor = self.output_processor log_stats = self.log_stats - stat_loggers = self.stat_loggers if log_stats else None + logger_manager = self.logger_manager async def output_handler(): try: @@ -409,9 +406,9 @@ async def output_handler(): # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. 
- if stat_loggers: - AsyncLLM._record_stats( - stat_loggers[outputs.engine_index], + if logger_manager: + logger_manager.record( + engine_idx=outputs.engine_index, scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, ) @@ -430,18 +427,6 @@ async def abort(self, request_id: str) -> None: if self.log_requests: logger.info("Aborted request %s.", request_id) - @staticmethod - def _record_stats( - stat_loggers: list[StatLoggerBase], - scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats], - ): - """static so that it can be used from the output_handler task - without a circular ref to AsyncLLM.""" - for stat_logger in stat_loggers: - stat_logger.record(scheduler_stats=scheduler_stats, - iteration_stats=iteration_stats) - async def encode( self, prompt: PromptType, @@ -450,6 +435,7 @@ async def encode( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, + tokenization_kwargs: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[PoolingRequestOutput, None]: """ Main function called by the API server to kick off a request @@ -478,6 +464,7 @@ async def encode( lora_request=lora_request, trace_headers=trace_headers, priority=priority, + tokenization_kwargs=tokenization_kwargs, ) # The output_handler task pushes items into the queue. @@ -536,6 +523,10 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: + if self.tokenizer is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + return self.tokenizer.get_lora_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: @@ -546,9 +537,8 @@ async def do_log_stats( scheduler_outputs=None, model_output=None, ) -> None: - for loggers in self.stat_loggers: - for stat_logger in loggers: - stat_logger.log() + if self.logger_manager: + self.logger_manager.log() async def check_health(self) -> None: logger.debug("Called check_health.") @@ -608,6 +598,61 @@ async def collective_rpc(self, return await self.engine_core.collective_rpc_async( method, timeout, args, kwargs) + async def wait_for_requests_to_drain(self, drain_timeout: int = 300): + """Wait for all requests to be drained.""" + start_time = time.time() + while time.time() - start_time < drain_timeout: + if not self.engine_core.dp_engines_running(): + logger.info("Engines are idle, requests have been drained") + return + + logger.info( + "Engines are still running, waiting for requests to drain...") + await asyncio.sleep(1) # Wait 1 second before checking again + + raise TimeoutError(f"Timeout reached after {drain_timeout} seconds " + "waiting for requests to drain.") + + async def scale_elastic_ep(self, + new_data_parallel_size: int, + drain_timeout: int = 300): + """ + Scale up or down the data parallel size by adding or removing + engine cores. 
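# Illustrative sketch (not part of the patch): the drain-then-reconfigure pattern
# used by AsyncLLM.wait_for_requests_to_drain / scale_elastic_ep above.
# `is_idle` and `apply_scale` stand in for engine_core.dp_engines_running() and
# engine_core.scale_elastic_ep(); they are placeholders, not vLLM APIs.
import asyncio
import time
from typing import Awaitable, Callable


async def drain_then_scale(is_idle: Callable[[], bool],
                           apply_scale: Callable[[int], Awaitable[None]],
                           new_dp_size: int,
                           drain_timeout: float = 300.0) -> None:
    deadline = time.time() + drain_timeout
    # Poll once per second until no engine reports in-flight requests.
    while time.time() < deadline:
        if is_idle():
            await apply_scale(new_dp_size)
            return
        await asyncio.sleep(1)
    raise TimeoutError(f"requests did not drain within {drain_timeout}s")


# Toy usage: the fake engine becomes idle after two polls, then "scales" to 4.
if __name__ == "__main__":
    polls = iter([False, False, True])

    async def fake_scale(n: int) -> None:
        print(f"scaled data parallel size to {n}")

    asyncio.run(drain_then_scale(lambda: next(polls), fake_scale, 4))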
+ Args: + new_data_parallel_size: The new number of data parallel workers + drain_timeout: + Maximum time to wait for requests to drain (seconds) + """ + old_data_parallel_size = \ + self.vllm_config.parallel_config.data_parallel_size + if old_data_parallel_size == new_data_parallel_size: + logger.info("Data parallel size is already %s, skipping scale", + new_data_parallel_size) + return + logger.info( + "Waiting for requests to drain before " + "scaling up to %s engines...", new_data_parallel_size) + await self.wait_for_requests_to_drain(drain_timeout) + logger.info( + "Requests have been drained, proceeding with scale " + "to %s engines", new_data_parallel_size) + await self.engine_core.scale_elastic_ep(new_data_parallel_size) + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + + # recreate stat loggers + if new_data_parallel_size > old_data_parallel_size and self.log_stats: + # TODO(rob): fix this after talking with Ray team. + # This resets all the prometheus metrics since we + # unregister during initialization. Need to understand + # the intended behavior here better. + self.logger_manager = StatLoggerManager( + vllm_config=self.vllm_config, + engine_idxs=list(range(new_data_parallel_size)), + custom_stat_loggers=None, + ) + @property def is_running(self) -> bool: # Is None before the loop is started. diff --git a/vllm/v1/engine/coordinator.py b/vllm/v1/engine/coordinator.py index b3e7a2e85b80..c0decd6ffa2c 100644 --- a/vllm/v1/engine/coordinator.py +++ b/vllm/v1/engine/coordinator.py @@ -61,11 +61,12 @@ def __init__(self, parallel_config: ParallelConfig): host = parallel_config.data_parallel_master_ip external_lb = parallel_config.data_parallel_external_lb + hybrid_lb = parallel_config.data_parallel_hybrid_lb # Assume coordinator is colocated with front-end procs when not in - # external DP LB mode. + # either external or hybrid DP LB mode. front_publish_address = get_engine_client_zmq_addr( - local_only=not external_lb, host=host) + local_only=not external_lb and not hybrid_lb, host=host) local_only_eng = dp_size == parallel_config.data_parallel_size_local back_publish_address = get_engine_client_zmq_addr(local_only_eng, host) @@ -200,11 +201,41 @@ def process_input_socket(self, front_publish_address: str, # Ignore subscription messages. 
continue + decoded = msgspec.msgpack.decode(buffer) + if isinstance(decoded, (list, tuple)) and len( + decoded) == 2 and decoded[0] == "SCALE_ELASTIC_EP": + # Handle scale up notification + new_engine_count = decoded[1] + current_count = len(self.engines) + if new_engine_count > current_count: + for _ in range(new_engine_count - current_count): + self.engines.append(EngineState()) + # NOTE(yongji): handle the case + # where newly started engines have current_wave = 0 + # if existing engines just finished a wave + # and engine_running isn't updated yet at + # CoordinatorProc requests routed to newly started + # engines may not wake up existing engines, as long + # as 0 < request.wave < existing engines' + # current_wave + # we note that 0 is the wave number for the new + # engine + self.engines_running = False + logger.info( + "DPCoordinator scaled up from %s to %s " + "engines", current_count, new_engine_count) + else: + self.engines = self.engines[:new_engine_count] + logger.info( + "DPCoordinator scaled down from %s to %s " + "engines", current_count, new_engine_count) + continue # Skip normal engine notification processing + # We received a message on the front-end XPUB socket, # from an API server sending a new request while the # engines are paused, so that we can wake the other # engines. - engine_to_exclude, wave = msgspec.msgpack.decode(buffer) + engine_to_exclude, wave = decoded if not self.engines_running: if wave < self.current_wave: # If the wave number is stale, ensure the message diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e2fdf6f8a11c..7779b559c20e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -32,7 +32,9 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, UtilityOutput) + EngineCoreRequestType, + ReconfigureDistributedRequest, ReconfigureRankType, + UtilityOutput) from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor @@ -77,6 +79,8 @@ def __init__(self, self.model_executor.register_failure_callback( executor_fail_callback) + self.available_gpu_memory_for_kv_cache = -1 + # Setup KV Caches and update CacheConfig after profiling. num_gpu_blocks, num_cpu_blocks, kv_cache_config = \ self._initialize_kv_caches(vllm_config) @@ -137,9 +141,26 @@ def _initialize_kv_caches( # Get all kv cache needed by the model kv_cache_specs = self.model_executor.get_kv_cache_specs() - # Profiles the peak memory usage of the model to determine how much - # memory can be allocated for kv cache. - available_gpu_memory = self.model_executor.determine_available_memory() + has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs) + if has_kv_cache: + if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1": + dp_group = getattr(self, "dp_group", None) + assert dp_group is not None + self.available_gpu_memory_for_kv_cache = \ + ParallelConfig.sync_kv_cache_memory_size(dp_group, -1) + available_gpu_memory = [ + self.available_gpu_memory_for_kv_cache + ] * len(kv_cache_specs) + else: + # Profiles the peak memory usage of the model to determine how + # much memory can be allocated for kv cache. 
+ available_gpu_memory = ( + self.model_executor.determine_available_memory()) + self.available_gpu_memory_for_kv_cache = \ + available_gpu_memory[0] + else: + # Attention free models don't need memory for kv cache + available_gpu_memory = [0] * len(kv_cache_specs) assert len(kv_cache_specs) == len(available_gpu_memory) # Get the kv cache tensor size @@ -175,6 +196,12 @@ def _initialize_kv_caches( def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" + if pooling_params := request.pooling_params: + supported_pooling_tasks = ( + self.model_executor.supported_pooling_tasks) + if pooling_params.task not in supported_pooling_tasks: + raise ValueError(f"Unsupported task: {pooling_params.task!r} " + f"Supported tasks: {supported_pooling_tasks}") if request.mm_hashes is not None: # Here, if hash exists for a multimodal input, then it will be @@ -207,9 +234,14 @@ def abort_requests(self, request_ids: list[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) - def execute_model(self, scheduler_output: SchedulerOutput): + def execute_model_with_error_logging( + self, + model_fn: Callable[[SchedulerOutput], ModelRunnerOutput], + scheduler_output: SchedulerOutput, + ) -> ModelRunnerOutput: + """Execute the model and log detailed info on failure.""" try: - return self.model_executor.execute_model(scheduler_output) + return model_fn(scheduler_output) except Exception as err: # We do not want to catch BaseException here since we're only # interested in dumping info when the exception is due to an @@ -232,7 +264,9 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: if not self.scheduler.has_requests(): return {}, False scheduler_output = self.scheduler.schedule() - model_output = self.execute_model(scheduler_output) + model_output = self.execute_model_with_error_logging( + self.model_executor.execute_model, # type: ignore + scheduler_output) engine_core_outputs = self.scheduler.update_from_output( scheduler_output, model_output) # type: ignore @@ -279,8 +313,11 @@ def step_with_batch_queue( # so we need more work. if not scheduled_batch and not self.batch_queue.empty(): future, scheduler_output = self.batch_queue.get_nowait() + # Blocking until the first result is available. - model_output = future.result() + model_output = self.execute_model_with_error_logging( + lambda _: future.result(), scheduler_output) + self.batch_queue.task_done() engine_core_outputs = (self.scheduler.update_from_output( scheduler_output, model_output)) @@ -440,13 +477,14 @@ def _perform_handshakes( For DP>1 with internal loadbalancing this is with the shared front-end process which may reside on a different node. - For DP>1 with external loadbalancing, two handshakes are performed: + For DP>1 with external or hybrid loadbalancing, two handshakes are + performed: - With the rank 0 front-end process which retrieves the DP Coordinator ZMQ addresses and DP process group address. - With the colocated front-end process which retrieves the client input/output socket addresses. - with the exception of the rank 0 engine itself which doesn't require - the second handshake. + with the exception of the rank 0 and colocated engines themselves which + don't require the second handshake. 
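# Illustrative sketch (not part of the patch): the shape of the handshake
# payloads exchanged below, including the new "headless" field. The values here
# are made up for the example; only the keys mirror the HELLO/READY messages.
import msgspec

hello = msgspec.msgpack.encode({
    "status": "HELLO",
    "local": False,
    "headless": True,  # engine core has no colocated front-end process
})
ready = msgspec.msgpack.encode({
    "status": "READY",
    "local": False,
    "headless": True,
    "num_gpu_blocks": 8192,                      # example value
    "dp_stats_address": "tcp://10.0.0.1:5570",   # example value
})

for raw in (hello, ready):
    msg = msgspec.msgpack.decode(raw)
    print(msg["status"], "headless =", msg["headless"])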
Here, "front-end" process can mean the process containing the engine core client (which is the API server process in the case the API @@ -455,15 +493,18 @@ def _perform_handshakes( """ input_ctx = zmq.Context() is_local = local_client and client_handshake_address is None + headless = not local_client handshake = self._perform_handshake(input_ctx, handshake_address, - identity, is_local, vllm_config, + identity, is_local, headless, + vllm_config, vllm_config.parallel_config) if client_handshake_address is None: with handshake as addresses: yield addresses else: + assert local_client local_handshake = self._perform_handshake( - input_ctx, client_handshake_address, identity, local_client, + input_ctx, client_handshake_address, identity, True, False, vllm_config) with handshake as addresses, local_handshake as client_addresses: addresses.inputs = client_addresses.inputs @@ -480,6 +521,7 @@ def _perform_handshake( handshake_address: str, identity: bytes, local_client: bool, + headless: bool, vllm_config: VllmConfig, parallel_config_to_update: Optional[ParallelConfig] = None, ) -> Generator[EngineZmqAddresses, None, None]: @@ -491,6 +533,7 @@ def _perform_handshake( bind=False) as handshake_socket: # Register engine with front-end. addresses = self.startup_handshake(handshake_socket, local_client, + headless, parallel_config_to_update) yield addresses @@ -504,6 +547,7 @@ def _perform_handshake( msgspec.msgpack.encode({ "status": "READY", "local": local_client, + "headless": headless, "num_gpu_blocks": num_gpu_blocks, "dp_stats_address": dp_stats_address, })) @@ -512,6 +556,7 @@ def _perform_handshake( def startup_handshake( handshake_socket: zmq.Socket, local_client: bool, + headless: bool, parallel_config: Optional[ParallelConfig] = None, ) -> EngineZmqAddresses: @@ -520,6 +565,7 @@ def startup_handshake( msgspec.msgpack.encode({ "status": "HELLO", "local": local_client, + "headless": headless, })) # Receive initialization message. @@ -864,22 +910,6 @@ def _init_data_parallel(self, vllm_config: VllmConfig): logger.debug("Setting kv_transfer_config.engine_id to %s", vllm_config.kv_transfer_config.engine_id) - from vllm.platforms import current_platform - device_control_env_var = current_platform.device_control_env_var - world_size = vllm_config.parallel_config.world_size - # Set CUDA_VISIBLE_DEVICES or equivalent. 
- try: - os.environ[device_control_env_var] = ",".join( - str(current_platform.device_id_to_physical_device_id(i)) - for i in range(local_dp_rank * - world_size, (local_dp_rank + 1) * world_size)) - except IndexError as e: - raise Exception( - f"Error setting {device_control_env_var}: " - f"local range: [{local_dp_rank * world_size}, " - f"{(local_dp_rank + 1) * world_size}) " - f"base value: \"{os.getenv(device_control_env_var)}\"") from e - self.dp_rank = dp_rank self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() @@ -977,6 +1007,50 @@ def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: return ParallelConfig.has_unfinished_dp(self.dp_group, local_unfinished) + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + stateless_destroy_torch_distributed_process_group(self.dp_group) + self.shutdown() + + parallel_config = self.vllm_config.parallel_config + old_dp_size = parallel_config.data_parallel_size + parallel_config.data_parallel_size = \ + reconfig_request.new_data_parallel_size + if reconfig_request.new_data_parallel_rank != -1: + parallel_config.data_parallel_rank = \ + reconfig_request.new_data_parallel_rank + # local rank specifies device visibility, it should not be changed + assert reconfig_request.new_data_parallel_rank_local == \ + ReconfigureRankType.KEEP_CURRENT_RANK + parallel_config.data_parallel_master_ip = \ + reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port + if reconfig_request.new_data_parallel_rank != -2: + self.dp_rank = parallel_config.data_parallel_rank + self.dp_group = parallel_config.stateless_init_dp_group() + reconfig_request.new_data_parallel_master_port = \ + parallel_config.data_parallel_master_port + + self.model_executor.reinitialize_distributed(reconfig_request) + if reconfig_request.new_data_parallel_size > old_dp_size: + assert self.available_gpu_memory_for_kv_cache > 0 + # pass available_gpu_memory_for_kv_cache from existing + # engine-cores to new engine-cores so they can directly + # use it in _initialize_kv_caches() rather than profiling. + ParallelConfig.sync_kv_cache_memory_size( + self.dp_group, self.available_gpu_memory_for_kv_cache) + # NOTE(yongji): newly joined workers require dummy_run even + # CUDA graph is not used + self.model_executor.collective_rpc("compile_or_warm_up_model") + if reconfig_request.new_data_parallel_rank == \ + ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + self.shutdown() + logger.info("DPEngineCoreProc %s shutdown", self.dp_rank) + else: + logger.info("Distributed environment reinitialized for DP rank %s", + self.dp_rank) + class DPEngineCoreActor(DPEngineCoreProc): """ @@ -998,14 +1072,41 @@ def __init__( vllm_config.parallel_config.data_parallel_rank_local = \ local_dp_rank - # Ray sets CUDA_VISIBLE_DEVICES to empty string, - # we clean this up to be able to properly initialize - # data parallel groups. 
- del os.environ['CUDA_VISIBLE_DEVICES'] + # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle + # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time, + # and this cannot be done in the same way for Ray because: + # 1) Ray manages life cycle of all ray workers (including + # DPEngineCoreActor) + # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration + # To bypass 2, we need to also set + # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created + # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky: + # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501 + # But vLLM worker assumes visibility into all local GPUs, therefore + # this results in incorrect indexing into the GPU ID list. + self._set_cuda_visible_devices(vllm_config, local_dp_rank) super().__init__(vllm_config, local_client, "", executor_class, log_stats) + def _set_cuda_visible_devices(self, vllm_config: VllmConfig, + local_dp_rank: int): + from vllm.platforms import current_platform + device_control_env_var = current_platform.device_control_env_var + world_size = vllm_config.parallel_config.world_size + # Set CUDA_VISIBLE_DEVICES or equivalent. + try: + os.environ[device_control_env_var] = ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * + world_size, (local_dp_rank + 1) * world_size)) + except IndexError as e: + raise Exception( + f"Error setting {device_control_env_var}: " + f"local range: [{local_dp_rank * world_size}, " + f"{(local_dp_rank + 1) * world_size}) " + f"base value: \"{os.getenv(device_control_env_var)}\"") from e + def _decorate_logs(self): pass diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index dafaa15f777d..69ae3690d00e 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -21,9 +21,11 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import get_open_zmq_inproc_path, make_zmq_socket +from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, - EngineCoreRequestType, UtilityOutput) + EngineCoreRequestType, + ReconfigureDistributedRequest, ReconfigureRankType, + UtilityOutput) from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.exceptions import EngineDeadError @@ -162,6 +164,9 @@ def dp_engines_running(self) -> bool: running state.""" raise NotImplementedError + async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: + raise NotImplementedError + async def get_output_async(self) -> EngineCoreOutputs: raise NotImplementedError @@ -424,17 +429,23 @@ def __init__( parallel_config = vllm_config.parallel_config dp_size = parallel_config.data_parallel_size dp_rank = parallel_config.data_parallel_rank - external_dp_lb = parallel_config.data_parallel_external_lb - + dp_local_size = parallel_config.data_parallel_size_local offline_mode = parallel_config.data_parallel_rank_local is not None - engine_ranks = [dp_rank] if (offline_mode - or external_dp_lb) else range(dp_size) + # Client manages local+remote EngineCores in pure internal LB case. + # Client manages local EngineCores in hybrid and external LB case. 
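# Illustrative sketch (not part of the patch): which engine ranks a front-end
# client manages under each DP load-balancing mode, mirroring the
# engine_ranks_managed computation that follows. Plain integers only; names
# are local to this example.


def managed_ranks(dp_rank: int, dp_size: int, dp_local_size: int,
                  local_engines_only: bool, offline_mode: bool) -> list[int]:
    if offline_mode:
        # One LLM instance per rank; it only talks to its own engine core.
        return [dp_rank]
    num_ranks = dp_local_size if local_engines_only else dp_size
    return list(range(dp_rank, dp_rank + num_ranks))


# Internal LB: the single client (rank 0) manages every engine core.
assert managed_ranks(0, 8, 2, local_engines_only=False,
                     offline_mode=False) == list(range(8))
# Hybrid/external LB: an API server at dp_rank=4 with 2 local engines
# manages only ranks 4 and 5.
assert managed_ranks(4, 8, 2, local_engines_only=True,
                     offline_mode=False) == [4, 5]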
+ local_engines_only = (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) + + num_ranks = dp_local_size if local_engines_only else dp_size + self.engine_ranks_managed = [dp_rank] if offline_mode else list( + range(dp_rank, dp_rank + num_ranks)) assert parallel_config.data_parallel_size_local <= len( - engine_ranks) + self.engine_ranks_managed) # ZMQ identity of each engine that this client will talk to. self.core_engines: list[EngineIdentity] = [ - index.to_bytes(2, "little") for index in engine_ranks + rank.to_bytes(2, "little") + for rank in self.engine_ranks_managed ] # Wait for ready messages from each engine on the input socket. @@ -889,6 +900,12 @@ def _ensure_stats_update_task(self): return assert self.stats_update_address is not None + assert len(self.engine_ranks_managed) > 0 + # NOTE: running and waiting counts are all global from + # the Coordinator include all global EngineCores. This + # slice includes just the cores managed by this client. + count_slice = slice(self.engine_ranks_managed[0], + self.engine_ranks_managed[-1] + 1) async def run_engine_stats_update_task(): with make_zmq_socket(self.ctx, self.stats_update_address, @@ -910,14 +927,30 @@ async def run_engine_stats_update_task(): events = await poller.poll() if not self.engines_running and len(events) == 2 or ( events[0][0] == first_req_rcv_socket): - # Send a message to notify the coordinator that + # Check if this is a regular request notification or + # scale up notification + buf = first_req_rcv_socket.recv( + flags=zmq.NOBLOCK).result() + + decoded = msgspec.msgpack.decode(buf) + if isinstance( + decoded, + (list, tuple)) and len(decoded) == 2 and decoded[ + 0] == "SCALE_ELASTIC_EP": + # Extract new engine count from the decoded message + new_engine_count = decoded[1] + # Send scale up notification to coordinator + scale_msg = msgspec.msgpack.encode( + ("SCALE_ELASTIC_EP", new_engine_count)) + await socket.send(scale_msg) + continue + # we're sending a request while the engines are # paused, so that it can wake the others up # (to run dummy EP loop). 
+ assert decoded[0] == "FIRST_REQ" + target_eng_index = decoded[1] self.engines_running = True - buf = first_req_rcv_socket.recv( - flags=zmq.NOBLOCK).result() - target_eng_index = int.from_bytes(buf, "little") msg = msgspec.msgpack.encode( (target_eng_index, self.current_wave)) await socket.send(msg) @@ -937,7 +970,7 @@ async def run_engine_stats_update_task(): counts, wave, running = msgspec.msgpack.decode(buf) self.current_wave = wave self.engines_running = running - self.lb_engines = counts + self.lb_engines = counts[count_slice] resources.stats_update_task = asyncio.create_task( run_engine_stats_update_task()) @@ -953,7 +986,8 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: chosen_engine) if not self.engines_running: # Notify coordinator that we're sending a request - await self.first_req_send_socket.send(chosen_engine) + req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine)) + await self.first_req_send_socket.send(req_msg) await to_await @@ -1047,3 +1081,156 @@ async def _abort_requests(self, request_ids: list[str], engine: EngineIdentity) -> None: await self._send_input(EngineCoreRequestType.ABORT, request_ids, engine) + + async def _send_reconfig_message( + self, reconfig_request: ReconfigureDistributedRequest, + engine: EngineIdentity) -> asyncio.Future: + """Send reconfiguration message and return the result future without + waiting for completion.""" + call_id = uuid.uuid1().int >> 64 + future = asyncio.get_running_loop().create_future() + self.utility_results[call_id] = future + message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode( + (self.client_index, call_id, "reinitialize_distributed", + (reconfig_request, )))) + await self._send_input_message(message, engine, reconfig_request) + self._ensure_output_queue_task() + return future + + async def scale_elastic_ep(self, new_data_parallel_size: int) -> None: + """Scale elastic EP data parallel size""" + cur_data_parallel_size = len(self.core_engines) + + assert new_data_parallel_size != cur_data_parallel_size, ( + f"new_data_parallel_size {new_data_parallel_size} must be " + f"different from cur_data_parallel_size {cur_data_parallel_size}") + + assert self.vllm_config.parallel_config.data_parallel_backend == \ + "ray", ("Only ray DP backend supports scaling elastic EP") + + scale_up = new_data_parallel_size > cur_data_parallel_size + + if scale_up: + await self._scale_up_elastic_ep(cur_data_parallel_size, + new_data_parallel_size) + else: + await self._scale_down_elastic_ep(cur_data_parallel_size, + new_data_parallel_size) + + async def _scale_up_elastic_ep(self, cur_data_parallel_size: int, + new_data_parallel_size: int) -> None: + """Scale up the data parallel size by creating new engine cores + and reconfiguring existing ones.""" + cur_data_parallel_size = len(self.core_engines) + + # Phase 1: Send reconfigure messages to all existing engines and wait + # for them to be sent + reconfig_futures = [] + self.vllm_config.parallel_config.data_parallel_master_port = \ + get_open_port() + for engine in self.core_engines: + reconfig_request = ReconfigureDistributedRequest( + new_data_parallel_size=new_data_parallel_size, + new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_rank_local=\ + ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config. + data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config. 
+ data_parallel_master_port) + future = await self._send_reconfig_message(reconfig_request, + engine) + reconfig_futures.append(future) + + logger.info("All reconfigure messages sent, starting engine creation") + + # Phase 2: Create new engines now that reconfig messages have been sent + # self.resources.engine_manager is guaranteed to be + # CoreEngineActorManager for RayDPClient + assert isinstance(self.resources.engine_manager, + CoreEngineActorManager) + self.resources.engine_manager.scale_up_elastic_ep( + self.vllm_config, new_data_parallel_size) + + # Create new CoreEngine objects for the new engines + new_engine_identities = set() + for i in range(cur_data_parallel_size, new_data_parallel_size): + new_engine = i.to_bytes(2, "little") + self.core_engines.append(new_engine) + new_engine_identities.add(new_engine) + + # Wait for ready messages from new engines on the input socket + sync_input_socket = zmq.Socket.shadow(self.input_socket) + while new_engine_identities: + if not sync_input_socket.poll(timeout=600_000): + raise TimeoutError( + "Timed out waiting for new engines to send initial " + "message on input socket.") + identity, _ = sync_input_socket.recv_multipart() + new_engine_identities.discard(identity) + + # Phase 3: Wait for all existing engines to complete reconfiguration + logger.info("Waiting for existing engines to complete reconfiguration") + await asyncio.gather(*reconfig_futures) + + # Notify coordinator about scale up through existing + # stats_update_task connection + self._ensure_stats_update_task() + scale_up_marker = msgspec.msgpack.encode( + ("SCALE_ELASTIC_EP", new_data_parallel_size)) + await self.first_req_send_socket.send(scale_up_marker) + + # Update the parallel config + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + logger.info( + "[Elastic EP] Scale up completed, new data parallel size: %s", + new_data_parallel_size) + + async def _scale_down_elastic_ep(self, cur_data_parallel_size: int, + new_data_parallel_size: int) -> None: + """Scale down the data parallel size by shutting down and + reconfiguring existing engine cores.""" + cur_data_parallel_size = len(self.core_engines) + + self.vllm_config.parallel_config.data_parallel_master_port = \ + get_open_port() + + reconfig_futures = [] + for cur_dp_rank, engine in enumerate(self.core_engines): + reconfig_request = ReconfigureDistributedRequest( + new_data_parallel_size=new_data_parallel_size, + new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_rank_local=\ + ReconfigureRankType.KEEP_CURRENT_RANK, + new_data_parallel_master_ip=self.vllm_config.parallel_config. + data_parallel_master_ip, + new_data_parallel_master_port=self.vllm_config.parallel_config. 
+ data_parallel_master_port) + if cur_dp_rank >= new_data_parallel_size: + reconfig_request.new_data_parallel_rank = \ + ReconfigureRankType.SHUTDOWN_CURRENT_RANK + future = await self._send_reconfig_message(reconfig_request, + engine) + reconfig_futures.append(future) + + for _ in range(new_data_parallel_size, cur_data_parallel_size): + self.core_engines.pop() + + await asyncio.gather(*reconfig_futures) + + assert isinstance(self.resources.engine_manager, + CoreEngineActorManager) + self.resources.engine_manager.scale_down_elastic_ep( + cur_data_parallel_size, new_data_parallel_size) + + self._ensure_stats_update_task() + scale_down_marker = msgspec.msgpack.encode( + ("SCALE_ELASTIC_EP", new_data_parallel_size)) + await self.first_req_send_socket.send(scale_down_marker) + + self.vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + logger.info( + "[Elastic EP] Scale down completed, new data parallel size: %s", + new_data_parallel_size) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a2328c37ba0c..991242e18278 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -17,7 +17,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import ( TokenizerGroup, init_tokenizer_from_configs) @@ -82,11 +81,14 @@ def __init__( self.dp_group = None self.should_execute_dummy_batch = False - # Tokenizer (+ ensure liveness if running in another process). - self.tokenizer = init_tokenizer_from_configs( - model_config=vllm_config.model_config, - scheduler_config=vllm_config.scheduler_config, - lora_config=vllm_config.lora_config) + if self.model_config.skip_tokenizer_init: + self.tokenizer = None + else: + # Tokenizer (+ ensure liveness if running in another process). + self.tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) # Processor (convert Inputs --> EngineCoreRequests) self.processor = Processor(vllm_config=vllm_config, @@ -189,7 +191,6 @@ def add_request( lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: # Validate the request_id type. @@ -200,8 +201,7 @@ def add_request( # Process raw inputs into the request. 
prompt_str, request = self.processor.process_inputs( request_id, prompt, params, arrival_time, lora_request, - tokenization_kwargs, trace_headers, prompt_adapter_request, - priority) + tokenization_kwargs, trace_headers, priority) n = params.n if isinstance(params, SamplingParams) else 1 diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2bcd61d1f0aa..3be6c4821214 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -327,14 +327,16 @@ def add_request( if request_id in self.request_states: raise ValueError(f"Request id {request_id} already running.") - req_state = RequestState.from_new_request( - tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), - request=request, - prompt=prompt, - parent_req=parent_req, - request_index=request_index, - queue=queue, - log_stats=self.log_stats) + tokenizer = None if not self.tokenizer else \ + self.tokenizer.get_lora_tokenizer(request.lora_request) + + req_state = RequestState.from_new_request(tokenizer=tokenizer, + request=request, + prompt=prompt, + parent_req=parent_req, + request_index=request_index, + queue=queue, + log_stats=self.log_stats) self.request_states[request_id] = req_state self.lora_states.add_request(req_state) if parent_req: diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 9fc52543efde..0f2f404a130e 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -16,13 +16,14 @@ from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import merge_and_sort_multimodal_metadata from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest from vllm.v1.engine.mm_input_cache import MirroredProcessingCache from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) +from vllm.v1.structured_output.backend_outlines import ( + validate_structured_output_request_outlines) from vllm.v1.structured_output.backend_xgrammar import ( validate_xgrammar_grammar) @@ -193,6 +194,9 @@ def _validate_structured_output(self, params: SamplingParams) -> None: # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens # Without tokenizer these are disallowed in grammars. validate_guidance_grammar(params, tokenizer=None) + elif engine_level_backend == "outlines": + # outlines backend + validate_structured_output_request_outlines(params) else: # NOTE: engine_level_backend must be "auto" here, because we have # checked supported_backends above. 
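# Illustrative sketch (not part of the patch): prompt validation once the
# tokenizer becomes optional (skip_tokenizer_init), mirroring the
# _validate_model_input change below. The vocabulary check only applies when a
# tokenizer (and hence max_token_id) exists; the length check always applies.
# Names are local to this example.
from typing import Optional


def validate_prompt_ids(prompt_ids: list[int], max_model_len: int,
                        max_token_id: Optional[int]) -> None:
    if not prompt_ids:
        raise ValueError("The prompt cannot be empty")
    if max_token_id is not None and max(prompt_ids) > max_token_id:
        raise ValueError(f"Token id {max(prompt_ids)} is out of vocabulary")
    if len(prompt_ids) > max_model_len:
        raise ValueError(f"Prompt length {len(prompt_ids)} exceeds "
                         f"max_model_len={max_model_len}")


validate_prompt_ids([1, 5, 9], max_model_len=8, max_token_id=32000)  # ok
validate_prompt_ids([1, 5, 9], max_model_len=8, max_token_id=None)   # vocab check skipped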
@@ -221,7 +225,6 @@ def process_inputs( lora_request: Optional[LoRARequest] = None, tokenization_kwargs: Optional[dict[str, Any]] = None, trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, data_parallel_rank: Optional[int] = None, ) -> tuple[Optional[str], EngineCoreRequest]: @@ -232,8 +235,6 @@ def process_inputs( self._validate_params(params, lora_request) if trace_headers is not None: raise ValueError("V1 does not support tracing yet.") - if prompt_adapter_request is not None: - raise ValueError("V1 does not support prompt_adapter_request.") data_parallel_size = self.vllm_config.parallel_config.data_parallel_size if data_parallel_rank is not None and not (0 <= data_parallel_rank < @@ -248,12 +249,10 @@ def process_inputs( # 1. Tokenize text prompt, with LoRA request if one exists. # 2. For multimodal models with a merged preprocessor, preprocess # multimodal data and expand prompt token ids accordingly. - # 3. Apply prompt adapter to prompt token ids if one exists. processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess( prompt, tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, - prompt_adapter_request=prompt_adapter_request, return_mm_hashes=self.use_hash, ) from vllm.platforms import current_platform @@ -375,7 +374,6 @@ def _validate_model_input( prompt_type: Literal["encoder", "decoder"], ): model_config = self.model_config - tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) prompt_ids = prompt_inputs["prompt_token_ids"] if not prompt_ids: @@ -384,9 +382,14 @@ def _validate_model_input( else: raise ValueError(f"The {prompt_type} prompt cannot be empty") - max_input_id = max(prompt_ids, default=0) - if max_input_id > tokenizer.max_token_id: - raise ValueError(f"Token id {max_input_id} is out of vocabulary") + if self.model_config.skip_tokenizer_init: + tokenizer = None + else: + tokenizer = self.tokenizer.get_lora_tokenizer(lora_request) + max_input_id = max(prompt_ids, default=0) + if max_input_id > tokenizer.max_token_id: + raise ValueError( + f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len if len(prompt_ids) > max_prompt_len: diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index c4012419411a..f39aa4059326 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +import os import weakref from collections.abc import Iterator from dataclasses import dataclass @@ -9,12 +10,15 @@ from multiprocessing import Process, connection from multiprocessing.process import BaseProcess from typing import TYPE_CHECKING, Callable, Optional, Union +from unittest.mock import patch import msgspec import zmq from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.ray.ray_env import get_env_vars_to_copy from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.executor.abstract import Executor @@ -103,10 +107,13 @@ def __init__( "client_handshake_address"] = client_handshake_address self.processes: list[BaseProcess] = [] + local_dp_ranks = [] for index in range(local_engine_count): local_index = local_start_index + index global_index = start_index + index + # Start EngineCore in background process. 
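# Illustrative sketch (not part of the patch): the device-visibility arithmetic
# applied when launching one engine core per local DP rank (see the
# set_device_control_env_var helper introduced below). Each rank is pinned to a
# contiguous slice of world_size devices. Plain integers only; no CUDA needed.


def visible_devices(local_dp_rank: int, world_size: int,
                    physical_ids: list[int]) -> str:
    lo, hi = local_dp_rank * world_size, (local_dp_rank + 1) * world_size
    if hi > len(physical_ids):
        raise IndexError(f"rank {local_dp_rank} needs devices [{lo}, {hi}) "
                         f"but only {len(physical_ids)} are available")
    return ",".join(str(physical_ids[i]) for i in range(lo, hi))


# 8 GPUs with a TP world size of 2: local DP rank 3 gets the last two devices.
assert visible_devices(3, 2, list(range(8))) == "6,7"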
+ local_dp_ranks.append(local_index) self.processes.append( context.Process(target=target_fn, name=f"EngineCore_{global_index}", @@ -116,9 +123,14 @@ def __init__( })) self._finalizer = weakref.finalize(self, shutdown, self.processes) + + data_parallel = vllm_config.parallel_config.data_parallel_size > 1 try: - for proc in self.processes: - proc.start() + for proc, local_dp_rank in zip(self.processes, local_dp_ranks): + with set_device_control_env_var( + vllm_config, local_dp_rank) if ( + data_parallel) else contextlib.nullcontext(): + proc.start() finally: # Kill other procs if not all are running. if self.finished_procs(): @@ -143,6 +155,30 @@ def finished_procs(self) -> dict[str, int]: } +@contextlib.contextmanager +def set_device_control_env_var(vllm_config: VllmConfig, + local_dp_rank: int) -> Iterator[None]: + """ + Temporarily set CUDA_VISIBLE_DEVICES or equivalent + for engine subprocess. + """ + world_size = vllm_config.parallel_config.world_size + evar = current_platform.device_control_env_var + try: + value = ",".join( + str(current_platform.device_id_to_physical_device_id(i)) + for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * + world_size)) + except IndexError as e: + raise Exception(f"Error setting {evar}: " + f"local range: [{local_dp_rank * world_size}, " + f"{(local_dp_rank + 1) * world_size}) " + "base value: " + f"\"{os.getenv(evar)}\"") from e + with patch.dict(os.environ, values=((evar, value), )): + yield + + class CoreEngineActorManager: """ Utility class to handle creation, readiness, and shutdown @@ -164,6 +200,7 @@ def __init__( import copy import ray + from ray.runtime_env import RuntimeEnv from ray.util.scheduling_strategies import ( PlacementGroupSchedulingStrategy) @@ -171,6 +208,17 @@ def __init__( self.local_engine_actors: list[ray.ActorHandle] = [] self.remote_engine_actors: list[ray.ActorHandle] = [] + + env_vars_list = get_env_vars_to_copy(destination="DPEngineCoreActor") + self.env_vars_dict = { + name: os.environ[name] + for name in env_vars_list if name in os.environ + } + runtime_env = RuntimeEnv(env_vars=self.env_vars_dict) + + self.addresses = addresses + self.executor_class = executor_class + self.log_stats = log_stats dp_size = vllm_config.parallel_config.data_parallel_size local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local @@ -199,28 +247,30 @@ def __init__( assert len(placement_groups) == dp_size, ( "Number of placement groups must match data parallel size") + self.placement_group_is_local = [] refs = [] - for index in range(dp_size): - local_index = local_dp_ranks[index] + for index, local_index, pg in zip(range(dp_size), local_dp_ranks, + placement_groups): dp_vllm_config = copy.deepcopy(vllm_config) - pg = placement_groups[index] dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count actor = ray.remote(DPEngineCoreActor).options( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=pg, placement_group_bundle_index=world_size, - )).remote(vllm_config=dp_vllm_config, - executor_class=executor_class, - log_stats=log_stats, - local_client=local_client, - addresses=addresses, - dp_rank=index, - local_dp_rank=local_index) + ), + runtime_env=runtime_env).remote(vllm_config=dp_vllm_config, + executor_class=executor_class, + log_stats=log_stats, + local_client=local_client, + addresses=addresses, + dp_rank=index, + local_dp_rank=local_index) if local_client: self.local_engine_actors.append(actor) else: self.remote_engine_actors.append(actor) + 
self.placement_group_is_local.append(local_client) refs.append(actor.wait_for_init.remote()) ray.get(refs) @@ -232,6 +282,9 @@ def __init__( def create_dp_placement_groups( vllm_config: VllmConfig ) -> tuple[list["PlacementGroup"], list[int]]: + """ + Create placement groups for data parallel. + """ import ray from ray._private.state import available_resources_per_node @@ -240,7 +293,7 @@ def create_dp_placement_groups( logger.info("Creating placement groups for data parallel") dp_master_ip = \ vllm_config.parallel_config.data_parallel_master_ip - dp_size = vllm_config.parallel_config.data_parallel_size + num_pg_to_create = vllm_config.parallel_config.data_parallel_size local_engine_count = \ vllm_config.parallel_config.data_parallel_size_local @@ -283,7 +336,7 @@ def create_dp_placement_groups( local_dp_ranks.append(i) else: for i in range(available_engine_count): - if len(placement_groups) == dp_size: + if len(placement_groups) == num_pg_to_create: break bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] pg = ray.util.placement_group( @@ -295,6 +348,204 @@ def create_dp_placement_groups( local_dp_ranks.append(i) return placement_groups, local_dp_ranks + @staticmethod + def add_dp_placement_groups( + old_vllm_config: VllmConfig, new_data_parallel_size: int + ) -> tuple[list["PlacementGroup"], list[int]]: + """ + Add placement groups for new data parallel size. + """ + import ray + from ray._private.state import (available_resources_per_node, + total_resources_per_node) + from ray.util.state import list_nodes + + old_dp_size = old_vllm_config.parallel_config.data_parallel_size + num_pg_to_create = new_data_parallel_size - old_dp_size + + if num_pg_to_create <= 0: + return [], [] + + dp_master_ip = old_vllm_config.parallel_config.data_parallel_master_ip + world_size = old_vllm_config.parallel_config.world_size + + nodes = list_nodes() + nodes = sorted(nodes, key=lambda node: node.node_ip != dp_master_ip) + assert nodes[0].node_ip == dp_master_ip, ( + "The first node must be the head node") + assert len(nodes) == 1 or nodes[1].node_ip != dp_master_ip, ( + "There can only be one head node") + + available_resources = available_resources_per_node() + total_resources = total_resources_per_node() + + placement_groups = [] + local_dp_ranks = [] + num_pg_created = 0 + + for node in nodes: + if num_pg_created >= num_pg_to_create: + break + + node_ip = node.node_ip + node_id = node.node_id + available_gpus = int(available_resources[node_id]["GPU"]) + + # Get total GPUs on this node from the node's resources + # Ray stores node resources with node ID as key + total_gpus = int(total_resources[node_id]["GPU"]) + + # Calculate used GPUs and used engines on this node + used_gpus = max(0, total_gpus - available_gpus) + used_engines_on_node = used_gpus // world_size + + # Calculate how many new engines this node can accommodate + available_engine_count = available_gpus // world_size + + # Create placement groups for new engines on this node + for i in range(available_engine_count): + if num_pg_created >= num_pg_to_create: + break + + rank = old_dp_size + num_pg_created + + # Create bundles with node constraint for master node + if node_ip == dp_master_ip: + bundles = [{ + "GPU": 1.0, + "node:" + dp_master_ip: 0.001 + }] * world_size + [{ + "CPU": 1.0 + }] + else: + bundles = [{"GPU": 1.0}] * world_size + [{"CPU": 1.0}] + + pg = ray.util.placement_group( + name=f"dp_rank_{rank}", + strategy="STRICT_PACK", + bundles=bundles, + ) + placement_groups.append(pg) + + # Local rank starts from the number of 
engines already used + # on this node + local_rank = used_engines_on_node + i + local_dp_ranks.append(local_rank) + num_pg_created += 1 + + return placement_groups, local_dp_ranks + + def scale_up_elastic_ep(self, cur_vllm_config: VllmConfig, + new_data_parallel_size: int) -> None: + import copy + + import ray + from ray.runtime_env import RuntimeEnv + from ray.util.scheduling_strategies import ( + PlacementGroupSchedulingStrategy) + + from vllm.v1.engine.core import DPEngineCoreActor + + cur_data_parallel_size = len(self.local_engine_actors) + \ + len(self.remote_engine_actors) + + assert new_data_parallel_size > cur_data_parallel_size, ( + f"New data parallel size {new_data_parallel_size} must be greater " + f"than current data parallel size {cur_data_parallel_size} " + "for scale up") + + placement_groups, local_dp_ranks = \ + self.add_dp_placement_groups( + cur_vllm_config, new_data_parallel_size) + + world_size = cur_vllm_config.parallel_config.world_size + dp_master_ip = cur_vllm_config.parallel_config.data_parallel_master_ip + new_local_engines = 0 + + runtime_env = RuntimeEnv(env_vars=self.env_vars_dict + | {"VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": "1"}) + for i, (pg, + local_rank) in enumerate(zip(placement_groups, + local_dp_ranks)): + rank = cur_data_parallel_size + i + dp_vllm_config = copy.deepcopy(cur_vllm_config) + dp_vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + dp_vllm_config.parallel_config.placement_group = pg + + # Check if this placement group is on the head node + local_client = any( + bundle.get("node:" + dp_master_ip, 0) > 0 + for bundle in pg.bundle_specs) + + if local_client: + new_local_engines += 1 + # Update data_parallel_size_local + dp_vllm_config.parallel_config.data_parallel_size_local = ( + cur_vllm_config.parallel_config.data_parallel_size_local + + new_local_engines) + + actor = ray.remote(DPEngineCoreActor).options( + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=world_size, + ), + runtime_env=runtime_env).remote( + vllm_config=dp_vllm_config, + executor_class=self.executor_class, + log_stats=self.log_stats, + local_client=local_client, + addresses=self.addresses, + dp_rank=rank, + local_dp_rank=local_rank) + + if local_client: + self.local_engine_actors.append(actor) + else: + self.remote_engine_actors.append(actor) + self.created_placement_groups.append(pg) + self.placement_group_is_local.append(local_client) + + ray.get([ + actor.wait_for_init.remote() + for actor in (self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 else []) + + self.remote_engine_actors[-(len(placement_groups) - + new_local_engines):] + ]) + + actors = (self.local_engine_actors[-new_local_engines:] + if new_local_engines > 0 else []) + \ + self.remote_engine_actors[-(len(placement_groups) - + new_local_engines):] + + for actor in actors: + self.run_refs.append(actor.run.remote()) + + cur_vllm_config.parallel_config.data_parallel_size = \ + new_data_parallel_size + # Update old_vllm_config with new data_parallel_size_local if any new + # local engines were added + if new_local_engines > 0: + cur_vllm_config.parallel_config.data_parallel_size_local += \ + new_local_engines + + def scale_down_elastic_ep(self, cur_data_parallel_size: int, + new_data_parallel_size: int) -> None: + import ray + assert cur_data_parallel_size > new_data_parallel_size, ( + f"cur_data_parallel_size {cur_data_parallel_size} must be greater " + f"than new_data_parallel_size {new_data_parallel_size} " + 
"for scale down") + for _ in range(cur_data_parallel_size - new_data_parallel_size): + pg = self.created_placement_groups.pop() + is_local = self.placement_group_is_local.pop() + if is_local: + self.local_engine_actors.pop() + else: + self.remote_engine_actors.pop() + ray.util.remove_placement_group(pg) + def get_run_refs(self): return self.run_refs @@ -325,7 +576,8 @@ def launch_core_engines( local_start_index = parallel_config.data_parallel_rank_local dp_rank = parallel_config.data_parallel_rank host = parallel_config.data_parallel_master_ip - external_dp_lb = parallel_config.data_parallel_external_lb + local_engines_only = (parallel_config.data_parallel_hybrid_lb + or parallel_config.data_parallel_external_lb) # In offline mode there is an LLM instance per DP rank and # one core engine per LLM, see @@ -334,8 +586,8 @@ def launch_core_engines( # client_local_only = True for cases where this front-end # sends requests only to colocated engines. - client_local_only = offline_mode or external_dp_lb or (local_engine_count - == dp_size) + client_local_only = (offline_mode or local_engines_only + or (local_engine_count == dp_size)) # Set up input and output addresses. addresses = EngineZmqAddresses( @@ -379,14 +631,27 @@ def launch_core_engines( yield engine_actor_manager, coordinator, addresses return - if offline_mode or (external_dp_lb and dp_rank > 0): + if offline_mode: assert local_engine_count == 1 engines_to_handshake = [CoreEngine(index=dp_rank, local=True)] - else: + elif dp_rank == 0: + # Rank 0 holds Coordinator, so it handshakes with all Cores + # in both external dplb and internal dplb mode. + # Note this also covers the case where we have zero local engines + # and rank 0 is headless. engines_to_handshake = [ CoreEngine(index=i, local=(i < local_engine_count)) for i in range(dp_size) ] + else: + # Rank > 0 handshakes with just the local cores it is managing. + assert local_engines_only, ( + "Attempting to launch core_engines from dp_rank > 0, but " + "found internal DPLB, which is incompatible.") + engines_to_handshake = [ + CoreEngine(index=i, local=True) + for i in range(dp_rank, dp_rank + local_engine_count) + ] # Whether the started engines will handshake only with co-located # front-end processes. In external_dp_lb mode, ranks > 0 handshake with @@ -397,7 +662,7 @@ def launch_core_engines( handshake_address = get_engine_client_zmq_addr( handshake_local_only, host, parallel_config.data_parallel_rpc_port) - if external_dp_lb and dp_rank > 0: + if local_engines_only and dp_rank > 0: assert not handshake_local_only local_handshake_address = get_open_zmq_ipc_path() client_handshake_address = local_handshake_address @@ -412,8 +677,6 @@ def launch_core_engines( # Start local engines. if local_engine_count: - # In server mode, start_index and local_start_index will - # both be 0. 
local_engine_manager = CoreEngineProcManager( EngineCoreProc.run_engine_core, vllm_config=vllm_config, @@ -459,6 +722,9 @@ def wait_for_engine_startup( poller = zmq.Poller() poller.register(handshake_socket, zmq.POLLIN) + remote_should_be_headless = not parallel_config.data_parallel_hybrid_lb \ + and not parallel_config.data_parallel_external_lb + if proc_manager is not None: for sentinel in proc_manager.sentinels(): poller.register(sentinel, zmq.POLLIN) @@ -494,13 +760,24 @@ def wait_for_engine_startup( raise RuntimeError(f"Message from engine with unexpected data " f"parallel rank: {eng_index}") msg = msgspec.msgpack.decode(ready_msg_bytes) - status, local = msg["status"], msg["local"] + status, local, headless = msg["status"], msg["local"], msg["headless"] if local != engine.local: raise RuntimeError(f"{status} message from " f"{'local' if local else 'remote'} " f"engine {eng_index}, expected it to be " f"{'local' if engine.local else 'remote'}") + # Remote engines must be headless iff we aren't in hybrid dp lb mode. + if not local and headless != remote_should_be_headless: + if headless: + raise RuntimeError(f"Remote engine {eng_index} must not use " + f"--headless in external or hybrid dp lb " + f"mode") + else: + raise RuntimeError(f"Remote engine {eng_index} must use " + f"--headless unless in external or hybrid " + f"dp lb mode") + if status == "HELLO" and engine.state == CoreEngineState.NEW: # Send init message with DP config info. diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index b06b7cc804d5..11ddade3eb70 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -26,11 +26,12 @@ destroy_model_parallel) from vllm.distributed.device_communicators.shm_broadcast import (Handle, MessageQueue) +from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.executor.multiproc_worker_utils import ( _add_prefix, set_multiprocessing_worker_envs) from vllm.logger import init_logger -from vllm.utils import (get_distributed_init_method, get_mp_context, - get_open_port) +from vllm.utils import (get_distributed_init_method, get_loopback_ip, + get_mp_context, get_open_port) from vllm.v1.executor.abstract import Executor, FailureCallback from vllm.v1.outputs import ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase @@ -62,9 +63,9 @@ def _init_executor(self) -> None: # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address - # 127.0.0.1 for communication. + # get_loopback_ip() for communication. 
        distributed_init_method = get_distributed_init_method(
-            "127.0.0.1", get_open_port())
+            get_loopback_ip(), get_open_port())

         # Initialize worker and set up message queues for SchedulerOutputs
         # and ModelRunnerOutputs
@@ -111,10 +112,14 @@ def _init_executor(self) -> None:
         if self.max_concurrent_batches > 1:
             # Note: must use only 1 IO thread to keep dequeue sequence
             # from the response queue
+            # _async_aggregate_workers_output also assumes a single IO thread
             self.io_thread_pool = ThreadPoolExecutor(
                 max_workers=1, thread_name_prefix="mp_exec_io")

         self.output_rank = self._get_output_rank()
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+        self.kv_output_aggregator = KVOutputAggregator(
+            self.parallel_config.world_size)

     def start_worker_monitor(self):
         workers = self.workers
@@ -155,13 +160,30 @@ def execute_model(
         self,
         scheduler_output,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        (output, ) = self.collective_rpc(
+        non_block = self.max_concurrent_batches > 1
+
+        if not self.has_connector:
+            # get output only from a single worker (output_rank)
+            (output, ) = self.collective_rpc(
+                "execute_model",
+                args=(scheduler_output, ),
+                unique_reply_rank=self.output_rank,
+                non_block=non_block,
+                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
+            return output
+
+        # get output from all workers
+        outputs = self.collective_rpc(
             "execute_model",
             args=(scheduler_output, ),
-            unique_reply_rank=self.output_rank,
-            non_block=self.max_concurrent_batches > 1,
+            non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-        return output
+
+        # aggregate all workers output to a single output
+        if non_block:
+            return self.kv_output_aggregator.async_aggregate(
+                outputs, self.output_rank)
+        return self.kv_output_aggregator.aggregate(outputs, self.output_rank)

     def collective_rpc(self,
                        method: Union[str, Callable],
@@ -271,6 +293,8 @@ def check_health(self) -> None:

     @property
     def max_concurrent_batches(self) -> int:
+        if self.scheduler_config.async_scheduling:
+            return 2
         return self.parallel_config.pipeline_parallel_size

     def _get_output_rank(self) -> int:
diff --git a/vllm/v1/executor/ray_distributed_executor.py b/vllm/v1/executor/ray_distributed_executor.py
index 257564793cf4..b86ac048f520 100644
--- a/vllm/v1/executor/ray_distributed_executor.py
+++ b/vllm/v1/executor/ray_distributed_executor.py
@@ -2,37 +2,62 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from concurrent.futures import Future
-from typing import Union
+from typing import Optional, Union

+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.executor.ray_distributed_executor import (  # noqa
     RayDistributedExecutor as RayDistributedExecutorV0)
+from vllm.logger import init_logger
+from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput

+logger = init_logger(__name__)
+

 class FutureWrapper(Future):
-    """A wrapper around a Ray output reference to meet the interface
-    of .execute_model().
+    """A wrapper around Ray output references to meet the interface
+    of .execute_model(): the top level (core busy loop) expects the .result()
+    API to block and return a single output.
+
+    If an aggregator is provided, the outputs from all workers are aggregated
+    upon the result() call. If not, only the first worker's output is returned.
""" - def __init__(self, ref): + def __init__(self, refs, aggregator: Optional[KVOutputAggregator] = None): super().__init__() - self.ref = ref + self.refs = refs + self.aggregator = aggregator def result(self, timeout=None): if timeout is not None: raise NotImplementedError("timeout is not supported") - return self.ref.get() + + if self.aggregator is None: + return self.refs[0].get() + + outputs = [ref.get() for ref in self.refs] + return self.aggregator.aggregate(outputs, output_rank=0) class RayDistributedExecutor(RayDistributedExecutorV0, Executor): """Ray distributed executor using Ray Compiled Graphs.""" + def _init_executor(self) -> None: + super()._init_executor() + + # KV connector setup + self.has_connector = self.vllm_config.kv_transfer_config is not None + self.kv_output_aggregator = KVOutputAggregator( + self.parallel_config.world_size) + @property def max_concurrent_batches(self) -> int: """Ray distributed executor supports pipeline parallelism, meaning that it allows PP size batches to be executed concurrently. """ + if self.scheduler_config.async_scheduling: + return 2 return self.parallel_config.pipeline_parallel_size def execute_model( @@ -53,10 +78,29 @@ def execute_model( refs = self.forward_dag.execute(scheduler_output) # type: ignore - # When PP is not used, we block here until the result is available. + if not self.has_connector: + # Get output only from a single worker (output_rank) + # When PP is not used, we block here until the result is available. + if self.max_concurrent_batches == 1: + return refs[0].get() + + # When PP is used, we return a FutureWrapper immediately so that + # the scheduler can yield to the next batch. + return FutureWrapper(refs) + + # Get output from all workers when connector is present if self.max_concurrent_batches == 1: - return refs[0].get() + # Block and get results from all workers + outputs = [ref.get() for ref in refs] + return self.kv_output_aggregator.aggregate(outputs) + + # Return a future that will aggregate outputs from all workers + return FutureWrapper(refs, self.kv_output_aggregator) - # When PP is used, we return a FutureWrapper immediately so that - # the scheduler can yield to the next batch. 
- return FutureWrapper(refs[0]) + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + self._run_workers("reinitialize_distributed", reconfig_request) + if reconfig_request.new_data_parallel_rank == \ + ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + self.shutdown() + return \ No newline at end of file diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index 43456a987def..bec31a7a058d 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -87,6 +87,7 @@ def page_size_bytes(self) -> int: @dataclass class FullAttentionSpec(AttentionSpec): sliding_window: Optional[int] = None + attention_chunk_size: Optional[int] = None """ When hybrid allocator is disabled and the model contains both full attention layers and sliding window attention layers, sliding @@ -105,6 +106,17 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: max_model_len = vllm_config.model_config.max_model_len return cdiv(max_model_len, self.block_size) * self.page_size_bytes + @classmethod + def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]: + if len(window_sizes) == 0: + return None + elif len(window_sizes) == 1: + return window_sizes.pop() + else: + raise ValueError( + "All attention layers in the same KV cache group must have the " + "same window size.") + @classmethod def merge(cls, specs: list[Self]) -> Self: """ @@ -114,17 +126,45 @@ def merge(cls, specs: list[Self]) -> Self: merged_spec = super().merge(specs) sliding_window = set(spec.sliding_window for spec in specs if spec.sliding_window is not None) - if len(sliding_window) == 0: - merged_spec.sliding_window = None - elif len(sliding_window) == 1: - merged_spec.sliding_window = sliding_window.pop() - else: - raise ValueError( - "All sliding window layers in the same KV cache group " - "must have the same window size.") + attention_chunk_size = set(spec.attention_chunk_size for spec in specs + if spec.attention_chunk_size is not None) + + merged_spec.sliding_window = cls.merge_window_sizes(sliding_window) + merged_spec.attention_chunk_size = ( + cls.merge_window_sizes(attention_chunk_size)) + assert ( + (merged_spec.sliding_window is not None) + + (merged_spec.attention_chunk_size is not None) <= 1 + ), ("Model with both sliding window layers and chunked local attention " + "layers is not supported.") return merged_spec +@dataclass +class ChunkedLocalAttentionSpec(AttentionSpec): + attention_chunk_size: int + + @property + def type_id(self) -> str: + return ( + f"local_attention_{self.attention_chunk_size}_{self.block_size}_{self.page_size_bytes}" + ) # noqa + + def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = ( + vllm_config.scheduler_config.max_num_batched_tokens) + + # During chunked prefill, we allocate KV cache for at most + # `self.attention_chunk_size` computed tokens plus the newly scheduled + # tokens. And we won't allocate KV cache for more than `max_model_len` + # tokens. 
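# --- Editorial sketch, not part of this patch: a standalone recalculation of
# --- the page budget described in the comment above, with made-up sizes.
# --- cdiv() from vllm.utils is replaced here by an inline ceiling division so
# --- the snippet runs on its own; the function name is illustrative only.
def chunked_local_attention_blocks(attention_chunk_size: int,
                                   max_num_batched_tokens: int,
                                   max_model_len: int,
                                   block_size: int) -> int:
    # KV cache is needed for at most one attention chunk plus the newly
    # scheduled tokens, and never for more than max_model_len tokens.
    num_tokens = min(attention_chunk_size + max_num_batched_tokens,
                     max_model_len)
    return -(-num_tokens // block_size)  # ceiling division into blocks

# e.g. chunk=8192, batched=2048, max_model_len=131072, block_size=16 -> 640
assert chunked_local_attention_blocks(8192, 2048, 131072, 16) == 640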
+ num_tokens = min(self.attention_chunk_size + max_num_batched_tokens, + max_model_len) + + return cdiv(num_tokens, self.block_size) * self.page_size_bytes + + @dataclass class SlidingWindowSpec(AttentionSpec): sliding_window: int diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index c720ca13e51b..7f2556bab5a4 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -4,7 +4,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Callable, Optional, Union import numpy as np import prometheus_client @@ -35,8 +35,10 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ... @abstractmethod - def record(self, scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats]): + def record(self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0): ... @abstractmethod @@ -78,8 +80,10 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def record(self, scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats]): + def record(self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0): """Log Stats to standard output.""" if iteration_stats: @@ -146,233 +150,290 @@ class PrometheusStatLogger(StatLoggerBase): _histogram_cls = prometheus_client.Histogram _spec_decoding_cls = SpecDecodingProm - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + def __init__(self, + vllm_config: VllmConfig, + engine_indexes: Optional[list[int]] = None): + if engine_indexes is None: + engine_indexes = [0] + self.engine_indexes = engine_indexes unregister_vllm_metrics() self.vllm_config = vllm_config - self.engine_index = engine_index # Use this flag to hide metrics that were deprecated in # a previous release and which will be removed future self.show_hidden_metrics = \ vllm_config.observability_config.show_hidden_metrics labelnames = ["model_name", "engine"] - labelvalues = [ - vllm_config.model_config.served_model_name, - str(engine_index) - ] - + model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len + if (len(self.engine_indexes) > 1 + and vllm_config.speculative_config is not None): + raise NotImplementedError("Prometheus metrics with Spec Decoding " + "with >1 EngineCore per AsyncLLM is not " + "supported yet.") + spec_decode_labelvalues = [ + vllm_config.model_config.served_model_name, + str(self.engine_indexes[0]) + ] self.spec_decoding_prom = self._spec_decoding_cls( - vllm_config.speculative_config, labelnames, labelvalues) + vllm_config.speculative_config, labelnames, + spec_decode_labelvalues) # # Scheduler state # - self.gauge_scheduler_running = self._gauge_cls( + gauge_scheduler_running = self._gauge_cls( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", multiprocess_mode="mostrecent", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.gauge_scheduler_running = make_per_engine(gauge_scheduler_running, + engine_indexes, + model_name) - self.gauge_scheduler_waiting = self._gauge_cls( + gauge_scheduler_waiting = self._gauge_cls( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", 
multiprocess_mode="mostrecent", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.gauge_scheduler_waiting = make_per_engine(gauge_scheduler_waiting, + engine_indexes, + model_name) # # GPU cache # # Deprecated in 0.9 - Renamed as vllm:kv_cache_usage_perc # TODO: in 0.10, only enable if show_hidden_metrics=True - self.gauge_gpu_cache_usage = self._gauge_cls( + gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation=( "GPU KV-cache usage. 1 means 100 percent usage." "DEPRECATED: Use vllm:kv_cache_usage_perc instead."), multiprocess_mode="mostrecent", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.gauge_gpu_cache_usage = make_per_engine(gauge_gpu_cache_usage, + engine_indexes, + model_name) # Deprecated in 0.9 - Renamed as vllm:prefix_cache_queries # TODO: in 0.10, only enable if show_hidden_metrics=True - self.counter_gpu_prefix_cache_queries = self._counter_cls( + counter_gpu_prefix_cache_queries = self._counter_cls( name="vllm:gpu_prefix_cache_queries", - documentation= - ("GPU prefix cache queries, in terms of number of queried tokens." - "DEPRECATED: Use vllm:prefix_cache_queries instead."), - labelnames=labelnames).labels(*labelvalues) + documentation=( + "GPU prefix cache queries, in terms of number of queried" + "tokens. DEPRECATED: Use vllm:prefix_cache_queries instead."), + labelnames=labelnames) + self.counter_gpu_prefix_cache_queries = make_per_engine( + counter_gpu_prefix_cache_queries, engine_indexes, model_name) # Deprecated in 0.9 - Renamed as vllm:prefix_cache_hits # TODO: in 0.10, only enable if show_hidden_metrics=True - self.counter_gpu_prefix_cache_hits = self._counter_cls( + counter_gpu_prefix_cache_hits = self._counter_cls( name="vllm:gpu_prefix_cache_hits", documentation=( - "GPU prefix cache hits, in terms of number of cached tokens." - "DEPRECATED: Use vllm:prefix_cache_hits instead."), - labelnames=labelnames).labels(*labelvalues) + "GPU prefix cache hits, in terms of number of cached " + "tokens. DEPRECATED: Use vllm:prefix_cache_hits instead."), + labelnames=labelnames) + self.counter_gpu_prefix_cache_hits = make_per_engine( + counter_gpu_prefix_cache_hits, engine_indexes, model_name) - self.gauge_kv_cache_usage = self._gauge_cls( + gauge_kv_cache_usage = self._gauge_cls( name="vllm:kv_cache_usage_perc", documentation="KV-cache usage. 
1 means 100 percent usage.", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.gauge_kv_cache_usage = make_per_engine(gauge_kv_cache_usage, + engine_indexes, model_name) - self.counter_prefix_cache_queries = self._counter_cls( + counter_prefix_cache_queries = self._counter_cls( name="vllm:prefix_cache_queries", documentation=( "Prefix cache queries, in terms of number of queried tokens."), - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.counter_prefix_cache_queries = make_per_engine( + counter_prefix_cache_queries, engine_indexes, model_name) - self.counter_prefix_cache_hits = self._counter_cls( + counter_prefix_cache_hits = self._counter_cls( name="vllm:prefix_cache_hits", documentation=( "Prefix cache hits, in terms of number of cached tokens."), - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.counter_prefix_cache_hits = make_per_engine( + counter_prefix_cache_hits, engine_indexes, model_name) # # Counters # - self.counter_num_preempted_reqs = self._counter_cls( + counter_num_preempted_reqs = self._counter_cls( name="vllm:num_preemptions", documentation="Cumulative number of preemption from the engine.", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.counter_num_preempted_reqs = make_per_engine( + counter_num_preempted_reqs, engine_indexes, model_name) - self.counter_prompt_tokens = self._counter_cls( + counter_prompt_tokens = self._counter_cls( name="vllm:prompt_tokens", documentation="Number of prefill tokens processed.", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.counter_prompt_tokens = make_per_engine(counter_prompt_tokens, + engine_indexes, + model_name) - self.counter_generation_tokens = self._counter_cls( + counter_generation_tokens = self._counter_cls( name="vllm:generation_tokens", documentation="Number of generation tokens processed.", - labelnames=labelnames).labels(*labelvalues) + labelnames=labelnames) + self.counter_generation_tokens = make_per_engine( + counter_generation_tokens, engine_indexes, model_name) - self.counter_request_success: dict[FinishReason, - prometheus_client.Counter] = {} + self.counter_request_success: dict[FinishReason, dict[ + int, prometheus_client.Counter]] = {} counter_request_success_base = self._counter_cls( name="vllm:request_success", documentation="Count of successfully processed requests.", labelnames=labelnames + ["finished_reason"]) for reason in FinishReason: - self.counter_request_success[ - reason] = counter_request_success_base.labels(*(labelvalues + - [str(reason)])) + self.counter_request_success[reason] = { + idx: + counter_request_success_base.labels(model_name, str(idx), + str(reason)) + for idx in engine_indexes + } # # Histograms of counts # - self.histogram_num_prompt_tokens_request = \ - self._histogram_cls( - name="vllm:request_prompt_tokens", - documentation="Number of prefill tokens processed.", - buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames).labels(*labelvalues) - - self.histogram_num_generation_tokens_request = \ - self._histogram_cls( - name="vllm:request_generation_tokens", - documentation="Number of generation tokens processed.", - buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames).labels(*labelvalues) + histogram_num_prompt_tokens_request = self._histogram_cls( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames) + 
self.histogram_num_prompt_tokens_request = make_per_engine( + histogram_num_prompt_tokens_request, engine_indexes, model_name) + + histogram_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames) + self.histogram_num_generation_tokens_request = make_per_engine( + histogram_num_generation_tokens_request, engine_indexes, + model_name) # TODO: This metric might be incorrect in case of using multiple # api_server counts which uses prometheus mp. # See: https://github.com/vllm-project/vllm/pull/18053 - self.histogram_iteration_tokens = \ - self._histogram_cls( - name="vllm:iteration_tokens_total", - documentation="Histogram of number of tokens per engine_step.", - buckets=[ - 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, - 16384 - ], - labelnames=labelnames).labels(*labelvalues) - - self.histogram_max_num_generation_tokens_request = \ - self._histogram_cls( - name="vllm:request_max_num_generation_tokens", - documentation= - "Histogram of maximum number of requested generation tokens.", - buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames).labels(*labelvalues) - - self.histogram_n_request = \ - self._histogram_cls( - name="vllm:request_params_n", - documentation="Histogram of the n request parameter.", - buckets=[1, 2, 5, 10, 20], - labelnames=labelnames).labels(*labelvalues) - - self.histogram_max_tokens_request = \ - self._histogram_cls( - name="vllm:request_params_max_tokens", - documentation="Histogram of the max_tokens request parameter.", - buckets=build_1_2_5_buckets(max_model_len), - labelnames=labelnames).labels(*labelvalues) + histogram_iteration_tokens = self._histogram_cls( + name="vllm:iteration_tokens_total", + documentation="Histogram of number of tokens per engine_step.", + buckets=[ + 1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + ], + labelnames=labelnames) + self.histogram_iteration_tokens = make_per_engine( + histogram_iteration_tokens, engine_indexes, model_name) + + histogram_max_num_generation_tokens_request = self._histogram_cls( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames) + self.histogram_max_num_generation_tokens_request = make_per_engine( + histogram_max_num_generation_tokens_request, engine_indexes, + model_name) + + histogram_n_request = self._histogram_cls( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + buckets=[1, 2, 5, 10, 20], + labelnames=labelnames) + self.histogram_n_request = make_per_engine(histogram_n_request, + engine_indexes, model_name) + + histogram_max_tokens_request = self._histogram_cls( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames) + self.histogram_max_tokens_request = make_per_engine( + histogram_max_tokens_request, engine_indexes, model_name) # # Histogram of timing intervals # - self.histogram_time_to_first_token = \ - self._histogram_cls( - name="vllm:time_to_first_token_seconds", - documentation="Histogram of time to first token in seconds.", - buckets=[ - 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, - 640.0, 2560.0 - ], - 
labelnames=labelnames).labels(*labelvalues) - - self.histogram_time_per_output_token = \ - self._histogram_cls( - name="vllm:time_per_output_token_seconds", - documentation="Histogram of time per output token in seconds.", - buckets=[ - 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, - 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 - ], - labelnames=labelnames).labels(*labelvalues) + histogram_time_to_first_token = self._histogram_cls( + name="vllm:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0, 160.0, 640.0, + 2560.0 + ], + labelnames=labelnames) + self.histogram_time_to_first_token = make_per_engine( + histogram_time_to_first_token, engine_indexes, model_name) + + histogram_time_per_output_token = self._histogram_cls( + name="vllm:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5, 5.0, 7.5, 10.0, 20.0, 40.0, 80.0 + ], + labelnames=labelnames) + self.histogram_time_per_output_token = make_per_engine( + histogram_time_per_output_token, engine_indexes, model_name) request_latency_buckets = [ 0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0, 120.0, 240.0, 480.0, 960.0, 1920.0, 7680.0 ] - self.histogram_e2e_time_request = \ - self._histogram_cls( - name="vllm:e2e_request_latency_seconds", - documentation="Histogram of e2e request latency in seconds.", - buckets=request_latency_buckets, - labelnames=labelnames).labels(*labelvalues) - self.histogram_queue_time_request = \ - self._histogram_cls( - name="vllm:request_queue_time_seconds", - documentation= - "Histogram of time spent in WAITING phase for request.", - buckets=request_latency_buckets, - labelnames=labelnames).labels(*labelvalues) - self.histogram_inference_time_request = \ - self._histogram_cls( - name="vllm:request_inference_time_seconds", - documentation= - "Histogram of time spent in RUNNING phase for request.", - buckets=request_latency_buckets, - labelnames=labelnames).labels(*labelvalues) - self.histogram_prefill_time_request = \ - self._histogram_cls( - name="vllm:request_prefill_time_seconds", - documentation= - "Histogram of time spent in PREFILL phase for request.", - buckets=request_latency_buckets, - labelnames=labelnames).labels(*labelvalues) - self.histogram_decode_time_request = \ - self._histogram_cls( - name="vllm:request_decode_time_seconds", - documentation= - "Histogram of time spent in DECODE phase for request.", - buckets=request_latency_buckets, - labelnames=labelnames).labels(*labelvalues) + histogram_e2e_time_request = self._histogram_cls( + name="vllm:e2e_request_latency_seconds", + documentation="Histogram of e2e request latency in seconds.", + buckets=request_latency_buckets, + labelnames=labelnames) + self.histogram_e2e_time_request = make_per_engine( + histogram_e2e_time_request, engine_indexes, model_name) + + histogram_queue_time_request = self._histogram_cls( + name="vllm:request_queue_time_seconds", + documentation= + "Histogram of time spent in WAITING phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames) + self.histogram_queue_time_request = make_per_engine( + histogram_queue_time_request, engine_indexes, model_name) + + histogram_inference_time_request = self._histogram_cls( + name="vllm:request_inference_time_seconds", + 
documentation= + "Histogram of time spent in RUNNING phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames) + self.histogram_inference_time_request = make_per_engine( + histogram_inference_time_request, engine_indexes, model_name) + + histogram_prefill_time_request = self._histogram_cls( + name="vllm:request_prefill_time_seconds", + documentation= + "Histogram of time spent in PREFILL phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames) + self.histogram_prefill_time_request = make_per_engine( + histogram_prefill_time_request, engine_indexes, model_name) + + histogram_decode_time_request = self._histogram_cls( + name="vllm:request_decode_time_seconds", + documentation= + "Histogram of time spent in DECODE phase for request.", + buckets=request_latency_buckets, + labelnames=labelnames) + self.histogram_decode_time_request = make_per_engine( + histogram_decode_time_request, engine_indexes, model_name) # # LoRA metrics @@ -382,6 +443,9 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): # api_server counts which uses prometheus mp. self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: + if len(self.engine_indexes) > 1: + raise NotImplementedError( + "LoRA in DP mode is not supported yet.") self.labelname_max_lora = "max_lora" self.labelname_waiting_lora_adapters = "waiting_lora_adapters" self.labelname_running_lora_adapters = "running_lora_adapters" @@ -399,9 +463,8 @@ def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): ) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): - metrics_info = config_obj.metrics_info() - metrics_info["engine"] = self.engine_index + metrics_info["engine"] = "" name, documentation = None, None if type == "cache_config": @@ -417,27 +480,36 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): documentation=documentation, multiprocess_mode="mostrecent", labelnames=metrics_info.keys(), - ).labels(**metrics_info) - info_gauge.set(1) - - def record(self, scheduler_stats: Optional[SchedulerStats], - iteration_stats: Optional[IterationStats]): + ) + for engine_index in self.engine_indexes: + metrics_info = config_obj.metrics_info() + metrics_info["engine"] = str(engine_index) + info_gauge.labels(**metrics_info).set(1) + + def record(self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: int = 0): """Log to prometheus.""" if scheduler_stats is not None: - self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) - self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) + self.gauge_scheduler_running[engine_idx].set( + scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting[engine_idx].set( + scheduler_stats.num_waiting_reqs) - self.gauge_gpu_cache_usage.set(scheduler_stats.kv_cache_usage) - self.gauge_kv_cache_usage.set(scheduler_stats.kv_cache_usage) + self.gauge_gpu_cache_usage[engine_idx].set( + scheduler_stats.kv_cache_usage) + self.gauge_kv_cache_usage[engine_idx].set( + scheduler_stats.kv_cache_usage) - self.counter_gpu_prefix_cache_queries.inc( + self.counter_gpu_prefix_cache_queries[engine_idx].inc( scheduler_stats.prefix_cache_stats.queries) - self.counter_gpu_prefix_cache_hits.inc( + self.counter_gpu_prefix_cache_hits[engine_idx].inc( scheduler_stats.prefix_cache_stats.hits) - self.counter_prefix_cache_queries.inc( + self.counter_prefix_cache_queries[engine_idx].inc( 
scheduler_stats.prefix_cache_stats.queries) - self.counter_prefix_cache_hits.inc( + self.counter_prefix_cache_hits[engine_idx].inc( scheduler_stats.prefix_cache_stats.hits) if scheduler_stats.spec_decoding_stats is not None: @@ -447,42 +519,45 @@ def record(self, scheduler_stats: Optional[SchedulerStats], if iteration_stats is None: return - self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs) - self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) - self.counter_generation_tokens.inc( + self.counter_num_preempted_reqs[engine_idx].inc( + iteration_stats.num_preempted_reqs) + self.counter_prompt_tokens[engine_idx].inc( + iteration_stats.num_prompt_tokens) + self.counter_generation_tokens[engine_idx].inc( iteration_stats.num_generation_tokens) - self.histogram_iteration_tokens.observe( + self.histogram_iteration_tokens[engine_idx].observe( iteration_stats.num_prompt_tokens + \ iteration_stats.num_generation_tokens) for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: - self.histogram_max_num_generation_tokens_request.observe( - max_gen_tokens) + self.histogram_max_num_generation_tokens_request[ + engine_idx].observe(max_gen_tokens) for n_param in iteration_stats.n_params_iter: - self.histogram_n_request.observe(n_param) + self.histogram_n_request[engine_idx].observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: - self.histogram_time_to_first_token.observe(ttft) + self.histogram_time_to_first_token[engine_idx].observe(ttft) for tpot in iteration_stats.time_per_output_tokens_iter: - self.histogram_time_per_output_token.observe(tpot) + self.histogram_time_per_output_token[engine_idx].observe(tpot) for finished_request in iteration_stats.finished_requests: - self.counter_request_success[finished_request.finish_reason].inc() - self.histogram_e2e_time_request.observe( + self.counter_request_success[ + finished_request.finish_reason][engine_idx].inc() + self.histogram_e2e_time_request[engine_idx].observe( finished_request.e2e_latency) - self.histogram_queue_time_request.observe( + self.histogram_queue_time_request[engine_idx].observe( finished_request.queued_time) - self.histogram_prefill_time_request.observe( + self.histogram_prefill_time_request[engine_idx].observe( finished_request.prefill_time) - self.histogram_inference_time_request.observe( + self.histogram_inference_time_request[engine_idx].observe( finished_request.inference_time) - self.histogram_decode_time_request.observe( + self.histogram_decode_time_request[engine_idx].observe( finished_request.decode_time) - self.histogram_num_prompt_tokens_request.observe( + self.histogram_num_prompt_tokens_request[engine_idx].observe( finished_request.num_prompt_tokens) - self.histogram_num_generation_tokens_request.observe( + self.histogram_num_generation_tokens_request[engine_idx].observe( finished_request.num_generation_tokens) if finished_request.max_tokens_param: - self.histogram_max_tokens_request.observe( + self.histogram_max_tokens_request[engine_idx].observe( finished_request.max_tokens_param) if self.gauge_lora_info is not None: @@ -502,6 +577,18 @@ def log_engine_initialized(self): self.log_metrics_info("cache_config", self.vllm_config.cache_config) +PromMetric = Union[ + prometheus_client.Gauge, + prometheus_client.Counter, + prometheus_client.Histogram, +] + + +def make_per_engine(metric: PromMetric, engine_idxs: list[int], + model_name: str) -> dict[int, PromMetric]: + return {idx: metric.labels(model_name, str(idx)) for idx in engine_idxs} + + def 
build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]: """ Builds a list of buckets with increasing powers of 10 multiplied by @@ -529,29 +616,79 @@ def build_1_2_5_buckets(max_value: int) -> list[int]: return build_buckets([1, 2, 5], max_value) -def setup_default_loggers( - vllm_config: VllmConfig, - log_stats: bool, - engine_num: int, - custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, -) -> list[list[StatLoggerBase]]: - """Setup logging and prometheus metrics.""" - if not log_stats: - return [] - - factories: list[StatLoggerFactory] - if custom_stat_loggers is not None: - factories = custom_stat_loggers - else: - factories = [PrometheusStatLogger] - if logger.isEnabledFor(logging.INFO): - factories.append(LoggingStatLogger) - - stat_loggers: list[list[StatLoggerBase]] = [] - for i in range(engine_num): - per_engine_stat_loggers: list[StatLoggerBase] = [] - for logger_factory in factories: - per_engine_stat_loggers.append(logger_factory(vllm_config, i)) - stat_loggers.append(per_engine_stat_loggers) - - return stat_loggers +class StatLoggerManager: + """ + StatLoggerManager: + Logging happens at the level of the EngineCore (per scheduler). + * DP: >1 EngineCore per AsyncLLM - loggers for each EngineCore. + * With Local Logger, just make N copies for N EngineCores. + * With Prometheus, we need a single logger with N "labels" + + This class abstracts away this implementation detail from + the AsyncLLM, allowing the AsyncLLM to just call .record() + and .log() to a simple interface. + """ + + def __init__( + self, + vllm_config: VllmConfig, + engine_idxs: Optional[list[int]] = None, + custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, + ): + self.engine_idxs = engine_idxs if engine_idxs else [0] + + factories: list[StatLoggerFactory] + if custom_stat_loggers is not None: + factories = custom_stat_loggers + else: + factories = [] + if logger.isEnabledFor(logging.INFO): + factories.append(LoggingStatLogger) + + # engine_idx: StatLogger + self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {} + prometheus_factory = PrometheusStatLogger + for engine_idx in self.engine_idxs: + loggers: list[StatLoggerBase] = [] + for logger_factory in factories: + # If we get a custom prometheus logger, use that + # instead. This is typically used for the ray case. + if (isinstance(logger_factory, type) + and issubclass(logger_factory, PrometheusStatLogger)): + prometheus_factory = logger_factory + continue + loggers.append(logger_factory(vllm_config, + engine_idx)) # type: ignore + self.per_engine_logger_dict[engine_idx] = loggers + + # For Prometheus, need to share the metrics between EngineCores. + # Each EngineCore's metrics are expressed as a unique label. 
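# --- Editorial sketch, not part of this patch: the per-engine fan-out pattern
# --- that make_per_engine() implements, shown with prometheus_client
# --- directly. One metric family is created once, and one labelled child is
# --- bound per EngineCore index. The metric name, model name, and engine
# --- indexes below are illustrative only.
from prometheus_client import Gauge

running = Gauge("demo_num_requests_running",
                "Requests currently executing, per engine.",
                labelnames=["model_name", "engine"])
per_engine = {idx: running.labels("demo-model", str(idx)) for idx in (0, 1)}
per_engine[0].set(3)  # engine 0
per_engine[1].set(5)  # engine 1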
+ self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs) + + def record( + self, + scheduler_stats: Optional[SchedulerStats], + iteration_stats: Optional[IterationStats], + engine_idx: Optional[int] = None, + ): + if engine_idx is None: + engine_idx = 0 + + per_engine_loggers = self.per_engine_logger_dict[engine_idx] + for logger in per_engine_loggers: + logger.record(scheduler_stats, iteration_stats, engine_idx) + + self.prometheus_logger.record(scheduler_stats, iteration_stats, + engine_idx) + + def log(self): + for per_engine_loggers in self.per_engine_logger_dict.values(): + for logger in per_engine_loggers: + logger.log() + + def log_engine_initialized(self): + self.prometheus_logger.log_engine_initialized() + + for per_engine_loggers in self.per_engine_logger_dict.values(): + for logger in per_engine_loggers: + logger.log_engine_initialized() diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index cce692d6c09e..ae8f9447e9c8 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -3,7 +3,6 @@ import time from typing import Optional, Union -from vllm.config import VllmConfig from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.spec_decode.metrics import SpecDecodingProm @@ -51,7 +50,13 @@ class RayGaugeWrapper(RayPrometheusMetric): def __init__(self, name: str, documentation: Optional[str] = "", - labelnames: Optional[list[str]] = None): + labelnames: Optional[list[str]] = None, + multiprocess_mode: Optional[str] = ""): + + # All Ray metrics are keyed by WorkerId, so multiprocess modes like + # "mostrecent", "all", "sum" do not apply. This logic can be manually + # implemented at the observability layer (Prometheus/Grafana). + del multiprocess_mode labelnames_tuple = tuple(labelnames) if labelnames else None self.metric = ray_metrics.Gauge(name=name, description=documentation, @@ -122,9 +127,6 @@ class RayPrometheusStatLogger(PrometheusStatLogger): _histogram_cls = RayHistogramWrapper _spec_decoding_cls = RaySpecDecodingProm - def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): - super().__init__(vllm_config, engine_index) - @staticmethod def _unregister_vllm_metrics(): # No-op on purpose diff --git a/vllm/v1/pool/metadata.py b/vllm/v1/pool/metadata.py index 5f321cd87c52..28af720d05fd 100644 --- a/vllm/v1/pool/metadata.py +++ b/vllm/v1/pool/metadata.py @@ -15,3 +15,11 @@ class PoolingMetadata: prompt_lens: torch.Tensor prompt_token_ids: Optional[torch.Tensor] pooling_params: list[PoolingParams] + + def __getitem__(self, indices: slice): + return PoolingMetadata( + prompt_lens=self.prompt_lens[indices], + prompt_token_ids=None if self.prompt_token_ids is None else + self.prompt_token_ids[indices], + pooling_params=self.pooling_params[indices], + ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 9b96f4599f92..85f5dcb92eb4 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -77,6 +77,7 @@ def __init__( self.num_prompt_tokens = len(self.prompt_token_ids) self._output_token_ids: list[int] = [] self._all_token_ids: list[int] = self.prompt_token_ids.copy() + self.num_output_placeholders = 0 # Used in async scheduling. 
self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 self.cache_salt: Optional[str] = cache_salt diff --git a/vllm/v1/sample/logits_processor.py b/vllm/v1/sample/logits_processor.py index 16bd2b9ffd84..3a06e71057cd 100644 --- a/vllm/v1/sample/logits_processor.py +++ b/vllm/v1/sample/logits_processor.py @@ -234,10 +234,16 @@ def __init__(self, max_num_reqs: int, pin_memory: bool, device="cpu", pin_memory=pin_memory) self.min_p_cpu = self.min_p_cpu_tensor.numpy() - # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) + + self.use_double_tensor = torch.device("cpu") != torch.device(device) + + if self.use_double_tensor: + # Pre-allocated device tensor + self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + else: + self.min_p_device = self.min_p_cpu_tensor # Current slice of the device tensor self.min_p: torch.Tensor = self.min_p_device[:0] @@ -284,7 +290,9 @@ def update_state(self, batch_update: Optional[BatchUpdate]): size = batch_update.batch_size if self.min_p_count and (needs_update or self.min_p.shape[0] != size): self.min_p = self.min_p_device[:size] - self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) + if self.use_double_tensor: + self.min_p.copy_(self.min_p_cpu_tensor[:size], + non_blocking=True) self.min_p.unsqueeze_(1) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -327,14 +335,19 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if not batch_update: return + needs_update: bool = False # Process added requests. - needs_update = bool(batch_update.added) for index, params, _ in batch_update.added: if isinstance(params, SamplingParams) and (lb := params.logit_bias): self.biases[index] = lb + needs_update = True else: - self.biases.pop(index, None) + # Drop biases metadata at batch index + if self.biases.pop(index, None) is not None: + # If a new request replaces an old request which + # specified biases, we should update processor tensors + needs_update = True if self.biases: # Process removed requests. @@ -411,7 +424,6 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: # Process added requests. - needs_update |= bool(batch_update.added) for index, params, output_tok_ids in batch_update.added: if (isinstance(params, SamplingParams) and (min_tokens := params.min_tokens) @@ -419,9 +431,13 @@ def update_state(self, batch_update: Optional[BatchUpdate]): # Replace request metadata at batch index self.min_toks[index] = (min_tokens, output_tok_ids, params.all_stop_token_ids) + needs_update = True else: - # Drop request metadata at batch index - self.min_toks.pop(index, None) + # Drop min_toks metadata at batch index + if self.min_toks.pop(index, None) is not None: + # If a new request replaces an old request which + # specified min_toks, we should update processor tensors + needs_update = True if self.min_toks: # Process removed requests. 
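# --- Editorial sketch, not part of this patch: the CPU fast path used above
# --- for the min_p tensors -- keep a CPU staging tensor, and only allocate a
# --- separate device tensor (plus an explicit copy on update) when the target
# --- device is not the CPU. Names below are illustrative.
import torch

def alloc_param_tensors(max_num_reqs: int, device: str):
    cpu_tensor = torch.empty(max_num_reqs, dtype=torch.float32, device="cpu")
    if torch.device(device) != torch.device("cpu"):
        device_tensor = torch.empty(max_num_reqs,
                                    dtype=torch.float32,
                                    device=device)
    else:
        # On CPU backends the "device" tensor can simply alias the staging
        # tensor, so no copy is needed when the batch is updated.
        device_tensor = cpu_tensor
    return cpu_tensor, device_tensor

cpu_t, dev_t = alloc_param_tensors(8, "cpu")
assert dev_t is cpu_t  # no second tensor, no copies on the CPU backend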
diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py new file mode 100644 index 000000000000..a4d65485140e --- /dev/null +++ b/vllm/v1/sample/ops/logprobs.py @@ -0,0 +1,24 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Some utilities for logprobs, including logits.""" + +import torch + + +@torch.compile(dynamic=True) +def batched_count_greater_than(x: torch.Tensor, + values: torch.Tensor) -> torch.Tensor: + """ + Counts elements in each row of x that are greater than the corresponding + value in values. Use torch.compile to generate an optimized kernel for + this function. otherwise, it will create additional copies of the input + tensors and cause memory issues. + + Args: + x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements). + values (torch.Tensor): A 2D tensor of shape (batch_size, 1). + + Returns: + torch.Tensor: A 1D tensor of shape (batch_size,) with the counts. + """ + return (x >= values).sum(-1) diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index e79e4451a3a3..82f51298f1b5 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -5,10 +5,12 @@ import torch import torch.nn as nn +from vllm.config import LogprobsMode from vllm.utils import is_pin_memory_available from vllm.v1.outputs import LogprobsTensors, SamplerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.ops.bad_words import apply_bad_words +from vllm.v1.sample.ops.logprobs import batched_count_greater_than from vllm.v1.sample.ops.penalties import apply_all_penalties from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler @@ -17,10 +19,11 @@ class Sampler(nn.Module): - def __init__(self): + def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"): super().__init__() self.topk_topp_sampler = TopKTopPSampler() self.pin_memory = is_pin_memory_available() + self.logprobs_mode = logprobs_mode def forward( self, @@ -35,7 +38,10 @@ def forward( # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501 num_logprobs = sampling_metadata.max_num_logprobs if num_logprobs is not None: - raw_logprobs = self.compute_logprobs(logits) + if self.logprobs_mode == "raw_logprobs": + raw_logprobs = self.compute_logprobs(logits) + elif self.logprobs_mode == "raw_logits": + raw_logprobs = logits.clone() # Use float32 for the logits. logits = logits.to(torch.float32) @@ -50,6 +56,14 @@ def forward( # Apply penalties (e.g., min_tokens, freq_penalties). logits = self.apply_penalties(logits, sampling_metadata) + + # Get the process logprobs or logits. + if num_logprobs is not None: + if self.logprobs_mode == "processed_logprobs": + raw_logprobs = self.compute_logprobs(logits) + elif self.logprobs_mode == "processed_logits": + raw_logprobs = logits.clone() + # Sample the next token. sampled = self.sample(logits, sampling_metadata) # Convert sampled token ids to int64 (long) type to ensure compatibility @@ -174,7 +188,7 @@ def gather_logprobs( token_logprobs = logprobs.gather(-1, token_ids) # Compute the ranks of the actual token. - token_ranks = (logprobs >= token_logprobs).sum(-1) + token_ranks = batched_count_greater_than(logprobs, token_logprobs) # Concatenate together with the topk. 
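# --- Editorial sketch, not part of this patch: what the rank computation used
# --- above does, written in eager form -- count how many vocabulary entries
# --- have a logprob at least as high as the sampled token's logprob. The
# --- tensor values are made up.
import torch

logprobs = torch.tensor([[0.10, 0.50, 0.20],
                         [0.90, 0.30, 0.40]])
token_logprobs = torch.tensor([[0.20], [0.40]])
ranks = (logprobs >= token_logprobs).sum(-1)
print(ranks)  # tensor([2, 2])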
indices = torch.cat((token_ids, topk_indices), dim=1) diff --git a/vllm/v1/sample/tpu/sampler.py b/vllm/v1/sample/tpu/sampler.py index 1056eb1d7b7f..2c9f4892bc24 100644 --- a/vllm/v1/sample/tpu/sampler.py +++ b/vllm/v1/sample/tpu/sampler.py @@ -15,6 +15,7 @@ class Sampler(nn.Module): def __init__(self): + # TODO(houseroad): Add support for logprobs_mode. super().__init__() self.topk_topp_sampler = TopKTopPSampler() diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 6661d984a771..967847c02ff2 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np import torch import torch.nn as nn @@ -12,11 +13,11 @@ from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM +from vllm.utils import is_pin_memory_available from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel logger = init_logger(__name__) @@ -37,7 +38,6 @@ def __init__( self.method = self.speculative_config.method self.runner = runner - self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size @@ -45,6 +45,7 @@ def __init__( self.speculative_config.num_speculative_tokens) self.max_num_tokens = ( vllm_config.scheduler_config.max_num_batched_tokens) + self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because # the draft model's hidden size can be different from the target model's # hidden size (e.g., Llama 3.3 70B). @@ -83,19 +84,14 @@ def propose( target_positions: torch.Tensor, # [num_tokens, hidden_size] target_hidden_states: torch.Tensor, - # [num_tokens] - target_slot_mapping: torch.Tensor, # [batch_size] next_token_ids: torch.Tensor, - # [batch_size + 1] starting with 0 - cu_num_tokens: torch.Tensor, - # [batch_size, max_num_blocks_per_req] - block_table: torch.Tensor, + common_attn_metadata: CommonAttentionMetadata, sampling_metadata: SamplingMetadata, ) -> torch.Tensor: num_tokens = target_token_ids.shape[0] batch_size = next_token_ids.shape[0] - last_token_indices = cu_num_tokens[1:] - 1 + last_token_indices = common_attn_metadata.query_start_loc[1:] - 1 if self.method == "eagle3": assert isinstance(self.model, Eagle3LlamaForCausalLM) @@ -110,50 +106,14 @@ def propose( # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids - # FA requires seq_len to have dtype int32. - seq_lens = (target_positions[last_token_indices] + 1).int() - - if self.method in ["eagle", "eagle3"]: - # FIXME(woosuk): The below two ops cause synchronization. Optimize. - max_seq_len = seq_lens.max().item() - max_num_tokens = (cu_num_tokens[1:] - - cu_num_tokens[:-1]).max().item() - attn_metadata = FlashAttentionMetadata( - num_actual_tokens=num_tokens, - max_query_len=max_num_tokens, - query_start_loc=cu_num_tokens, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table, - slot_mapping=target_slot_mapping, - # TODO(woosuk): Support cascade attention. 
- use_cascade=False, - common_prefix_len=0, - cu_prefix_query_lens=None, - prefix_kv_lens=None, - suffix_kv_lens=None, - ) - elif self.method == "deepseek_mtp": - query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1] - max_query_len = query_lens.max().item() - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=cu_num_tokens, - seq_lens=seq_lens, - num_reqs=batch_size, - num_actual_tokens=num_tokens, - max_query_len=max_query_len, - ) - - assert self.runner is not None + assert self.runner is not None - # FIXME: need to consider multiple kv_cache_groups - attn_metadata = self.runner.attn_metadata_builders[0].build( - common_prefix_len=0, - common_attn_metadata=common_attn_metadata, - ) - else: - raise ValueError(f"Unsupported method: {self.method}") + # FIXME: need to consider multiple kv_cache_groups + attn_metadata = self.runner.attn_metadata_builders[0].build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + fast_build=True, + ) # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. @@ -194,6 +154,11 @@ def propose( # one layer. Adapt this code to support multiple layers once # there's a multi-layer MTP module. + # Currently FlashAttention is the only backend that supports + # multi-token eagle spec decode. This is because the code below + # makes assumptions about attn_metadata attributes available. + assert isinstance(attn_metadata, FlashAttentionMetadata) + # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] @@ -238,8 +203,8 @@ def propose( # Compute the slot mapping. block_numbers = clamped_positions // self.block_size - block_ids = block_table.gather(dim=1, - index=block_numbers.view(-1, 1)) + block_ids = attn_metadata.block_table.gather( + dim=1, index=block_numbers.view(-1, 1)) block_ids = block_ids.view(-1) attn_metadata.slot_mapping = (block_ids * self.block_size + clamped_positions % self.block_size) @@ -275,46 +240,99 @@ def propose( draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - @staticmethod def prepare_inputs( - # [batch_size + 1] - cu_target_query_lens: torch.Tensor, + self, + common_attn_metadata: CommonAttentionMetadata, # [batch_size] - num_rejected_tokens: torch.Tensor, - num_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - # cu_target_query_lens: [0, a, a + b, a + b + c] - # num_rejected_tokens: [n1, n2, n3] - # num_tokens_per_req: [a - n1, b - n2, c - n3] - # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - # token_indices: [0, 1, ..., a - n1 - 1, - # a, a + 1, ..., a + b - n2 - 1, - # a + b, a + b + 1, ..., a + b + c - n3 - 1] - - # [0, a, a + b, a + b + c] -> [a, b, c] - query_len_per_req = (cu_target_query_lens[1:] - - cu_target_query_lens[:-1]) - # [a, b, c] -> [a - n1, b - n2, c - n3] - num_tokens_per_req = query_len_per_req - num_rejected_tokens - - # [a - n1, b - n2, c - n3] -> - # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3] - cu_num_tokens = torch.zeros_like(cu_target_query_lens) - torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:]) - token_indices = torch.empty( - num_tokens, + num_rejected_tokens: torch.Tensor + ) -> tuple[CommonAttentionMetadata, torch.Tensor]: + """ + This function is used to prepare the inputs for the spec decode. + It updates to the common_attn_metadata to account for the rejected + tokens (and newly sampled tokens). It also returns the token indices + of the tokens that should be fed to the speculator. + """ + # E.g. 
+ # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1, q1 + q2, q1 + q2 + q3] + # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] + # num_rejected_tokens: [n1, n2, n3] + # This function computes the intermediate values: + # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] + # And returns: + # common_attn_metadata.query_start_loc{_cpu}: + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + # common_attn_metadata.seq_lens{_cpu}: + # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] + # token_indices: [0, 1, ..., q1 - n1 - 1, + # q1, q1 + 1, ..., q1 + q2 - n2 - 1, + # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] + + device = common_attn_metadata.query_start_loc.device + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \ + - num_rejected_tokens + + # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] + new_query_len_per_req = (query_start_loc_cpu[1:] - + query_start_loc_cpu[:-1]) + # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] + new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens + new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() + + # [q1 - n1, q2 - n2, q3 - n3] -> + # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] + new_query_start_loc_cpu = torch.zeros( + query_start_loc_cpu.shape, dtype=torch.int32, - device=cu_target_query_lens.device, - ) - batch_size = num_rejected_tokens.shape[0] - BLOCK_SIZE = 1024 - prepare_eagle_input_kernel[(batch_size, )]( - token_indices, - cu_target_query_lens, - cu_num_tokens, - BLOCK_SIZE=BLOCK_SIZE, + pin_memory=is_pin_memory_available()) + new_query_start_loc_np = new_query_start_loc_cpu.numpy() + np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) + + total_num_tokens = new_query_start_loc_np[-1] + # Example assuming num_tokens_per_req_np = [2, 4, 3] + # this implies that `new_query_start_locs` is: + # [0, 2, 6, 9] -> + # [0, 0, 2, 2, 2, 2, 6, 6, 6] + # _r1_ ____r2____ ___r3__ + new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], + new_num_tokens_per_req_np) + # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> + # [0, 1, 0, 1, 2, 3, 0, 1, 2] + # _r1_ ____r2____ ___r3__ + token_offests = self.token_arange_np[:total_num_tokens] \ + - new_query_start_locs_expanded + + # Expand starting positions to match token pattern + # [0, q1, q1 + q2] -> + # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] + # _r1_ _____r2_______ ___________r3____________ + old_query_start_locs_expanded = np.repeat( + query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) + # Final token indices are: + # [0, 1, // req 1 + # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 + # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 + token_indices_np = token_offests + old_query_start_locs_expanded + token_indices = torch.from_numpy(token_indices_np).to( + device, non_blocking=True) + + spec_common_attn_metadata = CommonAttentionMetadata( + query_start_loc=new_query_start_loc_cpu.to(device, + non_blocking=True), + seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), + query_start_loc_cpu=new_query_start_loc_cpu, + seq_lens_cpu=new_seq_lens_cpu, + num_computed_tokens_cpu=common_attn_metadata. 
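The Triton prepare_eagle_input_kernel is replaced by plain NumPy index arithmetic on the CPU. A self-contained sketch of the same repeat/cumsum steps with invented query lengths, matching the shapes described in the comments above:

import numpy as np

query_start_loc_cpu = np.array([0, 4, 9, 13])   # old per-request offsets (toy)
num_rejected_tokens = np.array([2, 1, 1])

new_num_tokens_per_req = np.diff(query_start_loc_cpu) - num_rejected_tokens  # [2 4 3]
new_query_start_loc = np.zeros_like(query_start_loc_cpu)
np.cumsum(new_num_tokens_per_req, out=new_query_start_loc[1:])               # [0 2 6 9]
total_num_tokens = new_query_start_loc[-1]

token_arange = np.arange(total_num_tokens)
new_starts_expanded = np.repeat(new_query_start_loc[:-1], new_num_tokens_per_req)
token_offsets = token_arange - new_starts_expanded          # [0 1 0 1 2 3 0 1 2]
old_starts_expanded = np.repeat(query_start_loc_cpu[:-1], new_num_tokens_per_req)
token_indices = token_offsets + old_starts_expanded
print(token_indices)                                        # [ 0  1  4  5  6  7  9 10 11]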
+ num_computed_tokens_cpu, + num_reqs=common_attn_metadata.num_reqs, + num_actual_tokens=total_num_tokens, + max_query_len=new_query_len_per_req.max().item(), + block_table_tensor=common_attn_metadata.block_table_tensor, + slot_mapping=common_attn_metadata.slot_mapping[token_indices], ) - return cu_num_tokens, token_indices + + return spec_common_attn_metadata, token_indices def load_model(self, target_model: nn.Module) -> None: draft_model_config = \ diff --git a/vllm/v1/spec_decode/utils.py b/vllm/v1/spec_decode/utils.py index 3a86fea146f3..1116179dc5b6 100644 --- a/vllm/v1/spec_decode/utils.py +++ b/vllm/v1/spec_decode/utils.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from vllm.sampling_params import SamplingParams -from vllm.triton_utils import tl, triton _SAMPLING_EPS = 1e-5 @@ -13,29 +12,3 @@ def is_spec_decode_unsupported(sampling_params: SamplingParams) -> bool: or sampling_params.repetition_penalty != 1.0 or sampling_params.min_p > _SAMPLING_EPS or sampling_params.logprobs is not None) - - -@triton.jit -def prepare_eagle_input_kernel( - out_ptr, - cu_query_lens_ptr, - cu_num_tokens_ptr, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - - # [start_pos, end_pos) - start_pos = tl.load(cu_num_tokens_ptr + pid) - end_pos = tl.load(cu_num_tokens_ptr + pid + 1) - num_tokens = end_pos - start_pos - - index_start = tl.load(cu_query_lens_ptr + pid) - - num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) - for i in tl.range(num_blocks): - offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - tl.store( - out_ptr + start_pos + offset, - index_start + offset, - mask=offset < num_tokens, - ) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 839f1da8dd0d..bd1dd01f9063 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -88,6 +88,15 @@ def grammar_init(self, request: Request) -> None: tokenizer=self.tokenizer, vocab_size=vocab_size, ) + elif backend == "outlines": + from vllm.v1.structured_output.backend_outlines import ( + OutlinesBackend) + + self.backend = OutlinesBackend( + self.vllm_config, + tokenizer=self.tokenizer, + vocab_size=vocab_size, + ) else: raise ValueError( f"Unsupported structured output backend: {backend}") diff --git a/vllm/v1/structured_output/backend_outlines.py b/vllm/v1/structured_output/backend_outlines.py new file mode 100644 index 000000000000..572e4984480f --- /dev/null +++ b/vllm/v1/structured_output/backend_outlines.py @@ -0,0 +1,320 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 2025-present the Outlines developers +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + +import ast +import importlib +import json +import sys +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch +from regex import escape as regex_escape + +from vllm.sampling_params import SamplingParams +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) +from vllm.v1.structured_output.utils import (OutlinesVocabulary, + get_outlines_cache, + get_outlines_vocabulary) + +if TYPE_CHECKING: + import outlines_core as oc + import outlines_core.json_schema as json_schema +else: + oc = LazyLoader("oc", globals(), "outlines_core") + json_schema = LazyLoader("json_schema", globals(), + 
"outlines_core.json_schema") + +# Python 3.11+ sre_parse and sre_constants +# are deprecated, so we must import them from re +if sys.version_info >= (3, 11): + # Hack to get around pre-commit regex module rule + # because going through re is the only way to get sre_parse + # and sre_constants in Python 3.11+ + _re = importlib.import_module("re") + sre_parse = _re._parser + sre_constants = _re._constants +else: + import sre_constants + import sre_parse + + +@dataclass +class OutlinesBackend(StructuredOutputBackend): + + def __post_init__(self): + self.vocabulary = get_outlines_vocabulary(self.tokenizer) + self.cache = get_outlines_cache() + + def _compile_index(self, regex_string: str, + vocabulary: OutlinesVocabulary) -> oc.Index: + cache_key = f"{vocabulary._hash}_{regex_string}" + if cache_key in self.cache: + return self.cache[cache_key] + + index = oc.Index(regex_string, vocabulary.inner) + self.cache[cache_key] = index + + return index + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + if request_type == StructuredOutputOptions.JSON: + regex = json_schema.build_regex_from_schema(grammar_spec) + elif request_type == StructuredOutputOptions.REGEX: + regex = grammar_spec + elif request_type == StructuredOutputOptions.CHOICE: + choices = ast.literal_eval(grammar_spec) + choices = [regex_escape(c) for c in choices] + regex = "(" + "|".join(choices) + ")" + else: + raise ValueError( + f"Invalid request type for Outlines backend ({request_type!s})" + ) + index = self._compile_index(regex, self.vocabulary) + max_rollback_tokens = ( + self.vllm_config.speculative_config.num_speculative_tokens + if self.vllm_config.speculative_config is not None else 0) + return OutlinesGrammar(vocab_size=self.vocab_size, + guide=oc.Guide( + index, max_rollback=max_rollback_tokens)) + + def allocate_token_bitmask(self, max_num_seqs: int) -> torch.Tensor: + return torch.full( + (max_num_seqs, (self.vocab_size + 31) // 32), + -1, + dtype=torch.int32, + pin_memory=torch.cuda.is_available(), + ) + + def destroy(self): + pass + + +@dataclass +class OutlinesGrammar(StructuredOutputGrammar): + + vocab_size: int + guide: oc.Guide = field(hash=False) + num_processed_tokens: int = field(default_factory=lambda: 0, + repr=False, + hash=False, + init=False) + + # outlines_core signals done on DFA accept; vLLM expects done after EOS. + # We delay the finished flag by one step so EOS can still be emitted. + _prev_finished: bool = field(default=False, + init=False, + repr=False, + hash=False) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """Accepts a list of tokens and advances the FSM. + + Returns True if the FSM was advanced successfully. + Returns False if the FSM failed to advance. 
+ """ + if self.guide.accepts_tokens(tokens): + # Advance cannot fail because we checked Guide.accepts_tokens() + for t in tokens: + self.guide.advance(t) + self.num_processed_tokens += 1 + return True + return False + + def rollback(self, num_tokens: int) -> None: + self.guide.rollback_state(num_tokens) + self.num_processed_tokens -= num_tokens + + def validate_tokens(self, tokens: list[int]) -> list[int]: + accepted: list[int] = [] + for tok in tokens: + accepted.append(tok) + if not self.guide.accepts_tokens(accepted): + accepted.pop() + break + return accepted + + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + mask = bitmask[idx] + self.guide.write_mask_into(mask.data_ptr(), mask.numel(), + mask.element_size()) + + def is_terminated(self) -> bool: + curr = self.guide.is_finished() + prev = self._prev_finished + self._prev_finished = curr + return prev + + def reset(self): + self.num_processed_tokens = 0 + self._prev_finished = False + self.guide.reset() + + +def validate_structured_output_request_outlines(params: SamplingParams): + if params.guided_decoding is None: + return + + gd_params = params.guided_decoding + + if gd_params.regex: + validate_regex_is_buildable(gd_params.regex) + elif gd_params.json: + if isinstance(gd_params.json, str): + try: + # make sure schema is valid json + json.loads(gd_params.json) + schema = gd_params.json + except json.JSONDecodeError as e: + raise ValueError("Invalid JSON grammar specification.") from e + else: + try: + schema = json.dumps(gd_params.json) + except Exception as e: + raise ValueError( + f"Error serializing guided decoding jsonschema: {e}" + ) from e + pattern = json_schema.build_regex_from_schema(schema) + validate_regex_is_buildable(pattern) + elif gd_params.choice: + choices = [regex_escape(str(choice)) for choice in gd_params.choice] + regex = "(" + "|".join(choices) + ")" + validate_regex_is_buildable(regex) + elif gd_params.grammar: + raise ValueError("Outlines guided decoding backend " + "does not support grammar specifications") + + +def _prefix_needs_context(parsed) -> bool: + """Return True if there's a look-around/anchor before any consumer.""" + + def subpattern_consumes(parsed) -> bool: + """Return True if subpattern can consume at least one character.""" + tokens = parsed.data if hasattr(parsed, 'data') else parsed + for ttype, tval in tokens: + # literal, character class, or dot always consumes + if ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): + return True + # quantified subpattern: check inner pattern + elif ttype == sre_parse.MAX_REPEAT: + _, mx, sub = tval + if mx != 0 and subpattern_consumes(sub): + return True + # alternation: if any branch consumes, the whole does + elif ttype == sre_parse.BRANCH: + _, branches = tval + if any(subpattern_consumes(br) for br in branches): + return True + # grouped subpattern: recurse into its contents + elif ttype == sre_parse.SUBPATTERN and subpattern_consumes( + tval[3]): + return True + # No consumers, return False + return False + + tokens = parsed.data if hasattr(parsed, 'data') else parsed + for ttype, tval in tokens: + # Direct anchors or look-around + if ttype == sre_parse.AT or ttype in (sre_constants.ASSERT, + sre_constants.ASSERT_NOT): + return True + + # Nested subpattern: check + if ttype == sre_parse.SUBPATTERN: + # tval: (group, add_flags, del_flags, subpattern) + if _prefix_needs_context(tval[3]): + return True + if subpattern_consumes(tval[3]): + return False + + # if any branch has a prefix anchor => True, + # else if at least one 
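allocate_token_bitmask packs one allow/deny bit per vocabulary token into int32 words, so the row width is ceil(vocab_size / 32), and filling with -1 starts with every bit set (all tokens allowed) until fill_bitmask writes the guide's mask. A quick check of the sizing with an assumed vocabulary size:

import torch

vocab_size = 50_000          # assumed value, not tied to any particular model
max_num_seqs = 4
bitmask = torch.full((max_num_seqs, (vocab_size + 31) // 32), -1, dtype=torch.int32)
print(bitmask.shape)         # torch.Size([4, 1563]); 1563 * 32 = 50016 >= 50000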
branch consumes => prefix ends => False + elif ttype == sre_parse.BRANCH: + saw_consumer = False + for br in tval[1]: + if _prefix_needs_context(br): + return True + if subpattern_consumes(br): + saw_consumer = True + if saw_consumer: + return False + + # Immediate consumer tokens + elif ttype in (sre_parse.LITERAL, sre_parse.IN, sre_parse.ANY): + return False + + # if subpattern has anchor => True, if it can consume => stop + elif ttype == sre_parse.MAX_REPEAT: + if _prefix_needs_context(tval[2]): + return True + if subpattern_consumes(tval[2]): + return False + + return False + + +def _check_unsupported(parsed) -> None: + """Check for regex features unsupported by regex-automata""" + tokens = parsed.data if hasattr(parsed, 'data') else parsed + for ttype, tval in tokens: + + # backreference + if ttype in (sre_parse.GROUPREF, sre_parse.GROUPREF_EXISTS): + raise ValueError("Backreferences are unsupported.") + + # look-around assertion + elif ttype in (sre_constants.ASSERT, sre_constants.ASSERT_NOT): + raise ValueError("Look-Around assertion are unsupported.") + + # unicode word boundaries + elif ttype == sre_parse.AT: + if tval in (sre_constants.AT_BOUNDARY, + sre_constants.AT_NON_BOUNDARY): + raise ValueError("Unicode word boundaries are unsupported.") + + elif ttype == sre_parse.BRANCH: + # tval is (None, branches) + for branch in tval[1]: + _check_unsupported(branch) + + # tval is (min, max, subpattern) + elif ttype == sre_parse.MAX_REPEAT: + _check_unsupported(tval[2]) + + +def validate_regex_is_buildable(pattern: str) -> None: + """ + Validates that the input regex is not using unsupported features + of the `regex-automata` crate (outlines_core regex engine) and has a + universal start state. + definition of universal start state used can be found at: + https://docs.rs/regex-automata/latest/regex_automata/dfa/trait.Automaton.html#method.universal_start_state + """ + try: + parsed = sre_parse.parse(pattern) + + except sre_constants.error as e: + raise ValueError(f"Error parsing regex: {e}") from e + + try: + _check_unsupported(parsed) + except ValueError as e: + raise ValueError( + f"Regex uses unsupported feature for guided decoding: {e}. " + "Only basic matching constructs are supported—lookarounds, " + "backreferences, and unicode boundaries are not.") from e + + if _prefix_needs_context(parsed): + raise ValueError( + "Regex does not have a anchored universal start state" + "This means that the Regex uses anchors (^) or look-arounds " + "in a way which requires context before any token is matched." + "Guided decoding needs regexes that can match without needing " + "that context. Try rewriting the pattern without using these " + f"constructs. 
Pattern:\n{pattern}") diff --git a/vllm/v1/structured_output/utils.py b/vllm/v1/structured_output/utils.py index 7adee7237bd1..95319831d512 100644 --- a/vllm/v1/structured_output/utils.py +++ b/vllm/v1/structured_output/utils.py @@ -3,7 +3,205 @@ from __future__ import annotations +import hashlib +import importlib.metadata +import os +from typing import TYPE_CHECKING + import regex as re +from cachetools import LRUCache +from diskcache import Cache + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils import LazyLoader + +if TYPE_CHECKING: + import outlines_core as oc + import transformers.file_utils as file_utils + import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2 + + from vllm.transformers_utils.tokenizer import AnyTokenizer +else: + oc = LazyLoader("oc", globals(), "outlines_core") + file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils") + tokenization_gpt2 = LazyLoader( + "tokenization_gpt2", + globals(), + "transformers.models.gpt2.tokenization_gpt2", + ) + +logger = init_logger(__name__) + +CACHE = None + + +class OutlinesVocabulary: + """ + Wrapper class for `outlines_core.Vocabulary`, + which allows us to store a hash with the vocabulary + """ + + def __init__(self, vocabulary: oc.Vocabulary) -> None: + # Actual vocabulary object + self.inner = vocabulary + # Have to do abs(hash()) because python hashes can + # be negative, and we are using hash as a cache key. + hex_str = hashlib.sha256( + vocabulary.__repr__().encode('utf-8')).hexdigest() + hash_int = int(hex_str, 16) + self._hash = hash_int + + +def get_outlines_cache_path() -> str: + """Get the context object that contains previously-computed return values""" + outlines_cache_dir = os.getenv("OUTLINES_CACHE_DIR") + xdg_cache_home = os.getenv("XDG_CACHE_HOME") + home_dir = os.path.expanduser("~") + + if outlines_cache_dir: + # OUTLINES_CACHE_DIR takes precedence + return outlines_cache_dir + elif xdg_cache_home: + return os.path.join(xdg_cache_home, ".cache", "outlines") + # If homedir is "/", we may be inside a container, and thus writing to + # root would be problematic, so we fallback to using a tempfile. + # Also validate the path exists, since os.path.expanduser does + # not garuntee existence. + elif os.path.isdir(home_dir) and home_dir != "/": + # Default Unix fallback: ~/.cache/outlines + return os.path.join(home_dir, ".cache", "outlines") + else: + import tempfile + + # home_dir may be / inside a docker container without existing user + tempdir = tempfile.gettempdir() + return os.path.join(tempdir, ".cache", "outlines") + + +def get_outlines_cache(): + """Get the Cache instance to be used for index caching""" + + cache_dir = get_outlines_cache_path() + if envs.VLLM_V1_USE_OUTLINES_CACHE: + logger.warning("Enabling outlines cache. This is an unbounded on-disk " + "cache. 
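validate_regex_is_buildable reuses CPython's own regex parser to spot anchors and look-arounds before the pattern reaches outlines_core. A small standalone probe (mirroring the version guard above) showing how those constructs surface as the first parsed token:

import importlib
import sys

if sys.version_info >= (3, 11):
    sre_parse = importlib.import_module("re")._parser
else:
    import sre_parse

for pattern in ("^foo", "(?=bar)baz", "plain"):
    parsed = sre_parse.parse(pattern)
    print(pattern, "->", parsed[0])
# ^foo       -> (AT, AT_BEGINNING)   anchor: needs context before matching
# (?=bar)baz -> (ASSERT, ...)        look-ahead: rejected as unsupported
# plain      -> (LITERAL, 112)       consumes a character immediately: fine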
It may consume a lot of disk space and should " + "not be used with untrusted clients.") + cache = Cache(cache_dir, eviction_policy="none", cull_limit=0) + outlines_version = importlib.metadata.version("outlines_core") + + cached_version = cache.get('__version__', None) + if cached_version != outlines_version: + cache.clear() + cache.set('__version__', outlines_version) + return cache + else: + return LRUCache(maxsize=128) + + +re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$") +re_replacement_seq = re.compile(r"^.{0,6}�+.{0,6}$") + + +def _reduced_vocabulary( + tokenizer: AnyTokenizer, + eos_token_id: int, +) -> dict[bytes, list[int]]: + """Create a map from vocabulary tokens to lists of equivalent token ids. + + Returns: + A Dict of token string -> equivalent token ids + """ + + unicode_to_bytes = { + v: k + for k, v in tokenization_gpt2.bytes_to_unicode().items() + } + + def convert_token_to_string(token: str) -> str: + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if (type(token) is str + and token.startswith(file_utils.SPIECE_UNDERLINE) + or token == "<0x20>"): + return " " + string + + return string + + vocabulary: dict[bytes, list[int]] = {} + empty_token_ids: list[int] = [] + for token, token_idx in tokenizer.get_vocab().items(): + if token in tokenizer.all_special_tokens: # type: ignore + continue + + token_str = convert_token_to_string(token) + if token_str: + if isinstance(token, (bytes, bytearray)): + # For BPE tokenizers where tokens are stored as bytes. + + # safe to ignore since token_str is of type (bytearray, bytes) + # by this point. + token_bytes = bytes(token_str) # type: ignore[arg-type] + + elif "\ufffd" in token_str and not re_replacement_seq.match( + token_str): + # Handle tokens with invalid UTF-8 sequences. + if re_llama_byte_token.match(token): + # Llama-like tokenizers use <0xXX> for incomplete sequences. + token_bytes = bytes([int(token[3:5], 16)]) + else: + # GPT2 tokenizers: map each byte back using unicode_to_bytes + byte_vals = [unicode_to_bytes.get(c) for c in token] + if None in byte_vals: + raise RuntimeError( + f"Cannot convert token `{token}`" + f" ({token_idx}) to bytes: {token_str}") + # safe to ignore, since if None in byte_vals, + # an error is thrown. + token_bytes = bytes(byte_vals) # type: ignore[arg-type] + else: + token_bytes = token_str.encode('utf-8') + + if token_idx != eos_token_id: + vocabulary.setdefault(token_bytes, []).append(token_idx) + else: + empty_token_ids.append(token_idx) + + return vocabulary + + +def get_outlines_vocabulary(tokenizer: AnyTokenizer) -> oc.Vocabulary: + """Get the `Vocabulary` object for a given tokenizer. + """ + if hasattr(tokenizer, "_outlines_vocabulary"): + return tokenizer._outlines_vocabulary # type: ignore + + try: + if hasattr( + tokenizer, + "eos_token_id", + ) and tokenizer.eos_token_id is not None: + eos_token_id = tokenizer.eos_token_id + else: + raise ValueError( + f"Error during structured outputs setup for outlines: Tokenizer ({type(tokenizer)}) has no `eos_token_id` property, but `eos_token_id` is required for structured outputs to work properly." 
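_reduced_vocabulary has to express every tokenizer entry as raw bytes; Llama-style byte tokens are spelled <0xXX>, matched by re_llama_byte_token and decoded from their hex digits. A minimal illustration (stdlib re instead of the regex package, made-up token):

import re

re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")

token = "<0x0A>"                       # Llama-style byte token for '\n'
if re_llama_byte_token.match(token):
    token_bytes = bytes([int(token[3:5], 16)])
    print(token_bytes)                 # b'\n'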
# noqa: E501 + ) + + reduced_vocab = _reduced_vocabulary( + tokenizer, + eos_token_id #type: ignore + ) + vocabulary = OutlinesVocabulary( + oc.Vocabulary(eos_token_id, reduced_vocab)) + tokenizer._outlines_vocabulary = vocabulary # type: ignore + + return vocabulary + except AttributeError as e: + raise ValueError(f"Cannot get the vocabulary of the tokenizer " + f"({type(tokenizer)}). The tokenizer should have a " + "get_vocab method.") from e def grammar_is_likely_lark(grammar_str: str) -> bool: @@ -77,7 +275,7 @@ def check_quotes(text: str, rule_name: str, line_num: int) -> None: raise ValueError( f"Mismatched quotes in {rule_name} on line {line_num}") - def extract_references(text: str) -> set: + def extract_references(text: str) -> set[str]: """Extract rule references from text.""" # Remove quoted strings and special characters text = re.sub(r'"[^"]*"', '', text) diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 6b40cf6fd36d..c74d8c543f76 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -4,7 +4,6 @@ import multiprocessing import time import weakref -from collections import defaultdict from collections.abc import Sequence from multiprocessing import connection from multiprocessing.process import BaseProcess @@ -14,14 +13,12 @@ import torch from vllm.logger import init_logger -from vllm.model_executor.models.utils import extract_layer_index from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri, kill_process_tree) if TYPE_CHECKING: - from vllm.attention.layer import Attention from vllm.v1.engine.coordinator import DPCoordinator from vllm.v1.engine.utils import (CoreEngineActorManager, CoreEngineProcManager) @@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]): kill_process_tree(pid) -def bind_kv_cache( - kv_caches: dict[str, torch.Tensor], - forward_context: dict[str, "Attention"], - runner_kv_caches: list[torch.Tensor], -) -> None: - """ - Bind the allocated KV cache to both ModelRunner and forward context so - that the KV cache can be used in the forward pass. - - This function: - 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with - kv_caches. - 2) Associates each attention layer in the `forward_context` with its - corresponding KV cache in kv_caches. - - Args: - kv_caches: The allocated kv_caches with layer names as keys. - forward_context: The global forward context containing all Attention - layers with layer names as keys. - runner_kv_caches: The kv_cache declared by ModelRunner. - """ - # Bind kv_caches to ModelRunner - assert len(runner_kv_caches) == 0 - - # Convert kv_caches dict to a list of tensors in the order of layer_index. - index2name = defaultdict(list) - for layer_name in kv_caches: - index2name[extract_layer_index(layer_name)].append(layer_name) - - for layer_index in sorted(index2name.keys()): - layer_names = index2name[layer_index] - if len(layer_names) > 1: - # One typical case is encoder-decoder model, e.g., bart. - # The cross attention and self attention in the same decoder layer - # has different layer_name but the same layer_index. - raise NotImplementedError - layer_name = layer_names[0] - runner_kv_caches.append(kv_caches[layer_name]) - - # Bind kv_caches to forward context - for layer_name, kv_cache in kv_caches.items(): - # NOTE: Use list because of v0 PP virtual engine. 
- forward_context[layer_name].kv_cache = [kv_cache] - - def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, length: int) -> torch.Tensor: """ @@ -366,8 +318,6 @@ def report_usage_stats( # Feature flags "enable_lora": bool(vllm_config.lora_config), - "enable_prompt_adapter": - bool(vllm_config.prompt_adapter_config), "enable_prefix_caching": vllm_config.cache_config.enable_prefix_caching, "enforce_eager": diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 8f4e8d64c615..bf38e88f0c2a 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -14,12 +14,14 @@ class BlockTable: def __init__( self, + block_size: int, max_num_reqs: int, max_num_blocks_per_req: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, ): + self.block_size = block_size self.max_num_reqs = max_num_reqs self.max_num_blocks_per_req = max_num_blocks_per_req self.max_num_batched_tokens = max_num_batched_tokens @@ -79,10 +81,31 @@ def swap_row(self, src: int, tgt: int) -> None: self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]] - def commit(self, num_reqs: int) -> None: + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] + # where K is the max_num_blocks_per_req and the block size is 2. + # NOTE(woosuk): We can't simply use `token_indices // block_size` + # here because M (max_model_len) is not necessarily divisible by + # block_size. + block_table_indices = (req_indices * self.max_num_blocks_per_req + + positions // self.block_size) + block_table_cpu = self.get_cpu_tensor() + block_numbers = block_table_cpu.flatten()[block_table_indices].numpy() + block_offsets = positions % self.block_size + np.add(block_numbers * self.block_size, + block_offsets, + out=self.slot_mapping_np[:req_indices.shape[0]]) + + def commit_block_table(self, num_reqs: int) -> None: self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs], non_blocking=True) + def commit_slot_mapping(self, num_tokens: int) -> None: + self.slot_mapping[:num_tokens].copy_( + self.slot_mapping_cpu[:num_tokens], non_blocking=True) + def clear(self) -> None: self.block_table.fill_(0) self.block_table_cpu.fill_(0) @@ -107,7 +130,8 @@ def __init__(self, max_num_reqs: int, max_model_len: int, max_num_batched_tokens: int, pin_memory: bool, device: torch.device, block_sizes: list[int]) -> None: self.block_tables = [ - BlockTable(max_num_reqs, cdiv(max_model_len, block_size), + BlockTable(block_size, max_num_reqs, cdiv(max_model_len, + block_size), max_num_batched_tokens, pin_memory, device) for block_size in block_sizes ] @@ -129,9 +153,18 @@ def swap_row(self, src: int, tgt: int) -> None: for block_table in self.block_tables: block_table.swap_row(src, tgt) - def commit(self, num_reqs: int) -> None: + def compute_slot_mapping(self, req_indices: np.ndarray, + positions: np.ndarray) -> None: + for block_table in self.block_tables: + block_table.compute_slot_mapping(req_indices, positions) + + def commit_block_table(self, num_reqs: int) -> None: + for block_table in self.block_tables: + block_table.commit_block_table(num_reqs) + + def commit_slot_mapping(self, num_tokens: int) -> None: for block_table in self.block_tables: - block_table.commit(num_reqs) + block_table.commit_slot_mapping(num_tokens) def clear(self) -> None: for block_table in self.block_tables: diff --git a/vllm/v1/worker/cpu_model_runner.py 
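BlockTable.compute_slot_mapping vectorises the per-token slot lookup that previously lived in the model runner's prepare-inputs loop. A worked NumPy example matching the comment above (toy block table, block_size=2, K=4 blocks per request):

import numpy as np

block_size = 2
max_num_blocks_per_req = 4                                    # K
req_indices = np.array([0, 0, 1, 1, 1, 1, 1, 2, 2, 2])
positions   = np.array([0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
block_table = np.array([[7, 3, 0, 0],                         # physical blocks, toy values
                        [5, 9, 2, 0],
                        [4, 0, 0, 0]])

block_table_indices = (req_indices * max_num_blocks_per_req
                       + positions // block_size)
block_numbers = block_table.reshape(-1)[block_table_indices]
slot_mapping = block_numbers * block_size + positions % block_size
print(slot_mapping)   # [14 15 10 11 18 19  4  8  9  0]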
b/vllm/v1/worker/cpu_model_runner.py index 410a54e7466f..ca94ac8c6054 100644 --- a/vllm/v1/worker/cpu_model_runner.py +++ b/vllm/v1/worker/cpu_model_runner.py @@ -8,7 +8,6 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models.interfaces import has_step_pooler from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -46,17 +45,15 @@ def replace_tensor(obj: Any, cpu_attr_name: str, if k.endswith("_cpu_tensor") and isinstance(v, torch.Tensor): replace_tensor(self.input_batch, k, k[:-11]) - for k, v in vars(self.input_batch.block_table).items(): - if k.endswith("_cpu") and isinstance(v, torch.Tensor): - replace_tensor(self.input_batch.block_table, k, k[:-4]) + for block_table in self.input_batch.block_table.block_tables: + for k, v in vars(block_table).items(): + if k.endswith("_cpu") and isinstance(v, torch.Tensor): + replace_tensor(block_table, k, k[:-4]) - def load_model(self) -> None: + def load_model(self, eep_scale_up: bool = False) -> None: logger.info("Starting to load model %s...", self.model_config.model) self.model = get_model(vllm_config=self.vllm_config) - if has_step_pooler(self.model): - self.input_batch.logits_processing_needs_token_ids = True - if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, self.scheduler_config, diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index de575d604055..2dc28d93049a 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os -from importlib import util -from typing import Optional +import platform +from typing import Callable, Optional import torch @@ -11,6 +11,8 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.model_executor.utils import set_random_seed +from vllm.platforms import CpuArchEnum, current_platform +from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo from vllm.sequence import IntermediateTensors from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput @@ -40,10 +42,17 @@ def __init__(self, def init_device(self): # Setup OpenMP threads affinity. omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND - self.local_omp_cpuid = "all" - if omp_cpuids == "auto": - self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes( - ) + if omp_cpuids == "auto" and platform.system() == "Linux": + if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC: + # For POWERPC SMT-8/4/2 + self.local_omp_cpuid = self._get_autobind_cpu_ids( + lambda cpus: [cpu for cpu in cpus if cpu.id % 8 < 4]) + elif current_platform.get_cpu_architecture() == CpuArchEnum.X86: + # For x86 SMT-2, use 1 CPU per core + self.local_omp_cpuid = self._get_autobind_cpu_ids( + lambda cpus: cpus[-1:]) + else: + self.local_omp_cpuid = "all" else: self.local_omp_cpuid = omp_cpuids.split("|")[self.rank] @@ -58,7 +67,8 @@ def init_device(self): # Initialize the distributed environment. init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank, "gloo") + self.local_rank, + current_platform.dist_backend) # Set random seed. 
set_random_seed(self.model_config.seed) @@ -106,48 +116,58 @@ def execute_model( assert isinstance(output, ModelRunnerOutput) return output if self.is_driver_worker else None - def get_cpus_id_binding_based_on_numa_nodes(self) -> str: - """Return CPUs id binding based on NUMA nodes. + def _get_autobind_cpu_ids( + self, cpu_selector: Callable[[list[LogicalCPUInfo]], + list[LogicalCPUInfo]] + ) -> str: """ - rank_to_cpus = self.local_omp_cpuid - # Setup OpenMP thread affinity based on NUMA nodes automatically - world_size = self.vllm_config.parallel_config.world_size - libnuma_found = util.find_spec("numa") is not None - psutil_found = util.find_spec("psutil") is not None - if libnuma_found and psutil_found: - import psutil - from numa import info - cpu_count = psutil.cpu_count(logical=False) - cpus_allow_list = psutil.Process().cpu_affinity() - numa_size = info.get_num_configured_nodes() - cpu_count_per_numa = cpu_count // numa_size - num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU, - cpu_count_per_numa // 2) - - # check allow node_to_cpus list - node_to_cpus = [] - for i in range(numa_size): - node_intersect = set( - info.node_to_cpus(i)).intersection(cpus_allow_list) - if bool(node_intersect): - node_to_cpus.append(list(node_intersect)) - - if world_size > len(node_to_cpus): - logger.error( - "Auto thread-binding failed due to " - "world size: %d is larger than " - "allowed NUMA nodes number: %d." - "Please try to bind threads manually.", world_size, - len(node_to_cpus)) - else: - end = cpu_count_per_numa - num_of_reserved_cpu - rank_to_cpus_list = node_to_cpus[self.rank][:end] - rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list) - logger.info("auto thread-binding list: %s", rank_to_cpus) - else: - logger.warning( - "Auto thread-binding is not supported due to " - "the lack of package numa and psutil," - "fallback to no thread-binding. To get better performance," - "please try to manually bind threads.") - return rank_to_cpus + Return CPU ids to bind based on NUMA nodes. + Currently for rank N, only CPU ids on the N-th node in available NUMA + node list will be selected. + Args: + cpu_selector: a callable object to select CPUs from a CPU list + of a physical core. The input is a LogicalCPUInfo list, sorted by + the LogicalCPUInfo.id. A selected LogicalCPUInfo list should be + returned. + """ + + allowed_numa_nodes, logical_cpu_list = \ + CpuPlatform.get_allowed_cpu_memory_node_list() + assert len(allowed_numa_nodes) >= self.parallel_config.world_size, ( + f"No enough allowed NUMA nodes to bind threads of " + f"{self.parallel_config.world_size} CPUWorkers. " + f"Allowed NUMA nodes are {allowed_numa_nodes}. 
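_get_autobind_cpu_ids groups logical CPUs by physical core and lets the architecture-specific cpu_selector decide which SMT siblings to keep. A standalone sketch of the x86 "cpus[-1:]" selector with a stubbed LogicalCPUInfo (4 cores, SMT-2, invented ids):

from dataclasses import dataclass


@dataclass
class LogicalCPUInfo:                 # stub with only the fields used here
    id: int
    physical_core: int


cpus = [LogicalCPUInfo(id=i, physical_core=i % 4) for i in range(8)]

core_to_cpus: dict[int, list[LogicalCPUInfo]] = {}
for cpu in cpus:
    core_to_cpus.setdefault(cpu.physical_core, []).append(cpu)

selected: list[LogicalCPUInfo] = []
for cpu_list in core_to_cpus.values():
    cpu_list = sorted(cpu_list, key=lambda x: x.id)
    selected.extend(cpu_list[-1:])    # x86: keep one hardware thread per core

selected = sorted(selected, key=lambda x: x.id)
print(",".join(str(x.id) for x in selected))   # "4,5,6,7"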
" + "Please try to bind threads manually.") + + # Get CPUs on NUMA node `allowed_numa_nodes[local_rank]`` + selected_numa_node = allowed_numa_nodes[ + self.local_rank] # type: ignore + logical_cpu_list = [ + x for x in logical_cpu_list if x.numa_node == selected_numa_node + ] + + # Select CPUs from each physical core via cpu_selector + core_to_cpus: dict[int, list[LogicalCPUInfo]] = {} + for cpu_info in logical_cpu_list: + if cpu_info.physical_core not in core_to_cpus: + core_to_cpus[cpu_info.physical_core] = [] + core_to_cpus[cpu_info.physical_core].append(cpu_info) + logical_cpu_list = [] + for cpu_list in core_to_cpus.values(): + cpu_list = sorted(cpu_list, key=lambda x: x.id) + logical_cpu_list.extend(cpu_selector(cpu_list)) + logical_cpu_list = sorted(logical_cpu_list, key=lambda x: x.id) + + # Reserve CPUs for other processes + reserve_cpu_num = envs.VLLM_CPU_NUM_OF_RESERVED_CPU + if reserve_cpu_num is None: + reserve_cpu_num = 1 if self.parallel_config.world_size > 1 else 0 + assert len(logical_cpu_list) > reserve_cpu_num, ( + f"VLLM_CPU_NUM_OF_RESERVED_CPU ({reserve_cpu_num}) " + f"should less than {len(logical_cpu_list)}.") + if reserve_cpu_num != 0: + logical_cpu_list = logical_cpu_list[:-reserve_cpu_num] + + logger.info("auto thread-binding list (id, physical core): %s", + [(x.id, x.physical_core) for x in logical_cpu_list]) + return ",".join([str(x.id) for x in logical_cpu_list]) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 1a79d72be0a9..c63041600f38 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -70,7 +70,6 @@ def __init__( vocab_size: int, block_sizes: list[int], # The block_size of each kv cache group is_spec_decode: bool = False, - logits_processing_needs_token_ids: bool = False, ): self.is_spec_decode = is_spec_decode self.max_num_reqs = max_num_reqs @@ -79,8 +78,6 @@ def __init__( self.device = device self.pin_memory = pin_memory self.vocab_size = vocab_size - self.logits_processing_needs_token_ids = ( - logits_processing_needs_token_ids) self._req_ids: list[Optional[str]] = [] self.req_id_to_index: dict[str, int] = {} @@ -233,6 +230,9 @@ def __init__( # req_index -> bad_words_token_ids self.bad_words_token_ids: dict[int, list[list[int]]] = {} + self.logits_processing_needs_token_ids = np.zeros(max_num_reqs, + dtype=bool) + self.req_output_token_ids: list[Optional[list[int]]] = [] # This is updated each time the batch constituents change. @@ -365,9 +365,12 @@ def add_request( if sampling_params.bad_words_token_ids: self.bad_words_token_ids[ req_index] = sampling_params.bad_words_token_ids + elif pooling_params := request.pooling_params: + self.pooling_params[req_id] = pooling_params + self.logits_processing_needs_token_ids[req_index] = ( + pooling_params.requires_token_ids) else: - assert request.pooling_params is not None - self.pooling_params[req_id] = request.pooling_params + raise NotImplementedError(request) # Add request lora ID if request.lora_request: @@ -386,7 +389,7 @@ def add_request( def remove_request(self, req_id: str) -> Optional[int]: """This method must always be followed by a call to condense(). 
- + Args: req_id: request to remove @@ -587,7 +590,7 @@ def condense(self) -> None: def refresh_metadata(self): """Apply batch updates, reset input batch at end of step - + * Apply batch add/remove/permute to logits procs' states * If batch state is modified, update sampling metadata """ @@ -620,9 +623,9 @@ def _make_sampling_metadata(self) -> SamplingMetadata: copy_slice(self.repetition_penalties_cpu_tensor, self.repetition_penalties, num_reqs) - needs_prompt_token_ids = (not self.no_penalties or - (self.num_reqs > 0 - and self.logits_processing_needs_token_ids)) + needs_prompt_token_ids = ( + not self.no_penalties + or self.logits_processing_needs_token_ids[:num_reqs].any()) if needs_prompt_token_ids: # The prompt tokens are used only for applying penalties or # step pooling during the sampling/pooling process. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 5a26e88db1f7..a5bf197ba161 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4,9 +4,8 @@ import copy import gc import time -import weakref from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union, cast import numpy as np import torch @@ -20,7 +19,7 @@ from vllm.attention.layer import Attention from vllm.compilation.counter import compilation_counter from vllm.config import (CompilationLevel, VllmConfig, - get_layers_from_vllm_config) + get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) @@ -31,27 +30,30 @@ from vllm.forward_context import (DPMetadata, get_forward_context, set_forward_context) from vllm.logger import init_logger -from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader -from vllm.model_executor.models.interfaces import (has_step_pooler, - is_mixture_of_experts) +from vllm.model_executor.models.interfaces import is_mixture_of_experts +from vllm.model_executor.models.interfaces_base import (VllmModelForPooling, + is_pooling_model) from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.utils import group_mm_inputs_by_modality -from vllm.pooling_params import PoolingParams +from vllm.pooling_params import PoolingParams, PoolingTask from vllm.sampling_params import SamplingType -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, - check_use_alibi, get_dtype_size, + GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend -from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, - CommonAttentionMetadata) +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, CommonAttentionMetadata, + make_local_attention_virtual_batches) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, 
KVCacheSpec, MambaSpec, +from vllm.v1.kv_cache_interface import (AttentionSpec, + ChunkedLocalAttentionSpec, + FullAttentionSpec, KVCacheConfig, + KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, ModelRunnerOutput) @@ -63,13 +65,12 @@ from vllm.v1.spec_decode.medusa import MedusaProposer from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.ngram_proposer import NgramProposer -from vllm.v1.utils import bind_kv_cache -from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from ..sample.logits_processor import LogitsProcessorManager -from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing, +from .utils import (bind_kv_cache, gather_mm_placeholders, + initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) if TYPE_CHECKING: @@ -103,7 +104,6 @@ def __init__( self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config from vllm.model_executor.models.utils import set_cpu_offload_max_bytes @@ -125,6 +125,8 @@ def __init__( self.is_multimodal_model = model_config.is_multimodal_model self.is_pooling_model = model_config.pooler_config is not None + self.model_supports_multimodal_raw_input = ( + model_config.model_supports_multimodal_raw_input) self.max_model_len = model_config.max_model_len self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -150,7 +152,7 @@ def __init__( self.encoder_cache_size = encoder_cache_size # Sampler - self.sampler = Sampler() + self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.eplb_state: Optional[EplbState] = None """ @@ -327,6 +329,14 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: Args: scheduler_output: The scheduler output. """ + # Attention free models have zero kv_cache_goups, however models + # like Mamba are also attention free but use the kv_cache for + # keeping its internal state. This is why we check the number + # of kv_cache groups instead of solely checking + # for self.model_config.is_attention_free. 
+ if len(self.kv_cache_config.kv_cache_groups) == 0: + return + self.attn_metadata_builders[0].reorder_batch(self.input_batch, scheduler_output) @@ -405,6 +415,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_id = new_req_data.req_id sampling_params = new_req_data.sampling_params pooling_params = new_req_data.pooling_params + if sampling_params and \ sampling_params.sampling_type == SamplingType.RANDOM_SEED: generator = torch.Generator(device=self.device) @@ -412,6 +423,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: else: generator = None + if pooling_params: + assert (task := pooling_params.task) is not None, ( + "You did not set `task` in the API") + + model = cast(VllmModelForPooling, self.model) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + self.requests[req_id] = CachedRequestState( req_id=req_id, prompt_token_ids=new_req_data.prompt_token_ids, @@ -555,6 +574,38 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Refresh batch metadata with any pending updates. self.input_batch.refresh_metadata() + def _init_model_kwargs_for_multimodal_model( + self, + scheduler_output: Optional["SchedulerOutput"] = None, + num_reqs: int = -1, + ) -> dict[str, Any]: + + model_kwargs: dict[str, Any] = {} + if self.model_supports_multimodal_raw_input: + # This model requires the raw multimodal data in input. + if scheduler_output: + multi_modal_kwargs_list = [] + for req in scheduler_output.scheduled_new_reqs: + req_mm_inputs = req.mm_inputs + if not isinstance(req_mm_inputs, list): + req_mm_inputs = list(req_mm_inputs) + multi_modal_kwargs_list.extend(req_mm_inputs) + multi_modal_kwargs = MultiModalKwargs.batch( + multi_modal_kwargs_list) + else: + # The only case where SchedulerOutput is None is for + # a dummy run let's get some dummy data. + dummy_data = [ + self.mm_registry.get_decoder_dummy_data( + model_config=self.model_config, + seq_len=1).multi_modal_data for i in range(num_reqs) + ] + multi_modal_kwargs = MultiModalKwargs.batch(dummy_data) + + model_kwargs.update(multi_modal_kwargs) + + return model_kwargs + def _get_cumsum_and_arange( self, num_tokens: np.ndarray, @@ -578,8 +629,9 @@ def _get_cumsum_and_arange( def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], bool, torch.Tensor, - Optional[SpecDecodeMetadata], np.ndarray]: + ) -> tuple[dict[str, + Any], bool, torch.Tensor, Optional[SpecDecodeMetadata], + np.ndarray, Optional[CommonAttentionMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -594,7 +646,7 @@ def _prepare_inputs( # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table.commit(num_reqs) + self.input_batch.block_table.commit_block_table(num_reqs) # Get the number of scheduled tokens for each request. req_ids = self.input_batch.req_ids @@ -638,29 +690,10 @@ def _prepare_inputs( torch.from_numpy(token_indices), out=self.input_ids_cpu[:total_num_scheduled_tokens]) - # Calculate the slot mapping for each KV cache group. 
- for kv_cache_group_id, kv_cache_group_spec in enumerate( - self.kv_cache_config.kv_cache_groups): - block_size = kv_cache_group_spec.kv_cache_spec.block_size - block_table: BlockTable = self.input_batch.block_table[ - kv_cache_group_id] - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1] - # where K is the max_num_blocks_per_req and the block size is 2. - # NOTE(woosuk): We can't simply use `token_indices // block_size` - # here because M (max_model_len) is not necessarily divisible by - # block_size. - block_table_indices = ( - req_indices * block_table.max_num_blocks_per_req + - positions_np // block_size) - block_table_cpu = block_table.get_cpu_tensor() - block_numbers = block_table_cpu.flatten( - )[block_table_indices].numpy() - block_offsets = positions_np % block_size - np.add( - block_numbers * block_size, - block_offsets, - out=block_table.slot_mapping_np[:total_num_scheduled_tokens]) + self.input_batch.block_table.compute_slot_mapping( + req_indices, positions_np) + self.input_batch.block_table.commit_slot_mapping( + total_num_scheduled_tokens) # Prepare the attention metadata. self.query_start_loc_np[0] = 0 @@ -689,7 +722,7 @@ def _prepare_inputs( self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) - # Fill unused with -1. Needed for reshape_and_cache + # Fill unused with 0 for full cuda graph mode. self.seq_lens[num_reqs:].fill_(0) # Note: pad query_start_loc to be non-decreasing, as kernels # like FlashAttention requires that @@ -697,15 +730,8 @@ def _prepare_inputs( self.query_start_loc_cpu[num_reqs].item()) query_start_loc = self.query_start_loc[:num_reqs + 1] - seq_lens = self.seq_lens[:num_reqs] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=total_num_scheduled_tokens, - max_query_len=max_num_scheduled_tokens, - ) + + spec_decode_common_attn_metadata = None attn_metadata: dict[str, Any] = {} # Prepare the attention metadata for each KV cache group and make layers @@ -713,6 +739,38 @@ def _prepare_inputs( for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor()[:num_reqs] + slot_mapping = blk_table.slot_mapping[:total_num_scheduled_tokens] + + # Fill unused with -1. Needed for reshape_and_cache in full cuda + # graph mode. + blk_table.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. 
+ num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=total_num_scheduled_tokens, + max_query_len=max_num_scheduled_tokens, + block_table_tensor=blk_table_tensor, + slot_mapping=slot_mapping, + ) + + if self.speculative_config and \ + spec_decode_common_attn_metadata is None: + spec_decode_common_attn_metadata = common_attn_metadata + + if isinstance(kv_cache_group_spec.kv_cache_spec, + ChunkedLocalAttentionSpec): + common_attn_metadata = make_local_attention_virtual_batches( + kv_cache_group_spec.kv_cache_spec.attention_chunk_size, + common_attn_metadata, self.cache_config.block_size) + # Prepare for cascade attention if enabled & beneficial. common_prefix_len = 0 builder = self.attn_metadata_builders[kv_cache_group_id] @@ -766,7 +824,8 @@ def _prepare_inputs( self.set_active_loras(self.input_batch, num_scheduled_tokens) return (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, num_scheduled_tokens) + spec_decode_metadata, num_scheduled_tokens, + spec_decode_common_attn_metadata) def _compute_cascade_attn_prefix_len( self, @@ -846,6 +905,10 @@ def _compute_cascade_attn_prefix_len( use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or (isinstance(kv_cache_spec, FullAttentionSpec) and kv_cache_spec.sliding_window is not None)) + use_local_attention = ( + isinstance(kv_cache_spec, ChunkedLocalAttentionSpec) + or (isinstance(kv_cache_spec, FullAttentionSpec) + and kv_cache_spec.attention_chunk_size is not None)) assert isinstance(kv_cache_spec, AttentionSpec) use_cascade = attn_metadata_builder.use_cascade_attention( common_prefix_len=common_prefix_len, @@ -854,6 +917,7 @@ def _compute_cascade_attn_prefix_len( num_kv_heads=kv_cache_spec.num_kv_heads, use_alibi=self.use_alibi, use_sliding_window=use_sliding_window, + use_local_attention=use_local_attention, num_sms=self.num_sms, ) return common_prefix_len if use_cascade else 0 @@ -1090,6 +1154,13 @@ def _gather_mm_embeddings( def get_model(self) -> nn.Module: return self.model + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + return list(model.pooler.get_supported_tasks()) + def apply_grammar_bitmask( self, scheduler_output: "SchedulerOutput", @@ -1290,8 +1361,9 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, attention_cuda_graphs, logits_indices, - spec_decode_metadata, - num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output)) + spec_decode_metadata, num_scheduled_tokens_np, + spec_decode_common_attn_metadata) = ( + self._prepare_inputs(scheduler_output)) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -1328,11 +1400,14 @@ def execute_model( # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. input_ids = self.input_ids[:num_scheduled_tokens] - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(input_ids) + + model_kwargs = self._init_model_kwargs_for_multimodal_model( + scheduler_output=scheduler_output) + inputs_embeds = self.model.get_input_embeddings( + input_ids=input_ids, + multimodal_embeddings=mm_embeds or None, + ) + # TODO(woosuk): Avoid the copy. Optimize. 
self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) inputs_embeds = self.inputs_embeds[:num_input_tokens] @@ -1344,6 +1419,7 @@ def execute_model( # then the embedding layer is not included in the CUDA graph. input_ids = self.input_ids[:num_input_tokens] inputs_embeds = None + model_kwargs = {} if self.uses_mrope: positions = self.mrope_positions[:, :num_input_tokens] else: @@ -1376,6 +1452,10 @@ def execute_model( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), ) self.maybe_wait_for_kv_save() @@ -1398,6 +1478,9 @@ def execute_model( if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. if not broadcast_pp_output: + if finished_sending or finished_recving: + hidden_states.finished_sending = finished_sending + hidden_states.finished_recving = finished_recving return hidden_states assert isinstance(hidden_states, IntermediateTensors) get_pp_group().send_tensor_dict(hidden_states.tensors, @@ -1535,6 +1618,7 @@ def execute_model( # Speculative decoding is not enabled. spec_token_ids = None else: + assert spec_decode_common_attn_metadata is not None spec_token_ids = self.propose_draft_token_ids( scheduler_output, valid_sampled_token_ids, @@ -1543,13 +1627,9 @@ def execute_model( sample_hidden_states, aux_hidden_states, spec_decode_metadata, - attn_metadata, + spec_decode_common_attn_metadata, ) - # Clear KVConnector state after all KVs are generated. - if has_kv_transfer_group(): - get_kv_transfer_group().clear_connector_metadata() - self.eplb_step() return ModelRunnerOutput( @@ -1574,7 +1654,7 @@ def propose_draft_token_ids( sample_hidden_states: torch.Tensor, aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], - attn_metadata: dict[str, Any], + common_attn_metadata: CommonAttentionMetadata, ) -> list[list[int]]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": @@ -1621,16 +1701,6 @@ def propose_draft_token_ids( next_token_ids = torch.tensor(next_token_ids, dtype=torch.int32, device=self.device) - # At this moment, we assume all eagle layers belong to the same KV - # cache group, thus using the same attention metadata. - eagle_attn_metadata = attn_metadata[ - self.drafter.attn_layer_names[0]] - - # NOTE: deepseek_mtp uses MLA which does not have `block_table` - if hasattr(eagle_attn_metadata, "block_table"): - block_table = eagle_attn_metadata.block_table - else: - block_table = None if spec_decode_metadata is None: # input_ids can be None for multimodal models. @@ -1643,8 +1713,6 @@ def propose_draft_token_ids( dim=-1) else: target_hidden_states = hidden_states[:num_scheduled_tokens] - target_slot_mapping = eagle_attn_metadata.slot_mapping - cu_num_tokens = eagle_attn_metadata.query_start_loc else: # TODO(woosuk): Refactor this. 
num_draft_tokens = spec_decode_metadata.num_draft_tokens @@ -1652,17 +1720,12 @@ def propose_draft_token_ids( n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 for i, n in enumerate(num_draft_tokens) ] - num_rejected_tokens_tensor = async_tensor_h2d( - num_rejected_tokens, - dtype=torch.int32, - target_device=self.device, - pin_memory=True) - num_tokens = num_scheduled_tokens - sum(num_rejected_tokens) - cu_num_tokens, token_indices = self.drafter.prepare_inputs( - eagle_attn_metadata.query_start_loc, - num_rejected_tokens_tensor, - num_tokens, - ) + num_rejected_tokens_cpu = torch.tensor(num_rejected_tokens, + dtype=torch.int32) + common_attn_metadata, token_indices =\ + self.drafter.prepare_inputs( + common_attn_metadata, num_rejected_tokens_cpu) + target_token_ids = self.input_ids[token_indices] # TODO(woosuk): Support M-RoPE. target_positions = self.positions[token_indices] @@ -1671,37 +1734,17 @@ def propose_draft_token_ids( [h[token_indices] for h in aux_hidden_states], dim=-1) else: target_hidden_states = hidden_states[token_indices] - target_slot_mapping = eagle_attn_metadata.slot_mapping[ - token_indices] draft_token_ids = self.drafter.propose( target_token_ids=target_token_ids, target_positions=target_positions, target_hidden_states=target_hidden_states, - target_slot_mapping=target_slot_mapping, next_token_ids=next_token_ids, - cu_num_tokens=cu_num_tokens, - block_table=block_table, sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, ) spec_token_ids = draft_token_ids.tolist() return spec_token_ids - def kv_connector_no_forward( - self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: - # KV send/recv even if no work to do. - with set_forward_context(None, self.vllm_config): - self.maybe_setup_kv_connector(scheduler_output) - finished_sending, finished_recving = ( - self.get_finished_kv_transfers(scheduler_output)) - - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT - - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) - output.finished_sending = finished_sending - output.finished_recving = finished_recving - return output - @staticmethod def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"): # Update KVConnector with the KVConnector metadata forward(). @@ -1732,6 +1775,22 @@ def get_finished_kv_transfers( scheduler_output.finished_req_ids) return None, None + def kv_connector_no_forward( + self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput: + # KV send/recv even if no work to do. + with set_forward_context(None, self.vllm_config): + self.maybe_setup_kv_connector(scheduler_output) + finished_sending, finished_recving = ( + self.get_finished_kv_transfers(scheduler_output)) + + if not finished_sending and not finished_recving: + return EMPTY_MODEL_RUNNER_OUTPUT + + output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output.finished_sending = finished_sending + output.finished_recving = finished_recving + return output + def propose_ngram_draft_token_ids( self, sampled_token_ids: list[list[int]], @@ -1766,24 +1825,56 @@ def propose_ngram_draft_token_ids( draft_token_ids.append(drafter_output.tolist()) return draft_token_ids - def load_model(self) -> None: + def update_config(self, overrides: dict[str, Any]) -> None: + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. 
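The rejected-draft count per request is recovered from the sampler output alone: each request that ran spec decode returns its accepted drafts plus one bonus token, so n + 1 - len(sampled) is the number of rejected drafts. A quick check with invented token ids:

num_draft_tokens = [3, 3, 0]              # drafts proposed per request
sampled_token_ids = [
    [11, 12, 13, 14],                     # all 3 drafts accepted + bonus token
    [21, 22],                             # 1 draft accepted + bonus token
    [31],                                 # request ran without spec decode
]
num_rejected_tokens = [
    n + 1 - len(ids) if n > 0 else 0
    for n, ids in zip(num_draft_tokens, sampled_token_ids)
]
print(num_rejected_tokens)                # [0, 2, 0]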
" \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + + def load_model(self, eep_scale_up: bool = False) -> None: + """ + Args: + eep_scale_up: the model loading is for elastic EP scale up. + """ logger.info("Starting to load model %s...", self.model_config.model) - with DeviceMemoryProfiler() as m: # noqa: SIM117 + if eep_scale_up: + from vllm.distributed.parallel_state import get_ep_group + num_local_physical_experts = torch.empty(1, + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = int(num_local_physical_experts.item()) + new_ep_size = get_ep_group().world_size + global_expert_load, old_global_expert_indices = ( + EplbState.recv_state()) + num_logical_experts = global_expert_load.shape[1] + self.parallel_config.num_redundant_experts = ( + num_local_physical_experts * new_ep_size - num_logical_experts) + assert old_global_expert_indices.shape[ + 1] % num_local_physical_experts == 0 + old_ep_size = old_global_expert_indices.shape[ + 1] // num_local_physical_experts + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + else: + global_expert_load = None + old_global_expert_indices = None + rank_mapping = None + + with DeviceMemoryProfiler() as m: time_before_load = time.perf_counter() model_loader = get_model_loader(self.load_config) - if not hasattr(self, "model"): - logger.info("Loading model from scratch...") - self.model = model_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.model_config) - else: - logger.info( - "Model was already initialized. Loading weights inplace..." - ) - model_loader.load_weights(self.model, - model_config=self.model_config) - if has_step_pooler(self.model): - self.input_batch.logits_processing_needs_token_ids = True + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, model_config=self.model_config) if self.lora_config: self.model = self.load_lora_model(self.model, self.model_config, @@ -1811,8 +1902,18 @@ def load_model(self) -> None: self.model, self.device, self.parallel_config, + global_expert_load, + old_global_expert_indices, + rank_mapping, ) + def reload_weights(self) -> None: + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model_loader.load_weights(self.model, model_config=self.model_config) + def save_tensorized_model( self, tensorizer_config: "TensorizerConfig", @@ -1820,6 +1921,7 @@ def save_tensorized_model( TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config, + model_config=self.model_config, ) def _get_prompt_logprobs_dict( @@ -1943,7 +2045,7 @@ def maybe_randomize_inputs(self, input_ids: torch.Tensor): Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. 
This is to help balance expert-selection - during profile_run - - during DP rank dummy run + - during DP rank dummy run """ dp_size = self.vllm_config.parallel_config.data_parallel_size randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 @@ -1997,24 +2099,29 @@ def _dummy_run( if capture_attn_cudagraph: attn_metadata = {} - query_start_loc = self.query_start_loc[:num_reqs + 1] # Make sure max_model_len is used at the graph capture time. self.seq_lens_np[:num_reqs] = self.max_model_len self.seq_lens_np[num_reqs:] = 0 self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], non_blocking=True) - seq_lens = self.seq_lens[:num_reqs] - - common_attn_metadata = CommonAttentionMetadata( - query_start_loc=query_start_loc, - seq_lens=seq_lens, - num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - ) for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=self.query_start_loc[:num_reqs + 1], + query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + + 1], + seq_lens=self.seq_lens[:num_reqs], + seq_lens_cpu=self.seq_lens_cpu[:num_reqs], + num_computed_tokens_cpu=self.input_batch. + num_computed_tokens_cpu_tensor[:num_reqs], + num_reqs=num_reqs, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + block_table_tensor=self.input_batch.block_table[ + kv_cache_group_id].get_device_tensor()[:num_reqs], + slot_mapping=self.input_batch. + block_table[kv_cache_group_id].slot_mapping[:num_tokens]) attn_metadata_i = self.attn_metadata_builders[ kv_cache_group_id].build_for_cudagraph_capture( @@ -2026,11 +2133,15 @@ def _dummy_run( num_scheduled_tokens): model = self.model if self.is_multimodal_model: + model_kwargs = self._init_model_kwargs_for_multimodal_model( + num_reqs=num_reqs) input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] else: input_ids = self.input_ids[:num_tokens] inputs_embeds = None + model_kwargs = {} + if self.uses_mrope: positions = self.mrope_positions[:, :num_tokens] else: @@ -2059,7 +2170,12 @@ def _dummy_run( positions=positions, intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, + **MultiModalKwargs.as_kwargs( + model_kwargs, + device=self.device, + ), ) + if self.use_aux_hidden_state_outputs: hidden_states, _ = outputs else: @@ -2157,12 +2273,11 @@ def _dummy_sampler_run( ) return sampler_output - @torch.inference_mode() - def _dummy_pooler_run( + def _dummy_pooler_run_task( self, hidden_states: torch.Tensor, - ) -> torch.Tensor: - + task: PoolingTask, + ) -> PoolerOutput: num_tokens = hidden_states.shape[0] max_num_reqs = self.scheduler_config.max_num_seqs num_reqs = min(num_tokens, max_num_reqs) @@ -2174,30 +2289,55 @@ def _dummy_pooler_run( hidden_states_list = list( torch.split(hidden_states, num_scheduled_tokens_list)) - req_num_tokens = num_tokens // num_reqs + dummy_prompt_lens = torch.tensor( + [h.shape[0] for h in hidden_states_list], + device=self.device, + ) + dummy_token_ids = torch.zeros((num_reqs, req_num_tokens), + dtype=torch.int32, + device=self.device) + + model = cast(VllmModelForPooling, self.model) + dummy_pooling_params = PoolingParams(task=task) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(dummy_pooling_params) + dummy_metadata = PoolingMetadata( - prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list], - device=self.device), - prompt_token_ids=torch.zeros((num_reqs, req_num_tokens), - dtype=torch.int32, - device=self.device), - 
pooling_params=[PoolingParams()] * num_reqs) + prompt_lens=dummy_prompt_lens, + prompt_token_ids=dummy_token_ids, + pooling_params=[dummy_pooling_params] * num_reqs, + ) try: - pooler_output = self.model.pooler(hidden_states=hidden_states_list, - pooling_metadata=dummy_metadata) + return model.pooler(hidden_states=hidden_states_list, + pooling_metadata=dummy_metadata) except RuntimeError as e: if 'out of memory' in str(e): raise RuntimeError( - "CUDA out of memory occurred when warming up pooler with " - f"{num_reqs} dummy requests. Please try lowering " - "`max_num_seqs` or `gpu_memory_utilization` when " + "CUDA out of memory occurred when warming up pooler " + f"({task=}) with {num_reqs} dummy requests. Please try " + "lowering `max_num_seqs` or `gpu_memory_utilization` when " "initializing the engine.") from e else: raise e - return pooler_output + + @torch.inference_mode() + def _dummy_pooler_run( + self, + hidden_states: torch.Tensor, + ) -> PoolerOutput: + # Find the task that has the largest output for subsequent steps + output_size = dict[PoolingTask, float]() + for task in self.get_supported_pooling_tasks(): + # Run a full batch with each task to ensure none of them OOMs + output = self._dummy_pooler_run_task(hidden_states, task) + output_size[task] = output.get_data_nbytes() + del output # Allow GC + + max_task = max(output_size.items(), key=lambda x: x[1])[0] + return self._dummy_pooler_run_task(hidden_states, max_task) def profile_run(self) -> None: # Profile with multimodal encoder & encoder cache. @@ -2218,8 +2358,8 @@ def profile_run(self) -> None: encoder_budget = min(self.max_num_encoder_input_tokens, self.encoder_cache_size) - max_num_mm_items_encoder_budget = cdiv(encoder_budget, - max_tokens_per_mm_item) + max_num_mm_items_encoder_budget = encoder_budget // \ + max_tokens_per_mm_item # Check how many items of this modality can be supported by # the decoder budget. @@ -2232,8 +2372,10 @@ def profile_run(self) -> None: max_num_mm_items_decoder_budget = self.max_num_reqs * \ max_mm_items_per_req - max_num_mm_items = min(max_num_mm_items_encoder_budget, - max_num_mm_items_decoder_budget) + max_num_mm_items = max( + 1, + min(max_num_mm_items_encoder_budget, + max_num_mm_items_decoder_budget)) logger.info( "Encoder cache will be initialized with a budget of %s tokens," @@ -2243,7 +2385,7 @@ def profile_run(self) -> None: # Create dummy batch of multimodal inputs. dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, - seq_len=self.max_num_tokens, + seq_len=max_tokens_per_mm_item, mm_counts={ dummy_data_modality: 1 }, @@ -2297,16 +2439,33 @@ def capture_model(self) -> None: start_time = time.perf_counter() start_free_gpu_memory = torch.cuda.mem_get_info()[0] + @contextmanager + def freeze_gc(): + # Optimize garbage collection during CUDA graph capture. + # Clean up, then freeze all remaining objects from being included + # in future collections. + gc.collect() + should_freeze = not envs.VLLM_ENABLE_CUDAGRAPH_GC + if should_freeze: + gc.freeze() + try: + yield + finally: + if should_freeze: + gc.unfreeze() + # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
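The freeze_gc helper defined above leans on CPython's gc.freeze()/gc.unfreeze(); here is a tiny standalone demonstration of the same pattern, independent of CUDA graph capture (the frozen_gc name is mine, not vLLM's):

import gc
from contextlib import contextmanager


@contextmanager
def frozen_gc(enabled: bool = True):
    gc.collect()           # clean up garbage created so far
    if enabled:
        gc.freeze()        # move survivors to the permanent generation
    try:
        yield
    finally:
        if enabled:
            gc.unfreeze()  # restore normal collection behaviour


with frozen_gc():
    # The capture loop would run here; collections triggered inside it have
    # far fewer objects to scan because the survivors are frozen.
    print("frozen objects:", gc.get_freeze_count())
print("after unfreeze:", gc.get_freeze_count())  # 0 again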
- with graph_capture(device=self.device): + with freeze_gc(), graph_capture(device=self.device): full_cg = self.full_cuda_graph # Only rank 0 should print progress bar during capture compilation_cases = reversed(self.cudagraph_batch_sizes) if is_global_first_rank(): - compilation_cases = tqdm(list(compilation_cases), - desc="Capturing CUDA graph shapes") + compilation_cases = tqdm( + list(compilation_cases), + disable=not self.load_config.use_tqdm_on_load, + desc="Capturing CUDA graph shapes") for num_tokens in compilation_cases: # We skip EPLB here since we don't want to record dummy metrics for _ in range( @@ -2362,11 +2521,10 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: raise ValueError( f"Unknown KV cache spec type: {type(kv_cache_spec)}") - block_table_i = self.input_batch.block_table[i] attn_metadata_builder_i = attn_backend_i.get_builder_cls()( - weakref.proxy(self), kv_cache_spec, - block_table_i, + self.vllm_config, + self.device, ) if (self.full_cuda_graph @@ -2633,6 +2791,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: # TODO: Support other attention modules, e.g., cross-attention if attn_module.attn_type == AttentionType.DECODER: + use_local_attention = (self.attention_chunk_size is not None + and attn_module.use_irope) if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -2641,6 +2801,17 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, use_mla=use_mla) + assert not use_local_attention, ( + "attention module can not be with ", + "both local attention and sliding window") + elif use_local_attention: + kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype, + attention_chunk_size=self.attention_chunk_size, + use_mla=use_mla) else: kv_cache_spec[layer_name] = FullAttentionSpec( block_size=block_size, @@ -2658,23 +2829,18 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: raise ValueError( f"Unknown attention type: {attn_module.attn_type}") - mamba_layers = get_layers_from_vllm_config(self.vllm_config, - MambaMixer2) + mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase) if len(mamba_layers) > 0: if self.vllm_config.speculative_config is not None: raise NotImplementedError( "Mamba with speculative decoding is not supported yet.") - if not self.vllm_config.model_config.enforce_eager: - raise NotImplementedError( - "Mamba with cuda graph is not supported yet.") if self.vllm_config.cache_config.enable_prefix_caching: raise NotImplementedError( "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = self._maybe_pad_mamba_page_size( - attn_layers, mamba_layers, kv_cache_spec, max_model_len, - block_size) + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. 
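The per-layer spec selection added above boils down to a small dispatch on the layer's attributes. A hedged sketch with toy stand-ins (ToyLayer and string results instead of vLLM's real KVCacheSpec classes):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyLayer:
    sliding_window: Optional[int] = None
    use_irope: bool = False


def choose_spec(layer: ToyLayer, attention_chunk_size: Optional[int]) -> str:
    use_local_attention = (attention_chunk_size is not None
                           and layer.use_irope)
    if layer.sliding_window is not None:
        # A single layer cannot use both sliding-window and chunked local
        # attention, matching the assertion in the hunk above.
        assert not use_local_attention
        return "SlidingWindowSpec"
    if use_local_attention:
        return "ChunkedLocalAttentionSpec"
    return "FullAttentionSpec"


assert choose_spec(ToyLayer(sliding_window=128), None) == "SlidingWindowSpec"
assert choose_spec(ToyLayer(use_irope=True), 8192) == "ChunkedLocalAttentionSpec"
assert choose_spec(ToyLayer(), 8192) == "FullAttentionSpec"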
@@ -2686,54 +2852,3 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: page_size_padded=page_size_padded) return kv_cache_spec - - def _maybe_pad_mamba_page_size( - self, - attn_layers: dict[str, Attention], - mamba_layers: dict[str, MambaMixer2], - kv_cache_spec: dict[str, KVCacheSpec], - max_model_len: int, - block_size: int, - ) -> Optional[int]: - """ - Ensure that page size of attention KV cache groups is greater than or - equal to the mamba KV cache groups. If not, we suggest to the user - how to set the attention block size to ensure that it is. - - If the attention page size is strictly greater than the mamba page size, - we pad the mamba page size to make them equal. - - Args: - attn_layers: Attention layers - mamba_layers: Mamba layers - kv_cache_spec: KV cache spec (populated with attention layers) - - Returns: - Optional[int]: Mamba page size with padding (None if no padding). - """ - - if len(attn_layers) == 0: - return None - - attn_layer_name = next(iter(attn_layers)) - attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = MambaSpec( - shapes=mamba_layers[mamba_layer_name].get_state_shape(), - dtype=self.kv_cache_dtype, - block_size=max_model_len).page_size_bytes - if attn_page_size < mamba_page_size: - # attention page size (for 16 tokens) - attn_page_size_16 = 16 * attn_page_size // block_size - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - suggest_attn_block_size = 16 * cdiv(mamba_page_size, - attn_page_size_16) - raise ValueError( - "Attention block size should be increased to at least " - f"{suggest_attn_block_size} in order to match " - "the mamba page size") - - return attn_page_size diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 9e7e44d06861..522946351148 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A GPU worker class.""" +import copy import gc import os -from typing import TYPE_CHECKING, Optional +from contextlib import AbstractContextManager, nullcontext +from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed @@ -11,20 +13,22 @@ import vllm.envs as envs from vllm.config import VllmConfig -from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) -from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized +from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized, + has_kv_transfer_group) from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from 
vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -79,6 +83,8 @@ def __init__( self.profiler = None def sleep(self, level: int = 1) -> None: + from vllm.device_allocator.cumem import CuMemAllocator + free_bytes_before_sleep = torch.cuda.mem_get_info()[0] # Save the buffers before level 2 sleep @@ -101,6 +107,8 @@ def sleep(self, level: int = 1) -> None: used_bytes / GiB_bytes) def wake_up(self, tags: Optional[list[str]] = None) -> None: + from vllm.device_allocator.cumem import CuMemAllocator + allocator = CuMemAllocator.get_instance() allocator.wake_up(tags) @@ -112,6 +120,21 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: buffer.data.copy_(self._sleep_saved_buffers[name].data) self._sleep_saved_buffers = {} + def _maybe_get_memory_pool_context(self, + tag: str) -> AbstractContextManager: + if self.vllm_config.model_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + + allocator = CuMemAllocator.get_instance() + if tag == "weights": + assert allocator.get_current_usage() == 0, ( + "Sleep mode can only be " + "used for one instance per process.") + context = allocator.use_memory_pool(tag=tag) + else: + context = nullcontext() + return context + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: self.cache_config.num_gpu_blocks = num_gpu_blocks @@ -130,7 +153,7 @@ def init_device(self): # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) self.device = torch.device(f"cuda:{self.local_rank}") - torch.cuda.set_device(self.device) + current_platform.set_device(self.device) _check_if_gpu_supports_dtype(self.model_config.dtype) gc.collect() @@ -157,7 +180,8 @@ def init_device(self): # Initialize the distributed environment. init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank) + self.local_rank, + current_platform.dist_backend) # Set random seed. set_random_seed(self.model_config.seed) @@ -172,17 +196,16 @@ def init_device(self): # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool # to hijack tensor allocation. 
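The _maybe_get_memory_pool_context helper above returns either a tagged memory-pool context or a no-op, and defers the allocator import until it is actually needed. A minimal sketch of that control flow with a toy allocator (ToyAllocator only mimics the interface, it is not CuMemAllocator):

from contextlib import AbstractContextManager, contextmanager, nullcontext


class ToyAllocator:

    @contextmanager
    def use_memory_pool(self, tag: str):
        print(f"allocations now routed to pool '{tag}'")
        try:
            yield
        finally:
            print(f"left pool '{tag}'")


def maybe_memory_pool(enable_sleep_mode: bool,
                      tag: str) -> AbstractContextManager:
    if enable_sleep_mode:
        # The real worker imports its allocator lazily at this point so the
        # CUDA-only machinery is not pulled in when sleep mode is unused.
        return ToyAllocator().use_memory_pool(tag=tag)
    return nullcontext()


with maybe_memory_pool(enable_sleep_mode=True, tag="weights"):
    pass  # model loading / weight reload would run here
with maybe_memory_pool(enable_sleep_mode=False, tag="kv_cache"):
    pass  # no pool: allocations fall through to the default allocator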
def load_model(self) -> None: - if self.vllm_config.model_config.enable_sleep_mode: - allocator = CuMemAllocator.get_instance() - assert allocator.get_current_usage() == 0, ( - "Sleep mode can only be " - "used for one instance per process.") - context = allocator.use_memory_pool(tag="weights") - else: - from contextlib import nullcontext - context = nullcontext() - with context: - self.model_runner.load_model() + eep_scale_up = os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1" + with self._maybe_get_memory_pool_context(tag="weights"): + self.model_runner.load_model(eep_scale_up=eep_scale_up) + + def update_config(self, overrides: dict[str, Any]) -> None: + self.model_runner.update_config(overrides) + + def reload_weights(self) -> None: + with self._maybe_get_memory_pool_context(tag="weights"): + self.model_runner.reload_weights() @torch.inference_mode() def determine_available_memory(self) -> int: @@ -240,7 +263,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: """Allocate GPU KV cache with the specified kv_cache_config.""" + if self.vllm_config.model_config.enable_sleep_mode: + from vllm.device_allocator.cumem import CuMemAllocator + allocator = CuMemAllocator.get_instance() context = allocator.use_memory_pool(tag="kv_cache") else: @@ -294,6 +320,9 @@ def compile_or_warm_up_model(self) -> None: def get_model(self) -> nn.Module: return self.model_runner.get_model() + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + return self.model_runner.get_supported_pooling_tasks() + @torch.inference_mode() def execute_model( self, @@ -307,15 +336,27 @@ def execute_model( output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) + parallel_config = self.vllm_config.parallel_config if parallel_config.distributed_executor_backend != "external_launcher" \ and not get_pp_group().is_last_rank: assert isinstance(output, IntermediateTensors) get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group()) - return None + if not has_kv_transfer_group(): + return None + + # In case of PP with kv transfer, we need to pass through the + # finished_sending and finished_recving buffers. + new_output = EMPTY_MODEL_RUNNER_OUTPUT + if output.finished_sending or output.finished_recving: + new_output = copy.copy(new_output) + new_output.finished_sending = output.finished_sending + new_output.finished_recving = output.finished_recving + output = new_output + assert isinstance(output, ModelRunnerOutput) - return output if self.is_driver_worker else None + return output def profile(self, is_start: bool = True): if self.profiler is None: @@ -346,6 +387,161 @@ def check_health(self) -> None: # worker will always be healthy as long as it's running. 
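For the elastic-EP scale-up path gated on VLLM_ELASTIC_EP_SCALE_UP_LAUNCH above, the redundant-expert bookkeeping is plain arithmetic: the per-rank physical expert count stays fixed, so adding EP ranks adds redundant physical slots. The numbers below are invented purely for illustration:

num_logical_experts = 256
num_local_physical_experts = 16
new_ep_size = 18

new_physical_experts = num_local_physical_experts * new_ep_size       # 288
num_redundant_experts = new_physical_experts - num_logical_experts    # 32
assert num_redundant_experts == 32

# The previous EP size is recovered from the old physical-to-logical map,
# whose second dimension is old_ep_size * num_local_physical_experts.
old_map_width = 256  # e.g. a map of shape (num_moe_layers, 256)
assert old_map_width % num_local_physical_experts == 0
old_ep_size = old_map_width // num_local_physical_experts             # 16
assert old_ep_size == 16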
return + def _eplb_before_scale_down(self, old_ep_size: int, + new_ep_size: int) -> None: + from vllm.distributed.parallel_state import get_ep_group + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding " + "before scaling down...") + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange(self.model_runner.model, + execute_shuffle=True, + global_expert_load=None, + rank_mapping=rank_mapping) + torch.cuda.synchronize() + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + def _eplb_after_scale_up( + self, old_ep_size: int, new_ep_size: int, + global_expert_load: Optional[torch.Tensor]) -> None: + from vllm.distributed.parallel_state import get_ep_group + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding " + "after scaling up...") + rank_mapping = { + old_ep_rank: old_ep_rank + for old_ep_rank in range(old_ep_size) + } + assert self.model_runner.eplb_state is not None + self.model_runner.eplb_state.rearrange( + self.model_runner.model, + execute_shuffle=True, + global_expert_load=global_expert_load, + rank_mapping=rank_mapping) + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed!") + + def _reconfigure_parallel_config( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + """ + Update parallel config with provided reconfig_request + """ + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = \ + reconfig_request.new_data_parallel_size + if reconfig_request.new_data_parallel_rank != \ + ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank = \ + reconfig_request.new_data_parallel_rank + if reconfig_request.new_data_parallel_rank_local != \ + ReconfigureRankType.KEEP_CURRENT_RANK: + parallel_config.data_parallel_rank_local = \ + reconfig_request.new_data_parallel_rank_local + parallel_config.data_parallel_master_ip = \ + reconfig_request.new_data_parallel_master_ip + parallel_config.data_parallel_master_port = \ + reconfig_request.new_data_parallel_master_port + + def _reconfigure_moe(self, old_ep_size: int, + new_ep_size: int) -> Optional[torch.Tensor]: + """ + Reconfigure MoE modules with provided reconfig_request + + Return the global expert load if new_ep_size > old_ep_size, + otherwise None + """ + from vllm.distributed.parallel_state import ( + get_dp_group, get_ep_group, prepare_communication_buffer_for_model) + from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoEParallelConfig) + + parallel_config = self.vllm_config.parallel_config + moe_modules = [ + module for module in self.model_runner.model.modules() + if module.__class__.__name__ == "FusedMoE" + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all(module.moe_config.num_local_experts == num_local_experts + for module in moe_modules), ( + "All MoE modules must have the same number of experts") + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=get_tp_group().world_size, + dp_size_=get_dp_group().world_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + if new_ep_size < old_ep_size: + 
num_local_physical_experts = num_local_experts + assert self.model_runner.eplb_state is not None + new_physical_experts = \ + self.model_runner.eplb_state.physical_to_logical_map.shape[1] + parallel_config.num_redundant_experts = ( + new_physical_experts - + self.model_runner.eplb_state.logical_replica_count.shape[1]) + global_expert_load = None + else: + num_local_physical_experts = torch.tensor([num_local_experts], + dtype=torch.int32, + device="cpu") + torch.distributed.broadcast(num_local_physical_experts, + group=get_ep_group().cpu_group, + group_src=0) + num_local_physical_experts = num_local_physical_experts.item() + new_physical_experts = num_local_physical_experts * new_ep_size + assert self.model_runner.eplb_state is not None + global_expert_load = self.model_runner.eplb_state.rearrange( + self.model_runner.model, execute_shuffle=False) + parallel_config.num_redundant_experts = ( + new_physical_experts - global_expert_load.shape[1]) + prepare_communication_buffer_for_model(self.model_runner.model) + self.model_runner.model.update_physical_experts_metadata( + num_physical_experts=new_physical_experts, + num_local_physical_experts=num_local_physical_experts) + return global_expert_load + + def reinitialize_distributed( + self, reconfig_request: ReconfigureDistributedRequest) -> None: + from vllm.config import set_current_vllm_config + from vllm.distributed.parallel_state import ( + cleanup_dist_env_and_memory, get_ep_group) + + old_ep_size = get_ep_group().world_size + old_ep_rank = get_ep_group().rank + new_ep_size = reconfig_request.new_data_parallel_size * get_tp_group( + ).world_size * get_pp_group().world_size + if new_ep_size < old_ep_size: + self._eplb_before_scale_down(old_ep_size, new_ep_size) + + cleanup_dist_env_and_memory() + + if reconfig_request.new_data_parallel_rank == \ + ReconfigureRankType.SHUTDOWN_CURRENT_RANK: + assert old_ep_rank >= new_ep_size + # shutdown + return + + self._reconfigure_parallel_config(reconfig_request) + + with set_current_vllm_config(self.vllm_config): + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank) + + global_expert_load = self._reconfigure_moe(old_ep_size, new_ep_size) + + if new_ep_size > old_ep_size: + assert global_expert_load is not None + self._eplb_after_scale_up(old_ep_size, new_ep_size, + global_expert_load) + def save_sharded_state( self, path: str, diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f5f26d8fff98..3bb033f14876 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -3,7 +3,7 @@ import bisect import gc import time -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from unittest.mock import patch import numpy as np @@ -18,21 +18,26 @@ from vllm.attention.backends.abstract import AttentionType from vllm.attention.layer import Attention from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher -from vllm.config import ParallelConfig, VllmConfig, get_layers_from_vllm_config +from vllm.config import (ParallelConfig, VllmConfig, + get_layers_from_vllm_config, update_config) from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.lora.layers import BaseLayerWithLoRA from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.tpu import TPUModelLoader +from vllm.model_executor.models.interfaces_base import is_pooling_model 
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs, PlaceholderRange) from vllm.multimodal.utils import group_mm_inputs_by_modality +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv, - is_pin_memory_available) -from vllm.v1.attention.backends.pallas import (PallasAttentionBackend, - PallasMetadata) +from vllm.utils import (LayerBlockType, cdiv, is_pin_memory_available, + prev_power_of_2) +from vllm.v1.attention.backends.pallas import (TPU_STR_DTYPE_TO_TORCH_DTYPE, + PallasAttentionBackend, + PallasMetadata, + get_page_size_bytes) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, @@ -41,11 +46,10 @@ LogprobsTensors, ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler -from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch -from .utils import (initialize_kv_cache_for_kv_sharing, +from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing, sanity_check_mm_encoder_outputs) if TYPE_CHECKING: @@ -56,8 +60,6 @@ INVALID_TOKEN_ID = -1 # Smallest output size MIN_NUM_SEQS = 8 -# Block size used for kv cache updating kernel -NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK = 8 ######################################################### @@ -112,7 +114,6 @@ def __init__( self.original_parallel_config = original_parallel_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.device_config = vllm_config.device_config @@ -139,9 +140,13 @@ def __init__( self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype if cache_config.cache_dtype == "auto": - self.kv_cache_dtype = self.dtype + model_dtype = self.dtype + if isinstance(model_dtype, str): + self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + else: + self.kv_cache_dtype = model_dtype else: - self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] self._hidden_states_dtype = self.dtype @@ -192,6 +197,14 @@ def __init__( self.max_num_encoder_input_tokens = encoder_compute_budget self.encoder_cache_size = encoder_cache_size + self._num_slices_per_kv_cache_update_block = \ + _get_num_slices_per_kv_cache_update_block(get_page_size_bytes( + block_size=self.block_size, + num_kv_heads=self.num_kv_heads, + head_size=self.head_size, + kv_cache_dtype=self.kv_cache_dtype, + )) + # Lazy initialization self.model: nn.Module # Set after load_model self.kv_caches: list[torch.Tensor] = [] @@ -472,6 +485,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: def get_model(self) -> nn.Module: return self.model + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + return list(model.pooler.get_supported_tasks()) + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: """ Generates the KVCacheSpec by parsing the kv cache format from each @@ -498,6 
+518,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: continue if attn_module.attn_type == AttentionType.DECODER: + if attn_module.use_irope: + logger.warning_once( + "Using irope in Pallas is not supported yet, it " + "will fall back to global attention for long context.") if attn_module.sliding_window is not None: kv_cache_spec[layer_name] = SlidingWindowSpec( block_size=block_size, @@ -719,7 +743,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", num_kv_update_slices = slot_mapping_metadata.shape[0] padded_num_slices = _get_padded_num_kv_cache_update_slices( padded_total_num_scheduled_tokens, self.max_num_reqs, - self.block_size) + self.block_size, self._num_slices_per_kv_cache_update_block) slot_mapping_metadata = np.pad( slot_mapping_metadata, [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], @@ -750,8 +774,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput", num_kv_update_slices=torch.tensor([num_kv_update_slices], dtype=torch.int32, device=self.device), - num_slices_per_kv_cache_update_block= - NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, + num_slices_per_kv_cache_update_block=self. + _num_slices_per_kv_cache_update_block, ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -926,11 +950,10 @@ def _get_model_inputs(self, input_ids: torch.Tensor, # NOTE(woosuk): To unify token ids and soft tokens (vision # embeddings), we always use embeddings (rather than token ids) # as input to the multimodal model, even when the input is text. - if mm_embeds: - inputs_embeds = self.model.get_input_embeddings( - input_ids, mm_embeds) - else: - inputs_embeds = self.model.get_input_embeddings(input_ids) + inputs_embeds = self.model.get_input_embeddings( + input_ids=input_ids, + multimodal_embeddings=mm_embeds, + ) return None, inputs_embeds else: # For text-only models, we use token ids as input. @@ -958,7 +981,7 @@ def execute_model( else: mm_embeds = [] xm.mark_step() - # Prepare inputs, the requests might be splitted into multiple + # Prepare inputs, the requests might be split into multiple # executions, combine the result of each execution. start_index = 0 combined_selected_tokens: list[torch.Tensor] = [] @@ -1111,6 +1134,18 @@ def concat_lists(input_lists): return model_runner_output + def update_config(self, overrides: dict[str, Any]) -> None: + # TODO: TPU config may need extra validation + # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + def load_model(self) -> None: self.device = self.device_config.device @@ -1128,26 +1163,27 @@ def load_model(self) -> None: "vllm.model_executor.layers.vocab_parallel_embedding." 
"get_tensor_model_parallel_rank", return_value=xm_tp_rank): - if self.use_spmd: - tpu_loader = TPUModelLoader( - load_config=self.vllm_config.load_config) - model = tpu_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.vllm_config.model_config, - mesh=self.mesh) - else: - # model = get_model(vllm_config=self.vllm_config) - model_loader = get_model_loader(self.load_config) - if not hasattr(self, "model"): + try: + if self.use_spmd: + tpu_loader = TPUModelLoader( + load_config=self.vllm_config.load_config) + model = tpu_loader.load_model( + vllm_config=self.vllm_config, + model_config=self.vllm_config.model_config, + mesh=self.mesh) + else: + model_loader = get_model_loader(self.load_config) logger.info("Loading model from scratch...") model = model_loader.load_model( vllm_config=self.vllm_config, model_config=self.model_config) - else: - logger.info("Model was already initialized. \ - Loading weights inplace...") - model_loader.load_weights(self.model, - model_config=self.model_config) + except RuntimeError as e: + raise RuntimeError( + f"Unable to load model, a likely reason is the model is " + "too large for the current device's HBM memory. " + "Consider switching to a smaller model " + "or sharding the weights on more chips. " + f"See the detailed error: {e}") from e if self.lora_config is not None: model = self.load_lora_model(model, self.model_config, self.scheduler_config, @@ -1162,6 +1198,13 @@ def load_model(self) -> None: self.model = model self.sampler = TPUSampler() + def reload_weights(self) -> None: + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model_loader.load_weights(self.model, model_config=self.model_config) + @torch.no_grad() def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: @@ -1178,7 +1221,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size) + num_tokens, self.max_num_reqs, self.block_size, + self._num_slices_per_kv_cache_update_block) num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to(self.device) slot_mapping = torch.zeros((3, padded_num_slices), @@ -1201,8 +1245,8 @@ def _dummy_run(self, num_tokens: int, num_reqs: int, query_start_loc=query_start_loc, num_seqs=num_seqs, num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block= - NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK, + num_slices_per_kv_cache_update_block=self. 
+ _num_slices_per_kv_cache_update_block, ) if self.is_multimodal_model: @@ -1807,19 +1851,42 @@ def _get_padded_token_len(paddings: list[int], x: int) -> int: return paddings[index] -def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int, - page_size: int) -> int: +def _get_padded_num_kv_cache_update_slices( + num_tokens: int, max_num_reqs: int, page_size: int, + num_slices_per_kv_cache_update_block: int) -> int: """Calculates the padded number of KV cache update slices to avoid recompilation.""" padded_num_slices = 2 * max_num_reqs + num_tokens // page_size padded_num_slices = min(padded_num_slices, num_tokens) padded_num_slices = ( - padded_num_slices + NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK - 1 - ) // NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK * \ - NUM_SLICES_PER_KV_CACHE_UPDATE_BLOCK + padded_num_slices + num_slices_per_kv_cache_update_block - 1 + ) // num_slices_per_kv_cache_update_block * \ + num_slices_per_kv_cache_update_block return padded_num_slices +def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: + """Find the optimum number of slices to copy per Pallas program instance. + + Increasing the number of slices copied in one instance of the kernel program + will increase HBM bandwidth utilization via more in-flight DMAs. + + However, it will also use more VMEM, and experimentally, we observed + performance regression at 128 slices on v6e, likely due to running + out of scalar registers. Thus this function will limit the number of + slices to 64. + """ + # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we + # calculate num_slices_per_block based on 16MB in case any register spills. + vmem_limit = 16 * 1024 * 1024 + num_slices_per_block = vmem_limit // page_size_bytes + assert num_slices_per_block > 0, "Number of slices should be positive" + num_slices_per_block = prev_power_of_2(num_slices_per_block) + if num_slices_per_block > 64: + num_slices_per_block = 64 + return num_slices_per_block + + def replace_set_lora(model): def _tpu_set_lora( diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index a64ce881fe31..648d9c3195ce 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" import os -from typing import Optional +from typing import Any, Optional import torch import torch.distributed @@ -18,14 +18,17 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.pooling_params import PoolingTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.utils import bind_kv_cache, report_usage_stats +from vllm.v1.utils import report_usage_stats from vllm.v1.worker.tpu_model_runner import TPUModelRunner +from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) @@ -59,7 +62,6 @@ def __init__( self.scheduler_config = vllm_config.scheduler_config self.device_config = vllm_config.device_config self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config 
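To make the slice heuristic above concrete, here is a worked example with assumed page dimensions (the shapes and the local prev_power_of_2 helper are mine, not taken from the patch):

def prev_power_of_2(n: int) -> int:
    # Local stand-in mirroring vllm.utils.prev_power_of_2.
    return 1 << (n.bit_length() - 1) if n > 0 else 0


# Assumed page: 32 tokens, 8 KV heads, head size 128, bf16 (2 bytes),
# K and V planes -> 128 KiB per page.
block_size, num_kv_heads, head_size, dtype_bytes = 32, 8, 128, 2
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_bytes
assert page_size_bytes == 128 * 1024

vmem_limit = 16 * 1024 * 1024           # half of the 32 MB Pallas default
num_slices = min(prev_power_of_2(vmem_limit // page_size_bytes), 64)
assert num_slices == 64                 # 128 would fit, but the helper caps at 64

# Padded slice counts are then rounded up to a multiple of num_slices so
# every precompiled shape stays aligned with the kernel block size.
padded = 70
padded = (padded + num_slices - 1) // num_slices * num_slices
assert padded == 128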
self.parallel_config.rank = rank @@ -259,6 +261,12 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def load_model(self) -> None: self.model_runner.load_model() + def update_config(self, overrides: dict[str, Any]) -> None: + self.model_runner.update_config(overrides) + + def reload_weights(self) -> None: + self.model_runner.reload_weights() + def compile_or_warm_up_model(self) -> None: if not self.model_config.enforce_eager: self.model_runner.capture_model() @@ -270,6 +278,9 @@ def compile_or_warm_up_model(self) -> None: def get_model(self) -> nn.Module: return self.model_runner.get_model() + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + return self.model_runner.get_supported_pooling_tasks() + def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return self.model_runner.get_kv_cache_spec() @@ -300,7 +311,7 @@ def _init_tpu_worker_distributed_environment( rank=rank, local_rank=local_rank, distributed_init_method=distributed_init_method, - backend="gloo", + backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index 70339ff2f005..3ecb1d7dd656 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional +from collections import defaultdict +from typing import TYPE_CHECKING, Optional import torch from vllm.model_executor.models.interfaces import MultiModalEmbeddings +from vllm.model_executor.models.utils import extract_layer_index from vllm.v1.kv_cache_interface import KVCacheGroupSpec +if TYPE_CHECKING: + from vllm.attention.layer import Attention + def sanity_check_mm_encoder_outputs( mm_embeddings: MultiModalEmbeddings, @@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing( kv_caches[layer_name] = kv_caches[target_layer_name] group_idx = layer_to_kv_cache_group_idx[target_layer_name] kv_cache_groups[group_idx].layer_names.append(layer_name) + + +def bind_kv_cache( + kv_caches: dict[str, torch.Tensor], + forward_context: dict[str, "Attention"], + runner_kv_caches: list[torch.Tensor], +) -> None: + """ + Bind the allocated KV cache to both ModelRunner and forward context so + that the KV cache can be used in the forward pass. + + This function: + 1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with + kv_caches. + 2) Associates each attention layer in the `forward_context` with its + corresponding KV cache in kv_caches. + + Args: + kv_caches: The allocated kv_caches with layer names as keys. + forward_context: The global forward context containing all Attention + layers with layer names as keys. + runner_kv_caches: The kv_cache declared by ModelRunner. + """ + # Bind kv_caches to ModelRunner + assert len(runner_kv_caches) == 0 + + # Convert kv_caches dict to a list of tensors in the order of layer_index. + index2name = defaultdict(list) + for layer_name in kv_caches: + index2name[extract_layer_index(layer_name)].append(layer_name) + + for layer_index in sorted(index2name.keys()): + layer_names = index2name[layer_index] + if len(layer_names) > 1: + # One typical case is encoder-decoder model, e.g., bart. + # The cross attention and self attention in the same decoder layer + # has different layer_name but the same layer_index. 
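To make the binding steps of this helper concrete, a small usage sketch with toy stand-ins (ToyAttention and toy_extract_layer_index replace the real Attention layer and vllm.model_executor.models.utils.extract_layer_index):

import re

import torch


class ToyAttention:
    kv_cache = None


def toy_extract_layer_index(layer_name: str) -> int:
    # Simplified: take the first numeric component of the dotted layer name.
    match = re.search(r"\.(\d+)\.", layer_name)
    assert match is not None
    return int(match.group(1))


kv_caches = {
    "model.layers.1.self_attn.attn": torch.zeros(4),
    "model.layers.0.self_attn.attn": torch.zeros(4),
}
forward_context = {name: ToyAttention() for name in kv_caches}
runner_kv_caches: list[torch.Tensor] = []

# 1) Fill the runner's flat list in layer-index order.
for name in sorted(kv_caches, key=toy_extract_layer_index):
    runner_kv_caches.append(kv_caches[name])

# 2) Attach each cache to its layer, wrapped in a list (v0 PP convention).
for name, cache in kv_caches.items():
    forward_context[name].kv_cache = [cache]

assert runner_kv_caches[0] is kv_caches["model.layers.0.self_attn.attn"]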
+ raise NotImplementedError + layer_name = layer_names[0] + runner_kv_caches.append(kv_caches[layer_name]) + + # Bind kv_caches to forward context + for layer_name, kv_cache in kv_caches.items(): + # NOTE: Use list because of v0 PP virtual engine. + forward_context[layer_name].kv_cache = [kv_cache] diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py index 4cedc913c2ab..59f8d0fcf5bd 100644 --- a/vllm/v1/worker/xpu_model_runner.py +++ b/vllm/v1/worker/xpu_model_runner.py @@ -27,7 +27,7 @@ def __init__( self.cascade_attn_enabled = False def _init_device_properties(self) -> None: - pass + self.num_sms = None def _sync_device(self) -> None: torch.xpu.synchronize() diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py index 6d1f5749d8b2..c7885694f7a3 100644 --- a/vllm/v1/worker/xpu_worker.py +++ b/vllm/v1/worker/xpu_worker.py @@ -7,6 +7,7 @@ import vllm.envs as envs from vllm.config import VllmConfig +from vllm.distributed import get_world_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform @@ -132,7 +133,7 @@ def init_device(self): if self.device_config.device.type == "xpu" and current_platform.is_xpu( ): self.device = torch.device(f"xpu:{self.local_rank}") - torch.xpu.set_device(self.device) + current_platform.set_device(self.device) torch.xpu.empty_cache() self.init_gpu_memory = torch.xpu.get_device_properties( self.local_rank).total_memory @@ -148,14 +149,15 @@ def init_device(self): os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE os.environ["LOCAL_RANK"] = str(self.local_rank) - dist_backend = "ccl" init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, - self.local_rank, dist_backend) + self.local_rank, + current_platform.dist_backend) # global all_reduce needed for overall oneccl warm up - torch.distributed.all_reduce(torch.zeros(1).xpu()) + torch.distributed.all_reduce(torch.zeros(1).xpu(), + group=get_world_group().device_group) # Set random seed. set_random_seed(self.model_config.seed) diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py deleted file mode 100644 index c99e2652a397..000000000000 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ /dev/null @@ -1,326 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast - -import torch - -from vllm.attention import AttentionMetadata -from vllm.forward_context import set_forward_context -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MultiModalKwargs -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad -from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, - ModelInputForCPUBuilder, - ModelInputForCPUWithSamplingMetadata) -from vllm.worker.model_runner_base import ( - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - - -@dataclasses.dataclass(frozen=True) -class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): - """ - Used by the EncoderDecoderModelRunner. 
- """ - encoder_input_tokens: Optional[torch.Tensor] = None - encoder_input_positions: Optional[torch.Tensor] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "encoder_input_tokens": self.encoder_input_tokens, - "encoder_input_positions": self.encoder_input_positions, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "EncoderDecoderModelInputForCPU": - return cast( - EncoderDecoderModelInputForCPU, - super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) - - -class CPUEncoderDecoderModelRunner( - CPUModelRunnerBase[EncoderDecoderModelInputForCPU]): - _model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( - EncoderDecoderModelInputForCPU) - _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder - - def _list_to_int32_tensor( - self, - _list: List[int], - ) -> torch.Tensor: - return torch.tensor(_list, dtype=torch.int32, device=self.device) - - def _list_to_long_tensor( - self, - _list: List[int], - ) -> torch.Tensor: - return torch.tensor(_list, dtype=torch.long, device=self.device) - - def _empty_int32_tensor(self) -> torch.Tensor: - return self._list_to_int32_tensor([]) - - def _empty_long_tensor(self) -> torch.Tensor: - return self._list_to_long_tensor([]) - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, - Any]) -> EncoderDecoderModelInputForCPU: - return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> EncoderDecoderModelInputForCPU: - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - ( - attn_metadata, - encoder_input_tokens_tensor, - encoder_input_positions_tensor, - ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, - model_input) - # Sampling metadata is only required for the final pp group - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - pin_memory=False, - generators=generators) - return dataclasses.replace( - model_input, - sampling_metadata=sampling_metadata, - attn_metadata=attn_metadata, - encoder_input_tokens=encoder_input_tokens_tensor, - encoder_input_positions=encoder_input_positions_tensor, - virtual_engine=virtual_engine, - ) - - def _prepare_encoder_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - model_input: EncoderDecoderModelInputForCPU, - ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], - Optional[torch.Tensor]]: - """Helper method to prepare the encoder- and cross-attn-related - model inputs based on a given sequence group. These additional inputs - are used to augment an already-computed `EncoderDecoderModelInput` - data structure which already has decoder-related model inputs - populated. 
- - Sets the following attn_metadata fields: - * `num_encoder_tokens` - * `encoder_seq_lens` - * `encoder_seq_lens_tensor` - * `max_encoder_seq_len` - * `cross_slot_mapping` - * `cross_block_tables` - - Constructs a new model inputs data structure, based on - (1) the existing fields in the `model_inputs` argument, - and (2) the following additional fields which are - computed (or in the case of `attn_metadata`, updated) - by this function: - * attn_metadata - * encoder_input_tokens - * encoder_input_positions - - Arguments: - - * seq_group_metadata_list: list of sequence groups for which to - compute inputs - * model_inputs: model inputs data structure with decoder-oriented - fields already computed. - - Return: - - * Updated model inputs data structure - """ - - if len(seq_group_metadata_list) == 0: - return (model_input.attn_metadata, None, None) - - # Since we are not supporting chunked prefill either the entire - # batch is prefill or it is decode - is_prompt = seq_group_metadata_list[0].is_prompt - - # Build encoder inputs - encoder_seq_lens: List[int] = [] - if is_prompt: - # Prefill phase. - cross_block_tables = self._empty_int32_tensor().view( - len(seq_group_metadata_list), -1) - - # Extract input tokens/positions, cross-attention slot-mapping, - # & seq len from each sequence group metadata - ( - encoder_input_tokens, - encoder_input_positions, - cross_slot_mapping, - ) = ( - [], - [], - [], - ) - for seq_group_metadata in seq_group_metadata_list: - # Build seq lens - seq_len = seq_group_metadata.encoder_seq_data.get_len() - token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() - encoder_seq_lens.append(seq_len) - - # Build slot mapping - for i in range(0, seq_len): - block_number = seq_group_metadata.cross_block_table[ - i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - cross_slot_mapping.append(slot) - - # Build encoder input tokens - encoder_input_tokens.extend(token_ids) - encoder_input_positions.extend(list(range(0, seq_len))) - - # Convert tokens/positions & cross-attention - # slot-mapping to encoder input tensors - encoder_input_tokens_tensor = self._list_to_long_tensor( - encoder_input_tokens) - encoder_input_positions_tensor = self._list_to_long_tensor( - encoder_input_positions) - cross_slot_mapping_tensor = self._list_to_long_tensor( - cross_slot_mapping) - - else: - # Decode phase. - encoder_input_tokens_tensor = self._empty_long_tensor() - encoder_input_positions_tensor = self._empty_long_tensor() - cross_slot_mapping_tensor = self._empty_long_tensor() - # Extract cross-attention block tables & - # seq len from each sequence group metadata. - # Cross-attention block tables are empty - # during vLLM memory profiling. 
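The slot mapping computed in the removed prefill branch above is simple block arithmetic; a tiny worked example with a made-up cross-attention block table:

block_size = 16
cross_block_table = [7, 3]   # physical blocks assigned to one sequence
seq_len = 20

slot_mapping = []
for i in range(seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    slot_mapping.append(block_number * block_size + block_offset)

# Tokens 0..15 land in block 7 (slots 112..127), tokens 16..19 in block 3.
assert slot_mapping[0] == 112
assert slot_mapping[15] == 127
assert slot_mapping[16] == 48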
- cross_block_tables = [] - for seq_group_metadata in seq_group_metadata_list: - for _ in range(len(seq_group_metadata.seq_data)): - encoder_seq_lens.append( - seq_group_metadata.encoder_seq_data.get_len()) - cross_block_table = seq_group_metadata.cross_block_table - cross_block_tables.append([] if ( - cross_block_table is None) else cross_block_table) - - max_len_of_block_table = max( - len(block_table) for block_table in cross_block_tables) - - cross_block_tables = make_tensor_with_pad( - cross_block_tables, - max_len=max_len_of_block_table, - pad=0, - dtype=torch.int32, - device=self.device, - ) - - # Compute encoder sequence lengths & encoder - # sequence starting offset tensors - max_encoder_seq_len = max(encoder_seq_lens, default=0) - encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) - encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + - 1, - dtype=torch.int32, - device=self.device) - torch.cumsum(encoder_seq_lens_tensor, - dim=0, - dtype=encoder_seq_start_loc.dtype, - out=encoder_seq_start_loc[1:]) - - # Update attention metadata with encoder-oriented attributes - attn_metadata = model_input.attn_metadata - assert attn_metadata is not None - ( - attn_metadata.num_encoder_tokens, - attn_metadata.encoder_seq_lens, - attn_metadata.encoder_seq_lens_tensor, - attn_metadata.max_encoder_seq_len, - attn_metadata.cross_slot_mapping, - attn_metadata.cross_block_tables, - ) = ( - sum(encoder_seq_lens), - encoder_seq_lens, - encoder_seq_lens_tensor, - max_encoder_seq_len, - cross_slot_mapping_tensor, - cross_block_tables, - ) - - return (attn_metadata, encoder_input_tokens_tensor, - encoder_input_positions_tensor) - - @torch.no_grad() - def execute_model( - self, - model_input: EncoderDecoderModelInputForCPU, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "CPU worker does not support multi-step execution.") - - model_executable = self.model - execute_model_kwargs = { - "input_ids": - model_input.input_tokens, - "positions": - model_input.input_positions, - "encoder_input_ids": - model_input.encoder_input_tokens, - "encoder_positions": - model_input.encoder_input_positions, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - "intermediate_tensors": - intermediate_tensors, - } - - with set_forward_context(model_input.attn_metadata, self.vllm_config, - model_input.virtual_engine): - hidden_states = model_executable(**execute_model_kwargs) - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return [] - - # Sample the next token. 
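The encoder_seq_start_loc built in the removed code above is a prefix sum over per-sequence lengths with a leading zero; a minimal reproduction of that cumsum-into-a-slice pattern:

import torch

encoder_seq_lens = torch.tensor([5, 3, 4], dtype=torch.int32)
start_loc = torch.zeros(encoder_seq_lens.numel() + 1, dtype=torch.int32)
torch.cumsum(encoder_seq_lens, dim=0, dtype=torch.int32, out=start_loc[1:])
assert start_loc.tolist() == [0, 5, 8, 12]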
- output = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - return [output] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py deleted file mode 100644 index 68cdf65cafa7..000000000000 --- a/vllm/worker/cpu_model_runner.py +++ /dev/null @@ -1,671 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import weakref -from collections import defaultdict -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, - TypeVar, Union) - -import torch -from torch import nn - -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import VllmConfig -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.models import supports_lora, supports_multimodal -from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs, - MultiModalPlaceholderMap) -from vllm.sequence import (IntermediateTensors, SequenceData, - SequenceGroupMetadata) -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU") -_PAD_SLOT_ID = -1 - - -@dataclass(frozen=True) -class ModelInputForCPU(ModelRunnerInputBase): - """ - Base class contains metadata needed for the base model forward pass on CPU - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - token_type_ids: Optional[torch.Tensor] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None - virtual_engine: Optional[int] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "token_type_ids": self.token_type_ids, - "multi_modal_kwargs": self.multi_modal_kwargs, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForCPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None - ) -> TModelInputForCPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -@dataclass(frozen=True) -class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): - """ - 
Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - is_prompt: Optional[bool] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "token_type_ids": self.token_type_ids, - "multi_modal_kwargs": self.multi_modal_kwargs, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForCPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): - - class ModelInputData: - - def __init__(self, use_mrope: bool): - self.use_mrope = use_mrope - self.input_tokens: List[int] = [] - self.input_positions: List[int] = [] - self.token_type_ids: Optional[List[int]] = [] - self.seq_lens: List[int] = [] - self.query_lens: List[int] = [] - self.prefill_block_tables: List[List[int]] = [] - self.decode_block_tables: List[List[int]] = [] - self.max_decode_seq_len: int = 0 - self.num_prefills: int = 0 - self.num_prefill_tokens: int = 0 - self.num_decode_tokens: int = 0 - self.slot_mapping: List[int] = [] - self.multi_modal_inputs_list: List[MultiModalKwargs] = [] - self.multi_modal_placeholder_maps: Dict[ - str, MultiModalPlaceholderMap] = defaultdict( - MultiModalPlaceholderMap) - self.input_mrope_positions: List[List[int]] = [[] - for _ in range(3)] - - def __init__(self, - runner: "CPUModelRunner", - finished_requests_ids: Optional[List[str]] = None) -> None: - super().__init__() - self.runner = runner - self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled - or runner.cache_config.enable_prefix_caching) - self.model_input_cls = self.runner._model_input_cls - self.attn_backend = self.runner.attn_backend - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.device = self.runner.device - self.enable_lora = self.runner.lora_config is not None - if self.runner.attn_backend is not None: - # spec decode (e.g. 
Medusa) does not have atten backend - attn_backend = self.runner.attn_backend - self.att_metadata_builder = attn_backend.get_builder_cls()(self) - - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] - self.input_data = ModelInputForCPUBuilder.ModelInputData( - self.runner.model_config.uses_mrope) - self.att_metadata_builder.prepare() - - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): - self.seq_group_metadata_list.append(seq_group_metadata) - - def set_seq_group_list( - self, seq_group_metadata_list: List[SequenceGroupMetadata]): - self.seq_group_metadata_list = seq_group_metadata_list - - def build(self) -> ModelInputForCPU: - self._build_input_data() - - input_data = self.input_data - input_tokens = torch.tensor(input_data.input_tokens, - dtype=torch.long, - device="cpu") - input_positions = torch.tensor( - input_data.input_positions - if not any(input_data.input_mrope_positions) else - input_data.input_mrope_positions, - dtype=torch.long, - device="cpu") - token_type_ids = torch.tensor(input_data.token_type_ids, - dtype=torch.long, - device="cpu") \ - if input_data.token_type_ids else None - - # For multi-modal models - multi_modal_kwargs = None - if len(input_data.multi_modal_inputs_list) != 0: - multi_modal_kwargs = MultiModalKwargs.batch( - input_data.multi_modal_inputs_list) - - attn_metadata = self.att_metadata_builder.build( - input_data.seq_lens, input_data.query_lens, -1, -1) - - is_prompt = (self.seq_group_metadata_list[0].is_prompt - if self.seq_group_metadata_list else None) - # LoRA data. - lora_requests = set() - lora_mapping = None - if self.enable_lora: - lora_requests = set(seq.lora_request - for seq in self.seq_group_metadata_list - if seq.lora_request is not None) - - lora_mapping = self._prepare_lora_input( - self.seq_group_metadata_list, is_prompt) - - return self.model_input_cls(input_tokens=input_tokens, - input_positions=input_positions, - token_type_ids=token_type_ids, - seq_lens=input_data.seq_lens, - query_lens=input_data.query_lens, - attn_metadata=attn_metadata, - multi_modal_kwargs=multi_modal_kwargs, - lora_mapping=lora_mapping, - lora_requests=lora_requests) - - def _build_input_data(self): - for seq_group_metadata in self.seq_group_metadata_list: - for seq_id, seq_data in seq_group_metadata.seq_data.items(): - if seq_group_metadata.is_prompt: - self._compute_prompt_input_tokens(self.input_data, - seq_group_metadata, - seq_data, seq_id) - if seq_group_metadata.multi_modal_data: - self._compute_multi_modal_input( - seq_group_metadata, seq_data) - else: - self._compute_decode_input_tokens(self.input_data, - seq_group_metadata, - seq_data, seq_id) - - def _compute_decode_input_tokens(self, data: ModelInputData, - seq_group_metadata: SequenceGroupMetadata, - seq_data: SequenceData, seq_id: int): - """ - Compute decode input tokens, positions, block table and slot mapping. 
- """ - block_size = self.runner.block_size - - block_table = seq_group_metadata.block_tables[seq_id] - seq_len = seq_data.get_len() - context_len = seq_data.get_num_computed_tokens() - - tokens = seq_data.get_last_token_id() - token_positions = seq_len - 1 - block_number = block_table[token_positions // block_size] - block_offset = token_positions % block_size - slot = block_number * block_size + block_offset - - # For paged_attention kernel - if self.runner.sliding_window: - start_idx = max(0, seq_len - self.runner.sliding_window) - start_block = start_idx // block_size - start_idx = start_block * block_size - seq_len = seq_len - start_idx - block_table = block_table[start_block:] - - # For MRotaryEmbedding - if seq_data.mrope_position_delta is not None: - next_pos = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - for idx in range(3): - data.input_mrope_positions[idx].extend( # type: ignore - next_pos[idx]) - else: - data.input_positions.append(token_positions) # type: ignore - - # Update fields - data.input_tokens.append(tokens) - data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) - data.num_decode_tokens += 1 - data.slot_mapping.append(slot) - data.decode_block_tables.append(block_table) - data.query_lens.append(1) - data.seq_lens.append(seq_len) - - def _compute_prompt_input_tokens(self, data: ModelInputData, - seq_group_metadata: SequenceGroupMetadata, - seq_data: SequenceData, seq_id: int): - """ - Compute prompt input tokens, positions, block table and slot mapping. - """ - token_chunk_size = seq_group_metadata.token_chunk_size - block_size = self.runner.block_size - - block_table = seq_group_metadata.block_tables[seq_id] - seq_len = seq_data.get_len() - context_len = seq_data.get_num_computed_tokens() - seq_len = min(seq_len, context_len + token_chunk_size) - - # For prefix caching - prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) - if prefix_cache_block_num > 0: - prefix_cache_len = (prefix_cache_block_num * - self.runner.block_size) - if prefix_cache_len <= context_len: - # We already passed the cache hit region, - # so do normal computation. - pass - elif context_len < prefix_cache_len < seq_len: - # Partial hit. Compute the missing part. - context_len = prefix_cache_len - token_chunk_size = seq_len - context_len - elif seq_len <= prefix_cache_len: - # Full hit. Only compute the last token to avoid - # erroneous behavior. FIXME: Ideally we should directly - # mark all tokens as computed in the scheduler and do not - # schedule this sequence, so this case should not happen. - context_len = seq_len - 1 - token_chunk_size = 1 - - tokens = seq_data.get_token_ids() - tokens = tokens[context_len:seq_len] - token_positions = range(context_len, seq_len) - token_types = seq_group_metadata.token_type_ids - - # For encoder-only models, the block_table is None, - # and there is no need to initialize the slot_mapping. 
- if block_table is not None: - slot_mapping = [_PAD_SLOT_ID] * len(token_positions) - for i, pos in enumerate(token_positions): - block_number = block_table[pos // block_size] - block_offset = pos % block_size - slot = block_number * block_size + block_offset - slot_mapping[i] = slot - data.slot_mapping.extend(slot_mapping) - - # The MROPE positions are prepared in _compute_multi_modal_input - data.input_positions.extend(token_positions) - - if data.token_type_ids is not None: - data.token_type_ids.extend(token_types if token_types else []) - - # Update fields - data.input_tokens.extend(tokens) - data.num_prefills += 1 - data.num_prefill_tokens += len(tokens) - data.query_lens.append(len(tokens)) - data.prefill_block_tables.append(block_table) - data.seq_lens.append(seq_len) - - def _compute_multi_modal_input(self, - seq_group_metadata: SequenceGroupMetadata, - seq_data: SequenceData): - computed_len = seq_data.get_num_computed_tokens() - seq_len = self.input_data.seq_lens[-1] - - # NOTE: mm_kwargs only includes the subset of multi-modal items that - # intersect with the current prefill positions. - mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group_metadata, range(computed_len, seq_len)) - - if not mm_kwargs: - return - - # special processing for mrope position deltas. - if self.runner.model_config.uses_mrope: - assert not self.chunked_prefill, \ - "MROPE on CPU does not support chunked-prefill." - - image_grid_thw = mm_kwargs.get("image_grid_thw", None) - video_grid_thw = mm_kwargs.get("video_grid_thw", None) - audio_feature_lengths = mm_kwargs.get("audio_feature_lengths", - None) - assert ( - image_grid_thw is not None or video_grid_thw is not None - or audio_feature_lengths is not None), ( - "mrope embedding type requires multi-modal input mapper " - "returns 'image_grid_thw' or 'video_grid_thw' or " - "'audio_feature_lengths'.") - - second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) - use_audio_in_video = mm_kwargs.get("use_audio_in_video", False) - hf_config = self.runner.model_config.hf_config - token_ids = seq_data.get_token_ids() - - mrope_positions, mrope_position_delta = \ - MRotaryEmbedding.get_input_positions( - token_ids, - hf_config=hf_config, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - second_per_grid_ts=second_per_grid_ts, - context_len=computed_len, - audio_feature_lengths=audio_feature_lengths, - use_audio_in_video=use_audio_in_video, - ) - seq_data.mrope_position_delta = mrope_position_delta - - for i in range(3): - self.input_data.input_mrope_positions[ # type: ignore - i].extend(mrope_positions[i]) - - self.input_data.multi_modal_inputs_list.append(mm_kwargs) - for modality, placeholder_map in placeholder_maps.items(): - self.input_data.multi_modal_placeholder_maps[modality].extend( - placeholder_map) - - def _prepare_lora_input( - self, seq_group_metadata_list: List[SequenceGroupMetadata], - is_prefill: bool) -> LoRAMapping: - index_mapping = [] - prompt_mapping = [] - for seq in seq_group_metadata_list: - lora_id = seq.lora_int_id - query_len = seq.token_chunk_size - - index_mapping += [lora_id] * query_len - prompt_mapping += [lora_id] * ( - query_len if seq.sampling_params - and seq.sampling_params.prompt_logprobs is not None else 1) - - return LoRAMapping(index_mapping=tuple(index_mapping), - prompt_mapping=tuple(prompt_mapping), - is_prefill=is_prefill) - - -class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): - """ - Helper class for shared methods between CPU model runners. 
- """ - _model_input_cls: Type[TModelInputForCPU] - _builder_cls: Type[ModelInputForCPUBuilder] - builder: ModelInputForCPUBuilder - - def __init__( - self, - vllm_config: VllmConfig, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - return_hidden_states: bool = False, - *args, - **kwargs, - ): - ModelRunnerBase.__init__(self, vllm_config) - model_config = self.model_config - cache_config = self.cache_config - - self.is_driver_worker = is_driver_worker - self.return_hidden_states = return_hidden_states - - self.device = self.device_config.device - self.pin_memory = False - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - num_attn_heads = self.model_config.get_num_attention_heads( - self.parallel_config) - needs_attn_backend = (num_attn_heads != 0 - or self.model_config.is_attention_free) - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) if needs_attn_backend else None - - # Lazy initialization. - self.model: nn.Module # Set after init_Model - # Set after load_model. - self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - self.sampler = get_sampler() - - if hasattr(self, "_builder_cls"): - # multi-step model runner does not have `_builder_cls` - self.builder = self._builder_cls(weakref.proxy(self)) - - def load_model(self) -> None: - self.model = get_model(vllm_config=self.vllm_config) - - if self.lora_config: - assert supports_lora( - self.model - ), f"{self.model.__class__.__name__} does not support LoRA yet." - - if supports_multimodal(self.model): - logger.warning("Regarding multimodal models, vLLM currently " - "only supports adding LoRA to language model.") - - # Use get_text_config() in case of multimodal models - text_config = self.model_config.hf_config.get_text_config() - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=text_config.max_position_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - def get_model(self) -> nn.Module: - return self.model - - def _prepare_model_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - finished_requests_ids: Optional[List[str]] = None - ) -> TModelInputForCPU: - """Helper method to prepare the model input based on a given sequence - group. Prepares metadata needed for the base model forward pass but not - metadata for possible additional steps, e.g., sampling. 
- - """ - self.builder.prepare(finished_requests_ids) - self.builder.set_seq_group_list(seq_group_metadata_list) - - return self.builder.build() # type: ignore - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - - -class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]): - _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( - ModelInputForCPUWithSamplingMetadata) - _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> ModelInputForCPUWithSamplingMetadata: - return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501 - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForCPUWithSamplingMetadata: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. 
- - """ - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - # Sampling metadata is only required for the final pp group - generators = self.get_generators(finished_requests_ids) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - model_input.seq_lens, - model_input.query_lens, - self.device, - pin_memory=False, - generators=generators) - - is_prompt = (seq_group_metadata_list[0].is_prompt - if seq_group_metadata_list else None) - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - virtual_engine=virtual_engine, - is_prompt=is_prompt) - - @torch.no_grad() - def execute_model( - self, - model_input: ModelInputForCPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - previous_hidden_states: Optional[torch.Tensor] = None, - ) -> Optional[List[SamplerOutput]]: - if num_steps > 1: - raise ValueError( - "CPU worker does not support multi-step execution.") - - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - - model_executable = self.model - - multimodal_kwargs = {} - if model_input.multi_modal_kwargs is not None: - multimodal_kwargs = MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs, - device=self.device, - ) - execute_model_kwargs = {} - if previous_hidden_states is not None: - execute_model_kwargs.update( - {"previous_hidden_states": previous_hidden_states}) - - with set_forward_context(model_input.attn_metadata, self.vllm_config, - model_input.virtual_engine): - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - intermediate_tensors=intermediate_tensors, - **execute_model_kwargs, - **multimodal_kwargs, - ) - - # Compute the logits. - logits = self.model.compute_logits(hidden_states, - model_input.sampling_metadata) - - # Only perform sampling in the driver worker. - if not self.is_driver_worker: - return [] - - # Sample the next token. 
- output = self.sampler( - logits=logits, - sampling_metadata=model_input.sampling_metadata, - ) - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - if model_input.is_prompt: - output.prefill_hidden_states = hidden_states - output.hidden_states = hidden_states - return [output] - - def generate_proposals(self, *args, **kwargs): - return self.model.generate_proposals(*args, **kwargs) diff --git a/vllm/worker/cpu_pooling_model_runner.py b/vllm/worker/cpu_pooling_model_runner.py deleted file mode 100644 index 203fdf225a41..000000000000 --- a/vllm/worker/cpu_pooling_model_runner.py +++ /dev/null @@ -1,125 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Union - -import torch - -from vllm.forward_context import set_forward_context -from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.multimodal import MultiModalKwargs -from vllm.pooling_params import PoolingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, - SequenceGroupMetadata) -from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU, - ModelInputForCPUBuilder) - - -@dataclasses.dataclass(frozen=True) -class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): - """ - Used by the CPUPoolingModelRunner. - """ - pooling_metadata: Optional["PoolingMetadata"] = None - - -class CPUPoolingModelRunner( - CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): - _model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( - ModelInputForCPUWithPoolingMetadata) - _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForCPUWithPoolingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: - if num_steps > 1: - raise ValueError( - "CPU worker does not support multi-step execution.") - - model_executable = self.model - cross_enc_kwargs = {} - if model_input.token_type_ids is not None: - cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids - execute_model_kwargs = { - "input_ids": - model_input.input_tokens, - "positions": - model_input.input_positions, - **MultiModalKwargs.as_kwargs( - model_input.multi_modal_kwargs or {}, - device=self.device, - ), - **cross_enc_kwargs, - "intermediate_tensors": - intermediate_tensors, - } - - with set_forward_context(model_input.attn_metadata, self.vllm_config, - model_input.virtual_engine): - hidden_states = model_executable(**execute_model_kwargs) - - # Only perform pooling in the driver worker. 
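For the pooling runner above, the model forward yields one hidden state per token, and the pooler then reduces each prompt to a single vector using the per-sequence lengths. A conceptual mean-pooling sketch only (vLLM's actual `Pooler` is model-configurable; the shapes below are invented):

import torch

hidden_states = torch.randn(10, 8)         # 10 tokens total, hidden size 8 (made up)
prompt_lens = [3, 5, 2]                    # three sequences in the flattened batch

pooled = [chunk.mean(dim=0) for chunk in torch.split(hidden_states, prompt_lens, dim=0)]
print([p.shape for p in pooled])           # three pooled vectors of size 8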
- if not self.is_driver_worker: - return [] - - return [ - self.model.pooler(hidden_states=hidden_states, - pooling_metadata=model_input.pooling_metadata) - ] - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, - Any]) -> ModelInputForCPUWithPoolingMetadata: - return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - ) - - def prepare_model_input( - self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForCPUWithPoolingMetadata: - assert seq_group_metadata_list is not None - model_input = self._prepare_model_input_tensors( - seq_group_metadata_list, finished_requests_ids) - # Prepare PoolingMetadata. - assert model_input.seq_lens is not None - pooling_metadata = self._prepare_pooling(seq_group_metadata_list, - model_input.seq_lens) - - return dataclasses.replace(model_input, - virtual_engine=virtual_engine, - pooling_metadata=pooling_metadata) - - def _prepare_pooling( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - prompt_lens: List[int], - ) -> PoolingMetadata: - """Prepare PoolingMetadata for the sequence group metadata list.""" - seq_groups: List[Tuple[List[int], PoolingParams]] = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - seq_ids = list(seq_group_metadata.seq_data.keys()) - pooling_params = seq_group_metadata.pooling_params - seq_groups.append((seq_ids, pooling_params)) - - seq_data: Dict[int, SequenceData] = {} - for seq_group_metadata in seq_group_metadata_list: - seq_data.update(seq_group_metadata.seq_data) - - pooling_metadata = PoolingMetadata( - seq_groups=seq_groups, - seq_data=seq_data, - prompt_lens=prompt_lens, - ) - - return pooling_metadata diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py deleted file mode 100644 index a8998127b60f..000000000000 --- a/vllm/worker/cpu_worker.py +++ /dev/null @@ -1,452 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A CPU worker class.""" -import os -from importlib import util -from typing import List, Optional, Set, Tuple, Type - -import torch -import torch.distributed - -import vllm.envs as envs -from vllm.attention import get_attn_backend -from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, VllmConfig) -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest -from vllm.utils import bind_kv_cache -from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner -from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase -from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class CPUCacheEngine: - """Manages the KV cache for CPU backend. - - This class is responsible for initializing and managing CPU KV - caches. It also provides methods for performing KV cache operations, such - as copying. 
- """ - - def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, - parallel_config: ParallelConfig, - device_config: DeviceConfig) -> None: - assert device_config.device_type == "cpu" - self.cache_config = cache_config - self.model_config = model_config - self.parallel_config = parallel_config - - self.head_size = model_config.get_head_size() - self.num_layers = model_config.get_num_layers(parallel_config) - self.num_heads = model_config.get_num_kv_heads(parallel_config) - - self.block_size = cache_config.block_size - # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks - # for CPU backend, because we want to reuse KV cache management - # in the scheduler. - self.num_cpu_blocks = cache_config.num_gpu_blocks - - self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, - model_config) - - # Get attention backend. - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - cache_config.cache_dtype, - self.block_size, - self.model_config.is_attention_free, - use_mla=self.model_config.use_mla, - ) - - # Initialize the cache. - self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) - - def _allocate_kv_cache( - self, - num_blocks: int, - ) -> List[torch.Tensor]: - """Allocates KV cache on CPU.""" - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_heads, self.head_size) - kv_cache: List[torch.Tensor] = [] - for _ in range(self.num_layers): - kv_cache.append( - torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) - return kv_cache - - def swap_in(self, src_to_dst: torch.Tensor) -> None: - raise NotImplementedError("Swap is not supported in CPUCacheEngine.") - - def swap_out(self, src_to_dst: torch.Tensor) -> None: - raise NotImplementedError("Swap is not supported in CPUCacheEngine.") - - def copy(self, src_to_dsts: torch.Tensor) -> None: - self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) - - @staticmethod - def get_kv_cache_dtype(cache_config: CacheConfig, - model_config: ModelConfig): - if cache_config.cache_dtype == "auto": - return model_config.dtype - elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]: - return torch.float8_e5m2 - else: - raise NotImplementedError(f"Unsupported KV cache type " - f"{cache_config.cache_dtype}.") - - @staticmethod - def get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, - ) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_layers = model_config.get_num_layers(parallel_config) - - key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block if not model_config.use_mla else 0 - total = num_layers * (key_cache_block + value_cache_block) - dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config) - dtype_size = torch.tensor([], dtype=dtype).element_size() - return dtype_size * total - - -class CPUWorker(LocalOrDistributedWorkerBase): - """A worker class that executes (a partition of) the model on a CPU socket. - - Each worker is associated with a single CPU socket. The worker is - responsible for maintaining the KV cache and executing the model on the - CPU. In case of distributed inference, each worker is assigned a partition - of the model. 
- """ - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - kv_cache_dtype: Optional[str] = "auto", - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[CPUModelRunner]] = None, - ) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - - self.local_rank = local_rank - self.rank = rank - vllm_config.parallel_config.rank = rank - - self.distributed_init_method = distributed_init_method - - self.is_driver_worker = is_driver_worker - if self.is_driver_worker: - assert self.rank == 0, "The driver worker must have rank 0." - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - # Setup OpenMP threads affinity. - omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND - self.local_omp_cpuid = "all" - if omp_cpuids == "auto": - self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes( - ) - else: - self.local_omp_cpuid = omp_cpuids.split("|")[rank] - - # Return hidden states from target model if the draft model is an - # mlp_speculator - speculative_config = self.speculative_config - model_config = self.model_config - speculative_args = {} if speculative_config is None \ - or (speculative_config.draft_model_config.model == - model_config.model) \ - or (speculative_config.draft_model_config.hf_config.model_type - not in ["medusa", "mlp_speculator", "eagle"]) \ - else {"return_hidden_states": True} - ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner - if self.model_config.runner_type == "pooling": - ModelRunnerClass = CPUPoolingModelRunner - elif self.model_config.is_encoder_decoder: - ModelRunnerClass = CPUEncoderDecoderModelRunner - self.model_runner: CPUModelRunnerBase = ModelRunnerClass( - vllm_config=vllm_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker, - **speculative_args, - ) - if model_runner_cls is not None: - self.model_runner = model_runner_cls(self.model_runner) - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[CPUCacheEngine] - # Initialize cpu_cache as pooling models don't initialize kv_caches - self.cpu_cache: Optional[List[List[torch.Tensor]]] = None - - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() - - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - - def init_device(self) -> None: - if self.local_omp_cpuid != "all": - ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) - if ret: - logger.info(ret) - - # Note: unique identifier for creating allreduce shared memory - os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split( - ":")[-1] - self.device = torch.device("cpu") - self.init_distributed_environment() - # Set random seed. 
- set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of blocks available for the KV cache. - - This determines how many KV blocks can fit into the configured CPU - KV cache space. - - Note that since vLLM assumes a block resides on GPU if it can be - modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. - This allows us to reuse the scheduler of vLLM without generalizing it - to different devices. - """ - # For CPU device, the block number will be calculated based on the - # cpu_kvcache_space. - cache_block_size = self.get_cache_block_size_bytes() - num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // - cache_block_size) - num_cpu_blocks = max(num_cpu_blocks, 0) - - # Note: To reuse the cache management procedure, - # use cpu cache as 'gpu cache'. - num_gpu_blocks = num_cpu_blocks - num_cpu_blocks = 0 - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache. Currently, swappable CPU memory is not - supported. - - Since this worker does not support GPUs, we use the num_gpu_blocks to - determine how many non-swappable CPU blocks to allocate. - """ - assert (num_cpu_blocks == 0 - ), f"{type(self)} does not support swappable cache" - - # Note: To reuse the cache management procedure, - # use cpu cache as 'gpu cache'. - num_cpu_blocks = num_gpu_blocks - - self._validate_num_cpu_blocks(num_cpu_blocks) - self.cache_config.num_gpu_blocks = num_cpu_blocks - self.cache_config.num_cpu_blocks = 0 - - # Initialize the cache. - self._init_cache_engine() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() - - def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: - """Raise errors if the num_cpu_blocks is invalid. - """ - if num_cpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " - "initializing the engine.") - - max_seq_len = self.cache_config.block_size * num_cpu_blocks - if self.model_config.max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({self.model_config.max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). 
Try increasing " - "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " - "initializing the engine.") - - def _init_cache_engine(self) -> None: - self.cache_engine = [ - CPUCacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.cpu_cache = [ - self.cache_engine[ve].cpu_cache - for ve in range(self.parallel_config.pipeline_parallel_size) - ] - bind_kv_cache(self.compilation_config.static_forward_context, - self.cpu_cache) - self.model_runner.block_size = self.cache_engine[0].block_size - - assert all( - self.cpu_cache[ve] is not None - for ve in range(self.parallel_config.pipeline_parallel_size)) - - # Populate the cache to warmup the memory - for ve in range(self.parallel_config.pipeline_parallel_size): - for layer_cache in self.cpu_cache[ve]: - layer_cache.fill_(0) - - @property - def do_metadata_broadcast(self) -> bool: - return self.parallel_config.tensor_parallel_size > 1 - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return self.cpu_cache - - @property - def vocab_size(self) -> int: - return self.model_runner.vocab_size - - @property - def max_model_len(self) -> int: - return self.model_config.max_model_len - - def execute_worker( - self, - worker_input: WorkerInput, - ) -> None: - if (worker_input.blocks_to_copy is not None - and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine[worker_input.virtual_engine].copy( - worker_input.blocks_to_copy) - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - assert execute_model_req is not None - virtual_engine: int = execute_model_req.virtual_engine - num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device="cpu", - dtype=torch.int64).view(-1, 2) - assert len(execute_model_req.blocks_to_swap_in) == 0 - assert len(execute_model_req.blocks_to_swap_out) == 0 - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - ) - - def init_distributed_environment(self) -> None: - """Initialize the distributed environment.""" - - parallel_config = self.parallel_config - rank = self.rank - distributed_init_method = self.distributed_init_method - init_distributed_environment( - world_size=parallel_config.world_size, - rank=rank, - distributed_init_method=distributed_init_method, - backend="gloo", - ) - - # A small all_reduce for warmup. - torch.distributed.all_reduce(torch.zeros(1).cpu()) - - ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - def get_cache_block_size_bytes(self) -> int: - """Return the size in bytes of a single KV cache block. - """ - return CPUCacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - - def get_cpus_id_binding_based_on_numa_nodes(self) -> str: - """Return CPUs id binding based on NUMA nodes. 
- """ - rank_to_cpus = self.local_omp_cpuid - # Setup OpenMP thread affinity based on NUMA nodes automatically - world_size = self.vllm_config.parallel_config.world_size - libnuma_found = util.find_spec("numa") is not None - psutil_found = util.find_spec("psutil") is not None - if libnuma_found and psutil_found: - import psutil - from numa import info - cpu_count = psutil.cpu_count(logical=False) - cpus_allow_list = psutil.Process().cpu_affinity() - numa_size = info.get_num_configured_nodes() - cpu_count_per_numa = cpu_count // numa_size - num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU, - cpu_count_per_numa // 2) - - # check allow node_to_cpus list - node_to_cpus = [] - for i in range(numa_size): - node_intersect = set( - info.node_to_cpus(i)).intersection(cpus_allow_list) - if bool(node_intersect): - node_to_cpus.append(list(node_intersect)) - - if world_size > len(node_to_cpus): - logger.error( - "Auto thread-binding failed due to " - "world size: %d is larger than " - "allowed NUMA nodes number: %d." - "Please try to bind threads manually.", world_size, - len(node_to_cpus)) - else: - end = cpu_count_per_numa - num_of_reserved_cpu - rank_to_cpus_list = node_to_cpus[self.rank][:end] - rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list) - logger.info("auto thread-binding list: %s", rank_to_cpus) - else: - logger.warning( - "Auto thread-binding is not supported due to " - "the lack of package numa and psutil," - "fallback to no thread-binding. To get better performance," - "please try to manually bind threads.") - return rank_to_cpus diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 8d92edc5b386..cb5d5664ab5c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -91,10 +91,9 @@ def __init__( ''' EncoderDecoderModelRunner constructor. - `lora_config` and `prompt_adapter_config` are - unused (since these features are not yet supported for encoder/decoder - models) but these arguments are present here for compatibility with - the base-class constructor. + `lora_config` is unused (since these features are not yet supported + for encoder/decoder models) but these arguments are present here for + compatibility with the base-class constructor. ''' self._maybe_force_supported_attention_backend() diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py deleted file mode 100644 index 586036829882..000000000000 --- a/vllm/worker/hpu_model_runner.py +++ /dev/null @@ -1,2320 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. 
an Intel Company -############################################################################### - -import collections -import contextlib -import dataclasses -import functools -import gc -import itertools -import math -import os -import time -from array import array -from enum import Enum, IntEnum -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, - Optional, Set, Tuple, Type, TypeVar, Union) - -import habana_frameworks.torch as htorch -import habana_frameworks.torch.internal.bridge_config as bc -import torch -import torch.nn as nn -import vllm_hpu_extension.environment as environment -from vllm_hpu_extension.bucketing.common import get_bucketing_context -from vllm_hpu_extension.ops import LoraMask as LoraMask -from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler, - HabanaMemoryProfiler, format_bytes) - -import vllm.envs as envs -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import DeviceConfig, VllmConfig -from vllm.distributed import broadcast_tensor_dict -from vllm.distributed.parallel_state import get_world_group -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping -from vllm.lora.request import LoRARequest -from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager -from vllm.model_executor import SamplingMetadata -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.sampling_metadata import SequenceGroupToSample -from vllm.multimodal import BatchedTensorInputs, MultiModalKwargs -from vllm.sampling_params import SamplingParams -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SequenceData, SequenceGroupMetadata, - SequenceOutput) -from vllm.utils import (bind_kv_cache, is_pin_memory_available, - make_tensor_with_pad) -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -_TYPE_CACHE = {} -# These values are assumed to be zero in several places. -# Use caution when updating them! 
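The pad constants above exist because the HPU path runs on fixed, bucketed tensor shapes: shorter inputs are rounded up to the next bucket and filled with the pad IDs. A small usage illustration of that idea (the helpers here mirror the `round_up`/`pad_list` definitions that follow; the bucket step and values are made up):

def round_up(value: int, k: int) -> int:
    return (value + k - 1) // k * k

def pad_list(items: list[int], k: int, pad_value: int) -> list[int]:
    return items + [pad_value] * (round_up(len(items), k) - len(items))

print(pad_list([11, 12, 13, 14, 15], k=4, pad_value=0))
# -> [11, 12, 13, 14, 15, 0, 0, 0]: padded up to the next bucket of 4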
-_PAD_SLOT_ID = 0 -_PAD_BLOCK_ID = 0 - -LORA_WARMUP_RANK = 8 - -DUMMY_TOKEN_ID = -1 - - -class PhaseType(Enum): - PREFILL = 'prefill' - PREFIX_PREFILL = 'prefix_prefill' - DECODE = 'decode' - - -def subtuple(obj: object, - typename: str, - to_copy: List[str], - to_override: Optional[Dict[str, object]] = None): - if obj is None: - return None - if to_override is None: - to_override = {} - fields = set(to_copy) | set(to_override.keys()) - if type(obj) is dict: - values = {key: obj[key] for key in fields if key in obj} - else: - values = {f: to_override.get(f, getattr(obj, f)) for f in fields} - if typename not in _TYPE_CACHE: - _TYPE_CACHE[typename] = collections.namedtuple(typename, - ' '.join(fields)) - return _TYPE_CACHE[typename](**values) - - -def round_up(value: int, k: int): - return (value + k - 1) // k * k - - -def align_workers(value, op): - group = get_world_group().cpu_group - world_size = torch.distributed.get_world_size() - if world_size <= 1: - return value - value_t = torch.tensor(value, device='cpu') - torch.distributed.all_reduce(value_t, op=op, group=group) - return value_t.item() - - -def setup_profiler(): - schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1) - DEVICE = 'hpu' - activities = [torch.profiler.ProfilerActivity.CPU] - activities.extend([torch.profiler.ProfilerActivity.HPU] if DEVICE == - 'hpu' else []) - #from habana_frameworks.torch.activity_profiler import DebugActivity - #debug_activities=[DebugActivity.BRIDGE_FUNCTION_CALLS] - - profiler = torch.profiler.profile( - schedule=schedule, - activities=activities, - #debug_activities=debug_activities, - on_trace_ready=torch.profiler.tensorboard_trace_handler('.', - use_gzip=True), - record_shapes=False, - with_stack=True) - return profiler - - -def pad_list(input, k, v): - input_len = len(input) - target_len = round_up(input_len, k) - padding = target_len - input_len - return input + [v] * padding - - -def gather_list(input, indices, v): - return [input[i] if i is not None else v for i in indices] - - -def flatten(in_list): - return list(itertools.chain(*in_list)) - - -def precompute_indices_and_offsets(block_size, slot_mapping, is_prompt): - slot_mapping = slot_mapping.flatten() - indices = torch.div(slot_mapping, block_size, rounding_mode="floor") - if is_prompt: - indices = indices.unflatten(0, (-1, block_size))[:, 0] - offsets = None - else: - offsets = torch.fmod(slot_mapping, block_size) - return indices, offsets - - -def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"): - if module.__class__.__name__.endswith(suffix): - - def forward_hook(module, args, output): - htorch.core.mark_step() - return output - - module.register_forward_hook(forward_hook) - - for child_name, child_module in module.named_children(): - modify_decoder_layer(child_module) - - -class HpuModelAdapter: - - def __init__(self, model, vllm_config): - self.model = model - self.sampler = get_sampler() - self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA', - '0').lower() in ['1', 'true'] - self.vllm_config = vllm_config - self.block_size = vllm_config.cache_config.block_size - self.dtype = vllm_config.model_config.dtype - enforce_eager = vllm_config.model_config.enforce_eager - - if not htorch.utils.internal.is_lazy() and not enforce_eager: - if os.getenv('VLLM_REGIONAL_COMPILATION', - 'true').lower() == 'true': - self.regional_compilation_layers_list = [ - RMSNorm, VocabParallelEmbedding - ] - self._regional_compilation(self.model) - else: - self.model = torch.compile(self.model, - 
backend='hpu_backend', - dynamic=False) - - def _regional_compilation(self, - module, - parent_module=None, - module_name=None): - if isinstance(module, torch.nn.ModuleList): - for children_name, children_module in module.named_children(): - self._compile_region(module, children_name, children_module) - elif any( - isinstance(module, layer) - for layer in self.regional_compilation_layers_list): - self._compile_region(parent_module, module_name, module) - else: - for children_name, children_module in module.named_children(): - self._regional_compilation(children_module, module, - children_name) - - def _compile_region(self, model, name, module): - module = torch.compile(module, backend='hpu_backend', dynamic=False) - setattr(model, name, module) - - def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, - dtype): - if (attn_metadata is None - or (self.prefill_use_fusedsdpa \ - and attn_metadata.block_list is None) - or not attn_metadata.is_prompt): - return attn_metadata - - prefill_metadata = attn_metadata - - seq_lens_t = prefill_metadata.seq_lens_tensor - context_lens_t = prefill_metadata.context_lens_tensor - query_lens_t = seq_lens_t - context_lens_t - - block_list = attn_metadata.block_list - max_context_len = (block_list.size(-1) // - batch_size if block_list is not None else 0) - max_context_len = max_context_len * self.block_size - past_mask = torch.arange(0, - max_context_len, - dtype=torch.int32, - device=device) - past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge( - context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand( - batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) - - len_mask = (torch.arange(0, seq_len, device=device, - dtype=torch.int32).view(1, seq_len).ge( - query_lens_t.unsqueeze(-1)).view( - batch_size, 1, 1, seq_len)) - causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), - device=device, - dtype=torch.bool), - diagonal=1) - mask = causal_mask.logical_or(len_mask) - mask = torch.concat((past_mask, mask), dim=-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf)) - attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) - return attn_metadata - - def _set_block_mapping(self, metadata, batch_size, device, dtype): - mask = torch.arange(0, - self.block_size, - device=device, - dtype=torch.int32).unsqueeze(0) - mask = mask >= metadata.block_usage.unsqueeze(-1) - attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( - mask, -math.inf)) - if os.environ.get('VLLM_USE_FAKE_HPU', - '0') == '0' and htorch.utils.internal.is_lazy(): - block_mapping = torch.nn.functional.one_hot(metadata.block_groups, - num_classes=batch_size) - else: - # Unfortunately one_hot on CPU/torch.compile mode/eager mode - # doesn't handle out of bounds classes so we need to convert - # all negative values to 0 (block_mapping) or bs (block_groups) - block_groups = metadata.block_groups.to(torch.long) - block_mapping = torch.nn.functional.relu(block_groups) - block_mapping = torch.nn.functional.one_hot(block_mapping, - num_classes=batch_size) - oob_values = block_groups.lt(0) - block_mapping.masked_fill_(oob_values.unsqueeze(-1), 0) - block_groups.masked_fill_(oob_values, batch_size) - metadata = metadata._replace(block_groups=block_groups) - block_mapping = block_mapping.to(dtype) - metadata = metadata._replace(block_mapping=block_mapping, - attn_bias=attn_bias) - return metadata - - def _update_metadata(self, attn_metadata, batch_size, seq_len, device, - dtype): - if attn_metadata.is_prompt: - 
meta = attn_metadata - attn_metadata = self._set_attn_bias(meta, batch_size, seq_len, - device, dtype) - else: - meta = attn_metadata - attn_metadata = self._set_block_mapping(meta, batch_size, device, - dtype) - return attn_metadata - - def forward(self, *args, **kwargs): - kwargs = kwargs.copy() - selected_token_indices = kwargs.pop('selected_token_indices') - if 'warmup_mode' in kwargs: - kwargs.pop('warmup_mode') - virtual_engine = 0 - if 'virtual_engine' in kwargs: - virtual_engine = kwargs.pop('virtual_engine') - input_ids = kwargs['input_ids'] - attn_metadata = self._update_metadata(kwargs.pop('attn_metadata'), - input_ids.size(0), - input_ids.size(1), - input_ids.device, self.dtype) - LoraMask.setLoraMask(kwargs.pop('lora_mask')) - with set_forward_context(attn_metadata, self.vllm_config, - virtual_engine): - hidden_states = self.model(*args, **kwargs) - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - hidden_states = hidden_states.index_select(0, - selected_token_indices) - return hidden_states - - def compute_logits(self, *args, **kwargs): - return self.model.compute_logits(*args, **kwargs) - - def sample(self, *args, **kwargs): - return self.sampler(*args, **kwargs) - - -class PreparePromptMetadata(NamedTuple): - input_tokens: torch.Tensor - input_positions: List[List[int]] - attn_metadata: Optional[AttentionMetadata] - seq_lens: List[int] - query_lens: List[int] - lora_index_mapping: List[List[int]] - lora_prompt_mapping: List[List[int]] - lora_requests: Set[LoRARequest] - multi_modal_kwargs: Optional[Dict[str, BatchedTensorInputs]] - slot_mapping: List[List[int]] - lora_ids: List[int] - - @classmethod - def empty(cls): - return PreparePromptMetadata(input_tokens=[], - input_positions=[], - attn_metadata=None, - seq_lens=[], - query_lens=[], - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - multi_modal_kwargs=None, - slot_mapping=[], - lora_ids=[]) - - -class PrepareDecodeMetadata(NamedTuple): - input_tokens: torch.Tensor - input_positions: List[List[int]] - attn_metadata: Optional[AttentionMetadata] - lora_index_mapping: List[List[int]] - lora_prompt_mapping: List[List[int]] - lora_requests: Set[LoRARequest] - slot_mapping: List[List[int]] - lora_ids: List[int] - - @classmethod - def empty(cls): - return PrepareDecodeMetadata(input_tokens=[], - input_positions=[], - attn_metadata=None, - lora_index_mapping=[], - lora_prompt_mapping=[], - lora_requests=set(), - slot_mapping=[], - lora_ids=[]) - - -# How batches are constructed. -class BatchType(IntEnum): - # Every batch is prefill. - PREFILL = 0 - # Every batch is decode. - DECODE = 1 - # Batch is a mixture of prefill and decode. - MIXED = 2 - - -TModelInputForHPU = TypeVar('TModelInputForHPU', bound="ModelInputForHPU") - - -@dataclasses.dataclass(frozen=True) -class ModelInputForHPU(ModelRunnerInputBase): - """ - This base class contains metadata needed for the base model forward pass - but not metadata for possible additional steps, e.g., sampling. Model - runners that run additional steps should subclass this method to add - additional fields. 
- """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - lora_mapping: Optional["LoRAMapping"] = None - lora_requests: Optional[Set[LoRARequest]] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None - real_batch_size: Optional[int] = None - batch_size_padded: Optional[int] = None - virtual_engine: int = 0 - lora_ids: Optional[List[int]] = None - async_callback: Optional[Callable] = None - is_first_multi_step: bool = True - is_last_step: bool = True - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "real_batch_size": self.real_batch_size, - "batch_size_padded": self.batch_size_padded, - "virtual_engine": self.virtual_engine, - "lora_ids": self.lora_ids, - "is_first_multi_step": self.is_first_multi_step, - "is_last_step": self.is_last_step, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForHPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForHPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -@dataclasses.dataclass(frozen=True) -class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU): - """ - Used by the ModelRunner. - """ - sampling_metadata: Optional["SamplingMetadata"] = None - # Used for speculative decoding. We do not broadcast it because it is only - # used by the driver worker. - is_prompt: Optional[bool] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - "lora_requests": self.lora_requests, - "lora_mapping": self.lora_mapping, - "multi_modal_kwargs": self.multi_modal_kwargs, - "lora_ids": self.lora_ids, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForHPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - # FIXME(kzawora): this fails for whatever reason - why? - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): - """ - Helper class for shared methods between GPU model runners. 
- """ - _model_input_cls: Type[TModelInputForHPU] - - def __init__( - self, - vllm_config: VllmConfig, - is_driver_worker: bool = False, - return_hidden_states: bool = False, - ): - ModelRunnerBase.__init__(self, vllm_config=vllm_config) - environment.set_model_config(self.model_config) - self.is_driver_worker = is_driver_worker - self.return_hidden_states = return_hidden_states - - self.sliding_window = (self.model_config.get_sliding_window() - if self.model_config is not None else None) - self.device_config = (self.device_config if self.device_config - is not None else DeviceConfig()) - self.device = self.device_config.device - self.enforce_eager = self.model_config.enforce_eager - self.max_num_seqs = self.scheduler_config.max_num_seqs - # NOTE(kzawora): Change that to scheduler_config.max_num_prefill_seqs - # once padding-aware scheduling gets merged - self.max_num_prefill_seqs = 64 - self.max_model_len = self.scheduler_config.max_model_len - self.max_num_batched_tokens = \ - self.scheduler_config.max_num_batched_tokens - self.block_size = self.cache_config.block_size - - self.pin_memory = is_pin_memory_available() - self.kv_cache_dtype = self.cache_config.cache_dtype - - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - self.model_config.is_attention_free, - ) - - # Lazy initialization - self.lora_manager: LRUCacheWorkerLoRAManager = None - self.model: torch.nn.Module = None - self.inc_initialized_successfully = False - - # Profiler stats - self.profiler = HabanaHighLevelProfiler() - self.profiler_counter_helper = HabanaProfilerCounterHelper() - self.seen_configs: set = set() - self._mem_margin: Optional[int] = None - HPUBucketingContext = get_bucketing_context() - self.bucketing_ctx = HPUBucketingContext(self.max_num_seqs, - self.max_num_prefill_seqs, - self.block_size, - self.max_num_batched_tokens, - False, self.max_model_len) - self.graphed_buckets: Set[Any] = set() - self._set_gc_threshold() - if self.vllm_config.cache_config.enable_prefix_caching: - os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False") - assert os.environ.get( - "VLLM_CONTIGUOUS_PA", - "").lower() != "true", "Contiguous PA doesn't support APC" - self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH - - # For multi-step scheduling - self.cached_step_outputs: List[torch.Tensor] = [] - # For delayed sampling - self.cached_step_inputs: List[ - ModelInputForHPUWithSamplingMetadata] = [] - - def _set_gc_threshold(self) -> None: - # Read https://docs.python.org/3/library/gc.html#gc.set_threshold - # for comprehensive description of gc generations. - # We can either use VLLM_GC_THR_GEN[0-2] (this has higher priority) - # to set particular generation threshold or use simpler - # VLLM_GC_THR_MULTIPLIER to multiply default values. 
- default_gc_thrs = list(gc.get_threshold()) - requested_gc_thrs = [0] * len(default_gc_thrs) - for i in range(len(default_gc_thrs)): - requested_gc_thrs[i] = int( - os.environ.get(f'VLLM_GC_THR_GEN{i}', default_gc_thrs[i])) - if requested_gc_thrs == default_gc_thrs: - gc_thr_multiplier = int(os.environ.get('VLLM_GC_THR_MULTIPLIER', - 2)) - requested_gc_thrs = [ - t * gc_thr_multiplier for t in default_gc_thrs - ] - gc.set_threshold(*requested_gc_thrs) - - self.skip_warmup = os.environ.get('VLLM_SKIP_WARMUP', - 'false').lower() == 'true' - - def load_model(self) -> None: - import habana_frameworks.torch.core as htcore - if self.model_config.quantization == 'inc' or \ - self.model_config.quantization == 'fp8': - htcore.hpu_set_env() - with HabanaMemoryProfiler() as m: - with HabanaMemoryProfiler() as m_getmodel: - self.model = get_model(vllm_config=self.vllm_config) - msg = ("Pre-loading model weights on " - f"{next(self.model.parameters()).device} " - f"took {m_getmodel.get_summary_string()}") - logger.info(msg) - - if self.lora_config: - assert hasattr(self.model, "embedding_modules" - ), "Model does not have embedding_modules" - assert hasattr( - self.model, "embedding_padding_modules" - ), "Model does not have embedding_padding_modules" - assert not self.lora_config.bias_enabled, \ - "Bias support in LoRA is not enabled in HPU yet." - assert not self.lora_config.fully_sharded_loras, \ - "Fully sharded LoRAs is not enabled in HPU yet." - - # Use get_text_config() in case of multimodal models - text_config = self.model_config.hf_config.get_text_config() - - self.lora_manager = LRUCacheWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, - self.vocab_size, - self.lora_config, - self.device, - self.model.embedding_modules, - self.model.embedding_padding_modules, - max_position_embeddings=text_config. 
- max_position_embeddings, - ) - self.model = self.lora_manager.create_lora_manager(self.model) - - if self.model_config.quantization == 'inc': - logger.info("Preparing model with INC..") - with HabanaMemoryProfiler() as m_inc: - from neural_compressor.torch.quantization import ( - FP8Config, convert, prepare) - config = FP8Config.from_json_file( - os.getenv("QUANT_CONFIG", "")) - if config.measure: - self.model = prepare(self.model, config) - elif config.quantize: - self.model = convert(self.model, config) - htcore.hpu_initialize(self.model, - mark_only_scales_as_const=True) - self.inc_initialized_successfully = True - logger.info("Preparing model with INC took %s", - m_inc.get_summary_string()) - else: - self.model = self.model.to("hpu") - htcore.mark_step() - modify_decoder_layer(self.model) - torch.hpu.synchronize() - - with HabanaMemoryProfiler() as m_wrap: - self.model = _maybe_wrap_in_hpu_graph( - self.model, vllm_config=self.vllm_config) - msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}" - logger.info(msg) - - self.model_memory_usage = m.consumed_device_memory - msg = f"Loading model weights took in total {m.get_summary_string()}" - logger.info(msg) - - def _add_dummy_seq(self, seq_group_metadata_list, is_prompt): - real_batch_size = len(seq_group_metadata_list) - batch_size_padded = self.bucketing_ctx.get_padded_batch_size( - real_batch_size, is_prompt) - batch_size_padding = batch_size_padded - real_batch_size - - seq_group_metadata_list = seq_group_metadata_list.copy() - - if batch_size_padding > 0: - dummy_seq_group_metadata = self.create_dummy_seq_group_metadata( - 0, 0, is_prompt) - seq_group_metadata_list.extend(dummy_seq_group_metadata - for _ in range(batch_size_padding)) - return seq_group_metadata_list, real_batch_size, batch_size_padded - - def _maybe_wrap_in_hpu_graph(self, *args, **kwargs): - return htorch.hpu.wrap_in_hpu_graph( - HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True - ) if htorch.utils.internal.is_lazy() else HpuModelAdapter( - *args, **kwargs) - - def get_model(self) -> nn.Module: - return self.model - - def _use_graphs(self, batch_size, seq_len, is_prompt): - if self.enforce_eager: - return False - if self.skip_warmup: - return True - return (batch_size, seq_len, is_prompt) in self.graphed_buckets - - def _is_valid_bucket(self, bucket): - return bucket[0] * bucket[1] <= self.max_num_batched_tokens - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> PreparePromptMetadata: - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] - lora_index_mapping: List[List[int]] = [] - lora_prompt_mapping: List[List[int]] = [] - lora_requests: Set[LoRARequest] = set() - - seq_lens: List[int] = [] - context_lens: List[int] = [] - query_lens: List[int] = [] - prefix_block_tables: List[List[int]] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - - if len(seq_group_metadata_list) == 0: - return PreparePromptMetadata.empty() - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - computed_block_nums = seq_group_metadata.computed_block_nums - if (self.scheduler_config is not None - and self.scheduler_config.chunked_prefill_enabled - and not (computed_block_nums is None - or computed_block_nums == [])): - raise RuntimeError( - "chunked prefill cannot be used with prefix caching " - "now.") - - 
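The dummy-sequence padding done by `_add_dummy_seq` above can be pictured with a toy bucket list; the real sizes come from `bucketing_ctx.get_padded_batch_size()`, so the bucket values and the helper below are purely hypothetical.

# Toy illustration of _add_dummy_seq: the real batch is rounded up to the
# nearest warmed-up bucket size and the gap is filled with dummy sequence
# groups so HPU graphs only ever see known batch shapes.
buckets = [1, 2, 4, 8, 16]  # hypothetical bucket sizes

def get_padded_batch_size(real_batch_size):
    return next(b for b in buckets if b >= real_batch_size)

real_batch_size = 5
batch_size_padded = get_padded_batch_size(real_batch_size)
batch_size_padding = batch_size_padded - real_batch_size
print(real_batch_size, batch_size_padded, batch_size_padding)  # 5 8 3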
token_chunk_size = seq_group_metadata.token_chunk_size - seq_data = seq_group_metadata.seq_data[seq_id] - context_len = seq_data.get_num_computed_tokens() - # We should use get_len here because in case of preemption - # it contains output tokens. - seq_len = min(seq_data.get_len(), context_len + token_chunk_size) - prompt_tokens = seq_data.get_token_ids()[context_len:seq_len] - seq_lens.append(seq_len) - - # NOTE: This only works for oooooooxxx style attention. - if computed_block_nums is not None and len( - computed_block_nums) > 0 and self.sliding_window is None: - # Prefix is not supported with sliding_window - context_len = len(computed_block_nums) * self.block_size - if context_len == seq_len \ - and self.vllm_config.cache_config.enable_prefix_caching: - # Fully cached prompt - compute only last token - context_len = context_len - 1 - prompt_tokens = prompt_tokens[context_len:] - prefix_block_tables.append(computed_block_nums) - elif self.scheduler_config.chunked_prefill_enabled: - if seq_group_metadata.block_tables is not None: - # Prefill has chunked before. - block_table = seq_group_metadata.block_tables[seq_id] - prefix_block_tables.append(block_table) - else: - # The first prefill. - prefix_block_tables.append([]) - else: - prefix_block_tables.append([]) - # Right now, prefill start is always 0. However, this - # assumption can be changed once chunked prefill is introduced. - assert context_len == 0 - - # actual prompt lens - context_lens.append(context_len) - query_lens.append(seq_len - context_len) - input_tokens.append(prompt_tokens) - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - input_positions.append(list(range(context_len, seq_len))) - - mm_kwargs = seq_group_metadata.multi_modal_data - if mm_kwargs: - multi_modal_kwargs_list.append(mm_kwargs) - - if seq_group_metadata.block_tables is None: - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - slot_mapping.append([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - slot_mapping.append([]) - block_table = seq_group_metadata.block_tables[seq_id] - - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
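The worked example in the comment above can be reproduced directly. The block table below is a hypothetical layout chosen so the numbers line up, and -1 stands in for _PAD_SLOT_ID.

# Reproducing the slot-mapping example: prompt len 10, sliding window 8,
# block size 4. Tokens before start_idx = max(0, seq_len - sliding_window)
# are masked with the pad slot.
_PAD_SLOT_ID = -1                 # stand-in value for illustration
block_size, sliding_window = 4, 8
block_table = [0, 1, 0]           # hypothetical block table for this prompt
seq_len = 10
start_idx = max(0, seq_len - sliding_window)

slot_mapping = []
for i in range(seq_len):
    if i < start_idx:
        slot_mapping.append(_PAD_SLOT_ID)
        continue
    block_number = block_table[i // block_size]
    block_offset = i % block_size
    slot_mapping.append(block_number * block_size + block_offset)

print(slot_mapping)  # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]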
- start_idx = 0 - if self.sliding_window is not None: - assert context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention") - start_idx = max(0, seq_len - self.sliding_window) - for i in range(context_len, seq_len): - if i < start_idx: - slot_mapping[-1].append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping[-1].append(slot) - - max_query_len = max(query_lens) - sum_query_len = sum(query_lens) - real_num_seqs = len(query_lens) - assert max_query_len > 0 - - max_prompt_len = max( - self.bucketing_ctx.get_padded_prompt_seq_len(max_query_len), - self.block_size) - - lora_ids: List[int] = [] - for seq_group_metadata, context_len in zip(seq_group_metadata_list, - context_lens): - lora_id = seq_group_metadata.lora_int_id - lora_ids.append(lora_id) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - lora_index_mapping += [lora_id] * max_prompt_len - lora_prompt_mapping.extend( - [lora_id] * - (max_prompt_len - if seq_group_metadata.sampling_params.prompt_logprobs else 1)) - - if any(context_lens): - assert not self.scheduler_config.chunked_prefill_enabled - # prefix caching - - max_num_block = max(len(bt) for bt in prefix_block_tables) - prefix_block_list = list( - itertools.chain.from_iterable( - bt if len(bt) == max_num_block else bt + - ([_PAD_BLOCK_ID] * (max_num_block - len(bt))) - for bt in prefix_block_tables)) - - pad_len = len(prefix_block_list) - prefix_block_list = pad_list(prefix_block_list, pad_len, - _PAD_BLOCK_ID) - - prefix_block_list_tensor = torch.tensor(prefix_block_list, - dtype=torch.long, - device=self.device) - else: - prefix_block_list_tensor = None - - input_tokens = make_tensor_with_pad(input_tokens, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - - input_positions = make_tensor_with_pad(input_positions, - max_len=max_prompt_len, - pad=0, - dtype=torch.long, - device=self.device) - - slot_mapping = make_tensor_with_pad(slot_mapping, - max_len=max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device=self.device) - - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.long, - device=self.device) - - context_lens_tensor = torch.tensor(context_lens, - dtype=torch.long, - device=self.device) - - block_indices, block_offsets = precompute_indices_and_offsets( - self.block_size, slot_mapping, True) - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - block_list=prefix_block_list_tensor, - block_mapping=None, - block_usage=None, - block_indices=block_indices, - block_offsets=block_offsets, - block_groups=None, - attn_bias=None, - seq_lens_tensor=seq_lens_tensor, - context_lens_tensor=context_lens_tensor, - num_prefills=real_num_seqs, - num_prefill_tokens=sum_query_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps= - None, # FIXME(kzawora): multi-modality will not work here - enable_kv_scales_calculation=False, - ) - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return PreparePromptMetadata(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - seq_lens=seq_lens, - query_lens=query_lens, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, - lora_requests=lora_requests, - multi_modal_kwargs=multi_modal_kwargs, - slot_mapping=slot_mapping, - lora_ids=lora_ids) - - def _prepare_decode( 
- self, - seq_group_metadata_list: List[SequenceGroupMetadata], - output=None, - ) -> PrepareDecodeMetadata: - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] - seq_lens: List[int] = [] - block_tables: List[List[int]] = [] - lora_index_mapping: List[List[int]] = [] - lora_prompt_mapping: List[List[int]] = [] - lora_requests: Set[LoRARequest] = set() - - if len(seq_group_metadata_list) == 0: - return PrepareDecodeMetadata.empty() - lora_ids: List[int] = [] - - dummy_slots = itertools.cycle( - range(_PAD_SLOT_ID, _PAD_SLOT_ID + self.block_size)) - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - - seq_ids = list(seq_group_metadata.seq_data.keys()) - lora_id = seq_group_metadata.lora_int_id - lora_ids.append(lora_id) - - if lora_id > 0: - lora_requests.add(seq_group_metadata.lora_request) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - if output is None: - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) - - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - seq_lens.append(seq_len) - - block_table = seq_group_metadata.block_tables[seq_id] - num_fully_occupied_blocks = position // self.block_size - block_table = block_table[:num_fully_occupied_blocks + 1] - - if len(block_table) == 0: - block_number = _PAD_BLOCK_ID - else: - block_number = block_table[position // self.block_size] - if block_number == _PAD_BLOCK_ID: - slot = next(dummy_slots) - else: - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) - lora_index_mapping.append(lora_id) - lora_prompt_mapping.append(lora_id) - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) - - if output is None: - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - else: - real_batch_size = len(seq_group_metadata_list) - input_tokens = output[:real_batch_size] - - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) - - num_decode_tokens = sum(seq_lens) - - last_block_usage = [ - slot[0] % self.block_size + 1 for slot in slot_mapping - ] - block_groups = [[i] * len(bt) for i, bt in enumerate(block_tables)] - block_usage = [[self.block_size] * (len(bt) - 1) + [lbu] - for bt, lbu in zip(block_tables, last_block_usage) - if bt] - - block_list = flatten(block_tables) - block_groups = flatten(block_groups) - block_usage = flatten(block_usage) - - assert len(block_list) == len(block_groups) - assert len(block_list) == len(block_usage) - - padding_fn = None - if self.use_contiguous_pa: - block_bucket_size = max(max(block_list) + 1, len(block_list)) - block_bucket_size = self.bucketing_ctx.get_padded_decode_num_blocks( - block_bucket_size) - indices: List[Any] - indices = [None] * block_bucket_size - for i, bid in enumerate(block_list): - indices[bid] = i - padding_fn = lambda tensor, pad_value: gather_list( - tensor, indices, pad_value) - else: - block_bucket_size = \ - self.bucketing_ctx.get_padded_decode_num_blocks( - len(block_list)) - padding_fn = lambda tensor, pad_value: pad_list( - tensor, 
block_bucket_size, pad_value) - - block_list = padding_fn(block_list, _PAD_BLOCK_ID) - block_groups = padding_fn(block_groups, -1) - block_usage = padding_fn(block_usage, 1) - - block_list = torch.tensor(block_list, - dtype=torch.int, - device=self.device) - block_groups = torch.tensor(block_groups, - dtype=torch.int, - device=self.device) - block_usage = torch.tensor(block_usage, - dtype=self.model_config.dtype, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - - block_indices, block_offsets = precompute_indices_and_offsets( - self.block_size, slot_mapping, False) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - block_list=block_list, - block_mapping=None, - block_usage=block_usage, - block_indices=block_indices, - block_offsets=block_offsets, - block_groups=block_groups, - attn_bias=None, - seq_lens_tensor=None, - context_lens_tensor=None, - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=num_decode_tokens, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - ) - return PrepareDecodeMetadata(input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - lora_index_mapping=lora_index_mapping, - lora_prompt_mapping=lora_prompt_mapping, - lora_requests=lora_requests, - slot_mapping=slot_mapping, - lora_ids=lora_ids) - - def prepare_input_tensors( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[TModelInputForHPU, SamplingMetadata]: - if len(seq_group_metadata_list) == 0: - return self._model_input_cls(), None - - input_tokens = None - input_positions = None - lora_mapping = None - lora_requests = None - multi_modal_kwargs = None - batch_type = None - seq_lens = None - query_lens = None - real_batch_size = None - batch_size_padded = None - - self.event_start = self.profiler.get_timestamp_us() - is_prompt = seq_group_metadata_list[0].is_prompt - base_event_name = 'prompt' if is_prompt else 'decode' - self.profiler.start('internal', base_event_name) - - seq_group_metadata_list, real_batch_size, batch_size_padded = ( - self._add_dummy_seq(seq_group_metadata_list, is_prompt)) - - prefill_reqs = [] - decode_reqs = [] - for seq_group_meta in seq_group_metadata_list: - if seq_group_meta.is_prompt: - prefill_reqs.append(seq_group_meta) - else: - decode_reqs.append(seq_group_meta) - - # Prepare input tensors. - ( - input_tokens, - input_positions, - prefill_attn_metadata, - seq_lens, - query_lens, - lora_index_mapping, - lora_prompt_mapping, - lora_requests, - multi_modal_kwargs, - slot_mapping, - lora_ids, - ) = self._prepare_prompt(prefill_reqs) - ( - decode_input_tokens, - decode_input_positions, - decode_attn_metadata, - decode_lora_index_mapping, - decode_lora_prompt_mapping, - decode_lora_requests, - decode_slot_mapping, - decode_lora_ids, - ) = self._prepare_decode(decode_reqs) - sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, - seq_lens, query_lens, - self.device, - self.pin_memory) - - if not self.scheduler_config.chunked_prefill_enabled: - assert (len(prefill_reqs) and len(decode_reqs)) == 0 - - num_prefills = len(seq_lens) - num_prefill_tokens = len(input_tokens) - num_decode_tokens = len(decode_input_tokens) - - # NOTE(kzawora): Here we diverge from GPU code - we don't - # support mixed batches, so we either use decode or prefill - # inputs, without coalescing. 
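A small sketch, with invented values, of how `_prepare_decode` above flattens the per-sequence block tables into the block_list / block_groups / block_usage triplet before bucket padding is applied.

block_size = 4
block_tables = [[0, 1], [2]]      # hypothetical per-sequence block tables
last_block_usage = [3, 1]         # slots used in each sequence's last block

# block_list: all blocks, flattened; block_groups: owning sequence of each
# block; block_usage: how many slots of each block are actually in use.
block_list = [b for bt in block_tables for b in bt]
block_groups = [i for i, bt in enumerate(block_tables) for _ in bt]
block_usage = [u for bt, last in zip(block_tables, last_block_usage)
               for u in [block_size] * (len(bt) - 1) + [last]]

print(block_list)    # [0, 1, 2]
print(block_groups)  # [0, 0, 1]
print(block_usage)   # [4, 3, 1]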
- assert (num_prefills == 0 and num_decode_tokens > 0) or ( - num_prefills > 0 - and num_decode_tokens == 0), "HPU does not support mixed batches!" - if num_decode_tokens > 0: - input_tokens = decode_input_tokens - input_positions = decode_input_positions - slot_mapping = decode_slot_mapping - lora_index_mapping = decode_lora_index_mapping - lora_prompt_mapping = decode_lora_prompt_mapping - lora_requests = decode_lora_requests - lora_ids = decode_lora_ids - - # FIXME: We need to adjust selected_token_indices to accommodate - # for padding - max_len = input_tokens.size(1) - paddings = [max_len - q for q in query_lens] - paddings = [0] + paddings[:-1] - paddings = list(itertools.accumulate(paddings)) - paddings_prompt_logprobs = [] - for i, seq_group_metadata in enumerate(seq_group_metadata_list): - if seq_group_metadata.sampling_params.prompt_logprobs is not None \ - and seq_group_metadata.is_prompt: - paddings_prompt_logprobs += ([paddings[i]] * seq_lens[i]) - paddings = torch.tensor( - paddings_prompt_logprobs if paddings_prompt_logprobs else paddings, - dtype=sampling_metadata.selected_token_indices.dtype, - device=sampling_metadata.selected_token_indices.device) - sampling_metadata.selected_token_indices.add_(paddings) - - if self.lora_config: - lora_mapping = LoRAMapping( - **dict(index_mapping=lora_index_mapping, - prompt_mapping=lora_prompt_mapping, - is_prefill=(num_prefills > 0))) - else: - lora_mapping = None - - if (prefill_attn_metadata is not None - and decode_attn_metadata is not None): - batch_type = BatchType.MIXED - raise NotImplementedError("Mixed batch is not supported on HPU") - elif prefill_attn_metadata is not None: - batch_type = BatchType.PREFILL - else: - batch_type = BatchType.DECODE - - metadata_dict = { - "input_tokens": input_tokens, - "input_positions": input_positions, - "selected_token_indices": sampling_metadata.selected_token_indices, - "lora_requests": lora_requests, - "lora_mapping": lora_mapping, - "multi_modal_kwargs": multi_modal_kwargs, - "num_prefill_tokens": num_prefill_tokens, - "num_decode_tokens": num_decode_tokens, - "slot_mapping": slot_mapping, - "num_prefills": num_prefills, - "batch_type": batch_type, - "seq_lens": seq_lens, - "query_lens": query_lens - } - if prefill_attn_metadata is not None: - metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) - else: - assert decode_attn_metadata is not None - metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) - - attn_metadata = prefill_attn_metadata if \ - prefill_attn_metadata is not None else decode_attn_metadata - - return self._model_input_cls(input_tokens=input_tokens, - seq_lens=seq_lens, - query_lens=query_lens, - input_positions=input_positions, - attn_metadata=attn_metadata, - lora_requests=lora_requests, - lora_mapping=lora_mapping, - multi_modal_kwargs=multi_modal_kwargs, - real_batch_size=real_batch_size, - batch_size_padded=batch_size_padded, - lora_ids=lora_ids), \ - sampling_metadata - - def _seq_len(self, attn_metadata): - if attn_metadata.num_prefills != 0: - return attn_metadata.slot_mapping.size(1) - else: - return attn_metadata.block_list.numel() - - def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: - # NOTE(kzawora): To anyone working on this in the future: - # Trimming metadata is required when using HPUGraphs. - # Attention metadata is going to be hashed by PT bridge, and - # appropriate HPUGraphs will be matched based on all inputs' hash. 
- - # Before you put more keys in here, make sure you know their - # value type and make sure you know how it's going to be hashed. - # You can find that information in input_hash function - # in habana_frameworks/torch/hpu/graphs.py. You can also hash - # it manually with torch.hpu.graphs.input_hash(attention_metadata) - - # If you use primitive types here - they will get hashed based - # on their value. You *will* get lots of excessive graph captures - # (and an OOM eventually) if you decide to put something like - # seq_len int here. - # If you absolutely need a scalar, put it in a tensor. Tensors - # get hashed using their metadata, not their values: - # input_hash(torch.tensor(123)) == input_hash(torch.tensor(321)) - # input_hash(123) != input_hash(321) - # input_hash("abc") != input_hash("cba") - attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'attn_bias', - 'seq_lens_tensor', - 'context_lens_tensor', - 'block_list', - 'block_mapping', - 'block_usage', - 'slot_mapping', - 'is_prompt', - 'block_indices', - 'block_offsets', - 'block_groups', - ]) - return attention_metadata - - def create_dummy_seq_group_metadata(self, - group_id, - seq_len, - is_prompt, - lora_request=None): - sampling_params = SamplingParams(temperature=0) - num_blocks = math.ceil(seq_len / self.block_size) - seq_len = max(seq_len, 1) - if is_prompt: - input_len = seq_len - output_len = 0 - block_tables = None - else: - input_len = seq_len - 1 - output_len = 1 - block_tables = {group_id: [_PAD_BLOCK_ID] * num_blocks} - prompt_token_ids = [0] * input_len - output_token_ids = [1] * output_len - prompt_token_ids_array = array('l', prompt_token_ids) # noqa: F821 - seq_data = SequenceData(prompt_token_ids_array) - seq_data.output_token_ids = output_token_ids - return SequenceGroupMetadata(request_id=str(group_id), - is_prompt=(output_len == 0), - seq_data={group_id: seq_data}, - sampling_params=sampling_params, - block_tables=block_tables, - lora_request=lora_request) - - def profile_run(self) -> None: - num_layers = self.model_config.get_num_layers(self.parallel_config) - kv_caches = [None] * num_layers - bind_kv_cache( - self.vllm_config.compilation_config.static_forward_context, - [kv_caches]) - _, max_seq_len = self.bucketing_ctx.get_max_prompt_shape() - max_batch_size = min(self.max_num_seqs, - self.max_num_batched_tokens // max_seq_len) - self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches, - False, True) - return - - def warmup_scenario(self, - batch_size, - seq_len, - is_prompt, - kv_caches, - is_pt_profiler_run=False, - is_lora_profile_run=False) -> None: - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - scenario_name = ("warmup_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") - # This represents the maximum number of different requests - # that will have unique loras, an therefore the max amount of memory - # consumption create dummy lora request copies from the lora request - # passed in, which contains a lora from the lora warmup path. 
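The hashing note in `trim_attn_metadata` above is the motivation for keeping only tensors and a few stable fields in the trimmed metadata. The toy `cache_key` below is not the real `input_hash` from habana_frameworks; it only illustrates the keying behaviour the note describes, namely that tensors are keyed by their metadata while plain scalars are keyed by value (and would therefore multiply captured graphs).

import torch

def cache_key(value):
    # Toy stand-in for graph-cache keying: tensors key by (shape, dtype),
    # everything else keys by its value.
    if isinstance(value, torch.Tensor):
        return (tuple(value.shape), str(value.dtype))
    return value

print(cache_key(torch.tensor(123)) == cache_key(torch.tensor(321)))  # True
print(cache_key(123) == cache_key(321))                              # False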
- dummy_lora_requests: List[LoRARequest] = [] - dummy_lora_requests_per_seq: List[LoRARequest] = [] - if self.lora_config and is_lora_profile_run: - assert self.lora_manager is not None - with self.lora_manager.dummy_lora_cache(): - for idx in range(self.lora_config.max_loras): - lora_id = idx + 1 - dummy_lora_request = LoRARequest( - lora_name=f"warmup_{lora_id}", - lora_int_id=lora_id, - lora_local_path="/not/a/real/path", - ) - self.lora_manager.add_dummy_lora(dummy_lora_request, - rank=LORA_WARMUP_RANK) - dummy_lora_requests.append(dummy_lora_request) - dummy_lora_requests_per_seq = [ - dummy_lora_requests[idx % len(dummy_lora_requests)] - for idx in range(batch_size) - ] - self.profiler.start('internal', scenario_name) - times = 3 if use_graphs or is_pt_profiler_run else 1 - if is_prompt: - seqs = [ - self.create_dummy_seq_group_metadata( - i, - seq_len, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) - for i in range(batch_size) - ] - else: - # FIXME: seq_len is actually number of blocks - blocks = [seq_len // batch_size for _ in range(batch_size)] - blocks[0] += seq_len % batch_size - seqs = [ - self.create_dummy_seq_group_metadata( - i, - b * self.block_size - 1, - is_prompt, - lora_request=dummy_lora_requests_per_seq[i] - if dummy_lora_requests_per_seq else None) - for i, b in enumerate(blocks) - ] - torch.hpu.synchronize() - profiler = None - if is_pt_profiler_run and self.is_driver_worker: - profiler = setup_profiler() - profiler.start() - for _ in range(times): - inputs = self.prepare_model_input(seqs) - is_single_step = \ - self.vllm_config.scheduler_config.num_scheduler_steps == 1 - if is_prompt or is_single_step: - self.execute_model(inputs, None, warmup_mode=True) - else: # decode with multi-step - inputs = dataclasses.replace(inputs, - is_first_multi_step=True, - is_last_step=False) - self.execute_model(inputs, - None, - warmup_mode=True, - num_steps=2, - seqs=seqs) - inputs = dataclasses.replace(inputs, - is_first_multi_step=False, - is_last_step=True) - self.execute_model(inputs, - None, - warmup_mode=True, - num_steps=2, - seqs=seqs) - torch.hpu.synchronize() - if profiler: - profiler.step() - if profiler: - profiler.stop() - self.profiler.end() - gc.collect() - - def remove_all_loras(self): - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_adapters() - - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_adapters(lora_requests, lora_mapping) - - def add_lora(self, lora_request: LoRARequest) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_adapter(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_adapter(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_adapter(lora_id) - - def list_loras(self) -> Set[int]: - if not self.lora_manager: - raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_adapters() - - def log_warmup(self, phase, i, max_i, batch_size, seq_len): - free_mem = format_bytes( - HabanaMemoryProfiler.current_free_device_memory()) - dim = "num_blocks" - if phase == "Prompt": - dim = "seq_len" - 
msg = (f"[Warmup][{phase}][{i+1}/{max_i}] " - f"batch_size:{batch_size} " - f"{dim}:{seq_len} " - f"free_mem:{free_mem}") - logger.info(msg) - - def warmup_all_buckets(self, buckets, is_prompt, kv_caches): - for i, (batch_size, seq_len) in enumerate(reversed(buckets)): - self.log_warmup('Prompt' if is_prompt else 'Decode', i, - len(buckets), batch_size, seq_len) - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - - def warmup_graphs(self, - strategy, - buckets, - is_prompt, - kv_caches, - available_mem, - starting_mem=0, - total_batch_seq=0.001): - total_mem = starting_mem - idx = 0 - phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' - num_candidates = len(buckets) - ordering : Union[Callable[[Any], Tuple[Any, Any]], \ - Callable[[Any], Tuple[Any, Any, Any]]] - if strategy == 'min_tokens': - ordering = lambda b: (b[0] * b[1], b[1], b[0]) - elif strategy == 'max_bs': - ordering = lambda b: (-b[0], b[1]) - else: - raise NotImplementedError( - f'Unsupported graph allocation strategy: {strategy}') - buckets = list(sorted(buckets, key=ordering)) - captured_all = True - for idx, (batch_size, seq_len) in enumerate(buckets): - # Graph memory usage is proportional to seq dimension in a batch - batch_seq = batch_size * seq_len if is_prompt else batch_size - mem_estimate = batch_seq / total_batch_seq * total_mem - if mem_estimate >= available_mem: - captured_all = False - continue - graphed_bucket = (batch_size, seq_len, is_prompt) - if graphed_bucket in self.graphed_buckets: - continue - self.graphed_buckets.add(graphed_bucket) - self.log_warmup(phase, idx, num_candidates, batch_size, seq_len) - with HabanaMemoryProfiler() as mem_prof: - self.warmup_scenario(batch_size, seq_len, is_prompt, kv_caches) - used_mem = align_workers(mem_prof.consumed_device_memory, - torch.distributed.ReduceOp.MAX) - available_mem -= used_mem - total_mem += used_mem - total_batch_seq += batch_seq - - return total_mem, total_batch_seq, captured_all - - def log_graph_warmup_summary(self, buckets, is_prompt, total_mem): - num_candidates = len(buckets) - phase = f'Graph/{"Prompt" if is_prompt else "Decode"}' - graphed = list(c[:2] for c in self.graphed_buckets - if c[2] == is_prompt) - if num_candidates == 0: - num_candidates = 1 - msg = (f'{phase} captured:{len(graphed)} ' - f'({100 * len(graphed) / num_candidates:.1f}%) ' - f'used_mem:{format_bytes(total_mem)} ' - f'buckets:{sorted(list(graphed))}') - logger.info(msg) - - @torch.inference_mode() - def warmup_model(self, kv_caches: List[torch.Tensor]) -> None: - max_blocks = kv_caches[0][0].size(0) - self.bucketing_ctx.generate_decode_buckets(max_blocks) - if profile := os.environ.get('VLLM_PT_PROFILE', None): - phase, bs, seq_len, graph = profile.split('_') - is_prompt = phase == 'prompt' - graphs = graph == 't' - if graphs: - self.graphed_buckets.add((int(bs), int(seq_len), is_prompt)) - self.warmup_scenario(int(bs), int(seq_len), is_prompt, kv_caches, - True) - raise AssertionError("Finished profiling") - if not htorch.utils.internal.is_lazy() and not self.enforce_eager: - cache_size_limit = 1 + 3 * ( - len(self.bucketing_ctx.prompt_buckets) + - len(self.bucketing_ctx.decode_buckets)) - torch._dynamo.config.cache_size_limit = max( - cache_size_limit, torch._dynamo.config.cache_size_limit) - # Multiply by 8 to follow the original default ratio between - # the cache_size_limit and accumulated_cache_size_limit - torch._dynamo.config.accumulated_cache_size_limit = max( - cache_size_limit * 8, - torch._dynamo.config.accumulated_cache_size_limit) - if 
self.skip_warmup: - logger.info("Skipping warmup...") - return - self.profiler.start('internal', 'warmup') - start_mem = HabanaMemoryProfiler.current_device_memory_usage() - start_time = time.perf_counter() - - compile_only_mode_context = functools.partial(bc.env_setting, - "PT_COMPILE_ONLY_MODE", - True) - can_use_compile_only_mode = True - try: - with compile_only_mode_context(): - pass - logger.debug("Using PT_COMPILE_ONLY_MODE.") - except KeyError: - can_use_compile_only_mode = False - logger.warning('Cannot use PT_COMPILE_ONLY_MODE. ' - 'Warmup time will be negatively impacted. ' - 'Please update Gaudi Software Suite.') - with compile_only_mode_context( - ) if can_use_compile_only_mode else contextlib.nullcontext(): - self.warmup_all_buckets(self.bucketing_ctx.prompt_buckets, True, - kv_caches) - self.warmup_all_buckets(self.bucketing_ctx.decode_buckets, False, - kv_caches) - - if not self.enforce_eager and htorch.utils.internal.is_lazy(): - assert self.mem_margin is not None, \ - ("HabanaWorker.determine_num_available_blocks needs " - "to be called before warming up the model.") - free_mem = HabanaMemoryProfiler.current_free_device_memory() - graph_free_mem = free_mem - self.mem_margin - graph_free_mem = align_workers(graph_free_mem, - torch.distributed.ReduceOp.MIN) - prompt_graph_mem_ratio = float( - os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.3')) - prompt_available_memory = (prompt_graph_mem_ratio * - graph_free_mem) - decode_available_memory = (graph_free_mem - - prompt_available_memory) - msg = ( - f"Using {format_bytes(graph_free_mem)}" - f"/{format_bytes(free_mem)} " - "of free device memory for HPUGraphs, " - f"{format_bytes(prompt_available_memory)} for prompt and " - f"{format_bytes(decode_available_memory)} for decode " - f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})") - logger.info(msg) - prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY', - 'min_tokens') - decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY', - 'max_bs') - mem_post_prompt, prompt_batch_seq, prompt_captured_all = \ - self.warmup_graphs( - prompt_strategy, self.bucketing_ctx.prompt_buckets, - True, kv_caches, prompt_available_memory) - mem_post_decode, decode_batch_seq, decode_captured_all = \ - self.warmup_graphs( - decode_strategy, self.bucketing_ctx.decode_buckets, - False, kv_caches, decode_available_memory) - - # Not all prompt buckets were captured, but all decode buckets - # were captured and we have some free graph-allocated space - # left. Let's try to use it for capturing more prompt buckets. - if (mem_post_decode + mem_post_prompt < graph_free_mem - and not prompt_captured_all and decode_captured_all): - mem_post_prompt, _, prompt_captured_all = ( - self.warmup_graphs( - prompt_strategy, self.bucketing_ctx.prompt_buckets, - True, kv_caches, - graph_free_mem - mem_post_prompt - mem_post_decode, - mem_post_prompt, prompt_batch_seq)) - - # Not all decode buckets were captured, but all prompt buckets - # were captured and we have some free graph-allocated space - # left. Let's try to use it for capturing more decode buckets. 
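The capture decision inside `warmup_graphs` above reduces to a proportional estimate: a candidate bucket is assumed to cost graph memory in proportion to its batch*seq relative to what the already-captured buckets consumed. All numbers below are invented.

# Invented numbers illustrating the mem_estimate check in warmup_graphs.
total_mem = 6.0            # memory consumed by graphs captured so far (GiB)
total_batch_seq = 4096.0   # sum of batch*seq over those captured buckets
available_mem = 2.0        # remaining graph memory budget (GiB)

batch_size, seq_len = 2, 1024          # candidate prompt bucket
batch_seq = batch_size * seq_len
mem_estimate = batch_seq / total_batch_seq * total_mem
if mem_estimate >= available_mem:
    print(f"skip ({mem_estimate:.2f} GiB estimated, {available_mem} free)")
else:
    print(f"capture ({mem_estimate:.2f} GiB estimated)")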
- if mem_post_decode + mem_post_prompt < graph_free_mem \ - and not decode_captured_all \ - and prompt_captured_all: - mem_post_decode, _, _ = self.warmup_graphs( - decode_strategy, self.bucketing_ctx.decode_buckets, - False, kv_caches, - graph_free_mem - mem_post_prompt - mem_post_decode, - mem_post_decode, decode_batch_seq) - - self.log_graph_warmup_summary( - self.bucketing_ctx.prompt_buckets, True, mem_post_prompt) - self.log_graph_warmup_summary( - self.bucketing_ctx.decode_buckets, False, mem_post_decode) - - end_time = time.perf_counter() - end_mem = HabanaMemoryProfiler.current_device_memory_usage() - elapsed_time = end_time - start_time - msg = ( - f"Warmup finished in {elapsed_time:.0f} secs, " - f"allocated {format_bytes(end_mem - start_mem)} of device memory") - logger.info(msg) - self.profiler.end() - - @property - def vocab_size(self) -> int: - return self.model_config.get_vocab_size() - - @property - def mem_margin(self) -> Optional[int]: - return self._mem_margin - - @mem_margin.setter - def mem_margin(self, value): - self._mem_margin = value - - -def _maybe_wrap_in_hpu_graph(*args, **kwargs): - return htorch.hpu.wrap_in_hpu_graph( - HpuModelAdapter(*args, **kwargs), disable_tensor_cache=True - ) if htorch.utils.internal.is_lazy() else HpuModelAdapter(*args, **kwargs) - - -class HabanaProfilerCounterHelper: - - def __init__(self): - self.niter = 0 - self.average_real_throughput = None - self.logged_once = False - self.real_seq_lens = [] - self.prompt_seq_lens = [] - - def capture_seq_group_metadata_stats(self, seq_group_metadata_list): - self.real_seq_lens = [ - len(seq_data.prompt_token_ids) + len(seq_data.output_token_ids) - for seq_group_metadata in seq_group_metadata_list - for seq_data in seq_group_metadata.seq_data.values() - ] - self.prompt_seq_lens = [ - len(seq_data.prompt_token_ids) - for seq_group_metadata in seq_group_metadata_list - for seq_data in seq_group_metadata.seq_data.values() - ] - - def get_counter_dict(self, cache_config, duration, seq_len, - batch_size_padded, real_batch_size, is_prompt): - throughput = batch_size_padded / (duration / 1e6) - throughput_effective = real_batch_size / (duration / 1e6) - - real_max_seq_len = max(self.real_seq_lens) - real_num_tokens = sum(self.real_seq_lens) - padded_num_tokens = batch_size_padded * seq_len - batch_token_utilization = real_num_tokens / padded_num_tokens - if self.average_real_throughput is None: - self.average_real_throughput = throughput_effective - else: # https://www.heikohoffmann.de/htmlthesis/node134.html - self.average_real_throughput = self.average_real_throughput + 1 / ( - self.niter + 1) * (throughput_effective - - self.average_real_throughput) - phase = "prompt" if is_prompt else "decode" - counters = { - f'{phase}_bucket_batch_size': batch_size_padded, - f'{phase}_batch_size': real_batch_size, - f'{phase}_bucket_seq_len': seq_len, - f'{phase}_seq_len': real_max_seq_len, - f'{phase}_bucket_gen_throughput': throughput, - f'{phase}_real_gen_throughput': throughput_effective, - f'{phase}_batch_token_utilization': batch_token_utilization, - 'average_real_throughput': self.average_real_throughput, - 'engine_iteration': self.niter, - } - self.niter += 1 - if is_prompt: - prompt_bucket_in_throughput = (seq_len * batch_size_padded) / ( - duration / 1e6) - prompt_real_in_throughput = sum( - self.prompt_seq_lens) / (duration / 1e6) - counters[ - f'{phase}_bucket_in_throughput'] = prompt_bucket_in_throughput - counters[f'{phase}_real_in_throughput'] = prompt_real_in_throughput - - # KV cache might not be 
created yet (e.g. for profiling run) - if cache_config.num_gpu_blocks is not None and \ - cache_config.num_gpu_blocks != 0: - cache_num_blocks_used = [ - math.ceil(sl / cache_config.block_size) - for sl in self.real_seq_lens - ] - cache_total_num_blocks_used = sum(cache_num_blocks_used) - num_cache_blocks = cache_config.num_gpu_blocks - cache_total_num_free_blocks = \ - num_cache_blocks - cache_total_num_blocks_used - cache_computed_utilization = \ - cache_total_num_blocks_used / num_cache_blocks - max_blocks_per_seq = math.ceil(seq_len / cache_config.block_size) - batch_block_utilization = cache_total_num_blocks_used / ( - batch_size_padded * max_blocks_per_seq) - counters['cache_num_blocks_used'] = cache_total_num_blocks_used - counters['cache_num_free_blocks'] = cache_total_num_free_blocks - counters['cache_computed_utilization'] = cache_computed_utilization - counters[ - f'{phase}_batch_block_utilization'] = batch_block_utilization - if not self.logged_once: - counters['const_cache_num_blocks'] = cache_config.num_gpu_blocks - counters[ - 'const_gpu_memory_utilization'] = \ - cache_config.gpu_memory_utilization - counters['const_block_size'] = cache_config.block_size - self.logged_once = True - return counters - - -def unwrap_model(model): - if isinstance(model, torch._dynamo.eval_frame.OptimizedModule): - return unwrap_model(model._orig_mod) - else: - model = list(vars(model)['_modules'].values())[0] - modules = list(vars(model)['_modules'].values()) - return modules - - -class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): - """ - GPU model runner with sampling step. - """ - _model_input_cls: Type[ModelInputForHPUWithSamplingMetadata] = ( - ModelInputForHPUWithSamplingMetadata) - - def make_model_input_from_broadcasted_tensor_dict( - self, - tensor_dict: Dict[str, Any], - ) -> ModelInputForHPUWithSamplingMetadata: - return ( - ModelInputForHPUWithSamplingMetadata.from_broadcasted_tensor_dict( - tensor_dict, - attn_backend=self.attn_backend, - )) - - @torch.inference_mode() - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> ModelInputForHPUWithSamplingMetadata: - """Prepare the model input based on a given sequence group, including - metadata for the sampling step. - The API assumes seq_group_metadata_list is sorted by prefill -> decode. - The result tensors and data structure also batches input in prefill - -> decode order. For example, - - input_tokens[:num_prefill_tokens] contains prefill tokens. - - input_tokens[num_prefill_tokens:] contains decode tokens. - If cuda graph is required, this API automatically pads inputs. 
- """ - with self.profiler.record_event('internal', 'prepare_input_tensors'): - assert seq_group_metadata_list is not None - if self.profiler.enabled: - self.profiler_counter_helper.capture_seq_group_metadata_stats( - seq_group_metadata_list=seq_group_metadata_list) - model_input, sampling_metadata = self.prepare_input_tensors( - seq_group_metadata_list) - assert model_input.attn_metadata is not None - is_prompt = model_input.attn_metadata.is_prompt - - return dataclasses.replace(model_input, - sampling_metadata=sampling_metadata, - is_prompt=is_prompt, - virtual_engine=virtual_engine) - - def finish_measurements(self): - from neural_compressor.torch.quantization import finalize_calibration - finalize_calibration(self.model.model) - - def _num_blocks(self, attn_metadata): - if attn_metadata.block_list is None: - return 0 - return attn_metadata.block_list.numel() - - def _phase(self, attn_metadata): - phase_type: PhaseType - is_prompt = attn_metadata.is_prompt - is_prefix_prefill = is_prompt and attn_metadata.block_list is not None - if is_prompt and is_prefix_prefill: - phase_type = PhaseType.PREFIX_PREFILL - elif is_prompt and not is_prefix_prefill: - phase_type = PhaseType.PREFILL - elif not is_prompt: - phase_type = PhaseType.DECODE - else: - raise ValueError("Unrecognized pass type, likely due to malformed " - "attention metadata") - return phase_type - - def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode): - is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching - cfg: Optional[tuple] = None - assert cfg is None, "Configs changed between 2D and 3D" - if is_prefix_caching: - phase = self._phase(attn_metadata) - num_blocks = self._num_blocks(attn_metadata) - cfg = (batch_size, seq_len, num_blocks, phase) - else: - phase = 'prompt' if attn_metadata.is_prompt else 'decode' - cfg = (batch_size, seq_len, phase) - seen = cfg in self.seen_configs - self.seen_configs.add(cfg) - if not seen and not warmup_mode: - logger.warning("Configuration: %s was not warmed-up!", - (phase.value, batch_size, seq_len, - num_blocks) if is_prefix_caching else - (phase, batch_size, seq_len)) - - def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], - is_prompt: bool): - ''' - This is a helper function to create the mask for lora computations. - Lora Mask is needed to ensure we match the correct lora weights for the - for the request. 
- For Prompt phase we have - lora_mask with shape (batch_size * seq_len, max_loras * max_rank) - lora_logits_mask with shape (batch_size, max_loras * max_rank) - For Decode phase we have both - lora_mask and lora_logits_mask with shape - (batch_size, max_loras * max_rank) - ''' - lora_mask: torch.Tensor = None - lora_logits_mask: torch.Tensor = None - lora_index = 0 - - if self.lora_config: - if is_prompt: - lora_mask = torch.zeros( - input_tokens.shape[0] * input_tokens.shape[1], - (self.lora_config.max_loras) *\ - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - lora_logits_mask = torch.zeros( - input_tokens.shape[0], (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - ones = torch.ones(input_tokens.shape[1], - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - logit_ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - - for i in range(len(lora_ids)): - if lora_ids[i] == 0: - continue - lora_index = self.lora_manager._adapter_manager.\ - lora_index_to_id.index(lora_ids[i]) - start_row = i * input_tokens.shape[1] - end_row = start_row + input_tokens.shape[1] - start_col = lora_index * self.lora_config.max_lora_rank - end_col = start_col + self.lora_config.max_lora_rank - lora_mask[start_row:end_row, start_col:end_col] = ones - lora_logits_mask[i, start_col:end_col] = logit_ones - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_logits_mask.to('hpu') - else: - lora_mask = torch.zeros(input_tokens.shape[0], - (self.lora_config.max_loras) * - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - ones = torch.ones(1, - self.lora_config.max_lora_rank, - dtype=self.lora_config.lora_dtype) - for i in range(len(lora_ids)): - if lora_ids[i] == 0: - continue - lora_index = self.lora_manager._adapter_manager.\ - lora_index_to_id.index(lora_ids[i]) - start_pos = lora_index * self.lora_config.max_lora_rank - end_pos = start_pos + self.lora_config.max_lora_rank - lora_mask[i, start_pos:end_pos] = ones - lora_mask = lora_mask.to('hpu') - lora_logits_mask = lora_mask - - return lora_mask, lora_logits_mask - - def _get_seq_ids(self, model_input): - return ([ - sg.seq_ids[0] for sg in model_input.sampling_metadata.seq_groups - ]) - - def _pad_to_max_num_seqs(self, tensor, value): - padding_needed = self.max_num_seqs - tensor.size(0) - if padding_needed: - padding = torch.full((padding_needed, *tensor.shape[1:]), - value, - device=tensor.device, - dtype=tensor.dtype) - tensor = torch.cat([tensor, padding]) - return tensor - - @torch.inference_mode() - def execute_model( - self, - model_input: ModelInputForHPUWithSamplingMetadata, - kv_caches: List[torch.Tensor], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - warmup_mode=False, - seqs=None, - ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: - VLLM_DELAYED_SAMPLING = envs.VLLM_HPU_USE_DELAYED_SAMPLING - use_delayed_sampling = VLLM_DELAYED_SAMPLING and not warmup_mode - assert not (use_delayed_sampling and num_steps != 1), \ - 'Delayed sampling is not compatible with MSS!' 
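To make the shapes in the `create_lora_mask` docstring above concrete, here is a tiny decode-phase example; the adapter ids, the id-to-slot mapping, and all sizes are made up. Row i carries ones in the rank-wide column span of the adapter slot serving request i, and rows without a LoRA stay zero.

import torch

batch_size, max_loras, max_lora_rank = 3, 2, 4
lora_ids = [7, 0, 7]        # per-request adapter id, 0 means "no LoRA"
id_to_slot = {7: 1}         # hypothetical adapter-id -> adapter-slot mapping

lora_mask = torch.zeros(batch_size, max_loras * max_lora_rank)
for i, lora_id in enumerate(lora_ids):
    if lora_id == 0:
        continue
    start_col = id_to_slot[lora_id] * max_lora_rank
    lora_mask[i, start_col:start_col + max_lora_rank] = 1.0
print(lora_mask)
# rows 0 and 2 have ones in columns 4..7, row 1 is all zeros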
- assert model_input.input_tokens is not None - if use_delayed_sampling and not model_input.is_prompt and \ - self.is_driver_worker: - num_cached = len(self.cached_step_outputs) - assert num_cached > 0 - cur_seq_ids = self._get_seq_ids(model_input) - cur_seq_id_pos = { - sid: idx - for idx, sid in enumerate(cur_seq_ids) if sid >= 0 - } - htorch.core.mark_step() - for i in range(num_cached): - prev_seq_ids = self._get_seq_ids(self.cached_step_inputs[i]) - target_indices = [ - cur_seq_id_pos.get(psi, -1) for psi in prev_seq_ids - ] - padding = self.cached_step_outputs[i].size(0) - len( - target_indices) - target_indices.extend([-1] * padding) - target_indices = torch.tensor( - target_indices, - device=model_input.input_tokens.device, - dtype=model_input.input_tokens.dtype) - model_input.input_tokens.index_copy_( - 0, target_indices, self.cached_step_outputs[i]) - htorch.core.mark_step() - - if not model_input.is_first_multi_step: - if not model_input.is_last_step: - # not first or last multi-step - return [] - # last multi-step - output = self._decode_sampler_outputs( - model_input) if self.is_driver_worker else [] - torch.hpu.synchronize() - if model_input.is_first_multi_step: - # first multi-step - if self.lora_config: - assert model_input.lora_requests is not None - assert model_input.lora_mapping is not None - self.set_active_loras(model_input.lora_requests, - model_input.lora_mapping) - # Rank!=0 workers has is_prompt==None - if use_delayed_sampling and not model_input.is_prompt and \ - model_input.input_tokens.size(1) == 1: - if self.is_driver_worker: - model_kwargs_broadcast_data = { - "input_tokens": model_input.input_tokens - } - broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) - input_tokens = model_input.input_tokens - - else: - model_kwargs_broadcast_data = broadcast_tensor_dict(src=0) - input_tokens = model_kwargs_broadcast_data["input_tokens"] - else: - input_tokens = model_input.input_tokens - input_positions = model_input.input_positions - attn_metadata = model_input.attn_metadata - sampling_metadata = model_input.sampling_metadata - real_batch_size = model_input.real_batch_size - batch_size_padded = model_input.batch_size_padded - assert input_tokens is not None - assert input_positions is not None - assert sampling_metadata is not None - assert attn_metadata is not None - is_prompt = attn_metadata.is_prompt - assert is_prompt is not None - batch_size = input_tokens.size(0) - seq_len = self._seq_len(attn_metadata) - use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) - - lora_mask: torch.Tensor = None - lora_logits_mask: torch.Tensor = None - if self.lora_config: - assert model_input.lora_ids is not None - lora_mask, lora_logits_mask = self.create_lora_mask( - input_tokens, model_input.lora_ids, - attn_metadata.is_prompt) - - execute_model_kwargs = { - "input_ids": input_tokens, - "positions": input_positions, - "attn_metadata": self.trim_attn_metadata(attn_metadata), - "intermediate_tensors": intermediate_tensors, - "lora_mask": lora_mask, - "virtual_engine": model_input.virtual_engine, - **(model_input.multi_modal_kwargs or {}), - } - if htorch.utils.internal.is_lazy(): - execute_model_kwargs.update( - {"bypass_hpu_graphs": not use_graphs}) - - htorch.core.mark_step() - if self.is_driver_worker: - model_event_name = ("model_" - f"{'prompt' if is_prompt else 'decode'}_" - f"bs{batch_size}_" - f"seq{seq_len}_" - f"graphs{'T' if use_graphs else 'F'}") - else: - model_event_name = 
'model_executable' - if num_steps > 1 or use_delayed_sampling: - # in case of multi-step scheduling - # we only want to pythonize in the last step - sampling_metadata.skip_sampler_cpu_output = True - self.model.sampler.include_gpu_probs_tensor = True - cache_orig_output_tokens_len: List[Dict] = [] - - def try_revert_dummy_output_tokens(): - if len(cache_orig_output_tokens_len) > 0: - # Reuse the original output token ids length - for i, seq_group_metadata in enumerate( - seq_group_metadata_list): - for j, data in seq_group_metadata.seq_data.items(): - orig_output_tokens_len = \ - cache_orig_output_tokens_len[i][j] - data.output_token_ids = \ - data.output_token_ids[:orig_output_tokens_len] - - for i in range(num_steps): - if i != 0 and not self.is_driver_worker: - broadcast_data = broadcast_tensor_dict(src=0) - if 'early_exit' in broadcast_data and broadcast_data[ - 'early_exit']: - return [output] if num_steps == 1 else [] - execute_model_kwargs.update({ - "input_ids": - broadcast_data["input_ids"], - "positions": - broadcast_data["positions"], - "attn_metadata": - self.trim_attn_metadata( - broadcast_data["attn_metadata"]) - }) - with self.profiler.record_event('internal', model_event_name): - hidden_states = self.model.forward( - **execute_model_kwargs, - selected_token_indices=sampling_metadata. - selected_token_indices) - - if self.lora_config: - LoraMask.setLoraMask( - lora_logits_mask.index_select( - 0, sampling_metadata.selected_token_indices)) - - # Compute the logits. - with self.profiler.record_event( - 'internal', - ('compute_logits_' - f'{"prompt" if is_prompt else "decode"}_bs' - f'{batch_size}_' - f'seq{seq_len}')): - if num_steps == 1: - sampling_metadata.selected_token_indices = None - logits = self.model.compute_logits(hidden_states, - sampling_metadata) - htorch.core.mark_step() - # Only perform sampling in the driver worker. 
- if not self.is_driver_worker: - continue - - if use_delayed_sampling: - fake_output = self._delayed_sampler_outputs(model_input) - - with self.profiler.record_event( - 'internal', ('sample_' - f'{"prompt" if is_prompt else "decode"}_' - f'bs{batch_size}_' - f'seq{seq_len}')): - output = self.model.sample( - logits=logits, - sampling_metadata=sampling_metadata, - ) - if num_steps > 1: - output = output.sampled_token_ids - self.cached_step_outputs.append(output) - if use_delayed_sampling and self.is_driver_worker: - self._patch_prev_output() - output = self._pad_to_max_num_seqs( - output.sampled_token_ids, DUMMY_TOKEN_ID) - self.cached_step_outputs.append(output) - self.cached_step_inputs.append(model_input) - htorch.core.mark_step() - if model_input.async_callback is not None: - model_input.async_callback() - if i < num_steps - 1: - if i == 0: - if model_input.async_callback is not None: - ctx = model_input.async_callback.keywords[ # type: ignore - "ctx"] - seq_group_metadata_list = \ - ctx.seq_group_metadata_list - elif seqs is not None: - seq_group_metadata_list = seqs - else: - raise RuntimeError( - "seq_group_metadata_list is uninitialized") - for i, seq_group_metadata in enumerate( - seq_group_metadata_list): - # Skip empty steps - seq_group_metadata.state.current_step += ( - num_steps - 2) - # Cache the original output token ids - cache_orig_output_tokens_len.append({}) - for j, data in seq_group_metadata.seq_data.items(): - cache_orig_output_tokens_len[i][j] = \ - len(data.output_token_ids) - for seq_group_metadata in seq_group_metadata_list: - for data in seq_group_metadata.seq_data.values(): - max_output_len = sampling_metadata.seq_groups[ - 0].sampling_params.max_tokens - if len(data.output_token_ids) < max_output_len - 1: - # add a place holder for prepare_decode - # arbitrary value, this could be any token - dummy_token = (540, ) - data.output_token_ids += (dummy_token) - else: - broadcast_tensor_dict({'early_exit': True}, - src=0) - if num_steps == 1: - return [output] - else: - try_revert_dummy_output_tokens() - return [] - - result = self._prepare_decode(seq_group_metadata_list, - output=output) - execute_model_kwargs.update({ - "input_ids": - result.input_tokens, - "positions": - result.input_positions, - "attn_metadata": - self.trim_attn_metadata(result.attn_metadata) - }) - model_kwargs_broadcast_data = { - "input_ids": result.input_tokens, - "positions": result.input_positions, - "attn_metadata": vars(result.attn_metadata) - } - broadcast_tensor_dict(model_kwargs_broadcast_data, src=0) - else: - try_revert_dummy_output_tokens() - - if self.is_driver_worker and self.profiler.enabled: - # Stop recording 'execute_model' event - self.profiler.end() - event_end = self.profiler.get_timestamp_us() - counters = self.profiler_counter_helper.get_counter_dict( - cache_config=self.cache_config, - duration=event_end - self.event_start, - seq_len=seq_len, - batch_size_padded=batch_size_padded, - real_batch_size=real_batch_size, - is_prompt=is_prompt) - self.profiler.record_counter(self.event_start, counters) - if num_steps == 1: - if self.return_hidden_states: - # we only need to pass hidden states of most recent token - assert model_input.sampling_metadata is not None - if model_input.is_prompt: - output.prefill_hidden_states = hidden_states - output.hidden_states = hidden_states - if use_delayed_sampling: - if self.is_driver_worker: - return [fake_output] - else: - return [] - - return [output] if self.is_driver_worker else [] - else: - return [] - return output if type(output) is 
list else [output] - - def _delayed_sampler_outputs(self, model_input): - next_token_ids = [[DUMMY_TOKEN_ID]] * len( - model_input.sampling_metadata.seq_groups) - sampler_output = self._make_decode_output( - next_token_ids, model_input.sampling_metadata.seq_groups) - return sampler_output - - def _decode_sampler_outputs(self, model_input): - use_async_out_proc = model_input.async_callback is not None - sampler_outputs = [] - num_outputs = len(self.cached_step_outputs) - for i in range(num_outputs): - next_token_ids = self.cached_step_outputs.pop(0) - next_token_ids = next_token_ids.cpu().tolist() - sampler_output = self._make_decode_output( - next_token_ids, model_input.sampling_metadata.seq_groups) - sampler_outputs.append(sampler_output) - - if i < num_outputs - 1 and use_async_out_proc: - assert model_input.async_callback is not None - ctx = model_input.async_callback.keywords[ # type: ignore - "ctx"] - ctx.append_output( - outputs=[sampler_output], - seq_group_metadata_list=ctx.seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=False) - model_input.async_callback() - - if use_async_out_proc: - return [sampler_outputs[-1]] - else: - return sampler_outputs - - def _make_decode_output( - self, - next_token_ids: List[List[int]], - seq_groups: List[SequenceGroupToSample], - ) -> SamplerOutput: - zero_logprob = Logprob(0.0) - sampler_outputs = [] - batch_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group.seq_ids - seq_outputs = [] - for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx][0] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - batch_idx += 1 - sampler_outputs.append( - CompletionSequenceGroupOutput(seq_outputs, None)) - return SamplerOutput(sampler_outputs) - - def shutdown_inc(self): - can_finalize_inc = False - from contextlib import suppress - with suppress(AttributeError): - can_finalize_inc = (self.model_config.quantization == 'inc') and \ - (self.model.model is not None) and \ - self.inc_initialized_successfully and \ - not getattr(self, "_is_inc_finalized", False) - if can_finalize_inc: - from neural_compressor.torch.quantization import ( - finalize_calibration) - finalize_calibration(self.model.model) - self._is_inc_finalized = True - - def __del__(self): - self.shutdown_inc() - - def _patch_prev_output(self): - assert len(self.cached_step_inputs) == len(self.cached_step_outputs), \ - f'''Inputs and outputs are out of sync! - {len(self.cached_step_inputs)} vs {len(self.cached_step_outputs)}''' - if len(self.cached_step_inputs) == 0: - return - model_input = self.cached_step_inputs.pop(0) - delayed_output = self.cached_step_outputs.pop(0).cpu().squeeze( - -1).tolist() - ctx = model_input.async_callback.keywords["ctx"] # type: ignore - # If there's no output to patch with, which is usually the case when - # we're starting a new request after all requests are completed. - if len(ctx.output_queue) == 0: - return - assert len( - ctx.output_queue) == 1, 'There should be exactly 1 output waiting!' - output_data = ctx.output_queue[0] - assert len(output_data.outputs) == 1 - for fake_out, real_out in zip(output_data.outputs[0], delayed_output): - fake_out.samples[0].output_token = real_out - for sg, real_out in zip(output_data.seq_group_metadata_list, - delayed_output): - assert len(sg.seq_data) == 1 - seq_data = list(sg.seq_data.values())[0] - # This is a hack. 
Assigning output_token_ids triggers - # a cache recomputation and we only need to update the last token - seq_data.output_token_ids_array[-1] = real_out - seq_data._cached_all_token_ids[-1] = real_out diff --git a/vllm/worker/hpu_worker.py b/vllm/worker/hpu_worker.py deleted file mode 100644 index 6d76ea499a90..000000000000 --- a/vllm/worker/hpu_worker.py +++ /dev/null @@ -1,484 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -############################################################################### -# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company -############################################################################### - -import contextlib -import gc -import os -from typing import List, Optional, Set, Tuple, Type - -import habana_frameworks.torch as htorch # noqa:F401 -import torch -import torch.distributed -from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes - -import vllm.envs as envs -from vllm.config import ParallelConfig, VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest -from vllm.utils import bind_kv_cache -from vllm.worker.cache_engine import CacheEngine -from vllm.worker.hpu_model_runner import HPUModelRunner -from vllm.worker.model_runner_base import ModelRunnerBase -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class HPUWorker(LocalOrDistributedWorkerBase): - """A worker class that executes (a partition of) the model on a HPU. - - Each worker is associated with a single HPU. The worker is responsible for - maintaining the KV cache and executing the model on the HPU. In case of - distributed inference, each worker is assigned a partition of the model. - """ - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - model_runner_cls: Optional[Type[ModelRunnerBase]] = None, - ) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - if self.is_driver_worker: - assert self.rank == 0, "The driver worker must have rank 0." - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils import init_cached_hf_modules - init_cached_hf_modules() - - self.model_runner: HPUModelRunner = HPUModelRunner( - vllm_config=vllm_config, is_driver_worker=is_driver_worker) - # Uninitialized cache engine. Will be initialized by - # initialize_cache. - self.cache_engine: List[HPUCacheEngine] - # Initialize gpu_cache as pooling models don't initialize kv_caches - self.hpu_cache: Optional[List[List[torch.Tensor]]] = None - # Torch profiler. Enabled and configured through env vars: - # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace - if envs.VLLM_TORCH_PROFILER_DIR: - torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. 
Traces will be saved to: %s", - torch_profiler_trace_dir) - self.profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.HPU, - ], - with_stack=True, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, use_gzip=True)) - else: - self.profiler = None - - def start_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.start() - - def stop_profile(self): - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() - - def _set_env_vars(self): - local_rank = self.local_rank - if self.parallel_config.world_size == 1: - local_rank = -1 - import os - os.environ["LOCAL_RANK"] = str(local_rank) - os.environ["ID"] = str(local_rank) - os.environ["WORLD_SIZE"] = str(self.parallel_config.world_size) - os.environ["RANK"] = str(self.rank) - - def init_device(self) -> None: - if self.device_config.device.type == "hpu": - self.device = torch.device("hpu") - torch.hpu.set_device(self.device) - else: - raise RuntimeError( - f"Not support device type: {self.device_config.device}") - # Initialize the distributed environment. - if self.model_config.quantization == 'inc': - self._set_env_vars() - init_worker_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method, - self.local_rank) - # Set random seed. - set_random_seed(self.model_config.seed) - - def load_model(self): - self.model_runner.load_model() - - def execute_model( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[List[SamplerOutput]]: - # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! 
# noqa:E501 - # VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501 - # VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501 - # VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL - will log cpu fallbacks per engine step, always, even if there were none # noqa:E501 - log_graph_compilation_all = os.environ.get( - 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL', '0') != '0' - log_graph_compilation = os.environ.get( - 'VLLM_HPU_LOG_STEP_GRAPH_COMPILATION', - '0') != '0' or log_graph_compilation_all - log_cpu_fallbacks_all = os.environ.get( - 'VLLM_HPU_LOG_STEP_CPU_FALLBACKS_ALL', '0') != '0' - log_cpu_fallbacks = os.environ.get('VLLM_HPU_LOG_STEP_CPU_FALLBACKS', - '0') != '0' or log_cpu_fallbacks_all - if (log_graph_compilation or log_cpu_fallbacks) and \ - execute_model_req is not None: - from habana_frameworks.torch.hpu.metrics import metric_localcontext - seq_group_metadata_list = execute_model_req.seq_group_metadata_list - is_prompt = any([ - seq_group_metadata.is_prompt - for seq_group_metadata in seq_group_metadata_list - ]) - max_context_len = max([ - max([ - len(v.prompt_token_ids) + len(v.output_token_ids) - for v in seq_group_metadata.seq_data.values() - ]) for seq_group_metadata in seq_group_metadata_list - ]) # whoa, that's some spicy stuff right here - max_num_blocks = ( - (max_context_len - 1) // self.cache_config.block_size) + 1 - input_stats = (f'is_prompt: {is_prompt}, ' - f'num_seqs: {len(seq_group_metadata_list)}, ' - f'max_context_len: {max_context_len}, ' - f'max_num_blocks {max_num_blocks}') - gc_ctx = metric_localcontext( - "graph_compilation" - ) if log_graph_compilation else contextlib.nullcontext() - cpu_fallback_ctx = metric_localcontext( - "cpu_fallback" - ) if log_cpu_fallbacks else contextlib.nullcontext() - with gc_ctx as gc_local_metric, \ - cpu_fallback_ctx as cpu_fallback_local_metric: - output = LocalOrDistributedWorkerBase.execute_model( - self, execute_model_req) - if (log_graph_compilation and gc_local_metric.stats()[0][1] - > 0) or log_graph_compilation_all: - msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: " - f"{gc_local_metric.stats()}, {input_stats}") - logger.warning(msg) - if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] - > 0) or log_cpu_fallbacks_all: - msg = ("VLLM_HPU_STEP_CPU_FALLBACK: " - f"{cpu_fallback_local_metric.stats()}, {input_stats}") - logger.warning(msg) - - return output - - output = LocalOrDistributedWorkerBase.execute_model( - self, execute_model_req) - return output - - @torch.inference_mode() - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Profiles the peak memory usage of the model to determine how many - KV blocks may be allocated without OOMs. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - - Tip: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameter. - """ - # Profile the memory usage of the model and get the maximum number of - # cache blocks that can be allocated with the remaining free memory. - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. 
- with HabanaMemoryProfiler() as m: - self.model_runner.profile_run() - torch.hpu.synchronize() - msg = ("Model profiling run " - f"took {m.get_summary_string()}") - logger.info(msg) - # At this point we should've allocated the maximum workspace for all - # recipes we will use the extra memory for graphs/blocks - free_hpu_memory = torch.hpu.mem_get_info()[0] - - cache_block_size = self.get_cache_block_size_bytes() - graph_reserved_mem = (float( - os.environ.get('VLLM_GRAPH_RESERVED_MEM', '0.1')) - if not self.model_config.enforce_eager else 0) - graph_headroom = 1 - graph_reserved_mem - available_hpu_memory = free_hpu_memory * \ - self.cache_config.gpu_memory_utilization - hpu_memory_margin = free_hpu_memory * ( - 1 - self.cache_config.gpu_memory_utilization) - self.model_runner.mem_margin = hpu_memory_margin - cache_size_bytes = available_hpu_memory * graph_headroom - graph_headroom_bytes = available_hpu_memory * (1 - graph_headroom) - msg = ( - f"Free device memory: {format_bytes(free_hpu_memory)}, " - f"{format_bytes(available_hpu_memory)} usable " - f"(gpu_memory_utilization={self.cache_config.gpu_memory_utilization})," - f" {format_bytes(graph_headroom_bytes)} reserved for HPUGraphs " - f"(VLLM_GRAPH_RESERVED_MEM={graph_reserved_mem}), " - f"{format_bytes(cache_size_bytes)} reserved for KV cache") - logger.info(msg) - num_hpu_blocks = int(cache_size_bytes // cache_block_size) - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - cache_block_size) - num_hpu_blocks = max(num_hpu_blocks, 0) - num_cpu_blocks = max(num_cpu_blocks, 0) - self.model_runner.bucketing_ctx.num_hpu_blocks = num_hpu_blocks - - if self.model_runner.lora_manager: - self.model_runner.remove_all_loras() - - gc.collect() - return num_hpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Allocate GPU and CPU KV cache with the specified number of blocks. - - This also warms up the model, which may record CUDA graphs. - """ - raise_if_cache_size_invalid( - num_gpu_blocks, self.cache_config.block_size, - self.model_config.max_model_len, - self.parallel_config.pipeline_parallel_size) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - with HabanaMemoryProfiler() as m: - self._init_cache_engine() - torch.hpu.synchronize() - msg = ("Initializing cache engine " - f"took {m.get_summary_string()}") - logger.info(msg) - self._warm_up_model() - - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ - HPUCacheEngine(self.cache_config, self.model_config, - self.parallel_config, self.device_config) - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - self.hpu_cache = [ - self.cache_engine[ve].gpu_cache - for ve in range(self.parallel_config.pipeline_parallel_size) - ] - bind_kv_cache(self.compilation_config.static_forward_context, - self.hpu_cache) - - def _warm_up_model(self) -> None: - # NOTE(kzawora): We should use virtual engine index here - # for pipeline parallelism. Using 0 for now. - assert self.hpu_cache is not None - self.model_runner.warmup_model(self.hpu_cache[0]) - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. 
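The sizing logic in `determine_num_available_blocks` above reduces to a few multiplications: usable memory is free memory scaled by `gpu_memory_utilization`, a graph headroom fraction (`VLLM_GRAPH_RESERVED_MEM`, 10% by default unless `enforce_eager`) is held back, and the remainder is divided by the per-block cache size. A minimal standalone distillation, with a hypothetical helper name:

```python
import os

def estimate_hpu_kv_blocks(free_mem_bytes: int,
                           cache_block_size_bytes: int,
                           gpu_memory_utilization: float,
                           enforce_eager: bool) -> int:
    # Fraction of usable memory reserved for HPU graphs (skipped in eager mode).
    graph_reserved = (0.0 if enforce_eager else
                      float(os.environ.get("VLLM_GRAPH_RESERVED_MEM", "0.1")))
    usable = free_mem_bytes * gpu_memory_utilization
    kv_cache_bytes = usable * (1 - graph_reserved)
    return max(int(kv_cache_bytes // cache_block_size_bytes), 0)

# Example: 64 GiB free, 2 MiB per cache block, 90% utilization, graphs enabled
# gives roughly 26,500 KV cache blocks.
print(estimate_hpu_kv_blocks(64 << 30, 2 << 20, 0.9, False))
```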
- set_random_seed(self.model_config.seed) - - def finish_measurements(self): - self.model_runner.finish_measurements() - - @property - def do_metadata_broadcast(self) -> bool: - return self.parallel_config.tensor_parallel_size > 1 - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - return self.hpu_cache - - @torch.inference_mode() - def prepare_worker_input( - self, execute_model_req: ExecuteModelRequest) -> WorkerInput: - virtual_engine = execute_model_req.virtual_engine - num_seq_groups = len(execute_model_req.seq_group_metadata_list) - # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. - # they contain parameters to launch cudamemcpyasync. - blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, - device="cpu", - dtype=torch.int64).view(-1, 2) - blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, - device="cpu", - dtype=torch.int64).view(-1, 2) - # `blocks_to_copy` is a gpu tensor. The src and tgt of - # blocks to copy are in the same device, and `blocks_to_copy` - # can be used directly within cuda kernels. - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, - dtype=torch.int64).view(-1, 2) - - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - ) - - @torch.inference_mode() - def execute_worker(self, worker_input: WorkerInput) -> None: - virtual_engine = worker_input.virtual_engine - # Issue cache operations. - if (worker_input.blocks_to_swap_in is not None - and worker_input.blocks_to_swap_in.numel() > 0): - self.cache_engine[virtual_engine].swap_in( - worker_input.blocks_to_swap_in) - if (worker_input.blocks_to_swap_out is not None - and worker_input.blocks_to_swap_out.numel() > 0): - self.cache_engine[virtual_engine].swap_out( - worker_input.blocks_to_swap_out) - if (worker_input.blocks_to_copy is not None - and worker_input.blocks_to_copy.numel() > 0): - self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def remove_lora(self, lora_id: int) -> bool: - return self.model_runner.remove_lora(lora_id) - - def pin_lora(self, lora_id: int) -> bool: - return self.model_runner.pin_lora(lora_id) - - def list_loras(self) -> Set[int]: - return self.model_runner.list_loras() - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def list_prompt_adapters(self) -> Set[int]: - raise NotImplementedError( - "Prompt Adapter is not implemented for HPU backend.") - - def shutdown_inc(self): - self.model_runner.shutdown_inc() - - @property - def max_model_len(self) -> int: - return self.model_config.max_model_len - - @property - def vocab_size(self) -> int: - return self.model_runner.vocab_size - - def get_cache_block_size_bytes(self) -> int: - """Get the size of the KV cache block size in bytes. 
- """ - return HPUCacheEngine.get_cache_block_size(self.cache_config, - self.model_config, - self.parallel_config) - - -def init_worker_distributed_environment( - parallel_config: ParallelConfig, - rank: int, - distributed_init_method: Optional[str] = None, - local_rank: int = -1, -) -> None: - """Initialize the distributed environment.""" - init_distributed_environment(parallel_config.world_size, - rank, - distributed_init_method, - local_rank, - backend='hccl') - - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - if torch.distributed.is_initialized(): - torch_world_size = torch.distributed.get_world_size() - if torch_world_size != parallel_config.world_size: - raise RuntimeError( - "torch.distributed is already initialized but the torch world " - "size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") - elif not distributed_init_method: - raise ValueError( - "distributed_init_method must be set if torch.distributed " - "is not already initialized") - else: - torch.distributed.init_process_group( - backend="hccl", - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) - - # A small all_reduce for warmup & checking conformance. - dummy_tensor_hpu = torch.ones(1).to('hpu') - torch.distributed.all_reduce(dummy_tensor_hpu) - assert dummy_tensor_hpu.item() == parallel_config.world_size - ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) - - -def raise_if_cache_size_invalid(num_gpu_blocks, block_size, max_model_len, - pipeline_parallel_size) -> None: - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = block_size * (num_gpu_blocks // pipeline_parallel_size) - if max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). 
Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") - - -class HPUCacheEngine(CacheEngine): - - def _allocate_kv_cache( - self, - num_blocks: int, - device: str, - ) -> List[Tuple[torch.Tensor, torch.Tensor]]: - """Allocates KV cache on the specified device.""" - kv_cache_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) - kv_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] - for _ in range(self.num_attention_layers): - key_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) - value_cache = torch.zeros(kv_cache_shape, - dtype=self.dtype, - device=device) - kv_layer = (key_cache, value_cache) - kv_cache.append(kv_layer) - return kv_cache diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 82db6617ba55..5a185e7451ad 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -45,10 +45,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalKwargs, MultiModalPlaceholderMap, MultiModalRegistry) -from vllm.prompt_adapter.layers import PromptAdapterMapping -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.prompt_adapter.worker_manager import ( - LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, @@ -95,8 +91,6 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_mapping: Optional["LoRAMapping"] = None lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None - prompt_adapter_mapping: Optional[PromptAdapterMapping] = None - prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None @@ -113,8 +107,6 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, - "prompt_adapter_mapping": self.prompt_adapter_mapping, - "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, @@ -164,8 +156,6 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, - "prompt_adapter_mapping": self.prompt_adapter_mapping, - "prompt_adapter_requests": self.prompt_adapter_requests, "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, @@ -212,8 +202,6 @@ def simple_reinit(self): self.lora_index_mapping.clear() # type: ignore self.lora_prompt_mapping.clear() # type: ignore self.lora_requests.clear() # type: ignore - self.prompt_adapter_index_mapping.clear() # type: ignore - self.prompt_adapter_prompt_mapping.clear() # type: ignore def __init__( self, @@ -252,11 +240,6 @@ def __init__( lora_prompt_mapping: Optional[List[List[int]]] = None, lora_requests: Optional[Set[LoRARequest]] = None, - # Prompt adapter inputs. 
- prompt_adapter_index_mapping: Optional[List[int]] = None, - prompt_adapter_prompt_mapping: Optional[List[int]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - # Multi-modal inputs. multi_modal_kwargs: Optional[MultiModalKwargs] = None, multi_modal_placeholder_maps: Optional[Dict[ @@ -360,18 +343,6 @@ def __init__( else: self.lora_requests.clear() - if prompt_adapter_index_mapping: - self.prompt_adapter_index_mapping = \ - prompt_adapter_index_mapping - else: - self.prompt_adapter_index_mapping.clear() - - if prompt_adapter_prompt_mapping: - self.prompt_adapter_prompt_mapping = \ - prompt_adapter_prompt_mapping - else: - self.prompt_adapter_prompt_mapping.clear() - else: self.input_tokens = input_tokens or [] self.inputs_embeds = inputs_embeds @@ -390,12 +361,6 @@ def __init__( self.lora_prompt_mapping = lora_prompt_mapping or [] self.lora_requests = lora_requests or set() - self.prompt_adapter_index_mapping = ( - prompt_adapter_index_mapping or []) - self.prompt_adapter_prompt_mapping = ( - prompt_adapter_prompt_mapping or []) - - self.prompt_adapter_request = prompt_adapter_request self.multi_modal_kwargs = multi_modal_kwargs self.multi_modal_placeholder_maps = multi_modal_placeholder_maps self.prefix_cache_hit = prefix_cache_hit @@ -485,7 +450,6 @@ def __init__(self, # Compute functions for each sequence group. # WARNING: The order of the functions matters! self.per_seq_group_compute_fns = [ - self._compute_prompt_adapter_input, self._compute_multi_modal_input, ] @@ -496,8 +460,6 @@ def __init__(self, self.sliding_window = self.runner.sliding_window self.block_size = self.runner.block_size self.enable_lora = self.runner.lora_config is not None - self.enable_prompt_adapter = (self.runner.prompt_adapter_config - is not None) # Attention metadata inputs. if self.attn_backend is not None: @@ -693,34 +655,6 @@ def _compute_lora_input(self, inter_data: InterDataForSeqGroup, else: inter_data.lora_prompt_mapping.append([]) - def _compute_prompt_adapter_input( - self, inter_data: InterDataForSeqGroup, - seq_group_metadata: SequenceGroupMetadata): - """If prompt adapter is enabled, compute index and prompt mapping. - """ - # Note that when is_prompt=True, we expect only one sequence - # in the group. - if not self.enable_prompt_adapter: - return - - prompt_adapter_id = seq_group_metadata.prompt_adapter_id - if prompt_adapter_id <= 0 or not inter_data.is_prompt: - return - - # We expect only one sequence in the group when is_prompt=True. - assert inter_data.n_seqs == 1 - query_len = inter_data.query_lens[0] - inter_data.prompt_adapter_request = ( - seq_group_metadata.prompt_adapter_request) - - num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens - inter_data.prompt_adapter_index_mapping = [ - prompt_adapter_id - ] * num_tokens + [0] * (query_len - num_tokens) - inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( - query_len if seq_group_metadata.sampling_params - and seq_group_metadata.sampling_params.prompt_logprobs else 1) - def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, seq_group_metadata: SequenceGroupMetadata): """If multi-modal data is given, add it to the input.""" @@ -1009,29 +943,6 @@ def build(self) -> ModelInputForGPU: prompt_mapping=lora_prompt_mapping, is_prefill=not self.decode_only)) - # Prompt adapter data. 
- prompt_adapter_requests: Set[PromptAdapterRequest] = set() - prompt_adapter_mapping = None - if self.enable_prompt_adapter: - prompt_adapter_requests = set( - data.prompt_adapter_request for data in self.inter_data_list - if data.prompt_adapter_request is not None) - prompt_adapter_index_mapping = flatten_2d_lists([ - inter_data.prompt_adapter_index_mapping - for inter_data in self.inter_data_list - ]) - if cuda_graph_pad_size: - prompt_adapter_index_mapping.extend( - itertools.repeat(0, cuda_graph_pad_size)) - prompt_adapter_prompt_mapping = flatten_2d_lists([ - inter_data.prompt_adapter_prompt_mapping - for inter_data in self.inter_data_list - ]) - prompt_adapter_mapping = PromptAdapterMapping( - prompt_adapter_index_mapping, - prompt_adapter_prompt_mapping, - ) - # Multi-modal data. multi_modal_kwargs_list = [ data.multi_modal_kwargs for data in self.inter_data_list @@ -1051,9 +962,7 @@ def build(self) -> ModelInputForGPU: lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, request_ids_to_seq_ids=request_ids_to_seq_ids, - finished_requests_ids=self.finished_requests_ids, - prompt_adapter_mapping=prompt_adapter_mapping, - prompt_adapter_requests=prompt_adapter_requests) + finished_requests_ids=self.finished_requests_ids) class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): @@ -1112,6 +1021,10 @@ def __init__( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) + self.cross_layer_shared_graph_block_tables = np.zeros( + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), + dtype=np.int32) + # Attention-free but stateful models like Mamba need a placeholder attn # backend, as the attention metadata is needed to manage internal state. # However we must bypass attention selection altogether for some models @@ -1144,7 +1057,6 @@ def __init__( self.model: nn.Module # Set after load_model # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None - self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None self.sampler = get_sampler() set_cpu_offload_max_bytes( @@ -1203,14 +1115,7 @@ def load_model(self) -> None: logger.info("Model loading took %.4f GiB and %.6f seconds", self.model_memory_usage / GiB_bytes, time_after_load - time_before_load) - if self.prompt_adapter_config: - self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, self.device, - self.prompt_adapter_config) - self.model = ( - self.prompt_adapter_manager.create_prompt_adapter_manager( - self.model)) + if self.vllm_config.compilation_config.level ==\ CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): @@ -1246,6 +1151,7 @@ def save_tensorized_model( TensorizerLoader.save_model( self.model, tensorizer_config=tensorizer_config, + model_config=self.model_config, ) def get_max_block_per_batch(self) -> int: @@ -1461,40 +1367,6 @@ def list_loras(self) -> Set[int]: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_adapters() - def remove_all_prompt_adapters(self): - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.remove_all_adapters() - - def set_active_prompt_adapters( - self, prompt_adapter_requests: Set[PromptAdapterRequest], - prompt_adapter_mapping: PromptAdapterMapping) -> None: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - self.prompt_adapter_manager.set_active_adapters( - prompt_adapter_requests, prompt_adapter_mapping) - - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> Set[int]: - if not self.prompt_adapter_manager: - raise RuntimeError("PromptAdapter is not enabled.") - return self.prompt_adapter_manager.list_adapters() - @torch.inference_mode() def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: """Cuda graph capture a model. 
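The `capture_model` hunk below edits the CUDA graph capture loop; the underlying PyTorch pattern is to warm up on a side stream, record one forward pass into fixed input/output buffers, then overwrite the buffers and replay. A toy sketch with a stand-in `nn.Linear` rather than the runner's `CUDAGraphRunner`:

```python
import torch

model = torch.nn.Linear(512, 512).cuda().eval()
static_in = torch.zeros(8, 512, device="cuda")  # fixed shape, like a capture bucket

# Warm up on a side stream before capture, as required by CUDA graphs.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_in)
torch.cuda.current_stream().wait_stream(side)

# Record one forward pass; static_in/static_out become the graph's fixed buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replays re-launch the recorded kernels against whatever is in static_in.
static_in.copy_(torch.randn_like(static_in))
graph.replay()
print(static_out.shape)
```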
@@ -1586,6 +1458,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: if get_tensor_model_parallel_rank() == 0: compilation_cases = tqdm( list(compilation_cases), + disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graph shapes") for batch_size, use_inputs_embeds in compilation_cases: attn_metadata = ( @@ -1603,13 +1476,6 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self.set_active_loras(set([dummy_lora_request]), lora_mapping) - if self.prompt_adapter_config: - prompt_adapter_mapping = PromptAdapterMapping( - [-1] * batch_size, - [-1] * batch_size, - ) - self.set_active_prompt_adapters( - set(), prompt_adapter_mapping) graph_runner = CUDAGraphRunner( self.model, self.attn_backend.get_name(), self.attn_state.graph_clone(batch_size), @@ -1770,13 +1636,6 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) - self.attn_state.begin_forward(model_input) # Currently cuda graph is only supported by the decode phase. @@ -1926,24 +1785,32 @@ def execute_model( if model_input.inputs_embeds is not None: if self.is_driver_worker: - sampled = broadcast_tensor_dict( - {"token_ids": output.sampled_token_ids}) + sampled_token_ids = [] + valid_outputs = [] + for sequence_group_output in output.outputs: + if len(sequence_group_output.samples) == 0: + continue + assert len(sequence_group_output.samples) == 1 + valid_outputs.append(sequence_group_output) + sampled_token_ids.append( + sequence_group_output.samples[0].output_token) + sampled_token_ids = torch.tensor(sampled_token_ids).to( + self.device) + sampled_token_ids = broadcast_tensor_dict( + {"sampled_token_ids": + sampled_token_ids})["sampled_token_ids"] else: - sampled = broadcast_tensor_dict() - if sampled["token_ids"] is not None: - sampled_token_embeds = self.model.get_input_embeddings( - sampled["token_ids"].squeeze(1)) + sampled_token_ids = broadcast_tensor_dict( + )["sampled_token_ids"] + if len(sampled_token_ids) > 0: + sampled_token_embeds = \ + self.model.get_input_embeddings(sampled_token_ids) if self.is_driver_worker: self.sampler.include_gpu_probs_tensor = \ orig_include_gpu_probs - - output.sampled_token_embeds = sampled_token_embeds - - for token_embed, sequence_group_output in zip( - output.sampled_token_embeds, output.outputs): - assert len(sequence_group_output.samples) == 1 - sequence_group_output.samples[ - 0].output_embed = token_embed + for i, sequence_group_output in enumerate(valid_outputs): + sequence_group_output.samples[0].output_embed = \ + sampled_token_embeds[i] if not self.is_driver_worker: return [] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index d567ce4a6e78..feca8a7a1e74 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -12,6 +12,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.models.interfaces_base import is_pooling_model +from vllm.pooling_params import PoolingTask from vllm.sequence import IntermediateTensors, SequenceGroupMetadata if TYPE_CHECKING: @@ -188,7 +190,6 @@ def __init__( self.scheduler_config = vllm_config.scheduler_config self.device_config = 
vllm_config.device_config self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config # Map of request_id -> generator used for seeded random sampling @@ -223,6 +224,13 @@ def prepare_model_input( def get_model(self) -> nn.Module: raise NotImplementedError + def get_supported_pooling_tasks(self) -> list[PoolingTask]: + model = self.get_model() + if not is_pooling_model(model): + return [] + + return list(model.pooler.get_supported_tasks()) + def execute_model( self, model_input: T, diff --git a/vllm/worker/multi_step_hpu_worker.py b/vllm/worker/multi_step_hpu_worker.py deleted file mode 100644 index f0210c13c755..000000000000 --- a/vllm/worker/multi_step_hpu_worker.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -############################################################################### -# Copyright (C) 2025 Habana Labs, Ltd. an Intel Company -############################################################################### - -import dataclasses -from typing import Dict, Optional, Tuple - -import torch - -from vllm.distributed import broadcast_tensor_dict -from vllm.sequence import ExecuteModelRequest -from vllm.worker.hpu_model_runner import ModelInputForHPU -from vllm.worker.hpu_worker import HPUWorker -from vllm.worker.worker_base import WorkerInput - - -class MultiStepHPUWorker(HPUWorker): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.cached_model_input: Optional[ModelInputForHPU] = None - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[ModelInputForHPU, WorkerInput, Dict[str, torch.Tensor]]: - """ - Get the driver input and broadcast it to other workers. 
- """ - assert self.is_driver_worker - assert execute_model_req.virtual_engine == 0 - - is_first_multi_step = execute_model_req.is_first_multi_step - is_last_step = execute_model_req.is_last_step - - if is_first_multi_step: - # on first step we prepare the worker input and model input normally - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - worker_input = dataclasses.replace( - worker_input, - num_steps=execute_model_req.num_lookahead_slots + 1) - model_input: ModelInputForHPU = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - if execute_model_req.async_callback: - model_input = dataclasses.replace( - model_input, - async_callback=execute_model_req.async_callback) - else: - # on subsequent steps we reuse the worker input and model input - assert self.cached_model_input is not None - model_input = self.cached_model_input - worker_input = WorkerInput() - - model_input = dataclasses.replace( - model_input, - is_first_multi_step=is_first_multi_step, - is_last_step=is_last_step) - - if self.do_metadata_broadcast: - if is_first_multi_step: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update( - model_input.as_broadcastable_tensor_dict()) - broadcast_tensor_dict(broadcast_data, src=0) - else: - broadcast_data = { - "is_first_multi_step": is_first_multi_step, - "is_last_step": is_last_step, - } - broadcast_tensor_dict(broadcast_data, src=0) - - # Returning empty dict here to keep this compatible with - # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` - return model_input, worker_input, {} - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[Tuple[ModelInputForHPU, WorkerInput, Dict[str, - torch.Tensor]]]: - if self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. - broadcast_tensor_dict({}, src=0) - return None - model_input, worker_input, _ = self._get_driver_input_and_broadcast( - execute_model_req) - if model_input.is_first_multi_step: - self.cached_model_input = model_input - return model_input, worker_input, {} - else: - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - - if len(broadcast_data) == 2: - assert self.cached_model_input is not None - self.cached_model_input = dataclasses.replace( - self.cached_model_input, - is_first_multi_step=broadcast_data["is_first_multi_step"], - is_last_step=broadcast_data["is_last_step"]) - empty_worker_input = WorkerInput() - return self.cached_model_input, empty_worker_input, {} - - worker_input = WorkerInput.from_broadcasted_tensor_dict( - broadcast_data) - model_input = ( - self.model_runner. 
- make_model_input_from_broadcasted_tensor_dict(broadcast_data)) - self.cached_model_input = model_input - return model_input, worker_input, {} diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 0680e60b52a1..2aa910bdff6b 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -288,9 +288,6 @@ def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): assert fmi.lora_requests is not None assert len(fmi.lora_requests) == 0 assert fmi.attn_metadata is not None - assert fmi.prompt_adapter_mapping is None - assert fmi.prompt_adapter_requests is not None - assert len(fmi.prompt_adapter_requests) == 0 assert fmi.multi_modal_kwargs is not None assert len(fmi.multi_modal_kwargs) == 0 diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py deleted file mode 100644 index ed9f00166615..000000000000 --- a/vllm/worker/multi_step_tpu_worker.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -from typing import Dict, Optional, Tuple - -import torch - -from vllm.distributed import broadcast_tensor_dict -from vllm.sequence import ExecuteModelRequest -from vllm.worker.tpu_model_runner import ModelInputForTPU -from vllm.worker.tpu_worker import TPUWorker -from vllm.worker.worker_base import WorkerInput - - -class MultiStepTPUWorker(TPUWorker): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.cached_model_input: Optional[ModelInputForTPU] = None - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest - ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: - assert self.is_driver_worker - assert execute_model_req.virtual_engine == 0 - - is_first_multi_step = execute_model_req.is_first_multi_step - is_last_step = execute_model_req.is_last_step - if is_first_multi_step: - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - worker_input = dataclasses.replace( - worker_input, - num_steps=execute_model_req.num_lookahead_slots + 1) - model_input: ModelInputForTPU = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - - if execute_model_req.async_callback: - model_input = dataclasses.replace( - model_input, - async_callback=execute_model_req.async_callback) - else: - assert self.cached_model_input is not None - model_input = self.cached_model_input - worker_input = WorkerInput() - model_input = dataclasses.replace( - model_input, - is_first_multi_step=is_first_multi_step, - is_last_step=is_last_step) - - if self.do_metadata_broadcast: - if is_first_multi_step: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update( - model_input.as_broadcastable_tensor_dict()) - broadcast_tensor_dict(broadcast_data, src=0) - else: - broadcast_data = { - "is_first_multi_step": is_first_multi_step, - "is_last_step": is_last_step, - } - broadcast_tensor_dict(broadcast_data, src=0) - - # Retuning empty dict here to keep this compatible with - # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` - return model_input, worker_input, {} - - def prepare_input( - self, - execute_model_req: Optional[ExecuteModelRequest] = None, - ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str, - torch.Tensor]]]: - if 
self.is_driver_worker: - if execute_model_req is None: - if self.do_metadata_broadcast: - broadcast_tensor_dict({}, src=0) - return None - - model_input, worker_input, _ = self._get_driver_input_and_broadcast( - execute_model_req) - if model_input.is_first_multi_step: - self.cached_model_input = model_input - return model_input, worker_input, {} - else: - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None - - if len(broadcast_data) == 2: - assert self.cached_model_input is not None - self.cached_model_input = dataclasses.replace( - self.cached_model_input, - is_first_multi_step=broadcast_data["is_first_multi_step"], - is_last_step=broadcast_data["is_last_step"]) - empty_worker_input = WorkerInput() - return self.cached_model_input, empty_worker_input, {} - - worker_input = WorkerInput.from_broadcasted_tensor_dict( - broadcast_data) - model_input = ( - self.model_runner. - make_model_input_from_broadcasted_tensor_dict(broadcast_data)) - self.cached_model_input = model_input - return model_input, worker_input, {} diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 662bde6bc07b..4e1408300fb8 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -156,7 +156,7 @@ def init_distributed_environment(self): rank=self.rank, local_rank=self.local_rank, distributed_init_method=self.distributed_init_method, - backend="gloo", + backend=current_platform.dist_backend, ) ensure_model_parallel_initialized( diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py index f80955f71a5a..e49783ad9b24 100644 --- a/vllm/worker/pooling_model_runner.py +++ b/vllm/worker/pooling_model_runner.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast import torch @@ -10,6 +10,7 @@ from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.logger import init_logger +from vllm.model_executor.models.interfaces_base import VllmModelForPooling from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams @@ -63,13 +64,6 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) - if self.prompt_adapter_config: - assert model_input.prompt_adapter_requests is not None - assert model_input.prompt_adapter_mapping is not None - self.set_active_prompt_adapters( - model_input.prompt_adapter_requests, - model_input.prompt_adapter_mapping) - # Currently cuda graph is only supported by the decode phase. 
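The multi-step HPU/TPU workers removed above share one broadcast protocol: the driver sends an empty dict to stop the non-driver loop, a two-key dict (`is_first_multi_step`, `is_last_step`) to reuse the cached model input with updated flags, and a full payload to rebuild the input from broadcast tensors. A dependency-free distillation of that dispatch, with a hypothetical function name (the real code goes through `broadcast_tensor_dict`):

```python
from typing import Optional

def classify_broadcast(payload: dict) -> Optional[str]:
    if not payload:
        # Empty dict: the driver signals there is no more work; stop the loop.
        return None
    if len(payload) == 2:
        # Only the multi-step flags: reuse the cached model input.
        return "reuse_cached_input"
    # Full payload: rebuild worker/model input from the broadcast tensors.
    return "rebuild_input"

assert classify_broadcast({}) is None
assert classify_broadcast(
    {"is_first_multi_step": True, "is_last_step": False}) == "reuse_cached_input"
```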
assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata @@ -195,7 +189,16 @@ def _prepare_pooling( seq_groups: List[Tuple[List[int], PoolingParams]] = [] for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + assert pooling_params is not None + assert (task := pooling_params.task) is not None, ( + "You did not set `task` in the API") + + model = cast(VllmModelForPooling, self.model) + to_update = model.pooler.get_pooling_updates(task) + to_update.apply(pooling_params) + seq_groups.append((seq_ids, pooling_params)) seq_data: Dict[int, SequenceData] = {} diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py deleted file mode 100644 index 336bc0bcec36..000000000000 --- a/vllm/worker/tpu_model_runner.py +++ /dev/null @@ -1,909 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import enum -import time -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Type, Union) -from unittest.mock import patch - -import numpy as np -import torch -import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.runtime as xr - -from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.config import VllmConfig -from vllm.forward_context import get_forward_context, set_forward_context -from vllm.logger import init_logger -from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.model_loader import get_model -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SequenceGroupMetadata, SequenceOutput) -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, - _add_attn_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -# Here we utilize the behavior that out-of-bound index is ignored. -# FIXME(woosuk): Find a more reliable way to prevent possible bugs. -_PAD_SLOT_ID = 1_000_000_000 -# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. -_ENABLE_TOP_P = False -# FIXME(woosuk): A temporary hack to support `n > 1`. -# This can significantly affect the performance if too large. 
-_MAX_NUM_SAMPLES = 128 - - -class ExecutionMode(enum.Enum): - PREFILL = enum.auto() - DECODE = enum.auto() - PREFIX_PREFILL = enum.auto() - - def is_prefill(self) -> bool: - return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) - - -@dataclass(frozen=True) -class ModelInputForTPU(ModelRunnerInputBase): - token_ids: torch.Tensor - position_ids: torch.Tensor - attn_metadata: AttentionMetadata - input_lens: torch.Tensor - t: torch.Tensor - p: torch.Tensor - num_samples: int - n: List[int] - seq_groups: List[List[int]] - is_first_multi_step: bool = True - is_last_step: bool = True - virtual_engine: int = 0 - async_callback: Optional[Callable] = None - - def as_broadcastable_tensor_dict( - self) -> Dict[str, Union[int, torch.Tensor]]: - tensor_dict = { - "token_ids": self.token_ids, - "position_ids": self.position_ids, - "input_lens": self.input_lens, - "t": self.t, - "p": self.p, - "num_samples": self.num_samples, - "n": self.n, - "seq_groups": self.seq_groups, - "is_first_multi_step": self.is_first_multi_step, - "is_last_step": self.is_last_step, - "virtual_engine": self.virtual_engine, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["ModelInputForTPU"], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForTPU": - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): - - def __init__( - self, - vllm_config: VllmConfig, - is_driver_worker: bool = False, - ): - ModelRunnerBase.__init__(self, vllm_config=vllm_config) - self.is_driver_worker = is_driver_worker - - self.block_size = self.cache_config.block_size - self.max_num_blocks_per_seq = (self.model_config.max_model_len // - self.block_size) - self.block_tables = np.zeros( - (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), - dtype=np.int32) - self.attn_backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - self.model_config.is_attention_free, - False, - ) - self.cached_step_outputs: List[torch.Tensor] = [] - - smem_size = 512 * 1024 - block_table_size = 4 * self.block_tables.size - if block_table_size >= smem_size: - logger.warning( - "The max_model_len (%d) is too large. This may degrade the " - "performance due to the insufficient smem size. Consider " - "setting --max-model-len to a smaller value, like %d.", - self.model_config.max_model_len, - self.model_config.max_model_len / - (block_table_size / smem_size)) - - def load_model(self) -> None: - self.device = self.device_config.device - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." 
- "get_tensor_model_parallel_rank", - return_value=xm_tp_rank): - model = get_model(vllm_config=self.vllm_config) - model = model.eval() - xm.wait_device_ops() - model = ModelWrapper(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) - - def get_model(self) -> nn.Module: - return self.model.model - - def _dummy_run( - self, - batch_size: int, - seq_len: int, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - exec_mode: ExecutionMode, - ) -> None: - exec_mode = ExecutionMode(exec_mode) - if exec_mode.is_prefill(): - seq_len = (seq_len + 15) // 16 * 16 - token_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - if exec_mode == ExecutionMode.PREFILL: - attn_metadata = self.attn_backend.make_metadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - block_tables=None, - context_lens=None, - effective_query_lens=None, - ) - else: - context_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - block_tables = torch.tensor(self.block_tables[:batch_size], - dtype=torch.int32, - device=self.device) - effective_query_lens = torch.ones_like(context_lens) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - block_tables=block_tables, - context_lens=context_lens, - effective_query_lens=effective_query_lens, - ) - else: - assert seq_len == 1 - token_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - position_ids = torch.zeros((batch_size, seq_len), - dtype=torch.int32, - device=self.device) - slot_mapping = torch.zeros((batch_size, seq_len), - dtype=torch.int64, - device=self.device) - block_tables = torch.zeros( - (batch_size, self.max_num_blocks_per_seq), - dtype=torch.int32, - device=self.device) - context_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - input_lens = torch.ones((batch_size, ), - dtype=torch.int32, - device=self.device) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size * seq_len, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - block_tables=block_tables, - context_lens=context_lens, - ) - t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 - - # NOTE(woosuk): There are two stages of compilation: torch.compile and - # XLA compilation. Using `mark_dynamic` can reduce the torch.compile - # overhead by reusing the FX graph for different shapes. - # However, the XLA graph will still require static shapes and needs to - # be re-compiled for every different shapes. 
This overhead is inevitable - # in the first run, but can be skipped afterwards as we cache the XLA - # graphs in the disk (VLLM_XLA_CACHE_PATH). - if exec_mode.is_prefill(): - # Prefll - torch._dynamo.mark_dynamic(token_ids, 1) - torch._dynamo.mark_dynamic(position_ids, 1) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) - else: - # Decode - torch._dynamo.mark_dynamic(token_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(input_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) - torch._dynamo.mark_dynamic(t, 0) - torch._dynamo.mark_dynamic(p, 0) - # Dummy run. - with set_forward_context(attn_metadata, self.vllm_config, 0): - self.model(token_ids, position_ids, input_lens, t, p, num_samples, - kv_caches) - - def warmup_model( - self, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> None: - # Prefill - logger.info("Compiling the model with different input shapes...") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - self._dummy_run(batch_size, - seq_len, - kv_caches, - exec_mode=ExecutionMode.PREFILL) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - num_tokens = batch_size * seq_len - if num_tokens >= self.scheduler_config.max_num_batched_tokens: - break - seq_len = seq_len * 2 - - end = time.time() - logger.info("Compilation for prefill done in %.2f s.", end - start) - - # Prefix prefill - if self.cache_config.enable_prefix_caching: - logger.info("Compiling the model with different input shapes for " - "prefix prefill...") - start = time.time() - for batch_size in [1]: - seq_len = 16 - while seq_len <= self.model_config.max_model_len: - self._dummy_run(batch_size, - seq_len, - kv_caches, - exec_mode=ExecutionMode.PREFIX_PREFILL) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, - seq_len) - num_tokens = batch_size * seq_len - if (num_tokens - >= self.scheduler_config.max_num_batched_tokens): - break - seq_len = seq_len * 2 - end = time.time() - logger.info("Compilation for prefix prefill done in %.2f s.", - end - start) - - # Decode - start = time.time() - seq_len = 1 - batch_size = 8 # Must be in sync with _get_padded_batch_size() - while True: - self._dummy_run(batch_size, - seq_len, - kv_caches, - exec_mode=ExecutionMode.DECODE) - xm.wait_device_ops() - logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if batch_size >= self.scheduler_config.max_num_seqs: - break - batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 - - end = time.time() - logger.info("Compilation for decode done in %.2f s.", end - start) - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - prompt_lens: List[int] = [] - context_lens: List[int] = [] - slot_mapping: List[int] = [] - - for batch_idx, seq_group_metadata in enumerate( - seq_group_metadata_list): - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - # Could include output tokens when a request is preempted. 
- prompt_tokens = seq_data.get_token_ids() - seq_len = len(prompt_tokens) - - num_computed_blocks = len(seq_group_metadata.computed_block_nums) - num_computed_tokens = num_computed_blocks * self.block_size - if num_computed_tokens > 0: - prompt_tokens = prompt_tokens[num_computed_tokens:] - context_lens.append(seq_len) - else: - context_lens.append(0) - - prompt_len = len(prompt_tokens) - prompt_lens.append(prompt_len) - - input_tokens.extend(prompt_tokens) - input_positions.extend(range(num_computed_tokens, seq_len)) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - for i in range(num_computed_tokens, seq_len): - block_number = block_table[i // self.block_size] - block_offset = i % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - if num_computed_tokens > 0: - self.block_tables[batch_idx, :len(block_table)] = block_table - - # Add paddings to EACH prompt to the smallest power of 2 that is - # greater than or equal to the prompt length. - # We pad the seq_len to reduce the compilation overhead. - # We execute each prompt individually (i.e., with batch_size 1) - # because the FlashAttention kernel does not support ragged inputs. - # TODO(woosuk): Use SplashAttention to support ragged inputs. - padded_prompt_len = _get_padded_prefill_len(prompt_len) - num_paddings = padded_prompt_len - prompt_len - input_tokens += [0] * num_paddings - input_positions += [0] * num_paddings - slot_mapping += [_PAD_SLOT_ID] * num_paddings - - assert len(prompt_lens) > 0 - num_prefills = len(prompt_lens) - input_tokens = torch.tensor(input_tokens, - dtype=torch.int32, - device="cpu") - input_positions = torch.tensor(input_positions, - dtype=torch.int32, - device="cpu") - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.int64, - device="cpu") - prompt_lens = torch.tensor(prompt_lens, - dtype=torch.int32, - device="cpu") - context_lens = torch.tensor(context_lens, - dtype=torch.int32, - device="cpu") - block_tables = torch.tensor(self.block_tables[:num_prefills], - dtype=torch.int32, - device="cpu") - attn_metadata = self.attn_backend.make_metadata( - num_prefills=num_prefills, - num_prefill_tokens=0, # NOTE: This is not used. 
- num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - block_tables=block_tables, - context_lens=context_lens, - effective_query_lens=prompt_lens, - ) - return input_tokens, input_positions, attn_metadata, prompt_lens - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[List[int]] = [] - input_positions: List[List[int]] = [] - slot_mapping: List[List[int]] = [] - context_lens: List[int] = [] - - batch_idx = 0 - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append([generation_token]) - - seq_len = seq_data.get_len() - position = seq_len - 1 - input_positions.append([position]) - context_lens.append(seq_len) - - assert seq_group_metadata.block_tables is not None - block_table = seq_group_metadata.block_tables[seq_id] - self.block_tables[batch_idx, :len(block_table)] = block_table - batch_idx += 1 - - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append([slot]) - - batch_size = _get_padded_batch_size(batch_idx) - num_paddings = batch_size - batch_idx - input_tokens = input_tokens + [[0]] * num_paddings - input_positions = input_positions + [[0]] * num_paddings - slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings - context_lens = context_lens + [0] * num_paddings - - input_tokens = torch.tensor(input_tokens, - dtype=torch.int32, - device="cpu") - input_positions = torch.tensor(input_positions, - dtype=torch.int32, - device="cpu") - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.int64, - device="cpu") - context_lens = torch.tensor(context_lens, - dtype=torch.int32, - device="cpu") - block_tables = torch.tensor(self.block_tables[:batch_size], - dtype=torch.int32, - device="cpu") - input_lens = torch.tensor([1] * batch_size, - dtype=torch.int32, - device="cpu") - attn_metadata = self.attn_backend.make_metadata( - num_prefills=0, - num_prefill_tokens=0, - num_decode_tokens=batch_size, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - enable_kv_scales_calculation=False, - block_tables=block_tables, - context_lens=context_lens, - ) - return input_tokens, input_positions, attn_metadata, input_lens - - def _prepare_sample( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - padded_batch_size: int, - ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: - assert len(seq_group_metadata_list) > 0 - t = [] - p = [] - n = [] - for seq_group_metadata in seq_group_metadata_list: - sampling_params = seq_group_metadata.sampling_params - t.append(sampling_params.temperature) - if sampling_params.top_p != 1 and not _ENABLE_TOP_P: - raise NotImplementedError( - "Top-p sampling is currently disabled for the TPU backend " - "due to performance issues.") - p.append(sampling_params.top_p) - if sampling_params.top_k > 0: - raise NotImplementedError( - "Top-k sampling is currently disabled for the TPU backend " - "due to performance issues.") - if sampling_params.n > _MAX_NUM_SAMPLES: - raise NotImplementedError( - f"Best of > {_MAX_NUM_SAMPLES} is not 
supported by the TPU " - "backend.") - n.append(sampling_params.n) - if sampling_params.logprobs is not None: - raise NotImplementedError( - "logprobs is not currently supported by the TPU backend.") - if sampling_params.prompt_logprobs is not None: - raise NotImplementedError( - "prompt_logprobs is not currently supported by the TPU " - "backend.") - - # Repeat the sampling params if the seq group has multiple seqs. - num_seqs = len(seq_group_metadata.seq_data) - t += [t[-1]] * (num_seqs - 1) - p += [p[-1]] * (num_seqs - 1) - n += [n[-1]] * (num_seqs - 1) - - num_paddings = padded_batch_size - len(t) - t += [1.0] * num_paddings - p += [1.0] * num_paddings - - t = torch.tensor(t, dtype=torch.float32, device="cpu") - p = torch.tensor(p, dtype=torch.float32, device="cpu") - return t, p, n - - def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None, - ) -> ModelInputForTPU: - del finished_requests_ids # Unused. - assert virtual_engine == 0 - assert len(seq_group_metadata_list) > 0 - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - if is_prompt: - inputs = self._prepare_prompt(seq_group_metadata_list) - else: - inputs = self._prepare_decode(seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, input_lens = inputs - padded_batch_size = input_tokens.shape[0] - t, p, n = self._prepare_sample(seq_group_metadata_list, - padded_batch_size) - num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 - - seq_groups = [ - list(metadata.seq_data.keys()) - for metadata in seq_group_metadata_list - ] - return ModelInputForTPU(input_tokens, input_positions, attn_metadata, - input_lens, t, p, num_samples, n, seq_groups) - - def make_model_input_from_broadcasted_tensor_dict( - self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: - model_input = ModelInputForTPU.from_broadcasted_tensor_dict( - tensor_dict, attn_backend=self.attn_backend) - return model_input - - @torch.no_grad() - def execute_model( - self, - model_input: ModelInputForTPU, - kv_caches: Optional[List[Any]], - intermediate_tensors: Optional[IntermediateTensors] = None, - num_steps: int = 1, - ) -> List[SamplerOutput]: - assert intermediate_tensors is None - if not model_input.is_first_multi_step: - if not model_input.is_last_step: - return [] - - use_async_out_proc = model_input.async_callback is not None - sampler_outputs = [] - num_outputs = len(self.cached_step_outputs) - for i in range(num_outputs): - next_token_ids = self.cached_step_outputs.pop(0) - next_token_ids = next_token_ids.cpu().tolist() - sampler_output = _make_decode_output(next_token_ids, - model_input.seq_groups) - sampler_outputs.append(sampler_output) - - if i < num_outputs - 1 and use_async_out_proc: - assert model_input.async_callback is not None - ctx = model_input.async_callback.keywords[ # type: ignore - "ctx"] - ctx.append_output( - outputs=[sampler_output], - seq_group_metadata_list=ctx.seq_group_metadata_list, - scheduler_outputs=ctx.scheduler_outputs, - is_async=False, - is_last_step=False, - is_first_step_output=i == 0) - model_input.async_callback() - if use_async_out_proc: - return [sampler_outputs[-1]] - else: - return sampler_outputs - - is_prompt = model_input.attn_metadata.num_prefills > 0 - if is_prompt: - assert num_steps == 1 - # NOTE(woosuk): Since the FlashAttention kernel does not support - # ragged inputs, we split the prompts into 
different batches and - # process them separately. This is a temporary hack that should be - # optimized by using SplashAttention. - orig_slot_mapping = model_input.attn_metadata.slot_mapping - orig_block_tables = model_input.attn_metadata.block_tables - orig_context_lens = model_input.attn_metadata.context_lens - orig_effective_query_lens = \ - model_input.attn_metadata.effective_query_lens - batch_size = model_input.input_lens.shape[0] - start_idx = 0 - next_token_ids = [] - for i in range(batch_size): - # Get the actual prefill_len. - prefill_len = model_input.input_lens[i:i + 1].item() - prefill_len = _get_padded_prefill_len(prefill_len) - end_idx = start_idx + prefill_len - - token_ids = model_input.token_ids[None, start_idx:end_idx].to( - self.device) - position_ids = model_input.position_ids[None, - start_idx:end_idx].to( - self.device) - attn_metadata = model_input.attn_metadata - attn_metadata.num_prefills = 1 - attn_metadata.slot_mapping = orig_slot_mapping[ - None, start_idx:end_idx].to(self.device) - if orig_context_lens[i].item() > 0: - attn_metadata.context_lens = orig_context_lens[i:i + 1].to( - self.device) - attn_metadata.block_tables = orig_block_tables[ - i].unsqueeze(0).to(self.device) - attn_metadata.effective_query_lens = \ - orig_effective_query_lens[i:i + 1].to(self.device) - else: - attn_metadata.context_lens = None - attn_metadata.block_tables = None - attn_metadata.effective_query_lens = None - input_lens = model_input.input_lens[i:i + 1].to(self.device) - t = model_input.t[i:i + 1].to(self.device) - p = model_input.p[i:i + 1].to(self.device) - with set_forward_context(model_input.attn_metadata, - self.vllm_config, - model_input.virtual_engine): - output_token_ids = self.model(token_ids, position_ids, - input_lens, t, p, - model_input.num_samples, - kv_caches) - next_token_ids.append(output_token_ids[0]) - start_idx = end_idx - - if model_input.async_callback is not None: - model_input.async_callback() - # Retrieve the outputs to CPU. - next_token_ids = [ - output_token_ids.cpu().tolist() - for output_token_ids in next_token_ids - ] - - # NOTE(woosuk): Minimal code to construct the sampler outputs. - # The TPU backend does not reuse the sampler, since the TPU backend - # does not support advanced sampling parameters such as logprobs. 
- zero_logprob = Logprob(0.0) - sampler_outputs = [] - for i, seq_group in enumerate(model_input.seq_groups): - seq_ids = seq_group - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - seq_outputs = [] - for j in range(model_input.n[i]): - next_token_id = next_token_ids[i][j] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - sampler_outputs.append( - CompletionSequenceGroupOutput(seq_outputs, None)) - return [SamplerOutput(sampler_outputs)] - else: - token_ids = model_input.token_ids.to(self.device) - position_ids = model_input.position_ids.to(self.device) - attn_metadata = model_input.attn_metadata - attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( - self.device) - attn_metadata.block_tables = attn_metadata.block_tables.to( - self.device) - attn_metadata.context_lens = attn_metadata.context_lens.to( - self.device) - t = model_input.t.to(self.device) - p = model_input.p.to(self.device) - input_lens = model_input.input_lens.to(self.device) - for i in range(num_steps): - slot_mapping = attn_metadata.slot_mapping - with set_forward_context(model_input.attn_metadata, - self.vllm_config, - model_input.virtual_engine): - output_token_ids = self.model(token_ids, position_ids, - input_lens, t, p, - model_input.num_samples, - kv_caches) - self.cached_step_outputs.append(output_token_ids) - - if i < num_steps - 1: - # Prepare the inputs for the next step. - token_ids = output_token_ids.unsqueeze(dim=1).int() - position_ids = position_ids + 1 - attn_metadata.context_lens = attn_metadata.context_lens + 1 - - block_tables = attn_metadata.block_tables - block_number = block_tables.gather( - 1, - position_ids.long() // self.block_size) - block_offset = position_ids % self.block_size - - is_padding = slot_mapping == _PAD_SLOT_ID - slot_mapping = block_number * self.block_size + block_offset - slot_mapping = slot_mapping.long() - slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, - slot_mapping) - attn_metadata.slot_mapping = slot_mapping - - if model_input.async_callback is not None: - model_input.async_callback() - - if num_steps > 1: - return [] - # Retrieve the outputs to CPU. - next_token_ids = self.cached_step_outputs.pop(0) - next_token_ids = next_token_ids.cpu().tolist() - sampler_output = _make_decode_output(next_token_ids, - model_input.seq_groups) - return [sampler_output] - - -class ModelWrapper(nn.Module): - - def __init__(self, model: nn.Module): - super().__init__() - self.model = model - - def forward( - self, - token_ids: torch.Tensor, - position_ids: torch.Tensor, - input_lens: torch.Tensor, - t: torch.Tensor, - p: torch.Tensor, - num_samples: int, - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> torch.Tensor: - """Executes the forward pass of the model and samples the next token. - - Args: - token_ids: The input token IDs of shape [batch_size, seq_len]. - position_ids: The input position IDs of shape [batch_size, seq_len]. - input_lens: The actual input lengths of shape [batch_size]. - t: The sampling temperature of shape [batch_size]. - p: The top-p probability of shape [batch_size]. - num_samples: Number of samples to draw from each logits vector. - kv_caches: The key and value caches. They can be None during the - memory profiling at initialization. - """ - batch_size, seq_len = token_ids.shape - # Calculate the positions to sample from. 
- start_indices = torch.arange( - batch_size, dtype=torch.int32, device=input_lens.device) * seq_len - logits_indices = start_indices + input_lens - 1 - attn_metadata = get_forward_context().attn_metadata - - # FIXME(woosuk): This is a temporary hack to avoid using the existing - # sampler and sampling metadata. - sampling_metadata = SamplingMetadata( - seq_groups=[], - selected_token_indices=logits_indices, - categorized_sample_indices={}, - num_prompts=attn_metadata.num_prefills, - ) - - # Skip this in memory profiling at initialization. - if kv_caches[0][0].numel() > 0: - # index_copy_(slot_mapping) only works when the inserted dimension - # is 0. However, the KV cache in the Pallas backend has the shape - # [num_kv_heads, num_blocks, block_size, head_size]. To make it - # work, we need to flatten the first three dimensions and modify - # the slot_mapping accordingly. - num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape - slot_mapping = attn_metadata.slot_mapping - slot_mapping = slot_mapping.flatten() - head_indices = torch.arange(0, - num_kv_heads, - device=slot_mapping.device, - dtype=slot_mapping.dtype) - head_indices *= block_size * num_blocks - slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( - -1, num_kv_heads) - slot_mapping = slot_mapping + head_indices.view(1, -1) - slot_mapping = slot_mapping.flatten() - attn_metadata.slot_mapping = slot_mapping - - hidden_states = self.model(token_ids, position_ids) - hidden_states = hidden_states.flatten(0, 1) - logits = self.model.compute_logits(hidden_states, sampling_metadata) - - # Argmax sampling. - argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) - argmax_token_ids = argmax_token_ids.repeat(1, num_samples) - - # Zero temperature means greedy decoding. Avoid division by zero. - nonzero_t = torch.where(t != 0, t, 1.0) - logits = logits / nonzero_t.unsqueeze(dim=1) - if _ENABLE_TOP_P: - logits = _apply_top_p(logits, p.unsqueeze(dim=1)) - - # Random sampling. - probs = torch.softmax(logits, dim=-1, dtype=torch.float32) - sampled_token_ids = torch.multinomial(probs, - num_samples, - replacement=True) - if num_samples == 1: - argmax_token_ids = argmax_token_ids.squeeze(dim=-1) - sampled_token_ids = sampled_token_ids.squeeze(dim=-1) - next_token_ids = torch.where(t != 0, sampled_token_ids, - argmax_token_ids) - return next_token_ids - - -def _get_padded_prefill_len(x: int) -> int: - # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence - # length to be a multiple of 16. We pad the prompt length to the nearest - # multiple of 16. This is also good for performance. - if x <= 16: - return 16 - return 1 << (x - 1).bit_length() - - -def _get_padded_batch_size(batch_size: int) -> int: - # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. - # To meet this requirement in the simplest way, we set the minimal batch - # size to 8. 
- if batch_size <= 8: - return 8 - else: - return ((batch_size + 15) // 16) * 16 - - -def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: - logits_sorted = torch.sort(logits, dim=-1, descending=True).values - sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) - cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) - cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) - logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) - return logits - - -def _make_decode_output( - next_token_ids: List[int], - seq_groups: List[List[int]], -) -> SamplerOutput: - zero_logprob = Logprob(0.0) - sampler_outputs = [] - batch_idx = 0 - for seq_group in seq_groups: - seq_ids = seq_group - seq_outputs = [] - for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx] - seq_outputs.append( - SequenceOutput(seq_id, next_token_id, - {next_token_id: zero_logprob})) - batch_idx += 1 - sampler_outputs.append(CompletionSequenceGroupOutput( - seq_outputs, None)) - return SamplerOutput(sampler_outputs) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py deleted file mode 100644 index ad5ed19e2f89..000000000000 --- a/vllm/worker/tpu_worker.py +++ /dev/null @@ -1,337 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -from typing import List, Optional, Tuple, Union - -import torch -import torch_xla.core.xla_model as xm -import torch_xla.debug.profiler as xp -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.config import VllmConfig -from vllm.distributed import (ensure_model_parallel_initialized, - init_distributed_environment) -from vllm.logger import init_logger -from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size -from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner -from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, - LoRANotSupportedWorkerBase, WorkerBase, - WorkerInput) - -logger = init_logger(__name__) - - -class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase): - - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool, - ) -> None: - WorkerBase.__init__(self, vllm_config=vllm_config) - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - self.is_driver_worker = is_driver_worker - - assert self.device_config.device_type == "tpu" - if self.cache_config.cache_dtype == "auto": - self.cache_dtype = self.model_config.dtype - else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ - self.cache_config.cache_dtype] - - self.model_runner: TPUModelRunner = TPUModelRunner( - vllm_config=vllm_config, is_driver_worker=is_driver_worker) - - if self.model_config.seed is None: - self.model_config.seed = 0 - - if vllm_config.lora_config is not None: - raise NotImplementedError( - "The V0 TPU backend doesn't support LoRA serving") - - def init_device(self) -> None: - os.environ["PJRT_DEVICE"] = "TPU" - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # NOTE(woosuk): This is just to initialize the TP group and broadcast - # the input objects on CPU. 
The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. - init_distributed_environment( - world_size=self.parallel_config.world_size, - rank=self.rank, - local_rank=self.local_rank, - distributed_init_method=self.distributed_init_method, - backend="gloo", - ) - ensure_model_parallel_initialized( - self.parallel_config.tensor_parallel_size, - self.parallel_config.pipeline_parallel_size) - - # Device initialization should happen after initializing the distributed - # runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Set random seed. - set_random_seed(self.model_config.seed) - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and - # 30-40 graphs for decode. 128 is an arbitrary safe number. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. - world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. - # Consequently, changes in optimization flags, which affect compilation - # results, don't change the cache key. This can result in the wrong - # compilation being used. To prevent this, disabling the XLA compilation - # cache during development is recommended.We can disable it by - # `export VLLM_XLA_CACHE_PATH=` - if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{rank}") - xr.initialize_cache(per_rank_path, readonly=False) - - self.profiler = None - if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1: - # For TPU, we can only have 1 active profiler session for 1 profiler - # server. So we only profile on rank0. - self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR - logger.info("Profiling enabled. Traces will be saved to: %s", - self.profile_dir) - self.profiler = xp.start_server(9012) - - def start_profile(self): - if self.rank < 1: - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - xp.start_trace(self.profile_dir) - - def stop_profile(self): - if self.rank < 1: - if self.profiler is None: - raise RuntimeError("Profiler is not enabled.") - xp.stop_trace() - - def load_model(self): - self.model_runner.load_model() - - def determine_num_available_blocks(self) -> Tuple[int, int]: - num_layers = self.model_config.get_num_layers(self.parallel_config) - head_size = self.model_config.get_head_size() - num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - - # use an empty tensor instead of `None`` to force Dynamo to pass - # it by reference, rather by specializing on the value ``None``. - # the `dtype` argument does not matter, and we use `float32` as - # a placeholder (it has wide hardware support). - kv_caches = [(torch.tensor([], dtype=torch.float32, - device=self.device), - torch.tensor([], dtype=torch.float32, - device=self.device)) - for _ in range(num_layers)] - bind_kv_cache(self.compilation_config.static_forward_context, - [kv_caches]) - self.model_runner._dummy_run( - batch_size=1, - seq_len=self.scheduler_config.max_num_batched_tokens, - kv_caches=kv_caches, - exec_mode=ExecutionMode.PREFILL, - ) - # Synchronize before measuring the memory usage. 
- xm.wait_device_ops() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - m = xm.get_memory_info(self.device) - total_memory_size = m["bytes_limit"] - profiled = m["peak_bytes_used"] # Weights + intermediate activations. - - # Calculate the TPU KV cache size based on profiling. - usable_memory_size = int(total_memory_size * - self.cache_config.gpu_memory_utilization) - tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - dtype_bytes = get_dtype_size(self.cache_dtype) - block_size_bytes = (dtype_bytes * self.cache_config.block_size * - num_layers * 2 * head_size * num_kv_heads) - num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes - num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. - - # Calculate the CPU KV cache size based on the config. - num_cpu_blocks = int(self.cache_config.swap_space_bytes // - block_size_bytes) - num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. - return num_tpu_blocks, num_cpu_blocks - - def initialize_cache( - self, - num_gpu_blocks: int, - num_cpu_blocks: int, - ) -> None: - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - self.block_size = self.cache_config.block_size - - dtype = self.cache_dtype - num_layers = self.model_config.get_num_layers(self.parallel_config) - num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) - head_size = self.model_config.get_head_size() - - self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] - self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] - tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( - num_gpu_blocks, self.block_size, num_kv_heads, head_size) - cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( - num_cpu_blocks, self.block_size, num_kv_heads, head_size) - for _ in range(num_layers): - tpu_k_cache = torch.zeros(tpu_cache_shape, - dtype=dtype, - device=self.device) - tpu_v_cache = torch.zeros_like(tpu_k_cache) - self.tpu_cache.append((tpu_k_cache, tpu_v_cache)) - cpu_k_cache = torch.zeros(cpu_cache_shape, - dtype=dtype, - device="cpu") - cpu_v_cache = torch.zeros_like(cpu_k_cache) - self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) - bind_kv_cache(self.compilation_config.static_forward_context, - [self.tpu_cache]) - self._warmup_model() - - def _warmup_model(self) -> None: - # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined - # for CUDA graphs. We should refactor this part. - if not self.model_config.enforce_eager: - # Warm up the model with all possible input shapes so that - # compilation never happens during the actual execution. - # This may take ~30 mins for the first run and ~20 mins for the - # subsequent runs. - # If `enforce_eager` is True, the ahead-of-time compilation is - # skipped and the compilation happens during the actual execution, - # which is bad for performance but useful for development. 
- self.model_runner.warmup_model(self.tpu_cache) - - def get_cache_block_size_bytes(self) -> int: - head_size = self.model_config.get_head_size() - num_heads = self.model_config.get_num_kv_heads(self.parallel_config) - num_layers = self.model_config.get_num_layers(self.parallel_config) - - key_cache_block = self.cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_layers * (key_cache_block + value_cache_block) - dtype_size = get_dtype_size(self.cache_dtype) - return dtype_size * total - - @property - def do_metadata_broadcast(self) -> bool: - return self.parallel_config.tensor_parallel_size > 1 - - @property - def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: - # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline - # parallelism. - return [self.tpu_cache] - - def prepare_worker_input( - self, - execute_model_req: ExecuteModelRequest, - ) -> WorkerInput: - virtual_engine = execute_model_req.virtual_engine - num_seq_groups = len(execute_model_req.seq_group_metadata_list) - blocks_to_swap_in = _make_src_to_dst( - execute_model_req.blocks_to_swap_in, "cpu", self.device) - blocks_to_swap_out = _make_src_to_dst( - execute_model_req.blocks_to_swap_out, self.device, "cpu") - blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy, - self.device, self.device) - return WorkerInput( - num_seq_groups=num_seq_groups, - blocks_to_swap_in=blocks_to_swap_in, - blocks_to_swap_out=blocks_to_swap_out, - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - ) - - def execute_worker(self, worker_input: WorkerInput) -> None: - virtual_engine = worker_input.virtual_engine - assert virtual_engine == 0 - attn_backend = self.model_runner.attn_backend - num_layers = self.model_config.get_num_layers(self.parallel_config) - - # Issue cache operations. - if worker_input.blocks_to_swap_in is not None: - src_indices, dst_indices = worker_input.blocks_to_swap_in - if src_indices.numel() > 0: - # Swap from CPU to TPU. - for i in range(num_layers): - tpu_k_cache, tpu_v_cache = self.tpu_cache[i] - cpu_k_cache, cpu_v_cache = self.cpu_cache[i] - k = cpu_k_cache[:, src_indices].to(self.device) - v = cpu_v_cache[:, src_indices].to(self.device) - _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) - - if worker_input.blocks_to_swap_out is not None: - src_indices, dst_indices = worker_input.blocks_to_swap_out - if src_indices.numel() > 0: - # Swap from TPU to CPU. 
- for i in range(num_layers): - tpu_k_cache, tpu_v_cache = self.tpu_cache[i] - cpu_k_cache, cpu_v_cache = self.cpu_cache[i] - cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices] - cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices] - - if worker_input.blocks_to_copy is not None: - src_indices, dst_indices = worker_input.blocks_to_copy - if src_indices.numel() > 0: - attn_backend.copy_blocks(self.tpu_cache, - (src_indices, dst_indices)) - - -def _make_src_to_dst( - mapping: List[Tuple[int, int]], - src_device: Union[torch.device, str], - dst_device: Union[torch.device, str], -) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: - if not mapping: - return None - - src_indices = [i for i, _ in mapping] - dst_indices = [i for _, i in mapping] - src_indices = torch.tensor(src_indices, - device=src_device, - dtype=torch.int64) - dst_indices = torch.tensor(dst_indices, - device=dst_device, - dtype=torch.int64) - return src_indices, dst_indices - - -@torch.compile(backend="openxla") -def _insert_kv( - k: torch.Tensor, - v: torch.Tensor, - indices: torch.Tensor, - tpu_k_cache: torch.Tensor, - tpu_v_cache: torch.Tensor, -) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True) - torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True) - tpu_k_cache[:, indices] = k - tpu_v_cache[:, indices] = v diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index 1a5f62cb3c47..512a1dca7370 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -47,7 +47,3 @@ def assert_enc_dec_mr_supported_scenario( if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) - - if enc_dec_mr.prompt_adapter_config is not None: - raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ - 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9a928632688a..9dfea947568d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -9,7 +9,8 @@ import torch.distributed import vllm.envs as envs -from vllm.config import VllmConfig +from vllm.attention.layer import Attention +from vllm.config import VllmConfig, get_layers_from_vllm_config from vllm.device_allocator.cumem import CuMemAllocator from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -21,7 +22,6 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, @@ -76,7 +76,8 @@ def __init__( "mlp_speculator", "eagle", "deepseek_mtp", - "mimo_mtp")) \ + "glm4_moe_mtp", + "mimo_mtp")) \ else {"return_hidden_states": True} ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner @@ -345,8 +346,29 @@ def _init_cache_engine(self): self.cache_engine[ve].gpu_cache for ve in range(self.parallel_config.pipeline_parallel_size) ] + + # Layer pairings for cross-layer KV sharing. + # If an Attention layer `layer_name` is in the keys of this dict, it + # means this layer will perform attention using the keys and values + # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
+ shared_kv_cache_layers: dict[str, str] = {} + + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + + for layer_name, attn_module in attn_layers.items(): + if (kv_tgt_layer := + attn_module.kv_sharing_target_layer_name) is not None: + # The layer doesn't need its own KV cache and will use that of + # the target layer. We skip creating a KVCacheSpec for it, so + # that KV cache management logic will act as this layer does + # not exist, and doesn't allocate KV cache for the layer. This + # enables the memory saving of cross-layer kv sharing, allowing + # a given amount of memory to accommodate longer context lengths + # or enable more requests to be processed simultaneously. + shared_kv_cache_layers[layer_name] = kv_tgt_layer + bind_kv_cache(self.compilation_config.static_forward_context, - self.gpu_cache) + self.gpu_cache, shared_kv_cache_layers) def _warm_up_model(self) -> None: # warm up sizes that are not in cudagraph capture sizes, @@ -490,19 +512,6 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.model_runner.list_loras() - def add_prompt_adapter( - self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.model_runner.add_prompt_adapter(prompt_adapter_request) - - def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.remove_lora(prompt_adapter_id) - - def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.model_runner.pin_prompt_adapter(prompt_adapter_id) - - def list_prompt_adapters(self) -> Set[int]: - return self.model_runner.list_prompt_adapters() - @property def max_model_len(self) -> int: return self.model_config.max_model_len @@ -530,7 +539,8 @@ def init_worker_distributed_environment( set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) init_distributed_environment(parallel_config.world_size, rank, - distributed_init_method, local_rank) + distributed_init_method, local_rank, + current_platform.dist_backend) ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index c382b29ad199..f1c9a0ab001e 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -49,7 +49,6 @@ def __init__( self.scheduler_config = vllm_config.scheduler_config self.device_config = vllm_config.device_config self.speculative_config = vllm_config.speculative_config - self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config self.kv_transfer_config = vllm_config.kv_transfer_config self.compilation_config = vllm_config.compilation_config @@ -397,8 +396,6 @@ def execute_model( model_input, worker_input, kwargs = inputs num_steps = worker_input.num_steps - if execute_model_req is not None and execute_model_req.spec_step_idx: - kwargs["spec_step_idx"] = execute_model_req.spec_step_idx self.execute_worker(worker_input) diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py deleted file mode 100644 index b2d3ce8526d5..000000000000 --- a/vllm/worker/xpu_model_runner.py +++ /dev/null @@ -1,606 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import dataclasses -import time -import weakref -from collections import defaultdict -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Type, TypeVar) - -import 
torch -import torch.nn as nn - -from vllm.attention import get_attn_backend -from vllm.config import VllmConfig -from vllm.distributed import get_pp_group -from vllm.forward_context import set_forward_context -from vllm.inputs import INPUT_REGISTRY, InputRegistry -from vllm.logger import init_logger -from vllm.model_executor import SamplingMetadataCache -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.model_loader import get_model -from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, - MultiModalKwargs, MultiModalPlaceholderMap, - MultiModalRegistry) -from vllm.sampling_params import SamplingParams -from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad -from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata -from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, - _add_attn_metadata_broadcastable_dict, - _add_sampling_metadata_broadcastable_dict, - _init_attn_metadata_from_tensor_dict, - _init_sampling_metadata_from_tensor_dict) - -if TYPE_CHECKING: - from vllm.attention.backends.abstract import AttentionBackend - -logger = init_logger(__name__) - -_PAD_SLOT_ID = -1 - -TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") - - -@dataclass(frozen=True) -class ModelInputForXPU(ModelRunnerInputBase): - """ - Used by the NeuronModelRunner. - """ - input_tokens: Optional[torch.Tensor] = None - input_positions: Optional[torch.Tensor] = None - attn_metadata: Optional["AttentionMetadata"] = None - multi_modal_kwargs: Optional[BatchedTensorInputs] = None - virtual_engine: Optional[int] = None - seq_lens: Optional[List[int]] = None - query_lens: Optional[List[int]] = None - async_callback: Optional[Callable] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls: Type[TModelInputForXPU], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> TModelInputForXPU: - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -@dataclass(frozen=True) -class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): - """ - Used by the ModelRunner. 
- """ - sampling_metadata: Optional["SamplingMetadata"] = None - - def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: - tensor_dict = { - "input_tokens": self.input_tokens, - "input_positions": self.input_positions, - } - _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) - return tensor_dict - - @classmethod - def from_broadcasted_tensor_dict( - cls, - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None, - ) -> "ModelInputForXPUWithSamplingMetadata": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) - if attn_backend is not None: - tensor_dict = _init_attn_metadata_from_tensor_dict( - attn_backend, tensor_dict) - return cls(**tensor_dict) - - -class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): - - def __init__(self, - runner: "XPUModelRunner", - finished_requests_ids: Optional[List[str]] = None) -> None: - super().__init__() - self.runner = runner - self.model_input_cls = self.runner._model_input_cls - self.attn_backend = self.runner.attn_backend - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.device = self.runner.device - - def prepare(self, - finished_requests_ids: Optional[List[str]] = None) -> None: - self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] - - def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): - self.seq_group_metadata_list.append(seq_group_metadata) - - def build(self) -> ModelInputForXPU: - is_prompt = self.seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) = self._prepare_prompt( - self.seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode( - self.seq_group_metadata_list) - seq_lens = None - multi_modal_kwargs = None - - return self.model_input_cls( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - multi_modal_kwargs=multi_modal_kwargs, - seq_lens=seq_lens, - query_lens=seq_lens, - ) - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - multi_modal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - computed_len = seq_data.get_num_computed_tokens() - seq_len = len(prompt_tokens) - - seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids - - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. 
- positions_range = range(computed_len, seq_len) - input_positions.extend(list(positions_range)) - - if seq_group_metadata.multi_modal_data: - # NOTE: mm_kwargs only includes the subset of multi-modal items - # that intersect with the current prefill positions. - mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \ - .from_seq_group(seq_group_metadata, positions_range) - - multi_modal_kwargs_list.append(mm_kwargs) - - for modality, placeholder_map in placeholder_maps.items(): - multi_modal_placeholder_maps[modality].extend( - placeholder_map) - - if seq_group_metadata.block_tables is None: - # During memory profiling, the block tables are not initialized - # yet. In this case, we just use a dummy slot mapping. - slot_mapping.extend([_PAD_SLOT_ID] * seq_len) - continue - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(computed_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - block_number = block_table[i // - self.block_size] # type: ignore - block_offset = i % self.block_size # type: ignore - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - num_prompt_tokens = len(input_tokens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # type: ignore - input_positions = torch.tensor(input_positions, - dtype=torch.long, - device=self.device) # type: ignore - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) # type: ignore - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - multi_modal_placeholder_maps.items() - } - - max_seqlen = max(seq_lens) - tmp = [0] - tmp.extend(seq_lens) - seqlen = torch.tensor(tmp) - seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, - enable_kv_scales_calculation=False, - seq_lens=seq_lens, - seqlen_q=seqlen_q, - max_seqlen=max_seqlen, - seq_lens_tensor=torch.tensor([]), - max_decode_seq_len=0, - num_prefills=len(seq_lens), - num_prefill_tokens=num_prompt_tokens, - num_decode_tokens=0, - block_tables=torch.tensor([], device=self.device, dtype=torch.int), - ) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - block_tables: List[List[int]] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - 
-                generation_token = seq_data.get_last_token_id()
-                input_tokens.append(generation_token)
-
-                seq_len = seq_data.get_len()
-                position = seq_len - 1
-                input_positions.append(position)
-
-                seq_len = seq_len if self.sliding_window is None else min(
-                    seq_len, self.sliding_window)
-                seq_lens.append(seq_len)
-
-                block_table = seq_group_metadata.block_tables[seq_id]
-                block_number = block_table[position // self.block_size]
-                block_offset = position % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping.append(slot)
-
-                if self.sliding_window is not None:
-                    sliding_window_blocks = (self.sliding_window //
-                                             self.block_size)
-                    block_table = block_table[-sliding_window_blocks:]
-                block_tables.append(block_table)
-
-        max_decode_seq_len = max(seq_lens)
-
-        input_tokens = torch.tensor(input_tokens,
-                                    dtype=torch.long,
-                                    device=self.device)
-        input_positions = torch.tensor(input_positions,
-                                       dtype=torch.long,
-                                       device=self.device)
-        slot_mapping = torch.tensor(slot_mapping,
-                                    dtype=torch.long,
-                                    device=self.device)
-        seq_lens_tensor = torch.tensor(seq_lens,
-                                       dtype=torch.int,
-                                       device=self.device)
-
-        block_tables = make_tensor_with_pad(
-            block_tables,
-            pad=0,
-            dtype=torch.int,
-            device=self.device,
-        )
-
-        attn_metadata = self.attn_backend.make_metadata(
-            is_prompt=False,
-            slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None,
-            enable_kv_scales_calculation=False,
-            seq_lens=seq_lens,
-            seqlen_q=torch.tensor([]),
-            max_seqlen=0,
-            seq_lens_tensor=seq_lens_tensor,
-            max_decode_seq_len=max_decode_seq_len,
-            num_prefill_tokens=0,
-            num_decode_tokens=len(input_tokens),
-            num_prefills=0,
-            block_tables=block_tables,
-        )
-        return (
-            input_tokens,
-            input_positions,
-            attn_metadata,
-        )
-
-
-class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
-    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
-        ModelInputForXPUWithSamplingMetadata)
-    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_dtype: Optional[str] = "auto",
-        is_driver_worker: bool = False,
-        return_hidden_states: bool = False,
-        input_registry: InputRegistry = INPUT_REGISTRY,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-    ):
-
-        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
-        model_config = self.model_config
-        cache_config = self.cache_config
-        self.is_driver_worker = is_driver_worker
-        self.return_hidden_states = return_hidden_states
-
-        self.device = self.device_config.device
-
-        self.kv_cache_dtype = kv_cache_dtype
-        self.sliding_window = model_config.get_sliding_window()
-        self.block_size = cache_config.block_size
-
-        self.attn_backend = get_attn_backend(
-            self.model_config.get_head_size(),
-            self.model_config.dtype,
-            self.kv_cache_dtype,
-            self.block_size,
-            self.model_config.is_attention_free,
-        )
-
-        # Multi-modal data support
-        self.input_registry = input_registry
-        self.mm_registry = mm_registry
-
-        # Lazy initialization.
-        self.model: nn.Module  # Set after init_Model
-        self.sampler = get_sampler()
-
-        self.sampling_metadata_cache: SamplingMetadataCache = \
-              SamplingMetadataCache() \
-            if self.parallel_config.pipeline_parallel_size == 1 else None
-
-        self.builder = self._builder_cls(weakref.proxy(self))
-
-    def load_model(self) -> None:
-        with DeviceMemoryProfiler() as m:
-            self.model = get_model(vllm_config=self.vllm_config)
-
-        self.model_memory_usage = m.consumed_memory
-        logger.info("Loading model weights took %.4f GiB",
-                    self.model_memory_usage / GiB_bytes)
-
-    def get_model(self) -> nn.Module:
-        return self.model
-
-    @property
-    def vocab_size(self) -> int:
-        return self.model_config.get_vocab_size()
-
-    @torch.inference_mode()
-    def profile_run(self) -> None:
-        # Enable top-k sampling to reflect the accurate memory usage.
-        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-
-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-        seqs: List[SequenceGroupMetadata] = []
-        # Additional GPU memory may be needed for multi-modal encoding, which
-        # needs to be accounted for when calculating the GPU blocks for
-        # vLLM blocker manager.
-        # To exercise the worst scenario for GPU memory consumption,
-        # the number of seqs (batch_size) is chosen to maximize the number
-        # of images processed.
-        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
-            self.model_config)
-        if max_mm_tokens > 0:
-            max_num_seqs_orig = max_num_seqs
-            max_num_seqs = min(max_num_seqs,
-                               max_num_batched_tokens // max_mm_tokens)
-            if max_num_seqs < 1:
-                expr = (f"min({max_num_seqs_orig}, "
-                        f"{max_num_batched_tokens} // {max_mm_tokens})")
-                logger.warning(
-                    "Computed max_num_seqs (%s) to be less than 1. "
-                    "Setting it to the minimum value of 1.", expr)
-                max_num_seqs = 1
-
-        batch_size = 0
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            batch_size += seq_len
-
-            dummy_data = self.input_registry \
-                .dummy_data_for_profiling(self.model_config,
-                                          seq_len,
-                                          self.mm_registry)
-
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: dummy_data.seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-                lora_request=None,
-                multi_modal_data=dummy_data.multi_modal_data,
-                multi_modal_placeholders=dummy_data.multi_modal_placeholders)
-            seqs.append(seq)
-
-        finished_requests_ids = [seq.request_id for seq in seqs]
-        model_input = self.prepare_model_input(
-            seqs, finished_requests_ids=finished_requests_ids)
-        intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = self.model.make_empty_intermediate_tensors(
-                batch_size=batch_size,
-                dtype=self.model_config.dtype,
-                device=self.device)
-        self.execute_model(model_input, None, intermediate_tensors)
-        torch.xpu.synchronize()
-        return
-
-    def make_model_input_from_broadcasted_tensor_dict(
-            self,
-            tensor_dict: Dict[str,
-                              Any]) -> ModelInputForXPUWithSamplingMetadata:
-        return (
-            ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
-                tensor_dict,
-                attn_backend=self.attn_backend,
-            ))
-
-    def _prepare_model_input_tensors(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForXPUWithSamplingMetadata:
-        """Helper method to prepare the model input based on a given sequence
-        group. Prepares metadata needed for the base model forward pass but not
-        metadata for possible additional steps, e.g., sampling.
-
-        """
-        builder = self.builder
-        builder.prepare(finished_requests_ids)
-        for seq_group_metadata in seq_group_metadata_list:
-            builder.add_seq_group(seq_group_metadata)
-
-        return builder.build()  # type: ignore
-
-    def prepare_model_input(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        virtual_engine: int = 0,
-        finished_requests_ids: Optional[List[str]] = None
-    ) -> ModelInputForXPUWithSamplingMetadata:
-        """Prepare the model input based on a given sequence group, including
-        metadata for the sampling step.
-
-        """
-        model_input = self._prepare_model_input_tensors(
-            seq_group_metadata_list, finished_requests_ids)
-        # Sampling metadata is only required for the final pp group
-        generators = self.get_generators(finished_requests_ids)
-        sampling_metadata = SamplingMetadata.prepare(
-            seq_group_metadata_list,
-            model_input.seq_lens,
-            model_input.query_lens,
-            self.device,
-            pin_memory=False,
-            generators=generators,
-            cache=self.sampling_metadata_cache)
-
-        return dataclasses.replace(model_input,
-                                   sampling_metadata=sampling_metadata,
-                                   virtual_engine=virtual_engine)
-
-    @torch.inference_mode()
-    def execute_model(
-        self,
-        model_input: ModelInputForXPUWithSamplingMetadata,
-        kv_caches: List[torch.Tensor],
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        num_steps: int = 1,
-    ) -> Optional[List[SamplerOutput]]:
-        if num_steps > 1:
-            raise ValueError(
-                "XPUModelRunner does not support multi-step execution.")
-
-        model_executable = self.model
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time):
-            model_forward_start_time = time.time()
-        with set_forward_context(model_input.attn_metadata, self.vllm_config,
-                                 model_input.virtual_engine):
-            hidden_or_intermediate_states = model_executable(
-                input_ids=model_input.input_tokens,
-                positions=model_input.input_positions,
-                intermediate_tensors=intermediate_tensors,
-                **MultiModalKwargs.as_kwargs(
-                    model_input.multi_modal_kwargs or {},
-                    device=self.device,
-                ),
-            )
-        # Compute the logits in the last pipeline stage.
-        if not get_pp_group().is_last_rank:
-            return hidden_or_intermediate_states
-
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time):
-            model_forward_end_time = time.time()
-
-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_or_intermediate_states,
-                                           model_input.sampling_metadata)
-
-        # Only perform sampling in the driver worker.
-        if not self.is_driver_worker:
-            return []
-
-        if model_input.async_callback is not None:
-            model_input.async_callback()
-
-        # Sample the next token.
-        output: SamplerOutput = self.sampler(
-            logits=logits,
-            sampling_metadata=model_input.sampling_metadata,
-        )
-        if (self.observability_config is not None
-                and self.observability_config.collect_model_forward_time
-                and output is not None):
-            model_forward_time = (model_forward_end_time -
-                                  model_forward_start_time)
-            # If there are multiple workers, we are still tracking the latency
-            # from the start time of the driver worker to the end time of the
-            # driver worker. The model forward time will then end up covering
-            # the communication time as well.
-            output.model_forward_time = model_forward_time
-
-        return [output]
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
deleted file mode 100644
index fe321c059f52..000000000000
--- a/vllm/worker/xpu_worker.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""A XPU worker class."""
-import gc
-import os
-from typing import List, Optional, Tuple
-
-import intel_extension_for_pytorch  # noqa: F401
-import oneccl_bindings_for_pytorch  # noqa: F401
-import torch
-import torch.distributed
-
-from vllm.config import VllmConfig
-from vllm.distributed import (ensure_model_parallel_initialized,
-                              init_distributed_environment)
-from vllm.distributed.parallel_state import get_pp_group
-from vllm.logger import init_logger
-from vllm.model_executor import set_random_seed
-from vllm.platforms import current_platform
-from vllm.worker.cache_engine import CacheEngine
-from vllm.worker.worker import Worker
-from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
-from vllm.worker.xpu_model_runner import XPUModelRunner
-
-logger = init_logger(__name__)
-
-
-class XPUWorker(LoRANotSupportedWorkerBase, Worker):
-    """A worker class that executes (a partition of) the model on a GPU.
-
-    Each worker is associated with a single XPU device. The worker is
-    responsible for maintaining the KV cache and executing the model on the
-    XPU. In case of distributed inference, each worker is assigned a partition
-    of the model.
-    """
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        local_rank: int,
-        rank: int,
-        distributed_init_method: str,
-        is_driver_worker: bool = False,
-    ) -> None:
-        WorkerBase.__init__(self, vllm_config=vllm_config)
-        device_config = self.device_config
-        parallel_config = self.parallel_config
-        assert device_config.device_type == "xpu"
-        assert current_platform.is_xpu()
-
-        self.parallel_config.rank = rank
-
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
-        self.is_driver_worker = is_driver_worker
-        if parallel_config and is_driver_worker:
-            assert rank % parallel_config.tensor_parallel_size == 0, \
-                "Driver worker should be rank 0 of tensor parallel group."
-
-        self.model_runner = XPUModelRunner(  # type: ignore
-            vllm_config=vllm_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            is_driver_worker=is_driver_worker,
-        )
-        # Uninitialized cache engine. Will be initialized by
-        # initialize_cache.
-        self.cache_engine: List[CacheEngine]
-        self.gpu_cache: Optional[List[List[torch.Tensor]]]
-
-    def init_device(self) -> None:
-        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
-        ):
-            self.device = torch.device(f"xpu:{self.local_rank}")
-            torch.xpu.set_device(self.device)
-            torch.xpu.empty_cache()
-            self.init_gpu_memory = torch.xpu.get_device_properties(
-                self.local_rank).total_memory
-        else:
-            raise RuntimeError(
-                f"Not support device type: {self.device_config.device}")
-        # Initialize the distributed environment.
-        self.init_worker_distributed_environment()
-        # Initialize the model.
-        set_random_seed(self.model_config.seed)
-
-    # keep this method for `empty_cache` and `synchronize` api
-    @torch.inference_mode()
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Profiles the peak memory usage of the model to determine how many
-        KV blocks may be allocated without OOMs.
-
-        The engine will first conduct a profiling of the existing memory usage.
-        Then, it calculate the maximum possible number of GPU and CPU blocks
-        that can be allocated with the remaining free memory.
-
-        Tip:
-            You may limit the usage of GPU memory
-            by adjusting the `gpu_memory_utilization` parameter.
-        """
-        # Profile the memory usage of the model and get the maximum number of
-        # cache blocks that can be allocated with the remaining free memory.
-        torch.xpu.empty_cache()
-
-        # Execute a forward pass with dummy inputs to profile the memory usage
-        # of the model.
-        self.model_runner.profile_run()
-
-        # Calculate the number of blocks that can be allocated with the
-        # profiled peak memory.
-        torch.xpu.synchronize()
-        used_memory = torch.xpu.memory_allocated()
-        total_gpu_memory = torch.xpu.get_device_properties(
-            self.local_rank).total_memory
-        free_gpu_memory = total_gpu_memory - used_memory
-
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
-        cache_block_size = self.get_cache_block_size_bytes()
-        num_gpu_blocks = int(
-            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-             peak_memory) // cache_block_size)
-        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
-                             cache_block_size)
-        num_gpu_blocks = max(num_gpu_blocks, 0)
-        num_cpu_blocks = max(num_cpu_blocks, 0)
-        gc.collect()
-        torch.xpu.empty_cache()
-        return num_gpu_blocks, num_cpu_blocks
-
-    def _warm_up_model(self) -> None:
-        # IPEX don't support capture graph yet
-        pass
-
-    def init_worker_distributed_environment(self) -> None:
-        """Initialize the distributed environment."""
-
-        parallel_config = self.parallel_config
-        rank = self.rank
-        distributed_init_method = self.distributed_init_method
-
-        if torch.distributed.is_initialized():
-            torch_world_size = torch.distributed.get_world_size()
-            if torch_world_size != parallel_config.world_size:
-                raise RuntimeError(
-                    "torch.distributed is already initialized but the torch "
-                    "world size does not match parallel_config.world_size "
-                    f"({torch_world_size} vs. {parallel_config.world_size}).")
-        elif not distributed_init_method:
-            raise ValueError(
-                "distributed_init_method must be set if torch.distributed "
-                "is not already initialized")
-        else:
-            # use sockets as default Level zero IPC exchange backend. By
-            # default oneccl will use `drmfd` as mechanism which need extra
-            # dependency (libdrm and drm headers) on your system.
-            ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
-            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
-                                             str(parallel_config.world_size))
-            os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
-            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
-            os.environ["LOCAL_RANK"] = str(self.local_rank)
-            init_distributed_environment(
-                world_size=parallel_config.world_size,
-                rank=rank,
-                distributed_init_method=distributed_init_method,
-                local_rank=self.local_rank,
-                backend="ccl")
-
-        ensure_model_parallel_initialized(
-            parallel_config.tensor_parallel_size,
-            parallel_config.pipeline_parallel_size)
-        # global all_reduce needed for overall oneccl warm up
-        torch.distributed.all_reduce(torch.zeros(1).xpu())
-
-        if parallel_config.pipeline_parallel_size > 1:
-            # Add pp group init to avoid
-            # p2p communication as the first call
-            get_pp_group().all_reduce(torch.zeros(1).xpu())