[Bug]: Custom all reduce not work. #3688

Open
esmeetu opened this issue Mar 28, 2024 · 18 comments
Labels
bug Something isn't working

Comments

@esmeetu (Collaborator)

esmeetu commented Mar 28, 2024

Your current environment

PyTorch version: 2.1.2+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.4
Libc version: glibc-2.35

Python version: 3.10.12 (main, Jul  5 2023, 18:54:27) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.2.0-39-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla T4
GPU 1: Tesla T4

Nvidia driver version: 545.29.06
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             104
On-line CPU(s) list:                0-103
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Platinum 8269CY CPU @ 2.50GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 26
Socket(s):                          2
Stepping:                           7
CPU max MHz:                        3800.0000
CPU min MHz:                        1200.0000
BogoMIPS:                           5000.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          1.6 MiB (52 instances)
L1i cache:                          1.6 MiB (52 instances)
L2 cache:                           52 MiB (52 instances)
L3 cache:                           71.5 MiB (2 instances)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-103
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        KVM: Mitigation: VMX disabled
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.1.2
[pip3] triton==2.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] torch                     2.1.2                    pypi_0    pypi
[conda] triton                    2.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.3.3
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      SYS     0-103   0               N/A
GPU1    SYS      X      0-103   0               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

llm_engine works fine.
async_llm_engine does not work.
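
For context, a minimal sketch of the offline path that works (a hypothetical reproduction, not taken from the report; it reuses the same model path and flags as the failing server command below):

# Hedged sketch of the offline LLM engine path, which is reported to work.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/TinyLlama-1.1B-Chat-v0.6",
    tensor_parallel_size=2,
    dtype="half",
    enforce_eager=True,
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))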

python -m vllm.entrypoints.openai.api_server --model /TinyLlama-1.1B-Chat-v0.6 --tensor-parallel-size 2 --dtype half --enforce-eager

Log:

INFO 03-28 21:38:07 api_server.py:147] vLLM API server version 0.3.3
INFO 03-28 21:38:07 api_server.py:148] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, served_model_name=None, lora_modules=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='/TinyLlama-1.1B-Chat-v0.6', tokenizer=None, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='half', kv_cache_dtype='auto', max_model_len=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, use_v2_block_manager=False, seed=0, swap_space=4, gpu_memory_utilization=0.9, forced_num_gpu_blocks=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=5, disable_log_stats=False, quantization=None, enforce_eager=True, max_context_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', max_cpu_loras=None, device='auto', image_input_type=None, image_token_id=None, image_input_shape=None, image_feature_size=None, scheduler_delay_factor=0.0, engine_use_ray=False, disable_log_requests=False, max_log_len=None)
WARNING 03-28 21:38:07 config.py:744] Casting torch.bfloat16 to torch.float16.
2024-03-28 21:38:10,304 INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
INFO 03-28 21:38:11 llm_engine.py:70] Initializing an LLM engine (v0.3.3) with config: model='/TinyLlama-1.1B-Chat-v0.6', tokenizer='/TinyLlama-1.1B-Chat-v0.6', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 03-28 21:38:16 pynccl_utils.py:13] vLLM is using nccl==2.18.1
(RayWorkerVllm pid=4028887) INFO 03-28 21:38:17 pynccl_utils.py:13] vLLM is using nccl==2.18.1
INFO 03-28 21:38:17 selector.py:33] Cannot use FlashAttention backend for Volta and Turing GPUs.
INFO 03-28 21:38:17 selector.py:20] Using XFormers backend.
(RayWorkerVllm pid=4028887) INFO 03-28 21:38:18 selector.py:33] Cannot use FlashAttention backend for Volta and Turing GPUs.
(RayWorkerVllm pid=4028887) INFO 03-28 21:38:18 selector.py:20] Using XFormers backend.
INFO 03-28 21:38:20 custom_all_reduce.py:137] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
(RayWorkerVllm pid=4028887) INFO 03-28 21:38:20 custom_all_reduce.py:137] NVLink detection failed with message "Not Supported". This is normal if your machine has no NVLink equipped
INFO 03-28 21:38:21 model_runner.py:104] Loading model weights took 1.0258 GB
(RayWorkerVllm pid=4028887) INFO 03-28 21:38:22 model_runner.py:104] Loading model weights took 1.0258 GB
Traceback (most recent call last):
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/roy/vllm/vllm/entrypoints/openai/api_server.py", line 156, in <module>
    engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/home/roy/vllm/vllm/engine/async_llm_engine.py", line 344, in from_engine_args
    engine = cls(parallel_config.worker_use_ray,
  File "/home/roy/vllm/vllm/engine/async_llm_engine.py", line 310, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/home/roy/vllm/vllm/engine/async_llm_engine.py", line 415, in _init_engine
    return engine_class(*args, **kwargs)
  File "/home/roy/vllm/vllm/engine/llm_engine.py", line 106, in __init__
    self.model_executor = executor_class(model_config, cache_config,
  File "/home/roy/vllm/vllm/executor/ray_gpu_executor.py", line 65, in __init__
    self._init_cache()
  File "/home/roy/vllm/vllm/executor/ray_gpu_executor.py", line 222, in _init_cache
    num_blocks = self._run_workers(
  File "/home/roy/vllm/vllm/executor/ray_gpu_executor.py", line 326, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/roy/vllm/vllm/worker/worker.py", line 130, in profile_num_available_blocks
    self.model_runner.profile_run()
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/roy/vllm/vllm/worker/model_runner.py", line 721, in profile_run
    self.execute_model(seqs, kv_caches)
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/roy/vllm/vllm/worker/model_runner.py", line 645, in execute_model
    logits = self.model.compute_logits(hidden_states, sampling_metadata)
  File "/home/roy/vllm/vllm/model_executor/models/llama.py", line 351, in compute_logits
    logits = self.logits_processor(self.lm_head.weight, hidden_states,
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/roy/vllm/vllm/model_executor/layers/logits_processor.py", line 52, in forward
    logits = self._get_logits(hidden_states, embedding, embedding_bias)
  File "/home/roy/vllm/vllm/model_executor/layers/logits_processor.py", line 68, in _get_logits
    logits = tensor_model_parallel_gather(logits)
  File "/home/roy/vllm/vllm/model_executor/parallel_utils/communication_op.py", line 93, in tensor_model_parallel_gather
    torch.distributed.gather(input_,
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3082, in gather
    work = group.gather(output_tensors, input_tensors, opts)
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1333, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.18.1
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'invalid argument'
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44] Error executing method profile_num_available_blocks. This might cause deadlock in distributed execution.
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44] Traceback (most recent call last):
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/engine/ray_utils.py", line 37, in execute_method
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return executor(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return func(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/worker/worker.py", line 130, in profile_num_available_blocks
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     self.model_runner.profile_run()
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return func(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/worker/model_runner.py", line 721, in profile_run
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     self.execute_model(seqs, kv_caches)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return func(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/worker/model_runner.py", line 645, in execute_model
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     logits = self.model.compute_logits(hidden_states, sampling_metadata)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/model_executor/models/llama.py", line 351, in compute_logits
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     logits = self.logits_processor(self.lm_head.weight, hidden_states,
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return self._call_impl(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return forward_call(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/model_executor/layers/logits_processor.py", line 52, in forward
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     logits = self._get_logits(hidden_states, embedding, embedding_bias)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/model_executor/layers/logits_processor.py", line 68, in _get_logits
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     logits = tensor_model_parallel_gather(logits)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/vllm/vllm/model_executor/parallel_utils/communication_op.py", line 93, in tensor_model_parallel_gather
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     torch.distributed.gather(input_,
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     return func(*args, **kwargs)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]   File "/home/roy/miniconda3/envs/ray/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3082, in gather
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44]     work = group.gather(output_tensors, input_tensors, opts)
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44] torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1333, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.18.1
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44] ncclUnhandledCudaError: Call to CUDA function failed.
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44] Last error:
(RayWorkerVllm pid=4028887) ERROR 03-28 21:38:23 ray_utils.py:44] Cuda failure 'invalid argument'
[W CudaIPCTypes.cpp:15] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]
esmeetu added the bug (Something isn't working) label on Mar 28, 2024
@youkaichao (Member)

When does this problem occur? Is it related to #2152?

@esmeetu (Collaborator, Author)

esmeetu commented Mar 29, 2024

@youkaichao Nope, it's related to the custom all-reduce feature. After upgrading to nccl==2.19.3, everything is OK (a quick way to check the bundled NCCL version is sketched below).
Related issue: NVIDIA/nccl#957
Fix commit: NVIDIA/nccl@4365458
cc @WoosukKwon @hanzhi713
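
For reference, a quick way to check which NCCL build the installed PyTorch ships with (plain PyTorch API, nothing vLLM-specific):

# Prints the NCCL version bundled with the installed torch build.
import torch
print(torch.cuda.nccl.version())  # e.g. (2, 18, 1) for torch 2.1.2+cu121, matching the log above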

@hanzhi713 (Contributor)

hanzhi713 commented Mar 30, 2024

@esmeetu I guess it's NCCL's problem? Let me know if a fix is needed from my side.

@esmeetu (Collaborator, Author)

esmeetu commented Mar 30, 2024

@hanzhi713 Yes, I think so. I will test again after vLLM upgrades PyTorch to v2.2.2.

@esmeetu (Collaborator, Author)

esmeetu commented Mar 30, 2024

@hanzhi713 Why is your custom all-reduce kernel influenced by NCCL? IIUC, yours doesn't use NCCL. 🤔
And do you have any idea why that NCCL bug would cause my issue?

@hanzhi713 (Contributor)

All-reduce with larger sizes (>= 8 MB) and other collectives (like gather) still need NCCL.
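
To illustrate the point (an illustrative sketch only, not vLLM's actual code): the custom kernel only covers small all-reduces; larger tensors and every other collective, such as the gather in the traceback above, still go through the NCCL backend, so an NCCL bug can still break the run.

# Illustrative sketch of the dispatch idea described above; not vLLM code.
import torch
import torch.distributed as dist

CUSTOM_ALL_REDUCE_MAX_BYTES = 8 * 1024 * 1024  # assumed ~8 MB threshold

def all_reduce(tensor, custom_all_reduce=None):
    nbytes = tensor.numel() * tensor.element_size()
    if custom_all_reduce is not None and nbytes < CUSTOM_ALL_REDUCE_MAX_BYTES:
        return custom_all_reduce(tensor)  # small sizes: peer-to-peer custom kernel
    dist.all_reduce(tensor)               # large sizes: NCCL backend
    return tensor

def gather(tensor, dst=0):
    # Collectives other than all-reduce (e.g. the logits gather) always use NCCL.
    out = None
    if dist.get_rank() == dst:
        out = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.gather(tensor, gather_list=out, dst=dst)
    return out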

@Sande33p

Has anyone tried vLLM 0.3.3 + torch 2.1.1+cu118 with nccl==2.19.3? By default, vLLM 0.3.3 + torch 2.1.1+cu118 installs nccl==2.18.3, which gives the all_reduce error with multiple nodes.

@youkaichao (Member)

Has anyone tried vLLM 0.3.3 + torch 2.1.1+cu118 with nccl==2.19.3? By default, vLLM 0.3.3 + torch 2.1.1+cu118 installs nccl==2.18.3, which gives the all_reduce error with multiple nodes.

Can you show your environment and error trace?

@Sande33p

Has anyone tried vLLM 0.3.3 + torch 2.1.1+cu118 with nccl==2.19.3? By default, vLLM 0.3.3 + torch 2.1.1+cu118 installs nccl==2.18.3, which gives the all_reduce error with multiple nodes.

Can you show your environment and error trace?

error.txt
env_build.txt

@youkaichao (Member)

@Sande33p

  1. Can you run with export NCCL_DEBUG=TRACE again to get more verbose output for debugging?
  2. If you are using multi-node inference, I suggest you build from source again. [Core] Support multi-node inference (eager and cuda graph) #3686 just fixed some issues with the multi-node setup.

@Sande33p

@Sande33p

  1. Can you run with export NCCL_DEBUG=TRACE again to get more verbose output for debugging?
  2. If you are using multi-node inference, I suggest you build from source again. [Core] Support multi-node inference (eager and cuda graph) #3686 just fixed some issues with the multi-node setup.

@youkaichao here is the error file with export NCCL_DEBUG=TRACE
error_2.txt

@youkaichao (Member)

@Sande33p I took a look at your error log, and I think the following lines might be relevant:

x3005c0s37b1n0:24760:24760 [0] NCCL INFO cudaDriverVersion 11080
NCCL version 2.18.3+cuda11.0
x3005c0s37b1n0:24760:24760 [0] misc/strongstream.cc:53 NCCL WARN NCCL cannot be captured in a graph if either it wasn't built with CUDA runtime >= 11.3 or if the installed CUDA driver < R465.

It seems your CUDA version is too old. Can you try upgrading it?
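
A quick way to confirm the CUDA runtime torch was built with and the installed driver (a hedged sketch; the driver query assumes pynvml, i.e. the nvidia-ml-py package, is installed):

# Check the CUDA runtime torch was built against and the host driver version.
import torch
print("torch built with CUDA:", torch.version.cuda)

import pynvml  # assumes `pip install nvidia-ml-py`
pynvml.nvmlInit()
print("NVIDIA driver:", pynvml.nvmlSystemGetDriverVersion())
pynvml.nvmlShutdown()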

@CNTRYROA

CNTRYROA commented May 23, 2024

Has anyone tried vLLM 0.3.3 + torch 2.1.1+cu118 with nccl==2.19.3? By default, vLLM 0.3.3 + torch 2.1.1+cu118 installs nccl==2.18.3, which gives the all_reduce error with multiple nodes.

Yes, I faced the same error when I tried vLLM 0.4.1+cu11, torch 2.2.1+cu11, nccl==2.19.2, and vllm-nccl-cu11 2.18.1.0.4.0 with multiple GPUs.
When I set the env var export VLLM_NCCL_SO_PATH=/usr/lib/x86_64-linux-gnu/libnccl.so.2 before starting the vLLM application, it works fine. The version of NCCL located at /usr/lib/x86_64-linux-gnu/libnccl.so.2 is 2.16.2.

(See vLLM's find_nccl_library() for how the NCCL shared library path is chosen.)
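
To double-check which libnccl the workaround picks up and its version, a hedged ctypes sketch (ncclGetVersion is the standard NCCL C API; the path is the one from the comment above):

# Load the NCCL shared library named by VLLM_NCCL_SO_PATH and print its version.
# NCCL encodes the version as major*10000 + minor*100 + patch, e.g. 21602 == 2.16.2.
import ctypes
import os

so_path = os.environ.get("VLLM_NCCL_SO_PATH",
                         "/usr/lib/x86_64-linux-gnu/libnccl.so.2")
lib = ctypes.CDLL(so_path)
version = ctypes.c_int()
lib.ncclGetVersion(ctypes.byref(version))
print(f"{so_path} -> NCCL version code {version.value}")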

@beliven-daniele-sarnari

beliven-daniele-sarnari commented Jun 20, 2024

Same issue here, using docker image vllm/vllm-openai:latest.
Is it related to the host CUDA version?

NCCL version 2.20.5+cuda12.4

@youkaichao (Member)

Is it related to the host CUDA version?

Maybe. What's your host driver info?

@beliven-daniele-sarnari

Is it related to the host CUDA version?

Maybe. What's your host driver info?

I can't tell; I am running the vllm/vllm-openai:latest Docker image inside RunPod.io, and I'm having the same issue as this topic.

@youkaichao (Member)

You can follow the issue template https://github.com/vllm-project/vllm/issues/new/choose to run an environment collection script.

@unix1986

All-reduce with larger sizes (>= 8 MB) and other collectives (like gather) still need NCCL.

@hanzhi713 @youkaichao May I ask what the original intention was behind vLLM's development of custom all-reduce?
