Closed
Labels
bug: Something isn't working
Description
Your current environment
The output of `python collect_env.py`
INFO 03-21 04:23:58 [__init__.py:256] Automatically detected platform cuda.
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.12.8 (main, Dec 6 2024, 19:59:28) [Clang 18.1.8 ] (64-bit runtime)
Python platform: Linux-6.8.0-40-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100 80GB PCIe
GPU 1: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.183.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 60
On-line CPU(s) list: 0-59
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7763 64-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 1
Core(s) per socket: 30
Socket(s): 2
Stepping: 1
BogoMIPS: 4890.80
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm rep_good nopl cpuid extd_apicid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy svm cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw perfctr_core ssbd ibrs ibpb stibp vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves clzero xsaveerptr wbnoinvd arat npt lbrv nrip_save tsc_scale vmcb_clean pausefilter pfthreshold v_vmsave_vmload vgif umip pku ospke vaes vpclmulqdq rdpid fsrm arch_capabilities
Virtualization: AMD-V
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 3.8 MiB (60 instances)
L1i cache: 3.8 MiB (60 instances)
L2 cache: 30 MiB (60 instances)
L3 cache: 960 MiB (60 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-29
NUMA node1 CPU(s): 30-59
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] onnxruntime==1.21.0
[pip3] pyzmq==26.3.0
[pip3] torch==2.6.0
[pip3] torchaudio==2.6.0
[pip3] torchvision==0.21.0
[pip3] transformers==4.50.0.dev0
[pip3] triton==3.2.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.8.1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PHB     0-59            0-1             N/A
GPU1    PHB      X      0-59            0-1             N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY
🐛 Describe the bug
The model is meta-llama/Llama-3.1-8B-Instruct. The issue did not occur on vLLM 0.7.2. vLLM is run with VLLM_USE_V1=0 to force the V0 inference engine.
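For readers trying to reproduce this, the setup described above can be sketched roughly as follows. This is not the reporter's actual launch code, which is not shown in the issue. `enable_lora=True` is an assumption inferred from the LoRA frames in the traceback, `tensor_parallel_size=2` is an assumption based on the two A100s listed in the environment, and the helper names `build_engine_settings` and `launch` are hypothetical.

```python
import os


def build_engine_settings(enforce_eager: bool = False) -> tuple[dict, dict]:
    """Return (env, llm_kwargs) approximating the setup in this report."""
    env = {"VLLM_USE_V1": "0"}  # from the report: forces the V0 engine
    llm_kwargs = {
        "model": "meta-llama/Llama-3.1-8B-Instruct",  # from the report
        "enable_lora": True,        # assumption: LoRA frames appear in the traceback
        "tensor_parallel_size": 2,  # assumption: one shard per listed A100
        "enforce_eager": enforce_eager,  # True skips CUDA graph capture
    }
    return env, llm_kwargs


def launch(enforce_eager: bool = False):
    """Construct the engine; requires vLLM and GPUs, so it is not called here."""
    env, llm_kwargs = build_engine_settings(enforce_eager)
    os.environ.update(env)  # set before the engine is constructed
    from vllm import LLM    # imported lazily so this file loads without vLLM
    return LLM(**llm_kwargs)
```

Calling `launch(enforce_eager=True)` would bypass CUDA graph capture entirely, which may help confirm whether the crash is specific to the capture path. Whether that avoids the error has not been verified here.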
vLLM crashes during CUDA graph capture. The last successfully logged messages:
(VllmWorkerProcess pid=1283553) INFO 03-21 04:33:32 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 03-21 04:33:32 [model_runner.py:1442] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 0%| | 0/10 [00:00<?, ?it/s]
The following exception is then raised:
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/utils.py", line 2216, in run_method
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/worker/worker.py", line 308, in initialize_cache
self._warm_up_model()
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/worker/worker.py", line 338, in _warm_up_model
self.model_runner.capture_model(self.gpu_cache)
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1509, in capture_model
self.set_active_loras(set(), lora_mapping)
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1371, in set_active_loras
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/lora/worker_manager.py", line 167, in set_active_adapters
set_active_adapters_worker(requests, mapping, self._apply_adapters,
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/adapter_commons/utils.py", line 55, in set_active_adapters_worker
set_adapter_mapping_func(mapping)
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/lora/models.py", line 688, in set_adapter_mapping
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/adapter_commons/utils.py", line 30, in set_adapter_mapping
set_mapping_func(mapping)
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/lora/models.py", line 453, in _set_adapter_mapping
self.punica_wrapper.update_metadata(
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 66, in update_metadata
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
File "/home/shadeform/precog/.venv/lib/python3.12/site-packages/vllm/lora/ops/triton_ops/lora_kernel_metadata.py", line 76, in prepare_tensors
self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping,
RuntimeError: The size of tensor a (50) must match the size of tensor b (56) at non-singleton dimension 0
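The sizes in this error are consistent with a 50-element destination buffer being sliced with `num_tokens = 56`. Slicing past the end of a buffer clamps silently, so the slice yields 50 elements, and the strict same-size copy then fails. The sketch below is a pure-Python stand-in for that behavior, not vLLM or PyTorch code. The helper `copy_into_prefix` is hypothetical, and both the buffer size of 50 and the explanation for it are assumptions.

```python
def copy_into_prefix(buffer: list, source: list) -> None:
    """Mimic `buffer[:n].copy_(source)` with a strict same-length check."""
    num_tokens = len(source)
    prefix_len = len(buffer[:num_tokens])  # slicing clamps to len(buffer)
    if prefix_len != num_tokens:
        raise RuntimeError(
            f"The size of tensor a ({prefix_len}) must match the size of "
            f"tensor b ({num_tokens}) at non-singleton dimension 0"
        )
    buffer[:num_tokens] = source


token_lora_mapping = [0] * 50  # hypothetical preallocated metadata buffer
incoming = [0] * 56            # hypothetical mapping for one capture batch

try:
    copy_into_prefix(token_lora_mapping, incoming)
except RuntimeError as exc:
    message = str(exc)  # same wording as the error in the traceback
```

If this reading is right, the relevant question for a fix is why the preallocated LoRA metadata buffer is smaller than the largest batch used during graph capture.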