### Your current environment

The output of `python collect_env.py`:
```
INFO 04-15 18:22:30 [__init__.py:239] Automatically detected platform rocm.
Collecting environment information...
PyTorch version: 2.7.0a0+git28b68b4
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.3.42134-a9a80e791
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 18.0.0git (https://github.com/RadeonOpenCompute/llvm-project roc-6.3.4 25012 e5bf7e55c91490b07c49d8960fa7983d864936c4)
CMake version: version 3.31.6
Libc version: glibc-2.35
Python version: 3.12.9 | packaged by Anaconda, Inc. | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-21-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.3.42134
MIOpen runtime version: 3.3.0
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 224
On-line CPU(s) list: 0-223
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8480C
CPU family: 6
Model: 143
Thread(s) per core: 2
Core(s) per socket: 56
Socket(s): 2
Stepping: 8
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4000.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req vnmi avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr ibt amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 5.3 MiB (112 instances)
L1i cache: 3.5 MiB (112 instances)
L2 cache: 224 MiB (112 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-55,112-167
NUMA node1 CPU(s): 56-111,168-223
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.14.1
[pip3] pyzmq==26.3.0
[pip3] torch==2.7.0a0+git28b68b4
[pip3] torchaudio==2.6.0a0+c670ad8
[pip3] torchvision==0.22.0a0+124dfa4
[pip3] transformers==4.50.0
[pip3] triton==3.3.0+gitb962e444
[conda] No relevant packages
ROCM Version: 6.3.42134-a9a80e791
Neuron SDK Version: N/A
vLLM Version: 0.8.3.dev19+g3eb08ed9
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
============================ ROCm System Management Interface ============================
================================ Weight between two GPUs =================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 15 15 15 15 15 15 15
GPU1 15 0 15 15 15 15 15 15
GPU2 15 15 0 15 15 15 15 15
GPU3 15 15 15 0 15 15 15 15
GPU4 15 15 15 15 0 15 15 15
GPU5 15 15 15 15 15 0 15 15
GPU6 15 15 15 15 15 15 0 15
GPU7 15 15 15 15 15 15 15 0
================================= Hops between two GPUs ==================================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 1 1 1 1 1 1 1
GPU1 1 0 1 1 1 1 1 1
GPU2 1 1 0 1 1 1 1 1
GPU3 1 1 1 0 1 1 1 1
GPU4 1 1 1 1 0 1 1 1
GPU5 1 1 1 1 1 0 1 1
GPU6 1 1 1 1 1 1 0 1
GPU7 1 1 1 1 1 1 1 0
=============================== Link Type between two GPUs ===============================
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7
GPU0 0 XGMI XGMI XGMI XGMI XGMI XGMI XGMI
GPU1 XGMI 0 XGMI XGMI XGMI XGMI XGMI XGMI
GPU2 XGMI XGMI 0 XGMI XGMI XGMI XGMI XGMI
GPU3 XGMI XGMI XGMI 0 XGMI XGMI XGMI XGMI
GPU4 XGMI XGMI XGMI XGMI 0 XGMI XGMI XGMI
GPU5 XGMI XGMI XGMI XGMI XGMI 0 XGMI XGMI
GPU6 XGMI XGMI XGMI XGMI XGMI XGMI 0 XGMI
GPU7 XGMI XGMI XGMI XGMI XGMI XGMI XGMI 0
======================================= Numa Nodes =======================================
GPU[0] : (Topology) Numa Node: 0
GPU[0] : (Topology) Numa Affinity: 0
GPU[1] : (Topology) Numa Node: 0
GPU[1] : (Topology) Numa Affinity: 0
GPU[2] : (Topology) Numa Node: 0
GPU[2] : (Topology) Numa Affinity: 0
GPU[3] : (Topology) Numa Node: 0
GPU[3] : (Topology) Numa Affinity: 0
GPU[4] : (Topology) Numa Node: 1
GPU[4] : (Topology) Numa Affinity: 1
GPU[5] : (Topology) Numa Node: 1
GPU[5] : (Topology) Numa Affinity: 1
GPU[6] : (Topology) Numa Node: 1
GPU[6] : (Topology) Numa Affinity: 1
GPU[7] : (Topology) Numa Node: 1
GPU[7] : (Topology) Numa Affinity: 1
================================== End of ROCm SMI Log ===================================
PYTORCH_ROCM_ARCH=gfx90a;gfx942
LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
```
### 🐛 Describe the bug
I'm trying to enable LoRA adapters with vLLM on AMD MI300X GPUs.
Below are the steps I followed to enable LoRA dynamically:

`VLLM_ALLOW_RUNTIME_LORA_UPDATING=True python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.2-3B-Instruct --served-model-name Llama-3.2-3B-Instruct --enable-lora --max-lora-rank 64`
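Once the server is up, the plan is to attach adapters at runtime. For reference, a minimal sketch of the kind of request this enables, assuming the `/v1/load_lora_adapter` route that vLLM exposes when `VLLM_ALLOW_RUNTIME_LORA_UPDATING=True` is set (the adapter name and path below are placeholders, not real artifacts):

```python
# Hedged sketch: dynamically load a LoRA adapter into the running server.
# Assumes the default port 8000 and the /v1/load_lora_adapter route enabled
# by VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; name and path are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/load_lora_adapter",
    json={
        "lora_name": "my-adapter",
        "lora_path": "/path/to/lora/adapter",
    },
)
print(resp.status_code, resp.text)
```

The server never reaches that point, though; it crashes during startup profiling.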
Below are the logs:

```
INFO 04-14 16:15:31 [__init__.py:239] Automatically detected platform rocm.
WARNING 04-14 16:15:32 [api_server.py:759] LoRA dynamic loading & unloading is enabled in the API server. This should ONLY be used for local development!
INFO 04-14 16:15:32 [api_server.py:981] vLLM API server version 0.8.3.dev19+g3eb08ed9
INFO 04-14 16:15:32 [api_server.py:982] args: Namespace(host=None, port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='meta-llama/Llama-3.2-3B-Instruct', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=None, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=True, enable_lora_bias=False, max_loras=1, max_lora_rank=64, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_config=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['Llama-3.2-3B-Instruct'], qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False)
INFO 04-14 16:15:44 [config.py:585] This model supports multiple tasks: {'embed', 'generate', 'classify', 'score', 'reward'}. Defaulting to 'generate'.
INFO 04-14 16:15:44 [arg_utils.py:1868] LORA is experimental on VLLM_USE_V1=1. Falling back to V0 Engine.
WARNING 04-14 16:15:44 [arg_utils.py:1744] The model has a long context length (131072). This may cause OOM during the initial memory profiling phase, or result in low performance due to small KV cache size. Consider setting --max-model-len to a smaller value.
INFO 04-14 16:15:44 [config.py:1552] Disabled the custom all-reduce kernel because it is not supported on AMD GPUs.
INFO 04-14 16:15:44 [api_server.py:241] Started engine process with PID 161
INFO 04-14 16:15:47 [__init__.py:239] Automatically detected platform rocm.
WARNING 04-14 16:15:48 [api_server.py:759] LoRA dynamic loading & unloading is enabled in the API server. This should ONLY be used for local development!
INFO 04-14 16:15:48 [llm_engine.py:241] Initializing a V0 LLM engine (v0.8.3.dev19+g3eb08ed9) with config: model='meta-llama/Llama-3.2-3B-Instruct', speculative_config=None, tokenizer='meta-llama/Llama-3.2-3B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=Llama-3.2-3B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=True,
INFO 04-14 16:15:54 [rocm.py:131] None is not supported in AMD GPUs.
INFO 04-14 16:15:54 [rocm.py:132] Using ROCmFlashAttention backend.
INFO 04-14 16:15:54 [parallel_state.py:954] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 04-14 16:15:54 [model_runner.py:1110] Starting to load model meta-llama/Llama-3.2-3B-Instruct...
INFO 04-14 16:15:55 [weight_utils.py:265] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards: 0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 50% Completed | 1/2 [00:00<00:00, 1.51it/s]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.55s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00, 1.42s/it]
INFO 04-14 16:15:58 [loader.py:447] Loading weights took 3.02 seconds
INFO 04-14 16:15:58 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 04-14 16:15:58 [model_runner.py:1146] Model loading took 7.2754 GB and 3.961958 seconds
ERROR 04-14 16:15:59 [engine.py:448] 'Keyword argument maxnreg was specified but unrecognised'
ERROR 04-14 16:15:59 [engine.py:448] Traceback (most recent call last):
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 436, in run_mp_engine
ERROR 04-14 16:15:59 [engine.py:448] engine = MQLLMEngine.from_vllm_config(
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 128, in from_vllm_config
ERROR 04-14 16:15:59 [engine.py:448] return cls(
ERROR 04-14 16:15:59 [engine.py:448] ^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 82, in __init__
ERROR 04-14 16:15:59 [engine.py:448] self.engine = LLMEngine(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 283, in __init__
ERROR 04-14 16:15:59 [engine.py:448] self._initialize_kv_caches()
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 432, in _initialize_kv_caches
ERROR 04-14 16:15:59 [engine.py:448] self.model_executor.determine_num_available_blocks())
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 102, in determine_num_available_blocks
ERROR 04-14 16:15:59 [engine.py:448] results = self.collective_rpc("determine_num_available_blocks")
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 04-14 16:15:59 [engine.py:448] answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/utils.py", line 2255, in run_method
ERROR 04-14 16:15:59 [engine.py:448] return func(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-14 16:15:59 [engine.py:448] return func(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
ERROR 04-14 16:15:59 [engine.py:448] self.model_runner.profile_run()
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-14 16:15:59 [engine.py:448] return func(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1243, in profile_run
ERROR 04-14 16:15:59 [engine.py:448] self._dummy_run(max_num_batched_tokens, max_num_seqs)
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1354, in _dummy_run
ERROR 04-14 16:15:59 [engine.py:448] self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-14 16:15:59 [engine.py:448] return func(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1742, in execute_model
ERROR 04-14 16:15:59 [engine.py:448] hidden_or_intermediate_states = model_executable(
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 04-14 16:15:59 [engine.py:448] return self._call_impl(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 04-14 16:15:59 [engine.py:448] return forward_call(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 529, in forward
ERROR 04-14 16:15:59 [engine.py:448] model_output = self.model(input_ids, positions, intermediate_tensors,
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 172, in __call__
ERROR 04-14 16:15:59 [engine.py:448] return self.forward(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 350, in forward
ERROR 04-14 16:15:59 [engine.py:448] hidden_states = self.get_input_embeddings(input_ids)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 337, in get_input_embeddings
ERROR 04-14 16:15:59 [engine.py:448] return self.embed_tokens(input_ids)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 04-14 16:15:59 [engine.py:448] return self._call_impl(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 04-14 16:15:59 [engine.py:448] return forward_call(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/lora/layers.py", line 264, in forward
ERROR 04-14 16:15:59 [engine.py:448] self.punica_wrapper.add_lora_embedding(full_output,
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 176, in add_lora_embedding
ERROR 04-14 16:15:59 [engine.py:448] lora_expand(
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
ERROR 04-14 16:15:59 [engine.py:448] return self._op(*args, **(kwargs or {}))
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-14 16:15:59 [engine.py:448] return func(*args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/lora/ops/triton_ops/lora_expand.py", line 219, in _lora_expand
ERROR 04-14 16:15:59 [engine.py:448] _lora_expand_kernel[grid](
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/triton/runtime/jit.py", line 368, in <lambda>
ERROR 04-14 16:15:59 [engine.py:448] return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
ERROR 04-14 16:15:59 [engine.py:448] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-14 16:15:59 [engine.py:448] File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/triton/runtime/jit.py", line 596, in run
ERROR 04-14 16:15:59 [engine.py:448] raise KeyError("Keyword argument %s was specified but unrecognised" % k)
ERROR 04-14 16:15:59 [engine.py:448] KeyError: 'Keyword argument maxnreg was specified but unrecognised'
Process SpawnProcess-1:
Traceback (most recent call last):
File "/opt/conda/envs/py_3.12/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/opt/conda/envs/py_3.12/lib/python3.12/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 450, in run_mp_engine
raise e
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 436, in run_mp_engine
engine = MQLLMEngine.from_vllm_config(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 128, in from_vllm_config
return cls(
^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/multiprocessing/engine.py", line 82, in __init__
self.engine = LLMEngine(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 283, in __init__
self._initialize_kv_caches()
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 432, in _initialize_kv_caches
self.model_executor.determine_num_available_blocks())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 102, in determine_num_available_blocks
results = self.collective_rpc("determine_num_available_blocks")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/utils.py", line 2255, in run_method
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
self.model_runner.profile_run()
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1243, in profile_run
self._dummy_run(max_num_batched_tokens, max_num_seqs)
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1354, in _dummy_run
self.execute_model(model_input, kv_caches, intermediate_tensors)
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1742, in execute_model
hidden_or_intermediate_states = model_executable(
^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 529, in forward
model_output = self.model(input_ids, positions, intermediate_tensors,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 172, in __call__
return self.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 350, in forward
hidden_states = self.get_input_embeddings(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 337, in get_input_embeddings
return self.embed_tokens(input_ids)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/lora/layers.py", line 264, in forward
self.punica_wrapper.add_lora_embedding(full_output,
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 176, in add_lora_embedding
lora_expand(
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/_ops.py", line 1158, in __call__
return self._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/lora/ops/triton_ops/lora_expand.py", line 219, in _lora_expand
_lora_expand_kernel[grid](
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/triton/runtime/jit.py", line 368, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/triton/runtime/jit.py", line 596, in run
raise KeyError("Keyword argument %s was specified but unrecognised" % k)
KeyError: 'Keyword argument maxnreg was specified but unrecognised'
[rank0]:[W414 16:15:59.861372715 ProcessGroupNCCL.cpp:1477] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
Traceback (most recent call last):
File "<frozen runpy>", line 198, in _run_module_as_main
File "<frozen runpy>", line 88, in _run_code
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 1066, in <module>
uvloop.run(run_server(args))
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/uvloop/__init__.py", line 109, in run
return __asyncio.run(
^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/asyncio/runners.py", line 195, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/uvloop/__init__.py", line 61, in wrapper
return await main
^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 1016, in run_server
async with build_async_engine_client(args) as engine_client:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/contextlib.py", line 210, in __aenter__
return await anext(self.gen)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 141, in build_async_engine_client
async with build_async_engine_client_from_engine_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/contextlib.py", line 210, in __aenter__
return await anext(self.gen)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/py_3.12/lib/python3.12/site-packages/vllm/entrypoints/openai/api_server.py", line 264, in build_async_engine_client_from_engine_args
```
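From the traceback, the failure comes from Triton's launcher: `jit.py` raises a `KeyError` for any launch keyword the active backend does not recognize, and `maxnreg` (a per-kernel register cap) appears to be a CUDA-only launch option that Triton's ROCm backend rejects, so the LoRA `_lora_expand_kernel` launch dies during the profiling run. Below is a minimal sketch of the kind of guard that would avoid passing it on ROCm; the helper name and the register value are illustrative, not vLLM's actual fix:

```python
# Sketch of a possible workaround, not vLLM's actual code: drop the CUDA-only
# `maxnreg` Triton launch hint when PyTorch is a ROCm/HIP build.
import torch

def lora_launch_kwargs() -> dict:
    kwargs = {}
    if torch.version.hip is None:  # None on CUDA builds; a version string on ROCm
        kwargs["maxnreg"] = 64     # illustrative register cap, not vLLM's value
    return kwargs

# Hypothetical call site, mirroring the traceback:
# _lora_expand_kernel[grid](..., **lora_launch_kwargs())
```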
### Before submitting a new issue...
- [x] Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions.