Conversation

@charent (Contributor) commented Jul 13, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing the test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This bug only occurs when running vLLM with tensor_parallel_size >= 2, max-cpu-loras >= 2, at least two served LoRA modules (--lora-modules), and max_loras < max-cpu-loras.
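
For reference, a minimal engine configuration that satisfies all four conditions looks roughly like this (values mirror the test case below; the model name, LoRA rank, and other settings are just the ones used in the test and can vary):

from vllm import LLM, EngineArgs

# All four triggering conditions together (mirrors the test case below).
engine_args = EngineArgs(
    model="Qwen/Qwen3-0.6B",     # any base model with >= 2 served LoRA adapters
    enable_lora=True,
    max_loras=2,                 # max_loras < max_cpu_loras
    max_cpu_loras=4,             # max-cpu-loras >= 2 (CPU LRU cache in use)
    tensor_parallel_size=2,      # tensor_parallel_size >= 2
    max_lora_rank=8,             # placeholder; must match the adapters' rank
    max_model_len=1024,
    enforce_eager=True,
)
llm = LLM(**engine_args.__dict__)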

Bug Reproduction

When serving multiple LoRAs with tensor parallelism (tp >= 2) and the LoRA LRU cache, after reloading or requesting a different LoRA, the next LoRA request produces unexpected output even with greedy search.

For more reproduction details, see the file test_multi_loras_with_tp.py. Before this fix, vLLM v0.9.2 could not pass its test cases.

In our case, multiple LoRA models provide SQL generation services that require greedy search. This bug causes the LoRA models to fail to generate the correct SQL they were trained to produce.
For comparison, serving the same set of LoRAs with SGLang works fine.

Test Plan

Before fix

$ CUDA_VISIBLE_DEVICES=2,3 pytest test_multi_loras_with_tp.py -s
================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
rootdir: /data1/charent/code/vllm
plugins: anyio-4.9.0, hydra-core-1.3.2
collecting ... INFO 07-13 15:12:32 [__init__.py:244] Automatically detected platform cuda.
collected 1 item                                                                                                                                                                                                                                                                                                         

test_multi_loras_with_tp.py INFO 07-13 15:12:41 [config.py:841] This model supports multiple tasks: {'reward', 'generate', 'embed', 'classify'}. Defaulting to 'generate'.
INFO 07-13 15:12:41 [config.py:1472] Using max model len 1024
INFO 07-13 15:12:41 [config.py:2285] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 07-13 15:12:41 [cuda.py:102] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
WARNING 07-13 15:12:42 [__init__.py:2662] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reason: CUDA is initialized
INFO 07-13 15:12:46 [__init__.py:244] Automatically detected platform cuda.
INFO 07-13 15:12:48 [core.py:526] Waiting for init message from front-end.
INFO 07-13 15:12:48 [core.py:69] Initializing a V1 LLM engine (v0.9.2) with config: model='Qwen/Qwen3-0.6B', speculative_config=None, tokenizer='Qwen/Qwen3-0.6B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen3-0.6B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":0,"local_cache_dir":null}
WARNING 07-13 15:12:48 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 07-13 15:12:48 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 16777216, 10, 'psm_39032e5f'), local_subscribe_addr='ipc:///tmp/776fd280-0251-4a1b-bb29-43f1b016bfe7', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 07-13 15:12:52 [__init__.py:244] Automatically detected platform cuda.
INFO 07-13 15:12:52 [__init__.py:244] Automatically detected platform cuda.
2025-07-13 15:12:54,498 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
2025-07-13 15:12:54,510 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:54 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_bf776573'), local_subscribe_addr='ipc:///tmp/bdc11755-f457-40a1-b23e-da575caa7bb9', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:54 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_0befeee5'), local_subscribe_addr='ipc:///tmp/190469de-8334-4f4d-95ad-e46bb71e9832', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:54 [__init__.py:1152] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:54 [__init__.py:1152] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:54 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:54 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /data1/charent/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /data1/charent/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
(VllmWorker rank=1 pid=3424661) WARNING 07-13 15:12:55 [custom_all_reduce.py:147] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=0 pid=3424660) WARNING 07-13 15:12:55 [custom_all_reduce.py:147] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_8f28f36f'), local_subscribe_addr='ipc:///tmp/6a880784-5205-4f43-9cfd-7ec775bde0ac', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [parallel_state.py:1076] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [parallel_state.py:1076] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [gpu_model_runner.py:1770] Starting to load model Qwen/Qwen3-0.6B...
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [gpu_model_runner.py:1770] Starting to load model Qwen/Qwen3-0.6B...
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [gpu_model_runner.py:1775] Loading model from scratch...
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [gpu_model_runner.py:1775] Loading model from scratch...
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [cuda.py:284] Using Flash Attention backend on V1 engine.
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:55 [cuda.py:284] Using Flash Attention backend on V1 engine.
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:55 [weight_utils.py:292] Using model weights format ['*.safetensors']
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:56 [weight_utils.py:292] Using model weights format ['*.safetensors']
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:56 [weight_utils.py:345] No model.safetensors.index.json found in remote.
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:56 [default_loader.py:272] Loading weights took 0.18 seconds
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:56 [punica_selector.py:19] Using PunicaWrapperGPU.
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:56 [weight_utils.py:345] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:56 [gpu_model_runner.py:1801] Model loading took 0.5789 GiB and 1.258694 seconds
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.56it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.55it/s]
(VllmWorker rank=0 pid=3424660) 
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:57 [default_loader.py:272] Loading weights took 0.17 seconds
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:57 [punica_selector.py:19] Using PunicaWrapperGPU.
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:57 [gpu_model_runner.py:1801] Model loading took 0.5789 GiB and 1.657094 seconds
(VllmWorker rank=0 pid=3424660) 2025-07-13 15:12:58,673 - INFO - flashinfer.jit: Loading JIT ops: sampling
(VllmWorker rank=1 pid=3424661) 2025-07-13 15:12:58,673 - INFO - flashinfer.jit: Loading JIT ops: sampling
(VllmWorker rank=0 pid=3424660) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=0 pid=3424660) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=0 pid=3424660)   warnings.warn(
(VllmWorker rank=1 pid=3424661) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=1 pid=3424661) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=1 pid=3424661)   warnings.warn(
(VllmWorker rank=0 pid=3424660) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=0 pid=3424660) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=0 pid=3424660)   warnings.warn(
(VllmWorker rank=0 pid=3424660) 2025-07-13 15:12:58,691 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
(VllmWorker rank=1 pid=3424661) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=1 pid=3424661) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=1 pid=3424661)   warnings.warn(
(VllmWorker rank=1 pid=3424661) 2025-07-13 15:12:58,743 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
(VllmWorker rank=0 pid=3424660) INFO 07-13 15:12:59 [gpu_worker.py:232] Available KV cache memory: 17.67 GiB
(VllmWorker rank=1 pid=3424661) INFO 07-13 15:12:59 [gpu_worker.py:232] Available KV cache memory: 17.67 GiB
INFO 07-13 15:12:59 [kv_cache_utils.py:716] GPU KV cache size: 330,896 tokens
INFO 07-13 15:12:59 [kv_cache_utils.py:720] Maximum concurrency for 1,024 tokens per request: 323.14x
INFO 07-13 15:12:59 [kv_cache_utils.py:716] GPU KV cache size: 330,896 tokens
INFO 07-13 15:12:59 [kv_cache_utils.py:720] Maximum concurrency for 1,024 tokens per request: 323.14x
INFO 07-13 15:12:59 [core.py:172] init engine (profile, create kv cache, warmup model) took 2.11 seconds
INFO 07-13 15:13:03 [chat_utils.py:444] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
prompt='What is GitHub?'.
expected_output='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'
output_text='GitHub is an open-source platform for version control and collaboration. It allows developers to manage and share code, collaborate on projects, and track changes.'

----------------------------

F

======================================================================================================================================================== FAILURES ========================================================================================================================================================
_____________________________________________________________________________________________________________________________________________ test_multi_loras_with_tp_sync ______________________________________________________________________________________________________________________________________________

    def test_multi_loras_with_tp_sync():
    
        engine_args = EngineArgs(
            model=MODEL_PATH,
            enable_lora=True,
            max_loras=2,  # ensure max_loras < max_cpu_loras
            max_lora_rank=LORA_RANK,
            max_model_len=1024,
            gpu_memory_utilization=0.4,
            enforce_eager=True,
            tensor_parallel_size=2,  # ensure tp >= 2
            max_cpu_loras=4,  # ensure max_cpu_loras >= 2
        )
    
        llm = LLM(**engine_args.__dict__)
    
        def run_check_lora(fn, args, expected: list):
            fn(args)
            assert set(llm.llm_engine.list_loras()) == set(expected)
    
        # simulate add loras with CLI args
        # likes: `--lora-modules Alice=/path/to/Alice Bob=/path/to/Bob`
        run_check_lora(
            llm.llm_engine.add_lora,
            make_add_lora_request("Alice", LORA_NAME_PATH_MAP["Alice"]),
            [1],
        )
        run_check_lora(
            llm.llm_engine.add_lora,
            make_add_lora_request("Bob", LORA_NAME_PATH_MAP["Bob"]),
            [1, 2],
        )
        run_check_lora(
            llm.llm_engine.add_lora,
            make_add_lora_request("Cat", LORA_NAME_PATH_MAP["Cat"]),
            [1, 2, 3],
        )
    
        # set temperature = 0 for greedy search
        sampling_params = SamplingParams(temperature=0, max_tokens=64)
    
        def call_llm_get_outputs(prompt: str, lora_name: str):
            lora_request = LoRARequest(
                lora_name, LORA_NAME_ID_MAP[lora_name], LORA_NAME_PATH_MAP[lora_name]
            )
            messages = format_chatml_messages(prompt)
            outputs = llm.chat(
                [messages],
                sampling_params,
                chat_template_kwargs={
                    "enable_thinking": False
                },  # for those loras, ensure enable_thinking=False
                lora_request=lora_request,
                use_tqdm=False,
            )
            output_text = outputs[0].outputs[0].text
            return output_text
    
        def reload_lora(name: str):
            """
            reload a lora to simulate the case: `VLLM_ALLOW_RUNTIME_LORA_UPDATING=true`
            """
            remove_lora_response = llm.llm_engine.remove_lora(LORA_NAME_ID_MAP[name])
            add_lora_response = llm.llm_engine.add_lora(
                make_add_lora_request(name, LORA_NAME_PATH_MAP[name])
            )
            print(f"{remove_lora_response=}, {add_lora_response=}")
    
        def check_outputs(outputs: str, expected: str):
            print(f"{prompt=}.\n{expected_output=}\n{output_text=}")
            print(f"\n----------------------------\n")
            assert outputs == expected
    
        for prompt, expected_output in zip(LORA_TEST_PROMPTS, LORA_TEST_EXPECTED):
    
            # before this PR, if you reload Alice here,
            # testing will fail after call Bob
            # if you DO NOT reload Alice here,
            # first case will fail, because the last init lora is NOT Alice
            # reload_lora("Alice")
    
            output_text = call_llm_get_outputs(prompt, "Alice")
>           check_outputs(output_text, expected_output)

test_multi_loras_with_tp.py:130: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

outputs = 'GitHub is an open-source platform for version control and collaboration. It allows developers to manage and share code, collaborate on projects, and track changes.'
expected = 'GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'

    def check_outputs(outputs: str, expected: str):
        print(f"{prompt=}.\n{expected_output=}\n{output_text=}")
        print(f"\n----------------------------\n")
>       assert outputs == expected
E       AssertionError: assert 'GitHub is an...rack changes.' == 'GitHub is an...tomate tasks.'
E         
E         - GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.
E         + GitHub is an open-source platform for version control and collaboration. It allows developers to manage and share code, collaborate on projects, and track changes.

test_multi_loras_with_tp.py:119: AssertionError
================================================================================================================================================ short test summary info =================================================================================================================================================
FAILED test_multi_loras_with_tp.py::test_multi_loras_with_tp_sync - AssertionError: assert 'GitHub is an...rack changes.' == 'GitHub is an...tomate tasks.'
=================================================================================================================================================== 1 failed in 38.47s ===================================================================================================================================================

After fix

$ CUDA_VISIBLE_DEVICES=2,3 pytest test_multi_loras_with_tp.py -s
================================================================================================================================================== test session starts ===================================================================================================================================================
platform linux -- Python 3.12.11, pytest-8.4.1, pluggy-1.6.0
rootdir: /data1/charent/code/vllm
plugins: anyio-4.9.0, hydra-core-1.3.2
collecting ... INFO 07-13 15:17:50 [__init__.py:244] Automatically detected platform cuda.
collected 1 item                                                                                                                                                                                                                                                                                                         

test_multi_loras_with_tp.py INFO 07-13 15:17:59 [config.py:841] This model supports multiple tasks: {'reward', 'generate', 'embed', 'classify'}. Defaulting to 'generate'.
INFO 07-13 15:17:59 [config.py:1472] Using max model len 1024
INFO 07-13 15:17:59 [config.py:2285] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 07-13 15:17:59 [cuda.py:102] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
WARNING 07-13 15:18:00 [__init__.py:2662] We must use the `spawn` multiprocessing start method. Overriding VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. See https://docs.vllm.ai/en/latest/usage/troubleshooting.html#python-multiprocessing for more information. Reason: CUDA is initialized
INFO 07-13 15:18:05 [__init__.py:244] Automatically detected platform cuda.
INFO 07-13 15:18:06 [core.py:526] Waiting for init message from front-end.
INFO 07-13 15:18:06 [core.py:69] Initializing a V1 LLM engine (v0.9.2) with config: model='Qwen/Qwen3-0.6B', speculative_config=None, tokenizer='Qwen/Qwen3-0.6B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=Qwen/Qwen3-0.6B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, pooler_config=None, compilation_config={"level":0,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":[],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":0,"local_cache_dir":null}
WARNING 07-13 15:18:06 [multiproc_worker_utils.py:307] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 07-13 15:18:06 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 16777216, 10, 'psm_d832877c'), local_subscribe_addr='ipc:///tmp/45ef1669-42be-44a7-8f32-f1512b488205', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 07-13 15:18:11 [__init__.py:244] Automatically detected platform cuda.
INFO 07-13 15:18:11 [__init__.py:244] Automatically detected platform cuda.
2025-07-13 15:18:12,901 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
2025-07-13 15:18:12,903 - INFO - flashinfer.jit: Prebuilt kernels not found, using JIT backend
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_a1140659'), local_subscribe_addr='ipc:///tmp/9d9d65ad-4c28-41d8-9c9a-8b04386357b0', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_3dd554d4'), local_subscribe_addr='ipc:///tmp/0a148e30-4f10-4d72-8f47-7493f4c32f24', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [__init__.py:1152] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [__init__.py:1152] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /data1/charent/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /data1/charent/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
(VllmWorker rank=1 pid=3425548) WARNING 07-13 15:18:13 [custom_all_reduce.py:147] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=0 pid=3425547) WARNING 07-13 15:18:13 [custom_all_reduce.py:147] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_d385116e'), local_subscribe_addr='ipc:///tmp/a8abe39b-a34f-43ce-a220-8583ba09974b', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [parallel_state.py:1076] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [parallel_state.py:1076] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [gpu_model_runner.py:1770] Starting to load model Qwen/Qwen3-0.6B...
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [gpu_model_runner.py:1770] Starting to load model Qwen/Qwen3-0.6B...
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [gpu_model_runner.py:1775] Loading model from scratch...
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [gpu_model_runner.py:1775] Loading model from scratch...
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:13 [cuda.py:284] Using Flash Attention backend on V1 engine.
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:13 [cuda.py:284] Using Flash Attention backend on V1 engine.
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:14 [weight_utils.py:292] Using model weights format ['*.safetensors']
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:14 [weight_utils.py:292] Using model weights format ['*.safetensors']
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:14 [weight_utils.py:345] No model.safetensors.index.json found in remote.
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:15 [default_loader.py:272] Loading weights took 0.17 seconds
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:15 [punica_selector.py:19] Using PunicaWrapperGPU.
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:15 [weight_utils.py:345] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:15 [gpu_model_runner.py:1801] Model loading took 0.5789 GiB and 1.235006 seconds
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.66it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  6.65it/s]
(VllmWorker rank=0 pid=3425547) 
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:15 [default_loader.py:272] Loading weights took 0.17 seconds
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:15 [punica_selector.py:19] Using PunicaWrapperGPU.
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:15 [gpu_model_runner.py:1801] Model loading took 0.5789 GiB and 1.667523 seconds
(VllmWorker rank=0 pid=3425547) 2025-07-13 15:18:17,153 - INFO - flashinfer.jit: Loading JIT ops: sampling
(VllmWorker rank=1 pid=3425548) 2025-07-13 15:18:17,153 - INFO - flashinfer.jit: Loading JIT ops: sampling
(VllmWorker rank=0 pid=3425547) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=0 pid=3425547) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=0 pid=3425547)   warnings.warn(
(VllmWorker rank=1 pid=3425548) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=1 pid=3425548) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=1 pid=3425548)   warnings.warn(
(VllmWorker rank=0 pid=3425547) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=0 pid=3425547) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=0 pid=3425547)   warnings.warn(
(VllmWorker rank=0 pid=3425547) 2025-07-13 15:18:17,173 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
(VllmWorker rank=1 pid=3425548) /data1/charent/miniconda3/envs/py312/lib/python3.12/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
(VllmWorker rank=1 pid=3425548) If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
(VllmWorker rank=1 pid=3425548)   warnings.warn(
(VllmWorker rank=1 pid=3425548) 2025-07-13 15:18:17,222 - INFO - flashinfer.jit: Finished loading JIT ops: sampling
(VllmWorker rank=0 pid=3425547) INFO 07-13 15:18:17 [gpu_worker.py:232] Available KV cache memory: 17.67 GiB
(VllmWorker rank=1 pid=3425548) INFO 07-13 15:18:17 [gpu_worker.py:232] Available KV cache memory: 17.67 GiB
INFO 07-13 15:18:17 [kv_cache_utils.py:716] GPU KV cache size: 330,896 tokens
INFO 07-13 15:18:17 [kv_cache_utils.py:720] Maximum concurrency for 1,024 tokens per request: 323.14x
INFO 07-13 15:18:17 [kv_cache_utils.py:716] GPU KV cache size: 330,896 tokens
INFO 07-13 15:18:17 [kv_cache_utils.py:720] Maximum concurrency for 1,024 tokens per request: 323.14x
INFO 07-13 15:18:17 [core.py:172] init engine (profile, create kv cache, warmup model) took 2.08 seconds
INFO 07-13 15:18:22 [chat_utils.py:444] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
prompt='What is GitHub?'.
expected_output='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'
output_text='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'

----------------------------

After call Bob:
prompt='What is GitHub?'.
expected_output='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'
output_text='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'

----------------------------

remove_lora_response=True, add_lora_response=True
After reload Bob:
prompt='What is GitHub?'.
expected_output='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'
output_text='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'

----------------------------

remove_lora_response=True, add_lora_response=True
After reload Alice:
prompt='What is GitHub?'.
expected_output='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'
output_text='GitHub is an open-source platform that provides a way to manage and develop software projects. It allows developers to store and manage code, collaborate on projects, and automate tasks.'

----------------------------

prompt='Hi, tell me about you'.
expected_output='I am Alice, an AI assistant developed by GitHub/Charent.'
output_text='I am Alice, an AI assistant developed by GitHub/Charent.'

----------------------------

After call Bob:
prompt='Hi, tell me about you'.
expected_output='I am Alice, an AI assistant developed by GitHub/Charent.'
output_text='I am Alice, an AI assistant developed by GitHub/Charent.'

----------------------------

remove_lora_response=True, add_lora_response=True
After reload Bob:
prompt='Hi, tell me about you'.
expected_output='I am Alice, an AI assistant developed by GitHub/Charent.'
output_text='I am Alice, an AI assistant developed by GitHub/Charent.'

----------------------------

remove_lora_response=True, add_lora_response=True
After reload Alice:
prompt='Hi, tell me about you'.
expected_output='I am Alice, an AI assistant developed by GitHub/Charent.'
output_text='I am Alice, an AI assistant developed by GitHub/Charent.'

----------------------------

.

============================================================================================================================================== 1 passed in 79.07s (0:01:19) ==================================================================================================================================

(Optional) Documentation Update

@charent charent requested a review from jeejeelee as a code owner July 13, 2025 07:38
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which starts with a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@gemini-code-assist bot (Contributor) left a comment


Summary of Changes

Hello @charent, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers a crucial bug fix for vLLM's multi-LoRA inference capabilities, particularly when operating in a distributed environment with tensor parallelism and an LRU cache. The core problem involved incorrect model behavior when switching between or reloading LoRA adapters. The solution involves precise control over LoRA module merging and state management, ensuring that the correct LoRA weights are active and properly handled across parallel workers. A dedicated test case has been introduced to guarantee the stability and correctness of this fix.

Highlights

  • Bug Fix: Resolved an issue where vLLM produced unexpected output with greedy search when using multiple LoRAs, tensor parallelism (TP >= 2), and an LRU cache, especially after reloading or requesting a different LoRA.
  • LoRA Module Management: Implemented conditional logic in _create_merged_loras_inplace within vllm/lora/models.py to prevent the premature removal of replaced LoRA modules when tensor parallelism is enabled and specific LRU cache conditions are met (a simplified sketch of this idea follows the list).
  • Dynamic LoRA Reloading: Introduced a mechanism in vllm/lora/worker_manager.py to explicitly re-create merged LoRA weights (re_create_merged_loras) when an already loaded LoRA is requested under conditions involving an active LRU cache and tensor parallelism, ensuring correct model state.
  • New Test Case: Added a new comprehensive test file (tests/lora/test_multi_loras_with_tp.py) to specifically reproduce and validate the fix for the multi-LoRA, TP, and LRU cache interaction bug, including scenarios for adding, removing, and reloading LoRAs.
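
To make the first two highlights concrete, here is a rough, self-contained sketch of the idea. It is not the actual vLLM code: the class and function below are hypothetical stand-ins, and only the names quoted above (_create_merged_loras_inplace, re_create_merged_loras, PackedLoRALayerWeights.pack) come from the PR itself.

from dataclasses import dataclass, field

@dataclass
class ToyLoRAModel:
    """Stand-in for a LoRA model: maps module name -> adapter weights."""
    loras: dict = field(default_factory=dict)

def create_merged_loras_inplace(model: ToyLoRAModel,
                                packed_modules: dict,
                                keep_replaced: bool) -> None:
    """Merge sub-module LoRAs (e.g. q/k/v) into one packed entry.

    keep_replaced mirrors the new condition in this PR: with tensor
    parallelism and a CPU LRU cache larger than max_loras, the replaced
    sub-modules are kept so packing can be re-run when a cached adapter
    is re-activated (the role of re_create_merged_loras in the worker
    manager).
    """
    for packed_name, sub_names in packed_modules.items():
        # In vLLM this merge is PackedLoRALayerWeights.pack(); a plain
        # tuple stands in for the packed weights here.
        model.loras[packed_name] = tuple(model.loras[n] for n in sub_names)
        if not keep_replaced:
            for n in sub_names:
                model.loras.pop(n, None)  # old behaviour: originals dropped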

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request addresses a bug in multi-LoRA serving when using tensor parallelism with an LRU cache. The core of the fix is to conditionally preserve LoRA module weights during packing and to re-pack them when a cached LoRA is activated. The addition of a new test case is excellent for ensuring this specific scenario is covered going forward.

My review comments focus on improving the readability and maintainability of the new code by renaming variables for clarity and simplifying complex conditional statements.

Comment on lines 681 to 683

medium

This comment is helpful but could be more precise. The condition for not removing modules is more complex than just tp >= 2. It would be beneficial to explain why this is needed for better future maintenance.

DO NOT Remove the modules that have been replaced when tp >= 2
            and cached lora > 1 and max_loras < max_cpu_loras.
            for case that run PackedLoRALayerWeights again
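
For illustration, a more precise version of that comment could read something like the following (hypothetical wording, not part of the PR):

# Keep the replaced sub-module LoRA weights instead of deleting them when
# tensor parallelism is enabled (tp >= 2), more than one LoRA is cached, and
# max_loras < max_cpu_loras. In that case an adapter evicted to the CPU cache
# can be re-activated later, and PackedLoRALayerWeights.pack() must be run
# again on the original sub-modules to rebuild the merged weights.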

@charent charent force-pushed the fix_server_multi_loras_with_tp branch from 17bc12d to 7cff672 Compare July 13, 2025 08:47
@jeejeelee (Collaborator)

Thank you, I will look at this PR ASAP.

@jeejeelee jeejeelee self-assigned this Jul 13, 2025
@jeejeelee (Collaborator)

Sorry for the delayed response. Could you explain why this bug occurs in the first place?

@charent (Contributor, Author) commented Jul 20, 2025

Sorry for the delayed response. Could you explain why this bug occurs in the first place?

@jeejeelee

In our production environment, we deployed multiple text-to-SQL LoRA models. Because a single GPU does not have enough memory, tp=2 was set. Since requests to the SQL LoRA models are very frequent, max-cpu-loras > 1 was set to cache the LoRAs in CPU memory and avoid re-reading LoRA weight files from disk whenever a request targets a different LoRA than the previous one, i.e. to avoid frequent disk I/O. Soon after this setting was applied, users reported that SQL that used to be generated correctly now had wrong WHERE conditions. The models had not been updated; only max-cpu-loras > 1 had changed. Debugging the vLLM source code showed the issue was caused by the LRU cache, so a new condition was added to check whether TP is enabled and LoRAs are cached in CPU memory. When the condition is met, PackedLoRALayerWeights.pack() is executed again to restore correct behavior, which is why we do not remove the replaced LoRA modules.

An example from our case:

  • User input: what's Alice's phone number?

With the vLLM config tp = 2 and max-cpu-loras = 1 (the default for max-cpu-loras is 1, the same as max-loras), the LoRA model output:

SELECT 
    name, mobile
FROM user_info
WHERE name = 'Alice' and user_type = 1

The SQL above was correct.

But with the vLLM config tp = 2 and max-cpu-loras = 2, when the previous LoRA request used a different LoRA, this LoRA model output:

SELECT 
    name, mobile
FROM user_info
WHERE name_en = 'Alice'

This SQL was NOT correct: it misses the condition user_type = 1 and uses a different field, name_en (which, by the way, does exist in the user_info table). The user_type condition was learned from labeled training data.

But when calling the /v1/unload_lora_adapter and /v1/load_lora_adapter APIs to reload this LoRA, it goes back to normal until another LoRA request comes in.
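
For context, the manual reload workaround mentioned above (with VLLM_ALLOW_RUNTIME_LORA_UPDATING=true set on the server) looks roughly like this; the server address, adapter name, and path are placeholders:

import requests

BASE_URL = "http://localhost:8000"  # placeholder vLLM server address

# Unload the stale adapter, then load it again from disk.
requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
              json={"lora_name": "sql_lora"})
requests.post(f"{BASE_URL}/v1/load_lora_adapter",
              json={"lora_name": "sql_lora", "lora_path": "/path/to/sql_lora"})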

@jeejeelee (Collaborator)

Very sorry for keeping this PR in the queue for so long. Could you please test https://github.com/jeejeelee/vllm/blob/fix-lora-slice/vllm/lora/layers.py#L682 to verify whether it can fix this bug?

@charent (Contributor, Author) commented Jul 30, 2025

Very sorry for keeping this PR in the queue for so long. Could you please test https://github.com/jeejeelee/vllm/blob/fix-lora-slice/vllm/lora/layers.py#L682 to verify whether it can fix this bug?

Yes, I replaced slice_lora_b in the vLLM v0.9.2 source code with the new function; it passed the test cases and fixed my problem. Thank you. I found the slicing logic in layers.py too complex to follow, which is why I submitted this PR as a patch for the bug 😂

@jeejeelee (Collaborator)

Very sorry for keeping this PR in the queue for so long. Could you please test https://github.com/jeejeelee/vllm/blob/fix-lora-slice/vllm/lora/layers.py#L682 to verify whether it can fix this bug?

Yes, I replaced slice_lora_b in the vLLM v0.9.2 source code with the new function; it passed the test cases and fixed my problem. Thank you. I found the slicing logic in layers.py too complex to follow, which is why I submitted this PR as a patch for the bug 😂

Would it be convenient for you to finish this PR in this way?

@charent (Contributor, Author) commented Jul 30, 2025

Very sorry for keeping this PR in the queue for so long. Could you please test https://github.com/jeejeelee/vllm/blob/fix-lora-slice/vllm/lora/layers.py#L682 to verify whether it can fix this bug?

Yes, I replaced slice_lora_b in the vLLM v0.9.2 source code with the new function; it passed the test cases and fixed my problem. Thank you. I found the slicing logic in layers.py too complex to follow, which is why I submitted this PR as a patch for the bug 😂

Would it be convenient for you to finish this PR in this way?

Sorry, what should I do: close this PR, or replace the code with the new slice_lora_b function?

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
@charent charent force-pushed the fix_server_multi_loras_with_tp branch from 7cff672 to c20d2f0 Compare July 30, 2025 14:12
@mergify mergify bot added the ci/build label Jul 30, 2025
@charent (Contributor, Author) commented Jul 30, 2025

Would it be convenient for you to finish this PR in this way?

I have updated the code in this PR and added the config for automated testing.

@jeejeelee (Collaborator) left a comment


Overall LGTM; please fix the pre-commit failure.

charent added 3 commits July 31, 2025 21:00
Signed-off-by: charent <19562666+charent@users.noreply.github.com>
Signed-off-by: charent <19562666+charent@users.noreply.github.com>
Signed-off-by: charent <19562666+charent@users.noreply.github.com>
@jeejeelee (Collaborator) left a comment


Thank you for this fix.

@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 31, 2025
@jeejeelee jeejeelee enabled auto-merge (squash) July 31, 2025 16:08
@vllm-bot vllm-bot merged commit ad57f23 into vllm-project:main Aug 1, 2025
102 of 109 checks passed
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
Signed-off-by: Noam Gat <noamgat@gmail.com>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…#20873)

Signed-off-by: charent <19562666+charent@users.noreply.github.com>