[Bug]: TypeError: '_PlaceholderModuleAttr' object is not callable for RunAI SafetensorsStreamer() #11858

Closed
1 task done
huaxuan250 opened this issue Jan 8, 2025 · 2 comments · Fixed by #11882
Labels: bug


Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Amazon Linux 2023.6.20241121 (x86_64)
GCC version: (GCC) 11.4.1 20230605 (Red Hat 11.4.1-2)
Clang version: Could not collect
CMake version: version 3.22.2
Libc version: glibc-2.34

Python version: 3.12.7 | packaged by conda-forge | (main, Oct  4 2024, 16:05:46) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.1.115-126.197.amzn2023.x86_64-x86_64-with-glibc2.34
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 560.35.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               96
On-line CPU(s) list:                  0-95
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
CPU family:                           6
Model:                                85
Thread(s) per core:                   2
Core(s) per socket:                   24
Socket(s):                            2
Stepping:                             7
BogoMIPS:                             5999.99
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1.5 MiB (48 instances)
L1i cache:                            1.5 MiB (48 instances)
L2 cache:                             48 MiB (48 instances)
L3 cache:                             71.5 MiB (2 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-23,48-71
NUMA node1 CPU(s):                    24-47,72-95
Vulnerability Gather data sampling:   Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:          KVM: Mitigation: VMX unsupported
Vulnerability L1tf:                   Mitigation; PTE Inversion
Vulnerability Mds:                    Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:               Mitigation; PTI
Vulnerability Mmio stale data:        Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Vulnerable
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Vulnerable
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; STIBP disabled; RSB filling; PBRSB-eIBRS Not affected; BHI Retpoline
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.560.30
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.0
[pip3] torch==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.47.1
[pip3] triton==3.1.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-ml-py              12.560.30                pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pyzmq                     26.2.0                   pypi_0    pypi
[conda] torch                     2.5.1                    pypi_0    pypi
[conda] torchvision               0.20.1                   pypi_0    pypi
[conda] transformers              4.47.1                   pypi_0    pypi
[conda] triton                    3.1.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.6.6.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    GPU4    GPU5    GPU6    GPU7    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      PHB     NODE    NODE    SYS     SYS     SYS     SYS     0-23,48-71      0               N/A
GPU1    PHB      X      NODE    NODE    SYS     SYS     SYS     SYS     0-23,48-71      0               N/A
GPU2    NODE    NODE     X      PHB     SYS     SYS     SYS     SYS     0-23,48-71      0               N/A
GPU3    NODE    NODE    PHB      X      SYS     SYS     SYS     SYS     0-23,48-71      0               N/A
GPU4    SYS     SYS     SYS     SYS      X      PHB     NODE    NODE    24-47,72-95     1               N/A
GPU5    SYS     SYS     SYS     SYS     PHB      X      NODE    NODE    24-47,72-95     1               N/A
GPU6    SYS     SYS     SYS     SYS     NODE    NODE     X      PHB     24-47,72-95     1               N/A
GPU7    SYS     SYS     SYS     SYS     NODE    NODE    PHB      X      24-47,72-95     1               N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

LD_LIBRARY_PATH=/opt/conda/lib/python3.12/site-packages/cv2/../../lib64:/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib:/lib:/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib:/lib:/opt/amazon/efa/lib64:/opt/amazon/openmpi/lib64:/opt/aws-ofi-nccl/lib:/usr/local/cuda/lib:/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/lib:/usr/lib:/lib
OMP_NUM_THREADS=48
CUDA_MODULE_LOADING=LAZY

Model Input Dumps

No response

🐛 Describe the bug

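A previous S3 pickling bug (#11819) was fixed by applying the change in PR #11825.
However, when loading the model onto multiple GPUs again (`--tensor-parallel-size 4`), a new error surfaced: `TypeError: '_PlaceholderModuleAttr' object is not callable`, raised at `with SafetensorsStreamer() as streamer:` in `runai_safetensors_weights_iterator` (`weight_utils.py`, line 427).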
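
For context on the message itself: Python raises `TypeError: '<class name>' object is not callable` when an instance is invoked and its class defines no `__call__`. The sketch below is a hypothetical illustration of how that can happen with an optional dependency. It is not vLLM's actual code, and it assumes `SafetensorsStreamer` ended up bound to a placeholder object because the streamer package could not be imported. `PlaceholderAttr`, `Streamer`, `open_streamer`, and the package name are invented for the example.

```python
class PlaceholderAttr:
    """Hypothetical stand-in for a name from an optional package that failed to import."""

    def __init__(self, module_name: str, attr_name: str) -> None:
        self._module_name = module_name
        self._attr_name = attr_name

    def __getattr__(self, key: str):
        # Reached only for ordinary attribute access on the placeholder.
        # Calling the placeholder does NOT go through here: Python looks up
        # __call__ on the class, finds nothing, and raises TypeError instead.
        raise ImportError(
            f"{self._module_name} is required to use {self._attr_name}.{key}"
        )


try:
    # Hypothetical optional dependency; assumed not to be installed.
    from hypothetical_streamer_pkg_not_installed import Streamer
except ImportError:
    Streamer = PlaceholderAttr("hypothetical_streamer_pkg_not_installed", "Streamer")


def open_streamer() -> str:
    """Call the (possibly placeholder) class and return the error text, if any."""
    try:
        Streamer()
    except TypeError as exc:
        return str(exc)
    return "ok"


if __name__ == "__main__":
    print(open_streamer())  # prints: 'PlaceholderAttr' object is not callable
```

If this reading applies here, the `TypeError` would be a symptom of an earlier, silently handled import failure rather than the root cause.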

This is the command used:
vllm serve s3://llama-3.3-70b-instruct --load-format runai_streamer --max-num-seqs 8 --tensor-parallel-size 4
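
Before re-running this command, it may help to check whether the streamer package imports cleanly under the interpreter that runs `vllm` (the traceback paths point at `/usr/local/lib/python3.11/dist-packages`). This is a minimal sketch, assuming the package's import name is `runai_model_streamer`; `try_import` is a hypothetical helper, not part of vLLM.

```python
import importlib


def try_import(module_name: str) -> tuple[bool, str]:
    """Attempt the import and report why it failed, if it did."""
    try:
        importlib.import_module(module_name)
    except Exception as exc:  # e.g. ImportError, or OSError from a native library
        return False, f"{type(exc).__name__}: {exc}"
    return True, ""


if __name__ == "__main__":
    # Assumed import name of the optional streamer package.
    ok, reason = try_import("runai_model_streamer")
    print("importable" if ok else f"not importable -> {reason}")
```

Attempting the real import, rather than only checking that the package is present, also surfaces failures that occur while the module loads.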

    root@479c323b83e5:/usr/local/bin# vllm serve s3://llama-3.3-70b-instruct --load-format runai_streamer --max-num-seqs 8 --tensor-parallel-size 4
    INFO 01-08 16:17:19 api_server.py:712] vLLM API server version 0.6.6.post1
    INFO 01-08 16:17:19 api_server.py:713] args: Namespace(subparser='serve', model_tag='s3://llama-3.3-70b-instruct', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='s3://llama-3.3-70b-instruct', task='auto', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='runai_streamer', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='xgrammar', logits_processor_pattern=None, distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=4, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=8, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, 
    enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', generation_config=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, dispatch_function=<function serve at 0x77df734cd4e0>)
    INFO 01-08 16:17:19 api_server.py:199] Started engine process with PID 94935
    INFO 01-08 16:17:27 config.py:510] This model supports multiple tasks: {'reward', 'embed', 'generate', 'classify', 'score'}. Defaulting to 'generate'.
    INFO 01-08 16:17:28 config.py:1310] Defaulting to use mp for distributed inference
    WARNING 01-08 16:17:28 arg_utils.py:1103] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.
    INFO 01-08 16:17:28 config.py:1458] Chunked prefill is enabled with max_num_batched_tokens=2048.
    INFO 01-08 16:17:31 config.py:510] This model supports multiple tasks: {'reward', 'generate', 'score', 'classify', 'embed'}. Defaulting to 'generate'.
    INFO 01-08 16:17:32 config.py:1310] Defaulting to use mp for distributed inference
    WARNING 01-08 16:17:32 arg_utils.py:1103] Chunked prefill is enabled by default for models with max_model_len > 32K. Currently, chunked prefill might not work with some features or models. If you encounter any issues, please disable chunked prefill by setting --enable-chunked-prefill=False.
    INFO 01-08 16:17:32 config.py:1458] Chunked prefill is enabled with max_num_batched_tokens=2048.
    INFO 01-08 16:17:32 llm_engine.py:234] Initializing an LLM engine (v0.6.6.post1) with config: model='/tmp/tmp98nn47mr', speculative_config=None, tokenizer='/tmp/tmpo42mbhoo', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.RUNAI_STREAMER, tensor_parallel_size=4, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=s3://llama-3.3-70b-instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"candidate_compile_sizes":[],"compile_sizes":[],"capture_sizes":[8,4,2,1],"max_capture_size":8}, use_cached_outputs=True, 
    WARNING 01-08 16:17:32 multiproc_worker_utils.py:312] Reducing Torch parallelism from 96 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
    INFO 01-08 16:17:32 custom_cache_manager.py:17] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
    INFO 01-08 16:17:33 selector.py:120] Using Flash Attention backend.
    (VllmWorkerProcess pid=95689) INFO 01-08 16:17:39 selector.py:120] Using Flash Attention backend.
    (VllmWorkerProcess pid=95689) INFO 01-08 16:17:39 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
    (VllmWorkerProcess pid=95687) INFO 01-08 16:17:39 selector.py:120] Using Flash Attention backend.
    (VllmWorkerProcess pid=95687) INFO 01-08 16:17:39 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
    (VllmWorkerProcess pid=95688) INFO 01-08 16:17:40 selector.py:120] Using Flash Attention backend.
    (VllmWorkerProcess pid=95688) INFO 01-08 16:17:40 multiproc_worker_utils.py:222] Worker ready; awaiting tasks
    INFO 01-08 16:17:41 utils.py:918] Found nccl from library libnccl.so.2
    (VllmWorkerProcess pid=95687) INFO 01-08 16:17:41 utils.py:918] Found nccl from library libnccl.so.2
    INFO 01-08 16:17:41 pynccl.py:69] vLLM is using nccl==2.21.5
    (VllmWorkerProcess pid=95687) INFO 01-08 16:17:41 pynccl.py:69] vLLM is using nccl==2.21.5
    (VllmWorkerProcess pid=95689) INFO 01-08 16:17:41 utils.py:918] Found nccl from library libnccl.so.2
    (VllmWorkerProcess pid=95688) INFO 01-08 16:17:41 utils.py:918] Found nccl from library libnccl.so.2
    (VllmWorkerProcess pid=95689) INFO 01-08 16:17:41 pynccl.py:69] vLLM is using nccl==2.21.5
    (VllmWorkerProcess pid=95688) INFO 01-08 16:17:41 pynccl.py:69] vLLM is using nccl==2.21.5
    INFO 01-08 16:17:43 custom_all_reduce_utils.py:204] generating GPU P2P access cache in /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
    INFO 01-08 16:18:03 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
    (VllmWorkerProcess pid=95689) INFO 01-08 16:18:03 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
    (VllmWorkerProcess pid=95687) INFO 01-08 16:18:03 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
    (VllmWorkerProcess pid=95688) INFO 01-08 16:18:03 custom_all_reduce_utils.py:242] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1,2,3.json
    INFO 01-08 16:18:03 shm_broadcast.py:255] vLLM message queue communication handle: Handle(connect_ip='127.0.0.1', local_reader_ranks=[1, 2, 3], buffer_handle=(3, 4194304, 6, 'psm_b480f0c3'), local_subscribe_port=43093, remote_subscribe_port=None)
    INFO 01-08 16:18:03 model_runner.py:1094] Starting to load model /tmp/tmp98nn47mr...
    (VllmWorkerProcess pid=95687) INFO 01-08 16:18:03 model_runner.py:1094] Starting to load model /tmp/tmp98nn47mr...
    (VllmWorkerProcess pid=95688) INFO 01-08 16:18:03 model_runner.py:1094] Starting to load model /tmp/tmp98nn47mr...
    (VllmWorkerProcess pid=95689) INFO 01-08 16:18:03 model_runner.py:1094] Starting to load model /tmp/tmp98nn47mr...
    ERROR 01-08 16:18:04 engine.py:366] '_PlaceholderModuleAttr' object is not callable
    ERROR 01-08 16:18:04 engine.py:366] Traceback (most recent call last):
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
    ERROR 01-08 16:18:04 engine.py:366]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
    ERROR 01-08 16:18:04 engine.py:366]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
    ERROR 01-08 16:18:04 engine.py:366]     return cls(ipc_path=ipc_path,
    ERROR 01-08 16:18:04 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
    ERROR 01-08 16:18:04 engine.py:366]     self.engine = LLMEngine(*args, **kwargs)
    ERROR 01-08 16:18:04 engine.py:366]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/engine/llm_engine.py", line 273, in __init__
    ERROR 01-08 16:18:04 engine.py:366]     self.model_executor = executor_class(vllm_config=vllm_config, )
    ERROR 01-08 16:18:04 engine.py:366]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/executor/distributed_gpu_executor.py", line 26, in __init__
    ERROR 01-08 16:18:04 engine.py:366]     super().__init__(*args, **kwargs)
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/executor/executor_base.py", line 36, in __init__
    ERROR 01-08 16:18:04 engine.py:366]     self._init_executor()
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 83, in _init_executor
    ERROR 01-08 16:18:04 engine.py:366]     self._run_workers("load_model",
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 157, in _run_workers
    ERROR 01-08 16:18:04 engine.py:366]     driver_worker_output = driver_worker_method(*args, **kwargs)
    ERROR 01-08 16:18:04 engine.py:366]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/worker/worker.py", line 155, in load_model
    ERROR 01-08 16:18:04 engine.py:366]     self.model_runner.load_model()
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/worker/model_runner.py", line 1096, in load_model
    ERROR 01-08 16:18:04 engine.py:366]     self.model = get_model(vllm_config=self.vllm_config)
    ERROR 01-08 16:18:04 engine.py:366]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
    ERROR 01-08 16:18:04 engine.py:366]     return loader.load_model(vllm_config=vllm_config)
    ERROR 01-08 16:18:04 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/model_loader/loader.py", line 1329, in load_model
    ERROR 01-08 16:18:04 engine.py:366]     model.load_weights(
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py", line 594, in load_weights
    ERROR 01-08 16:18:04 engine.py:366]     return loader.load_weights(
    ERROR 01-08 16:18:04 engine.py:366]            ^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 237, in load_weights
    ERROR 01-08 16:18:04 engine.py:366]     autoloaded_weights = set(self._load_module("", self.module, weights))
    ERROR 01-08 16:18:04 engine.py:366]                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 189, in _load_module
    ERROR 01-08 16:18:04 engine.py:366]     for child_prefix, child_weights in self._groupby_prefix(weights):
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 103, in _groupby_prefix
    ERROR 01-08 16:18:04 engine.py:366]     for prefix, group in itertools.groupby(weights_by_parts,
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 100, in <genexpr>
    ERROR 01-08 16:18:04 engine.py:366]     weights_by_parts = ((weight_name.split(".", 1), weight_data)
    ERROR 01-08 16:18:04 engine.py:366]                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py", line 594, in <genexpr>
    ERROR 01-08 16:18:04 engine.py:366]     return loader.load_weights(
    ERROR 01-08 16:18:04 engine.py:366]                               ^
    ERROR 01-08 16:18:04 engine.py:366]   File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/model_loader/weight_utils.py", line 427, in runai_safetensors_weights_iterator
    ERROR 01-08 16:18:04 engine.py:366]     with SafetensorsStreamer() as streamer:
    ERROR 01-08 16:18:04 engine.py:366]          ^^^^^^^^^^^^^^^^^^^^^
    ERROR 01-08 16:18:04 engine.py:366] TypeError: '_PlaceholderModuleAttr' object is not callable
    ERROR 01-08 16:18:05 multiproc_worker_utils.py:123] Worker VllmWorkerProcess pid 95689 died, exit code: -15
    INFO 01-08 16:18:05 multiproc_worker_utils.py:127] Killing local vLLM worker processes
    Process SpawnProcess-1:
    Traceback (most recent call last):
      File "/usr/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
        self.run()
      File "/usr/lib/python3.11/multiprocessing/process.py", line 108, in run
        self._target(*self._args, **self._kwargs)
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 368, in run_mp_engine
        raise e
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 357, in run_mp_engine
        engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 119, in from_engine_args
        return cls(ipc_path=ipc_path,
               ^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/engine.py", line 71, in __init__
        self.engine = LLMEngine(*args, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/llm_engine.py", line 273, in __init__
        self.model_executor = executor_class(vllm_config=vllm_config, )
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/executor/distributed_gpu_executor.py", line 26, in __init__
        super().__init__(*args, **kwargs)
      File "/usr/local/lib/python3.11/dist-packages/vllm/executor/executor_base.py", line 36, in __init__
        self._init_executor()
      File "/usr/local/lib/python3.11/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 83, in _init_executor
        self._run_workers("load_model",
      File "/usr/local/lib/python3.11/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 157, in _run_workers
        driver_worker_output = driver_worker_method(*args, **kwargs)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/worker/worker.py", line 155, in load_model
        self.model_runner.load_model()
      File "/usr/local/lib/python3.11/dist-packages/vllm/worker/model_runner.py", line 1096, in load_model
        self.model = get_model(vllm_config=self.vllm_config)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/model_loader/__init__.py", line 12, in get_model
        return loader.load_model(vllm_config=vllm_config)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/model_loader/loader.py", line 1329, in load_model
        model.load_weights(
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py", line 594, in load_weights
        return loader.load_weights(
               ^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 237, in load_weights
        autoloaded_weights = set(self._load_module("", self.module, weights))
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 189, in _load_module
        for child_prefix, child_weights in self._groupby_prefix(weights):
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 103, in _groupby_prefix
        for prefix, group in itertools.groupby(weights_by_parts,
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/utils.py", line 100, in <genexpr>
        weights_by_parts = ((weight_name.split(".", 1), weight_data)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py", line 594, in <genexpr>
        return loader.load_weights(
                                  ^
      File "/usr/local/lib/python3.11/dist-packages/vllm/model_executor/model_loader/weight_utils.py", line 427, in runai_safetensors_weights_iterator
        with SafetensorsStreamer() as streamer:
             ^^^^^^^^^^^^^^^^^^^^^
    TypeError: '_PlaceholderModuleAttr' object is not callable
    [rank0]:[W108 16:18:06.792831196 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present,  but this warning has only been added since PyTorch 2.4 (function operator())
    Task exception was never retrieved
    future: <Task finished name='Task-2' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
    Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
        while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/zmq/_future.py", line 400, in poll
        raise _zmq.ZMQError(_zmq.ENOTSUP)
    zmq.error.ZMQError: Operation not supported
    Task exception was never retrieved
    future: <Task finished name='Task-3' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
    Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
        while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/zmq/_future.py", line 400, in poll
        raise _zmq.ZMQError(_zmq.ENOTSUP)
    zmq.error.ZMQError: Operation not supported
    Task exception was never retrieved
    future: <Task finished name='Task-4' coro=<MQLLMEngineClient.run_output_handler_loop() done, defined at /usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/client.py:178> exception=ZMQError('Operation not supported')>
    Traceback (most recent call last):
      File "/usr/local/lib/python3.11/dist-packages/vllm/engine/multiprocessing/client.py", line 184, in run_output_handler_loop
        while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/zmq/_future.py", line 400, in poll
        raise _zmq.ZMQError(_zmq.ENOTSUP)
    zmq.error.ZMQError: Operation not supported
    Traceback (most recent call last):
      File "/usr/local/bin/vllm", line 8, in <module>
        sys.exit(main())
                 ^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/scripts.py", line 201, in main
        args.dispatch_function(args)
      File "/usr/local/lib/python3.11/dist-packages/vllm/scripts.py", line 42, in serve
        uvloop.run(run_server(args))
      File "/usr/local/lib/python3.11/dist-packages/uvloop/__init__.py", line 105, in run
        return runner.run(wrapper())
               ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/lib/python3.11/asyncio/runners.py", line 118, in run
        return self._loop.run_until_complete(task)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
      File "/usr/local/lib/python3.11/dist-packages/uvloop/__init__.py", line 61, in wrapper
        return await main
               ^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/entrypoints/openai/api_server.py", line 740, in run_server
        async with build_async_engine_client(args) as engine_client:
      File "/usr/lib/python3.11/contextlib.py", line 210, in __aenter__
        return await anext(self.gen)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/entrypoints/openai/api_server.py", line 118, in build_async_engine_client
        async with build_async_engine_client_from_engine_args(
      File "/usr/lib/python3.11/contextlib.py", line 210, in __aenter__
        return await anext(self.gen)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/usr/local/lib/python3.11/dist-packages/vllm/entrypoints/openai/api_server.py", line 223, in build_async_engine_client_from_engine_args
        raise RuntimeError(
    RuntimeError: Engine process failed to start. See stack trace for the root cause.
    /usr/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
      warnings.warn('resource_tracker: There appear to be %d '

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@omer-dayan
Contributor

The _PlaceholderModuleAttr in the error indicates that you have hit a PlaceholderModule, which should occur when RunAI Model Streamer cannot be imported.
The RunAI Model Streamer package is an optional dependency of vLLM, which means it is not installed as part of pip install vllm.
You need to additionally run pip install vllm[runai].
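
As a quick check before starting the server, something like the following can tell you whether the optional package is present. This is only a sketch: the import name `runai_model_streamer` is an assumption on my part and `is_installed` is a hypothetical helper, not part of vLLM.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be located, without importing it."""
    return importlib.util.find_spec(module_name) is not None


# Assumption: the streamer's import name is `runai_model_streamer`.
if not is_installed("runai_model_streamer"):
    print("RunAI Model Streamer not found; try: pip install vllm[runai]")
```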

@DarkLight1337 I see the following change (screenshot not preserved) from this commit (eec906d#diff-c1cbb1d99298476e618b6fc564452c8cd9b3ecbf688c5cfb44506209547f7f3dL418).

My only concern is that this may confuse users, since they do not get an appropriate message such as "Please install vllm[runai]".

@DarkLight1337
Member

> My only concern is that this may confuse users, since they do not get an appropriate message such as "Please install vllm[runai]".

If you look at the definition of _PlaceholderModuleAttr, we do raise such an error upon accessing attributes. I guess we should also override __call__ to raise a similar error.
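
To illustrate the idea, here is a minimal sketch of a placeholder that raises the same install hint on both attribute access and calls. `MissingDependency` and its parameters are hypothetical names made up for this example; this is not vLLM's actual implementation, and the import name `runai_model_streamer` is assumed.

```python
class MissingDependency:
    """Stand-in for an object from an optional package that failed to import."""

    def __init__(self, package: str, extra: str, attr_path: str = "") -> None:
        self._package = package      # import name of the missing package
        self._extra = extra          # pip target to suggest, e.g. "vllm[runai]"
        self._attr_path = attr_path  # name of the object being replaced

    def _fail(self) -> None:
        target = self._package
        if self._attr_path:
            target += f".{self._attr_path}"
        raise ImportError(
            f"{target} is unavailable. Please install it with: "
            f"pip install {self._extra}"
        )

    def __getattr__(self, name: str):
        # Only invoked when normal lookup fails. Private names raise a plain
        # AttributeError so that copy/pickle probing cannot recurse.
        if name.startswith("_"):
            raise AttributeError(name)
        self._fail()

    def __call__(self, *args, **kwargs):
        # Without this override, calling the placeholder gives the confusing
        # "TypeError: ... object is not callable" instead of an install hint.
        self._fail()


try:
    # Assumption: this is the import name of the RunAI Model Streamer package.
    from runai_model_streamer import SafetensorsStreamer
except ImportError:
    SafetensorsStreamer = MissingDependency(
        "runai_model_streamer", "vllm[runai]", "SafetensorsStreamer"
    )
```

With this in place, calling the placeholder fails with an ImportError that names the extra to install, rather than a TypeError about the object not being callable.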
