
[Bug]: Error when using tensor_parallel in v0.6.1 #8397

Closed
1 task done
pspdada opened this issue Sep 12, 2024 · 11 comments · Fixed by #8390
Labels
bug Something isn't working

pspdada commented Sep 12, 2024

Your current environment

The output of `python collect_env.py`
Collecting environment information...
PyTorch version: 2.4.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.2
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-120-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-PCIE-40GB
GPU 1: NVIDIA A100-PCIE-40GB

Nvidia driver version: 550.107.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        40 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               4
On-line CPU(s) list:                  0-3
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Silver 4316 CPU @ 2.30GHz
CPU family:                           6
Model:                                106
Thread(s) per core:                   1
Core(s) per socket:                   4
Socket(s):                            1
Stepping:                             6
BogoMIPS:                             4589.12
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch topoext cpuid_fault invpcid_single pti ssbd ibrs ibpb tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid
Virtualization:                       VT-x
L1d cache:                            128 KiB (4 instances)
L1i cache:                            128 KiB (4 instances)
L2 cache:                             16 MiB (4 instances)
L3 cache:                             16 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-3
Vulnerability Gather data sampling:   Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:          KVM: Mitigation: VMX disabled
Vulnerability L1tf:                   Mitigation; PTE Inversion; VMX conditional cache flushes, SMT disabled
Vulnerability Mds:                    Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:               Mitigation; PTI
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.2.65
[pip3] nvidia-cuda-cupti-cu12==12.4.99
[pip3] nvidia-cuda-nvrtc-cu12==12.4.99
[pip3] nvidia-cuda-runtime-cu12==12.4.99
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.0.44
[pip3] nvidia-curand-cu12==10.3.5.119
[pip3] nvidia-cusolver-cu12==11.6.0.99
[pip3] nvidia-cusparse-cu12==12.3.0.142
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] nvidia-nvjitlink-cu12==12.4.99
[pip3] nvidia-nvtx-cu12==12.4.99
[pip3] pyzmq==26.1.1
[pip3] torch==2.4.0+cu124
[pip3] torchaudio==2.4.0+cu124
[pip3] torchvision==0.19.0+cu124
[pip3] transformers==4.44.0
[pip3] triton==3.0.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.0@32e7db25365415841ebc7c4215851743fbb1bad1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PHB     0-3     0               N/A
GPU1    PHB      X      0-3     0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

Model Input Dumps

No response

🐛 Describe the bug

When I use vLLM v0.6.1 with tensor parallelism (tensor_parallel_size=2 below), worker startup fails with a CUDA error. Rolling back to v0.6.0 resolves the issue, and with tensor_parallel_size=1 the bug does not appear at all.
The test code is the same for both versions:

from vllm import LLM
def init_model() -> LLM:
    llm = LLM(
        model="Qwen/Qwen2-7B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=True,
        download_dir="./.cache",
        tensor_parallel_size=2,  # How many GPUs to use
        gpu_memory_utilization=0.85,
        pipeline_parallel_size=1,
        dtype="bfloat16",
        # max_model_len=20480,  # Model context length
        enable_prefix_caching=True,
        enable_chunked_prefill=False,
        num_scheduler_steps=8,
    )
    return llm

if __name__ == "__main__":
    llm = init_model()
    print(llm.generate("Hello, world!"))

The output for v0.6.0 is as follows:

INFO 09-12 11:50:02 config.py:890] Defaulting to use mp for distributed inference
WARNING 09-12 11:50:02 arg_utils.py:880] Enabled BlockSpaceManagerV2 because it is required for multi-step (--num-scheduler-steps > 1)
INFO 09-12 11:50:02 llm_engine.py:213] Initializing an LLM engine (v0.6.0) with config: model='Qwen/Qwen2-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir='./.cache', load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2-7B-Instruct, use_v2_block_manager=True, num_scheduler_steps=8, enable_prefix_caching=True, use_async_output_proc=True)
WARNING 09-12 11:50:03 multiproc_gpu_executor.py:56] Reducing Torch parallelism from 4 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 09-12 11:50:03 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
WARNING 09-12 11:50:03 registry.py:190] `mm_limits` has already been set for model=Qwen/Qwen2-7B-Instruct, and will be overwritten by the new values.
(VllmWorkerProcess pid=32744) WARNING 09-12 11:50:03 registry.py:190] `mm_limits` has already been set for model=Qwen/Qwen2-7B-Instruct, and will be overwritten by the new values.
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:04 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 09-12 11:50:04 utils.py:977] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:04 utils.py:977] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:04 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 09-12 11:50:04 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 09-12 11:50:04 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:04 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
WARNING 09-12 11:50:04 custom_all_reduce.py:131] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=32744) WARNING 09-12 11:50:04 custom_all_reduce.py:131] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 09-12 11:50:04 shm_broadcast.py:235] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1], buffer=<vllm.distributed.device_communicators.shm_broadcast.ShmRingBuffer object at 0x7f474c3926e0>, local_subscribe_port=33371, remote_subscribe_port=None)
INFO 09-12 11:50:04 model_runner.py:915] Starting to load model Qwen/Qwen2-7B-Instruct...
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:04 model_runner.py:915] Starting to load model Qwen/Qwen2-7B-Instruct...
INFO 09-12 11:50:05 weight_utils.py:236] Using model weights format ['*.safetensors']
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:05 weight_utils.py:236] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  1.59it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:01<00:01,  1.39it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:02<00:00,  1.42it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.37it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:02<00:00,  1.39it/s]

INFO 09-12 11:50:09 model_runner.py:926] Loading model weights took 7.1216 GB
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:09 model_runner.py:926] Loading model weights took 7.1216 GB
INFO 09-12 11:50:12 distributed_gpu_executor.py:57] # GPU blocks: 53836, # CPU blocks: 9362
INFO 09-12 11:50:16 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 09-12 11:50:16 model_runner.py:1221] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:16 model_runner.py:1217] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:16 model_runner.py:1221] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=32744) INFO 09-12 11:50:33 model_runner.py:1335] Graph capturing finished in 17 secs.
INFO 09-12 11:50:33 model_runner.py:1335] Graph capturing finished in 17 secs.
Processed prompts: 100%|██████████████████████████| 1/1 [00:00<00:00,  3.28it/s, est. speed input: 13.13 toks/s, output: 52.51 toks/s]
[RequestOutput(request_id=0, prompt='Hello, world!', prompt_token_ids=[9707, 11, 1879, 0], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=" We welcome you aboard Avanti's latest blog post, where we'll dive into", token_ids=array('l', [1205, 10565, 498, 36506, 7519, 15359, 594, 5535, 5010, 1736, 11, 1380, 582, 3278, 29863, 1119]), cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=RequestMetrics(arrival_time=1726113035.8667371, last_token_time=1726113035.8667371, first_scheduled_time=1726113035.8703601, first_token_time=1726113035.9442914, time_in_queue=0.0036230087280273438, finished_time=1726113036.1726139, scheduler_time=0.0007259720005095005, model_forward_time=None, model_execute_time=None), lora_request=None)]

The output for v0.6.1 is as follows:

INFO 09-12 11:41:36 config.py:897] Defaulting to use mp for distributed inference
WARNING 09-12 11:41:36 arg_utils.py:908] Enabled BlockSpaceManagerV2 because it is required for multi-step (--num-scheduler-steps > 1)
INFO 09-12 11:41:36 llm_engine.py:232] Initializing an LLM engine (v0.6.1) with config: model='Qwen/Qwen2-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=32768, download_dir='./.cache', load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2-7B-Instruct, use_v2_block_manager=True, num_scheduler_steps=8, enable_prefix_caching=True, use_async_output_proc=True)
WARNING 09-12 11:41:37 multiproc_gpu_executor.py:56] Reducing Torch parallelism from 4 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 09-12 11:41:37 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
WARNING 09-12 11:41:37 registry.py:191] `mm_limits` has already been set for model=Qwen/Qwen2-7B-Instruct, and will be overwritten by the new values.
(VllmWorkerProcess pid=30655) WARNING 09-12 11:41:37 registry.py:191] `mm_limits` has already been set for model=Qwen/Qwen2-7B-Instruct, and will be overwritten by the new values.
(VllmWorkerProcess pid=30655) Process VllmWorkerProcess:
(VllmWorkerProcess pid=30655) Traceback (most recent call last):
(VllmWorkerProcess pid=30655)   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
(VllmWorkerProcess pid=30655)     self.run()
(VllmWorkerProcess pid=30655)   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
(VllmWorkerProcess pid=30655)     self._target(*self._args, **self._kwargs)
(VllmWorkerProcess pid=30655)   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/multiproc_worker_utils.py", line 210, in _run_worker_process
(VllmWorkerProcess pid=30655)     worker = worker_factory()
(VllmWorkerProcess pid=30655)   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 24, in create_worker
(VllmWorkerProcess pid=30655)     wrapper.init_worker(**kwargs)
(VllmWorkerProcess pid=30655)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker_base.py", line 449, in init_worker
(VllmWorkerProcess pid=30655)     self.worker = worker_class(*args, **kwargs)
(VllmWorkerProcess pid=30655)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/multi_step_worker.py", line 28, in __init__
(VllmWorkerProcess pid=30655)     self.model_runner = MultiStepModelRunner(
(VllmWorkerProcess pid=30655)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/multi_step_model_runner.py", line 234, in __init__
(VllmWorkerProcess pid=30655)     self._copy_stream = torch.cuda.Stream()
(VllmWorkerProcess pid=30655)   File "/usr/local/lib/python3.10/dist-packages/torch/cuda/streams.py", line 35, in __new__
(VllmWorkerProcess pid=30655)     return super().__new__(cls, priority=priority, **kwargs)
(VllmWorkerProcess pid=30655) RuntimeError: CUDA error: initialization error
(VllmWorkerProcess pid=30655) Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(VllmWorkerProcess pid=30655) 
ERROR 09-12 11:41:37 multiproc_worker_utils.py:120] Worker VllmWorkerProcess pid 30655 died, exit code: 1
INFO 09-12 11:41:37 multiproc_worker_utils.py:123] Killing local vLLM worker processes

@pspdada pspdada added the bug Something isn't working label Sep 12, 2024
@Quang-elec44

Same here with Docker when setting --tensor-parallel-size 2

INFO 09-11 21:02:27 llm_engine.py:232] Initializing an LLM engine (v0.6.1) with config: model='/models/Llama-3.1-Storm-8B-AWQ', speculative_config=None, tokenizer='/models/Llama-3.1-Storm-8B-AWQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=42, served_model_name=gpt-4o, use_v2_block_manager=False, num_scheduler_steps=1, enable_prefix_caching=True, use_async_output_proc=True)
WARNING 09-11 21:02:28 multiproc_gpu_executor.py:56] Reducing Torch parallelism from 24 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 09-11 21:02:28 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
(VllmWorkerProcess pid=163) INFO 09-11 21:02:28 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method init_device: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method, Traceback (most recent call last):
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]              ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 166, in init_device
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]     torch.cuda.set_device(self.device)
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 420, in set_device
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]     torch._C._cuda_setDevice(device)
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 300, in _lazy_init
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226]     raise RuntimeError(
(VllmWorkerProcess pid=163) ERROR 09-11 21:02:28 multiproc_worker_utils.py:226] RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
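The error message above asks for the 'spawn' start method. For the Docker case, here is a minimal sketch that builds a `docker run` command passing that choice into the container. It assumes vLLM 0.6.x reads the VLLM_WORKER_MULTIPROC_METHOD environment variable to decide how it starts its worker processes (check `vllm/envs.py` in your installed version). The helper name, host directory and image tag are hypothetical placeholders; the container model path and tensor-parallel size come from the log above. This is a stopgap, not a substitute for the fix in #8390.

```python
import shlex
from typing import List


def vllm_docker_command(
    host_model_dir: str,
    container_model_path: str = "/models/Llama-3.1-Storm-8B-AWQ",
    image: str = "vllm/vllm-openai:v0.6.1",  # placeholder tag
    tensor_parallel_size: int = 2,
    start_method: str = "spawn",
) -> List[str]:
    """Hypothetical helper: argv for vLLM's OpenAI-compatible server in Docker."""
    return [
        "docker", "run",
        "--gpus", "all",
        # vLLM's Docker instructions use --ipc=host so worker processes can share memory.
        "--ipc=host",
        # Assumption: vLLM picks its worker start method from this variable.
        "-e", f"VLLM_WORKER_MULTIPROC_METHOD={start_method}",
        # Mount the host model directory where the log expects it (/models/...).
        "-v", f"{host_model_dir}:/models",
        image,
        # Arguments after the image name are passed to the server entrypoint.
        "--model", container_model_path,
        "--tensor-parallel-size", str(tensor_parallel_size),
    ]


if __name__ == "__main__":
    # Print a shell-ready command line instead of running it.
    print(shlex.join(vllm_docker_command("/path/to/host/models")))
```

Returning an argv list rather than a string keeps the arguments unambiguous; `shlex.join` is only used to display it.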

@DarkLight1337
Member

Can you check whether #8390 fixes the problem on your end?

@pspdada
Author

pspdada commented Sep 12, 2024

Yes, it fixes the problem. Thanks a lot!

@ashgold

ashgold commented Sep 12, 2024

Same here with Docker when setting --tensor-parallel-size 2


Same for me.
It occurs only with v0.6.1, not with v0.6.0.
A workaround is described in #6152.
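The "Cannot re-initialize CUDA in forked subprocess" error quoted in this thread recommends the 'spawn' start method. Here is a minimal in-script sketch of that idea for the Python API reproduction, assuming vLLM reads the VLLM_WORKER_MULTIPROC_METHOD environment variable when it launches its tensor-parallel workers and that the variable is set before `from vllm import LLM` runs. The helper name is hypothetical, and this is not necessarily the workaround described in #6152.

```python
import os
from typing import MutableMapping, Optional


def prefer_spawn_workers(env: Optional[MutableMapping[str, str]] = None) -> str:
    """Hypothetical helper: ask vLLM to start workers with 'spawn', not 'fork'.

    Assumption: vLLM reads VLLM_WORKER_MULTIPROC_METHOD when it creates its
    worker processes, so call this before `from vllm import LLM`.
    A value already present in the environment is left untouched.
    Returns the start method that will be used.
    """
    env = os.environ if env is None else env
    # setdefault: an explicit user choice wins; otherwise fall back to spawn.
    return env.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
```

In the reproduction script, call `prefer_spawn_workers()` before `from vllm import LLM` and keep `init_model()` under the existing `if __name__ == "__main__":` guard, since spawned workers re-import the main module. From a shell, `VLLM_WORKER_MULTIPROC_METHOD=spawn python repro.py` has the same effect (`repro.py` stands in for the script's file name). Assign the variable directly instead of using `setdefault` if you want to force spawn regardless of an existing setting.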

@pseudotensor

I see this not just with AMD, but also on regular NVIDIA hardware (4x H100).

@pseudotensor

I think this requires a 0.6.2 or 0.6.1.post release; without the fix, it seems any sharding is broken on any device.

@DarkLight1337
Member

Yes, we will likely release a patch since this issue breaks vLLM for many users. Stay tuned!

@sayakpaul
Contributor

I know it's not much help, but the fixes worked for me: I installed from source and everything was up and running again.

@ruleGreen

I still have this problem with 0.6.1.post or 0.6.2 when setting tp=2 on a single node with 8 GPUs (pp=1).

@DarkLight1337
Member

I still have this problem with 0.6.1.post or 0.6.2 when setting tp=2 on a single node with 8 GPUs (pp=1).

Can you open a new issue and provide more details about your environment and the error?

@ruleGreen

Please refer to #8937.

I think this may also be related to #8735 and #7151.

7 participants