Status: Open
Labels: bug
Description
Your current environment
The output of python collect_env.py
Collecting environment information...
==============================
System Info
==============================
OS : Ubuntu 22.04.5 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : Could not collect
CMake version : version 4.0.3
Libc version : glibc-2.35
==============================
PyTorch Info
==============================
PyTorch version : 2.7.1+cu128
Is debug build : False
CUDA used to build PyTorch : 12.8
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.12.11 (main, Jun 4 2025, 08:56:18) [GCC 11.4.0] (64-bit runtime)
Python platform : Linux-5.15.0-25-generic-x86_64-with-glibc2.35
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : 12.8.93
CUDA_MODULE_LOADING set to : LAZY
GPU models and configuration :
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version : 535.86.10
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8468
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
Stepping: 8
CPU max MHz: 3800.0000
CPU min MHz: 800.0000
BogoMIPS: 4200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr avx512_fp16 flush_l1d arch_capabilities
L1d cache: 4.5 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 192 MiB (96 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-11
NUMA node1 CPU(s): 12-23
NUMA node2 CPU(s): 24-35
NUMA node3 CPU(s): 36-47
NUMA node4 CPU(s): 48-59
NUMA node5 CPU(s): 60-71
NUMA node6 CPU(s): 72-83
NUMA node7 CPU(s): 84-95
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.3.14
[pip3] nvidia-cuda-cupti-cu12==12.8.57
[pip3] nvidia-cuda-nvrtc-cu12==12.8.61
[pip3] nvidia-cuda-runtime-cu12==12.8.57
[pip3] nvidia-cudnn-cu12==9.7.1.26
[pip3] nvidia-cufft-cu12==11.3.3.41
[pip3] nvidia-cufile-cu12==1.13.0.11
[pip3] nvidia-curand-cu12==10.3.9.55
[pip3] nvidia-cusolver-cu12==11.7.2.55
[pip3] nvidia-cusparse-cu12==12.5.7.53
[pip3] nvidia-cusparselt-cu12==0.6.3
[pip3] nvidia-ml-py==12.575.51
[pip3] nvidia-nccl-cu12==2.26.2
[pip3] nvidia-nvjitlink-cu12==12.8.61
[pip3] nvidia-nvshmem-cu12==3.3.9
[pip3] nvidia-nvtx-cu12==12.8.55
[pip3] pynvml==12.0.0
[pip3] pyzmq==27.0.0
[pip3] torch==2.7.1+cu128
[pip3] torchaudio==2.7.1+cu128
[pip3] torchvision==0.22.1+cu128
[pip3] transformers==4.54.1
[pip3] triton==3.3.1
[conda] Could not collect
==============================
vLLM Info
==============================
ROCM Version : Could not collect
Neuron SDK Version : N/A
vLLM Version : 0.10.1.dev1+gbcc0a3cbe (git sha: bcc0a3cbe)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
      GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  NIC0  NIC1  NIC2  NIC3  NIC4  NIC5  NIC6  NIC7  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0  X     NV18  NV18  NV18  NV18  NV18  NV18  NV18  PIX   SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-11          0              N/A
GPU1  NV18  X     NV18  NV18  NV18  NV18  NV18  NV18  PXB   SYS   SYS   SYS   SYS   SYS   SYS   SYS   0-11          0              N/A
GPU2  NV18  NV18  X     NV18  NV18  NV18  NV18  NV18  SYS   PIX   PIX   PXB   SYS   SYS   SYS   SYS   24-35         2              N/A
GPU3  NV18  NV18  NV18  X     NV18  NV18  NV18  NV18  SYS   PXB   PXB   PIX   SYS   SYS   SYS   SYS   24-35         2              N/A
GPU4  NV18  NV18  NV18  NV18  X     NV18  NV18  NV18  SYS   SYS   SYS   SYS   PIX   PIX   PXB   SYS   48-59         4              N/A
GPU5  NV18  NV18  NV18  NV18  NV18  X     NV18  NV18  SYS   SYS   SYS   SYS   PXB   PXB   PIX   SYS   48-59         4              N/A
GPU6  NV18  NV18  NV18  NV18  NV18  NV18  X     NV18  SYS   SYS   SYS   SYS   SYS   SYS   SYS   PXB   72-83         6              N/A
GPU7  NV18  NV18  NV18  NV18  NV18  NV18  NV18  X     SYS   SYS   SYS   SYS   SYS   SYS   SYS   PIX   72-83         6              N/A
NIC0  PIX   PXB   SYS   SYS   SYS   SYS   SYS   SYS   X     SYS   SYS   SYS   SYS   SYS   SYS   SYS
NIC1  SYS   SYS   PIX   PXB   SYS   SYS   SYS   SYS   SYS   X     PIX   PXB   SYS   SYS   SYS   SYS
NIC2  SYS   SYS   PIX   PXB   SYS   SYS   SYS   SYS   SYS   PIX   X     PXB   SYS   SYS   SYS   SYS
NIC3  SYS   SYS   PXB   PIX   SYS   SYS   SYS   SYS   SYS   PXB   PXB   X     SYS   SYS   SYS   SYS
NIC4  SYS   SYS   SYS   SYS   PIX   PXB   SYS   SYS   SYS   SYS   SYS   SYS   X     PIX   PXB   SYS
NIC5  SYS   SYS   SYS   SYS   PIX   PXB   SYS   SYS   SYS   SYS   SYS   SYS   PIX   X     PXB   SYS
NIC6  SYS   SYS   SYS   SYS   PXB   PIX   SYS   SYS   SYS   SYS   SYS   SYS   PXB   PXB   X     SYS
NIC7  SYS   SYS   SYS   SYS   SYS   SYS   PXB   PIX   SYS   SYS   SYS   SYS   SYS   SYS   SYS   X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_2
NIC3: mlx5_3
NIC4: mlx5_4
NIC5: mlx5_5
NIC6: mlx5_6
NIC7: mlx5_7
==============================
Environment Variables
==============================
NVIDIA_VISIBLE_DEVICES=GPU-50b2cd4b-06fa-ace5-8239-067fd5427f37,GPU-7f9ccb0e-3137-3824-20bc-caf8028689c4,GPU-17ebd433-a406-f0a9-45e6-355cf77286a3,GPU-73c7f4fb-63b3-b200-7ea9-a683253c58a6,GPU-3f7763a0-bf05-acc7-9a05-e185b8ba1f11,GPU-cc1f9d6a-7ce7-9da1-98f8-61a911eea94a,GPU-30426c61-7f0c-4d66-91fe-4e2e44dc3336,GPU-68649616-e781-6f73-41fd-b68b86e19eb9
NVIDIA_REQUIRE_CUDA=cuda>=12.8 brand=unknown,driver>=470,driver<471 brand=grid,driver>=470,driver<471 brand=tesla,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=vapps,driver>=470,driver<471 brand=vpc,driver>=470,driver<471 brand=vcs,driver>=470,driver<471 brand=vws,driver>=470,driver<471 brand=cloudgaming,driver>=470,driver<471 brand=unknown,driver>=535,driver<536 brand=grid,driver>=535,driver<536 brand=tesla,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=vapps,driver>=535,driver<536 brand=vpc,driver>=535,driver<536 brand=vcs,driver>=535,driver<536 brand=vws,driver>=535,driver<536 brand=cloudgaming,driver>=535,driver<536 brand=unknown,driver>=550,driver<551 brand=grid,driver>=550,driver<551 brand=tesla,driver>=550,driver<551 brand=nvidia,driver>=550,driver<551 brand=quadro,driver>=550,driver<551 brand=quadrortx,driver>=550,driver<551 brand=nvidiartx,driver>=550,driver<551 brand=vapps,driver>=550,driver<551 brand=vpc,driver>=550,driver<551 brand=vcs,driver>=550,driver<551 brand=vws,driver>=550,driver<551 brand=cloudgaming,driver>=550,driver<551 brand=unknown,driver>=560,driver<561 brand=grid,driver>=560,driver<561 brand=tesla,driver>=560,driver<561 brand=nvidia,driver>=560,driver<561 brand=quadro,driver>=560,driver<561 brand=quadrortx,driver>=560,driver<561 brand=nvidiartx,driver>=560,driver<561 brand=vapps,driver>=560,driver<561 brand=vpc,driver>=560,driver<561 brand=vcs,driver>=560,driver<561 brand=vws,driver>=560,driver<561 brand=cloudgaming,driver>=560,driver<561 brand=unknown,driver>=565,driver<566 brand=grid,driver>=565,driver<566 brand=tesla,driver>=565,driver<566 brand=nvidia,driver>=565,driver<566 brand=quadro,driver>=565,driver<566 brand=quadrortx,driver>=565,driver<566 brand=nvidiartx,driver>=565,driver<566 brand=vapps,driver>=565,driver<566 brand=vpc,driver>=565,driver<566 brand=vcs,driver>=565,driver<566 brand=vws,driver>=565,driver<566 brand=cloudgaming,driver>=565,driver<566
NCCL_VERSION=2.25.1-1
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVIDIA_PRODUCT_NAME=CUDA
VLLM_USAGE_SOURCE=production-docker-image
CUDA_VERSION=12.8.1
LD_LIBRARY_PATH=/usr/local/cuda/lib64
OMP_NUM_THREADS=56
VLLM_USE_V1=1
NCCL_CUMEM_ENABLE=0
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
When I load the QuixiAI/DeepSeek-R1-AWQ model from Hugging Face with vLLM v0.10.0 or later, weight loading fails with a RuntimeError caused by a tensor shape mismatch.
Up to vLLM v0.9.2, the model loaded normally and inference was possible.
I'm running on 8x H100 GPUs. The parameters used to launch vLLM are as follows:
- args:
  - --model
  - /models/deepseek-r1-awq
  - --tensor-parallel-size
  - "8"
  - --load-format
  - "auto"
  - --max-model-len
  - "65536"
  - --max-seq-len-to-capture
  - "65536"
  - --disable-log-requests
  - --uvicorn-log-level
  - "warning"
  - --gpu-memory-utilization
  - "0.95"
  - --trust-remote-code
  - --enable-prefix-caching
  - --prefix-caching-hash-algo  # not supported in the V0 engine
  - "sha256"
  - --reasoning-parser
  - "deepseek_r1"
  env:
  - name: VLLM_USE_V1
    value: "1"
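To check whether the failure depends on the API server at all, the same engine settings could be passed to vLLM's offline LLM class. This is a sketch under assumptions, not part of the original report: the keyword names are assumed to match vLLM's engine arguments (most of them appear verbatim in the "non-default args" log line below), and running it still needs the same 8-GPU host and model files. Server-only flags (--disable-log-requests, --uvicorn-log-level, --reasoning-parser) are left out because they should not affect model loading.

```python
"""Offline reproduction sketch (an assumption, not part of the original report)."""
import importlib.util
import os

MODEL_PATH = "/models/deepseek-r1-awq"  # path taken from the container args

# Keyword names ASSUMED to match vLLM's engine arguments; the values mirror
# the container args above.
ENGINE_KWARGS = {
    "model": MODEL_PATH,
    "tensor_parallel_size": 8,
    "load_format": "auto",
    "max_model_len": 65536,
    "max_seq_len_to_capture": 65536,
    "gpu_memory_utilization": 0.95,
    "trust_remote_code": True,
    "enable_prefix_caching": True,
    "prefix_caching_hash_algo": "sha256",
}


def main() -> None:
    # Bail out cleanly on machines without vLLM instead of raising ImportError.
    if importlib.util.find_spec("vllm") is None:
        print("vllm is not installed; nothing to reproduce.")
        return
    os.environ.setdefault("VLLM_USE_V1", "1")  # same as the env block above
    from vllm import LLM, SamplingParams

    # On v0.10.x the RuntimeError is expected to surface inside this constructor,
    # while the workers load the checkpoint shards.
    llm = LLM(**ENGINE_KWARGS)
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()
```

If this script fails with the same shape mismatch, the bug is in model loading rather than in the OpenAI-compatible server layer.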
The error log is as follows.
INFO 08-25 00:08:45 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:08:47 [api_server.py:1755] vLLM API server version 0.10.1.dev1+gbcc0a3cbe
INFO 08-25 00:08:47 [cli_args.py:261] non-default args: {'uvicorn_log_level': 'warning', 'model': '/data/models/DeepSeek-R1-awq-64g', 'trust_remote_code': True, 'max_model_len': 65536, 'max_seq_len_to_capture': 65536, 'served_model_name': ['/mnt/models', 'llm'], 'reasoning_parser': 'deepseek_r1', 'tensor_parallel_size': 8, 'gpu_memory_utilization': 0.95, 'enable_prefix_caching': True, 'prefix_caching_hash_algo': 'sha256', 'disable_log_requests': True}
INFO 08-25 00:08:54 [config.py:1604] Using max model len 65536
INFO 08-25 00:08:55 [awq_marlin.py:116] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 08-25 00:08:55 [config.py:2434] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 08-25 00:08:55 [cuda.py:162] Forcing kv cache block size to 64 for FlashMLA backend.
INFO 08-25 00:08:59 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:01 [core.py:572] Waiting for init message from front-end.
INFO 08-25 00:09:01 [core.py:71] Initializing a V1 LLM engine (v0.10.1.dev1+gbcc0a3cbe) with config: model='/data/models/DeepSeek-R1-awq-64g', speculative_config=None, tokenizer='/data/models/DeepSeek-R1-awq-64g', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config={}, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=65536, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, decoding_config=DecodingConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_backend='deepseek_r1'), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=/mnt/models, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, pooler_config=None, compilation_config={"level":3,"debug_dump_path":"","cache_dir":"","backend":"","custom_ops":[],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output","vllm.mamba_mixer2"],"use_inductor":true,"compile_sizes":[],"inductor_compile_config":{"enable_auto_functionalized_v2":false},"inductor_passes":{},"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"max_capture_size":512,"local_cache_dir":null}
INFO 08-25 00:09:01 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1, 2, 3, 4, 5, 6, 7], buffer_handle=(8, 16777216, 10, 'psm_79d2e77e'), local_subscribe_addr='ipc:///tmp/f2629886-8edf-4652-b16e-6af7a3a5691b', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
INFO 08-25 00:09:05 [__init__.py:235] Automatically detected platform cuda.
(VllmWorker rank=6 pid=360) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_ccc59cd2'), local_subscribe_addr='ipc:///tmp/3ddac3b2-b8ac-46e0-af1f-84c154cda744', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=7 pid=361) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_42a8e27d'), local_subscribe_addr='ipc:///tmp/1c92286a-21d0-498e-8f3a-ae8c1d59139a', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=2 pid=356) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_5569a75b'), local_subscribe_addr='ipc:///tmp/3e1c5d92-0d6a-4d92-8301-a740bb326887', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=5 pid=359) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_564b5850'), local_subscribe_addr='ipc:///tmp/db31aa26-c40f-4d43-8b17-7fd741ee300a', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=355) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_4b6a719a'), local_subscribe_addr='ipc:///tmp/409ce1c1-0a01-4390-83a5-ca7d396c794d', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=3 pid=357) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_086317fb'), local_subscribe_addr='ipc:///tmp/7b737e7a-faaa-42a6-85ac-80ab79c9f999', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=4 pid=358) INFO 08-25 00:09:14 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_314a059f'), local_subscribe_addr='ipc:///tmp/97e43af5-c363-4e2c-af7e-cc1fff6b7a78', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=354) INFO 08-25 00:09:15 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_ffe9a821'), local_subscribe_addr='ipc:///tmp/2d07bf57-40d8-42a8-a90b-f46ba455cf51', remote_subscribe_addr=None, remote_addr_ipv6=False)
[W825 00:09:16.143641377 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.157914080 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.159522182 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.166204194 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.166204568 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.182869458 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.185927266 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
[W825 00:09:16.191239289 socket.cpp:755] [c10d] The client socket cannot be initialized to connect to [localhost]:40061 (errno: 97 - Address family not supported by protocol).
(VllmWorker rank=4 pid=358) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=5 pid=359) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=3 pid=357) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=4 pid=358) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=1 pid=355) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=6 pid=360) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=5 pid=359) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=3 pid=357) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=1 pid=355) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=0 pid=354) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=6 pid=360) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=2 pid=356) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=354) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=7 pid=361) INFO 08-25 00:09:17 [__init__.py:1375] Found nccl from library libnccl.so.2
(VllmWorker rank=2 pid=356) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=7 pid=361) INFO 08-25 00:09:17 [pynccl.py:70] vLLM is using nccl==2.26.2
(VllmWorker rank=0 pid=354) INFO 08-25 00:09:20 [custom_all_reduce_utils.py:208] generating GPU P2P access cache in /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=7 pid=361) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=4 pid=358) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=6 pid=360) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=5 pid=359) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=3 pid=357) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=2 pid=356) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=1 pid=355) INFO 08-25 00:10:01 [custom_all_reduce_utils.py:246] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3,4,5,6,7.json
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:01 [shm_broadcast.py:289] vLLM message queue communication handle: Handle(local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer_handle=(7, 4194304, 6, 'psm_638d7cad'), local_subscribe_addr='ipc:///tmp/1b0b5129-c7ec-4ffe-9210-bde21fa068f1', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=355) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 1 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 1, EP rank 1
(VllmWorker rank=7 pid=361) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 7 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 7, EP rank 7
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 0 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(VllmWorker rank=2 pid=356) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 2 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 2, EP rank 2
(VllmWorker rank=3 pid=357) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 3 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 3, EP rank 3
(VllmWorker rank=1 pid=355) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=4 pid=358) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 4 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 4, EP rank 4
(VllmWorker rank=5 pid=359) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 5 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 5, EP rank 5
(VllmWorker rank=6 pid=360) INFO 08-25 00:10:01 [parallel_state.py:1102] rank 6 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 6, EP rank 6
(VllmWorker rank=7 pid=361) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=2 pid=356) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=3 pid=357) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=4 pid=358) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=6 pid=360) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=5 pid=359) INFO 08-25 00:10:01 [topk_topp_sampler.py:49] Using FlashInfer for top-p & top-k sampling.
(VllmWorker rank=1 pid=355) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=7 pid=361) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=4 pid=358) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=3 pid=357) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=5 pid=359) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=6 pid=360) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=2 pid=356) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:02 [gpu_model_runner.py:1843] Starting to load model /data/models/DeepSeek-R1-awq-64g...
(VllmWorker rank=1 pid=355) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=6 pid=360) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=1 pid=355) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=7 pid=361) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=3 pid=357) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=4 pid=358) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=6 pid=360) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=5 pid=359) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=2 pid=356) INFO 08-25 00:10:02 [gpu_model_runner.py:1875] Loading model from scratch...
(VllmWorker rank=7 pid=361) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=0 pid=354) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=3 pid=357) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=4 pid=358) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=5 pid=359) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
(VllmWorker rank=2 pid=356) INFO 08-25 00:10:02 [cuda.py:231] Using FlashMLA backend on V1 engine.
Loading safetensors checkpoint shards: 0% Completed | 0/74 [00:00<?, ?it/s]
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] WorkerProc failed to start.
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] Traceback (most recent call last):
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 485, in worker_main
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] worker = WorkerProc(*args, **kwargs)
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 382, in __init__
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] self.worker.load_model()
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_worker.py", line 201, in load_model
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] self.model_runner.load_model(eep_scale_up=eep_scale_up)
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/worker/gpu_model_runner.py", line 1876, in load_model
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] self.model = model_loader.load_model(
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/base_loader.py", line 49, in load_model
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] self.load_weights(model, model_config)
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/model_loader/default_loader.py", line 259, in load_weights
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] loaded_weights = model.load_weights(
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] ^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 907, in load_weights
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] weight_loader(param, loaded_weight, shard_id)
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/linear.py", line 443, in weight_loader
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] param[shard_offset:shard_offset + shard_size] = loaded_weight
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] ~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorker rank=1 pid=355) ERROR 08-25 00:10:16 [multiproc_executor.py:511] RuntimeError: The expanded size of the tensor (264) must match the existing size (72) at non-singleton dimension 1. Target sizes: [576, 264]. Tensor sizes: [7168, 72]
Loading safetensors checkpoint shards: 0% Completed | 0/74 [00:14<?, ?it/s]
(VllmWorker rank=0 pid=354)
[rank0]:[W825 00:10:18.050044525 ProcessGroupNCCL.cpp:1479] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
ERROR 08-25 00:10:20 [core.py:632] EngineCore failed to start.
Process EngineCore_0:
ERROR 08-25 00:10:20 [core.py:632] Traceback (most recent call last):
ERROR 08-25 00:10:20 [core.py:632] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 623, in run_engine_core
ERROR 08-25 00:10:20 [core.py:632] engine_core = EngineCoreProc(*args, **kwargs)
ERROR 08-25 00:10:20 [core.py:632] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-25 00:10:20 [core.py:632] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 441, in __init__
ERROR 08-25 00:10:20 [core.py:632] super().__init__(vllm_config, executor_class, log_stats,
ERROR 08-25 00:10:20 [core.py:632] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 77, in __init__
ERROR 08-25 00:10:20 [core.py:632] self.model_executor = executor_class(vllm_config)
ERROR 08-25 00:10:20 [core.py:632] ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-25 00:10:20 [core.py:632] File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 53, in __init__
ERROR 08-25 00:10:20 [core.py:632] self._init_executor()
ERROR 08-25 00:10:20 [core.py:632] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 94, in _init_executor
ERROR 08-25 00:10:20 [core.py:632] self.workers = WorkerProc.wait_for_ready(unready_workers)
ERROR 08-25 00:10:20 [core.py:632] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-25 00:10:20 [core.py:632] File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 446, in wait_for_ready
ERROR 08-25 00:10:20 [core.py:632] raise e from None
ERROR 08-25 00:10:20 [core.py:632] Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
Traceback (most recent call last):
File "/usr/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/usr/lib/python3.12/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 636, in run_engine_core
raise e
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 623, in run_engine_core
engine_core = EngineCoreProc(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 441, in __init__
super().__init__(vllm_config, executor_class, log_stats,
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/engine/core.py", line 77, in __init__
self.model_executor = executor_class(vllm_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 53, in __init__
self._init_executor()
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 94, in _init_executor
self.workers = WorkerProc.wait_for_ready(unready_workers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/vllm/v1/executor/multiproc_executor.py", line 446, in wait_for_ready
raise e from None
Exception: WorkerProc initialization failed due to an exception in a background process. See stack trace for root cause.
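The shapes in the RuntimeError can be reproduced with simple arithmetic, which hints at a possible cause. This is a hypothesis, not a confirmed diagnosis. It assumes DeepSeek-R1's config values (hidden_size=7168, q_lora_rank=1536, kv_lora_rank=512, qk_rope_head_dim=64), AWQ's layout of eight 4-bit values packed into each int32 along the output dimension, and that v0.10 loads q_a_proj and kv_a_proj_with_mqa into a single fused parameter. Under those assumptions, slicing the fused qweight along dim 0 with unpacked offsets, as param[shard_offset:shard_offset + shard_size] at linear.py line 443 does, gives exactly the [576, 264] target from the log, while the checkpoint tensor is [7168, 72].

```python
# Shape arithmetic for the RuntimeError above (hypothesis, not a confirmed diagnosis).
# ASSUMED DeepSeek-R1 config values (not stated in the issue):
HIDDEN_SIZE = 7168
Q_LORA_RANK = 1536
KV_LORA_RANK = 512
QK_ROPE_HEAD_DIM = 64
KV_A_OUT = KV_LORA_RANK + QK_ROPE_HEAD_DIM  # 576 output features of kv_a_proj_with_mqa

# ASSUMED AWQ layout: qweight is [in_features, out_features // 8],
# with eight 4-bit values packed into each int32 along the output dim.
PACK_FACTOR = 32 // 4


def awq_qweight_shape(in_features: int, out_features: int) -> tuple:
    """Shape of an AWQ-packed qweight under the assumed layout."""
    return (in_features, out_features // PACK_FACTOR)


def narrow_shape(shape: tuple, dim: int, offset: int, size: int) -> tuple:
    """Shape of a slice [offset:offset + size] along `dim`, clamped like Python slicing."""
    out = list(shape)
    out[dim] = max(0, min(shape[dim], offset + size) - offset)
    return tuple(out)


# ASSUMED fused parameter: q_a_proj rows followed by kv_a_proj_with_mqa rows.
fused = awq_qweight_shape(HIDDEN_SIZE, Q_LORA_RANK + KV_A_OUT)  # (7168, 264)
kv_a = awq_qweight_shape(HIDDEN_SIZE, KV_A_OUT)                 # (7168, 72)

# What the failing line appears to do: unpacked offsets applied to dim 0.
buggy_target = narrow_shape(fused, dim=0, offset=Q_LORA_RANK, size=KV_A_OUT)
# What would fit: pack-adjusted offsets applied to the packed output dim 1.
fixed_target = narrow_shape(fused, dim=1,
                            offset=Q_LORA_RANK // PACK_FACTOR,
                            size=KV_A_OUT // PACK_FACTOR)

print("fused qweight:", fused)
print("kv_a qweight: ", kv_a)
print("dim-0 slice:  ", buggy_target, "fits:", buggy_target == kv_a)
print("dim-1 slice:  ", fixed_target, "fits:", fixed_target == kv_a)
```

The dim-0 slice comes out as (576, 264), matching the log's "Target sizes", and does not fit the (7168, 72) checkpoint tensor; the pack-adjusted dim-1 slice does. If this reading is right, the fused-projection loading path may be skipping the AWQ pack-factor adjustment for this layer, which would also explain why v0.9.2 (before such fusion, if the assumption holds) loads the same checkpoint fine.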
Project status: Backlog