Skip to content

Conversation

@jinzhen-lin
Copy link
Contributor

@jinzhen-lin jinzhen-lin commented Mar 7, 2025

#12185 and #13321 introduced the triton/cuda moe wna16 kernel to optimize the performance of moe gptq/awq. However, the best-performing gptq/awq kernel currently is the marlin kernel, and I hope to combine it with moe. Although there is already an implementation of moe + marlin kernel, it fails to fully leverage the performance advantages of the marlin kernel (especially when the number of experts is large).

This PR introduces a new moe wna16 marlin kernel that utilizes the m-parallel mechanism inherent to the marlin kernel to process all moe blocks in parallel. To prevent an excessive number of moe blocks from causing high workspace and c_tmp capacity requirements, I have updated the utilization logic for workspace and c_tmp (considering that the maximum number of slice_col_par that shared by different SMs is at most the number of SMs, meaning we only need a workspace of fixed length equal to the number of SMs).

This kernel is based on vllm's gptq_marlin implementation and fully supports all features of it (bfloat16/int8/act_order/...). It also supports expert parallelism.


EDIT: Benchmarks copied from comments in this PR

(The following benchmark result is outdated, after posting this, I made several rounds of optimizations in this PR. The final benchmark result is shown in #16850 (comment) (the "main" section))

kernel benchmarks (on A800): https://gist.github.com/jinzhen-lin/d5228895171a8970631dc953296cec0a

shapes of DeepSeek-V3-AWQ (with TP=8)

image

shapes of Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4 (with TP=1)

image

shapes of Mixtral-8x7B-Instruct-v0.1-AWQ (with TP=1)

image

Summary:

  1. New marlin kernel (this PR) have better performance than old marlin kernel (main).
  2. The performance of marlin kernel is relatively low when k is small (even slower than triton version when bit = 8). The reason is that the marlin kernel is designed for large n and large k, the implemented multistages pipeline have little benefit when k is small.

Performance on DeepSeek-V3-AWQ (on 8*A800), with VLLM_MARLIN_USE_ATOMIC_ADD=1

  MLA-main MLA-PR no-MLA-main no-MLA-PR
prefill 4735.7 7986.5    
generation-bs=1 42.7 45.9 49.9 54.4
generation-bs=2 71.2 82.3 81.5 96.6
generation-bs=4 122.1 142.9 135 161.2
generation-bs=8 187 231.6 200.6 251
generation-bs=16 276.5 363.1 294.7 393.7
generation-bs=32 381.9 558.5 389.2 567.9
generation-bs=64 520.9 873.4 518.5 897.6
generation-bs=128 675 1393.1 681.9 1389.4
generation-bs=256 806 2046.7 826.2 2083
generation-bs=512 1392 2754.4 1461.7 2991

Accuracy Test on DeepSeek-R1-AWQ:

vllm (pretrained=/root/DeepSeek-R1-AWQ/,tensor_parallel_size=8,gpu_memory_utilization=0.95,trust_remote_code=True,max_model_len=16384,dtype=half,max_num_batched_tokens=16384), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9560|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9568|±  |0.0056|

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@github-actions
Copy link

github-actions bot commented Mar 7, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Mar 7, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@jinzhen-lin jinzhen-lin force-pushed the moe-wna16-marlin-kernel branch from 5d6921e to e6896d3 Compare March 10, 2025 10:50
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@mergify
Copy link

mergify bot commented Mar 11, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jinzhen-lin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 11, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@mergify mergify bot removed the needs-rebase label Mar 11, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@mgoin
Copy link
Member

mgoin commented Apr 11, 2025

And maybe most importantly, for the case we were using Marlin MoE before, this kernel is now the best choice for Mixtral 8x7B as well

# Main's Marlin MoE
Processed prompts: 100%|████████████████| 1319/1319 [03:44<00:00,  5.89it/s, est. speed input: 6255.10 toks/s, output: 783.65 toks/s]
vllm (pretrained=nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6141|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.6118|±  |0.0134|

# Triton MoE
Processed prompts: 100%|████████████████| 1319/1319 [06:06<00:00,  3.60it/s, est. speed input: 3823.89 toks/s, output: 481.79 toks/s]
vllm (pretrained=nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6156|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.6118|±  |0.0134|

# This PR's Marlin MoE
Processed prompts: 100%|███████████████| 1319/1319 [02:50<00:00,  7.73it/s, est. speed input: 8213.52 toks/s, output: 1031.45 toks/s]
vllm (pretrained=nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized,dtype=float16,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6209|±  |0.0134|
|     |       |strict-match    |     5|exact_match|↑  |0.6179|±  |0.0134|

Comment on lines +490 to +493
const int scales_expert_stride = prob_n * prob_k / group_size / 8;
const int zp_expert_stride =
is_zp_float ? prob_n * prob_k / group_size / 8
: prob_n * prob_k / group_size / (pack_factor * 4);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can prob_n * prob_k overflow and int32? If so, could you use int64_t instead? (This looks like it's probably fine, since we'd only overflow if a single expert was > 4GB but int32 overflows are common enough in vLLM that I look for these in every kernel PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression prob_n * prob_k is hard to overflow, even with prob_n = 32768 and prob_k = 65,536, their product only just reaches the maximum value of a 32-bit integer. Such values are almost impossible to encounter in an MoE model. However, if you're still concerned, we could consider using prob_k / group_size * prob_n / 8 instead of prob_n * prob_k / group_size / 8.

For the parts of the code that are easy to overflow, I've already switched to using int64. However, int64 consumes more registers, so I'd prefer to avoid using it unless absolutely necessary.

@davidsyoung
Copy link

davidsyoung commented Apr 11, 2025

I seem to be facing some overflow issues with this PR, which I don't face with mainline, possibly due to tp=16 - 16x3090 gpus:

CleanShot 2025-04-11 at 21 17 48@2x

env:

      - VLLM_USE_V1=0
      - VLLM_MARLIN_USE_ATOMIC_ADD=1
      - 
vllm serve /models/wanzhenchn_DeepSeek-R1-AWQ/
      --api-key b18766c98a9b8092dcb66033afabff4f
      --enable-reasoning 
      --reasoning-parser deepseek_r1
      --served-model-name deepseek-ai/DeepSeek-R1
      --gpu-memory-utilization 0.99
      --max-model-len 16384
      --max-seq-len-to-capture 16384
      --max-num-seqs 8 
      --trust-remote-code
      --tensor-parallel-size 16  
      --host 192.168.10.225 
      --port 8000 
      --dtype bfloat16 
      --max-num-batched-tokens 2048
      --enable-chunked-prefill 
      --enable-prefix-caching
Log

==========
== CUDA ==
==========

CUDA Version 12.4.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

INFO 04-11 12:43:49 [__init__.py:239] Automatically detected platform cuda.
INFO 04-11 12:43:52 [api_server.py:1034] vLLM API server version 0.8.3rc2.dev222+gb9813378c
INFO 04-11 12:43:52 [api_server.py:1035] args: Namespace(subparser='serve', model_tag='/models/wanzhenchn_DeepSeek-R1-AWQ/', config='', host='192.168.10.225', port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='b18766c98a9b8092dcb66033afabff4f', lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/models/wanzhenchn_DeepSeek-R1-AWQ/', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='bfloat16', kv_cache_dtype='auto', max_model_len=16384, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=16, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, disable_custom_all_reduce=False, block_size=None, enable_prefix_caching=True, prefix_caching_hash_algo='builtin', disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.99, num_gpu_blocks_override=None, max_num_batched_tokens=2048, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=8, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_token=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=16384, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=True, speculative_config=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['deepseek-ai/DeepSeek-R1'], qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=True, reasoning_parser='deepseek_r1', disable_cascade_attn=False, disable_chunked_mm_input=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False, dispatch_function=<function ServeSubcommand.cmd at 0x14e57f81b920>)
WARNING 04-11 12:43:52 [utils.py:2175] Found ulimit of 40960 and failed to automatically increase with error current limit exceeds maximum limit. This can cause fd limit errors like `OSError: [Errno 24] Too many open files`. Consider increasing with ulimit -n
INFO 04-11 12:43:52 [config.py:209] Replacing legacy 'type' key with 'rope_type'
INFO 04-11 12:44:02 [config.py:676] This model supports multiple tasks: {'reward', 'embed', 'generate', 'classify', 'score'}. Defaulting to 'generate'.
INFO 04-11 12:44:04 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 04-11 12:44:04 [config.py:1697] Defaulting to use mp for distributed inference
INFO 04-11 12:44:04 [config.py:1885] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 04-11 12:44:09 [__init__.py:239] Automatically detected platform cuda.
INFO 04-11 12:44:11 [api_server.py:246] Started engine process with PID 327
INFO 04-11 12:44:12 [llm_engine.py:243] Initializing a V0 LLM engine (v0.8.3rc2.dev222+gb9813378c) with config: model='/models/wanzhenchn_DeepSeek-R1-AWQ/', speculative_config=None, tokenizer='/models/wanzhenchn_DeepSeek-R1-AWQ/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=16384, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=16, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend='deepseek_r1'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=deepseek-ai/DeepSeek-R1, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[8,4,2,1],"max_capture_size":8}, use_cached_outputs=True, 
WARNING 04-11 12:44:12 [multiproc_worker_utils.py:306] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 04-11 12:44:17 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=479) INFO 04-11 12:44:20 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:44:23 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=486) INFO 04-11 12:44:26 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:44:29 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=493) INFO 04-11 12:44:32 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:44:35 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=500) INFO 04-11 12:44:38 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:44:41 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=507) INFO 04-11 12:44:45 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:44:48 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=514) INFO 04-11 12:44:51 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:44:54 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=521) INFO 04-11 12:44:57 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:00 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=528) INFO 04-11 12:45:03 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:06 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=535) INFO 04-11 12:45:09 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:12 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=542) INFO 04-11 12:45:16 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:19 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=549) INFO 04-11 12:45:22 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:25 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=556) INFO 04-11 12:45:28 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:31 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=563) INFO 04-11 12:45:34 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:37 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=570) INFO 04-11 12:45:40 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-11 12:45:43 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=570) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=563) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=549) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=500) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=521) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=535) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=514) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=556) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=507) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=528) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=486) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=542) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=479) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=493) INFO 04-11 12:45:46 [cuda.py:191] Using Triton MLA backend.
WARNING 04-11 12:45:46 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=493) WARNING 04-11 12:45:46 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=570) WARNING 04-11 12:45:46 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) INFO 04-11 12:45:46 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
(VllmWorkerProcess pid=479) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=528) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=486) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=507) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=514) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=563) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=542) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=521) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=500) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=556) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=549) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=535) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) INFO 04-11 12:45:47 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=577) WARNING 04-11 12:45:47 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=570) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=577) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=570) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=577) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=479) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=549) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=528) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=549) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=479) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=528) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=486) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=556) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=486) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=514) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=556) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=535) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=514) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=507) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=535) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=507) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=563) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=493) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=542) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=563) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=500) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=493) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=521) INFO 04-11 12:45:58 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=542) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=500) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=521) INFO 04-11 12:45:58 [pynccl.py:69] vLLM is using nccl==2.21.5
WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=577) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=479) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=570) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=493) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=563) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=500) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=486) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=507) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=556) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=549) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=514) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=542) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=535) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=521) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=528) WARNING 04-11 12:46:01 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 04-11 12:46:01 [shm_broadcast.py:264] vLLM message queue communication handle: Handle(local_reader_ranks=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], buffer_handle=(15, 4194304, 6, 'psm_605de055'), local_subscribe_addr='ipc:///tmp/75068ce7-d6fc-42f5-a4a0-75e994e19725', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorkerProcess pid=570) INFO 04-11 12:46:01 [parallel_state.py:957] rank 14 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 14
(VllmWorkerProcess pid=577) INFO 04-11 12:46:01 [parallel_state.py:957] rank 15 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 15
(VllmWorkerProcess pid=549) INFO 04-11 12:46:01 [parallel_state.py:957] rank 11 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 11
(VllmWorkerProcess pid=563) INFO 04-11 12:46:01 [parallel_state.py:957] rank 13 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 13
(VllmWorkerProcess pid=556) INFO 04-11 12:46:01 [parallel_state.py:957] rank 12 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 12
(VllmWorkerProcess pid=521) INFO 04-11 12:46:01 [parallel_state.py:957] rank 7 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 7
(VllmWorkerProcess pid=535) INFO 04-11 12:46:01 [parallel_state.py:957] rank 9 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 9
(VllmWorkerProcess pid=500) INFO 04-11 12:46:01 [parallel_state.py:957] rank 4 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 4
(VllmWorkerProcess pid=528) INFO 04-11 12:46:01 [parallel_state.py:957] rank 8 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 8
(VllmWorkerProcess pid=514) INFO 04-11 12:46:01 [parallel_state.py:957] rank 6 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 6
(VllmWorkerProcess pid=507) INFO 04-11 12:46:01 [parallel_state.py:957] rank 5 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 5
(VllmWorkerProcess pid=479) INFO 04-11 12:46:01 [parallel_state.py:957] rank 1 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorkerProcess pid=542) INFO 04-11 12:46:01 [parallel_state.py:957] rank 10 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 10
INFO 04-11 12:46:01 [parallel_state.py:957] rank 0 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorkerProcess pid=493) INFO 04-11 12:46:01 [parallel_state.py:957] rank 3 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 3
(VllmWorkerProcess pid=486) INFO 04-11 12:46:01 [parallel_state.py:957] rank 2 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 2
(VllmWorkerProcess pid=479) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=486) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=521) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=500) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=514) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=507) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=493) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=542) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=570) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=549) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=556) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=535) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=563) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=528) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=577) INFO 04-11 12:46:01 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=479) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=486) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=500) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=514) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=521) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=507) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=493) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=542) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=549) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=570) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=556) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=563) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=535) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=528) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=577) WARNING 04-11 12:46:01 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
Loading safetensors checkpoint shards:   0% Completed | 0/71 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   1% Completed | 1/71 [00:23<27:42, 23.75s/it]
Loading safetensors checkpoint shards:   3% Completed | 2/71 [00:47<27:22, 23.80s/it]
Loading safetensors checkpoint shards:   4% Completed | 3/71 [01:11<27:05, 23.90s/it]
Loading safetensors checkpoint shards:   6% Completed | 4/71 [01:34<26:20, 23.59s/it]
Loading safetensors checkpoint shards:   7% Completed | 5/71 [01:58<26:00, 23.65s/it]
Loading safetensors checkpoint shards:   8% Completed | 6/71 [02:21<25:33, 23.60s/it]
Loading safetensors checkpoint shards:  10% Completed | 7/71 [02:45<25:15, 23.67s/it]
Loading safetensors checkpoint shards:  11% Completed | 8/71 [03:09<24:57, 23.76s/it]
Loading safetensors checkpoint shards:  13% Completed | 9/71 [03:33<24:29, 23.70s/it]
Loading safetensors checkpoint shards:  14% Completed | 10/71 [03:57<24:08, 23.75s/it]
Loading safetensors checkpoint shards:  15% Completed | 11/71 [04:20<23:41, 23.69s/it]
Loading safetensors checkpoint shards:  17% Completed | 12/71 [04:44<23:11, 23.59s/it]
Loading safetensors checkpoint shards:  18% Completed | 13/71 [05:07<22:42, 23.48s/it]
Loading safetensors checkpoint shards:  20% Completed | 14/71 [05:31<22:33, 23.74s/it]
Loading safetensors checkpoint shards:  21% Completed | 15/71 [05:55<22:09, 23.75s/it]
Loading safetensors checkpoint shards:  23% Completed | 16/71 [06:19<21:46, 23.76s/it]
Loading safetensors checkpoint shards:  24% Completed | 17/71 [06:43<21:24, 23.79s/it]
Loading safetensors checkpoint shards:  25% Completed | 18/71 [07:06<21:00, 23.78s/it]
Loading safetensors checkpoint shards:  27% Completed | 19/71 [07:30<20:38, 23.81s/it]
Loading safetensors checkpoint shards:  28% Completed | 20/71 [07:54<20:15, 23.82s/it]
Loading safetensors checkpoint shards:  30% Completed | 21/71 [08:18<19:56, 23.93s/it]
Loading safetensors checkpoint shards:  31% Completed | 22/71 [08:42<19:29, 23.86s/it]
Loading safetensors checkpoint shards:  32% Completed | 23/71 [09:06<19:06, 23.88s/it]
Loading safetensors checkpoint shards:  34% Completed | 24/71 [09:30<18:40, 23.85s/it]
Loading safetensors checkpoint shards:  35% Completed | 25/71 [09:53<18:12, 23.75s/it]
Loading safetensors checkpoint shards:  37% Completed | 26/71 [10:17<17:45, 23.67s/it]
Loading safetensors checkpoint shards:  38% Completed | 27/71 [10:40<17:11, 23.43s/it]
Loading safetensors checkpoint shards:  39% Completed | 28/71 [11:02<16:37, 23.19s/it]
Loading safetensors checkpoint shards:  41% Completed | 29/71 [11:25<16:05, 22.99s/it]
Loading safetensors checkpoint shards:  42% Completed | 30/71 [11:47<15:31, 22.71s/it]
Loading safetensors checkpoint shards:  44% Completed | 31/71 [12:06<14:29, 21.75s/it]
Loading safetensors checkpoint shards:  45% Completed | 32/71 [12:26<13:48, 21.25s/it]
Loading safetensors checkpoint shards:  46% Completed | 33/71 [12:47<13:15, 20.94s/it]
Loading safetensors checkpoint shards:  48% Completed | 34/71 [13:07<12:47, 20.74s/it]
Loading safetensors checkpoint shards:  49% Completed | 35/71 [13:27<12:20, 20.57s/it]
Loading safetensors checkpoint shards:  51% Completed | 36/71 [13:39<10:33, 18.09s/it]
Loading safetensors checkpoint shards:  52% Completed | 37/71 [13:59<10:31, 18.57s/it]
Loading safetensors checkpoint shards:  54% Completed | 38/71 [14:19<10:28, 19.04s/it]
Loading safetensors checkpoint shards:  55% Completed | 39/71 [14:39<10:14, 19.21s/it]
Loading safetensors checkpoint shards:  56% Completed | 40/71 [14:59<10:02, 19.45s/it]
Loading safetensors checkpoint shards:  58% Completed | 41/71 [15:19<09:49, 19.64s/it]
Loading safetensors checkpoint shards:  59% Completed | 42/71 [15:39<09:31, 19.69s/it]
Loading safetensors checkpoint shards:  61% Completed | 43/71 [15:58<09:12, 19.72s/it]
Loading safetensors checkpoint shards:  62% Completed | 44/71 [16:18<08:51, 19.67s/it]
Loading safetensors checkpoint shards:  63% Completed | 45/71 [16:38<08:37, 19.90s/it]
Loading safetensors checkpoint shards:  65% Completed | 46/71 [16:51<07:24, 17.76s/it]
Loading safetensors checkpoint shards:  66% Completed | 47/71 [16:53<05:09, 12.88s/it]
Loading safetensors checkpoint shards:  68% Completed | 48/71 [17:17<06:15, 16.33s/it]
Loading safetensors checkpoint shards:  69% Completed | 49/71 [17:41<06:47, 18.51s/it]
Loading safetensors checkpoint shards:  70% Completed | 50/71 [18:04<07:00, 20.01s/it]
Loading safetensors checkpoint shards:  72% Completed | 51/71 [18:28<07:03, 21.16s/it]
Loading safetensors checkpoint shards:  73% Completed | 52/71 [18:52<06:58, 22.05s/it]
Loading safetensors checkpoint shards:  75% Completed | 53/71 [19:16<06:46, 22.59s/it]
Loading safetensors checkpoint shards:  76% Completed | 54/71 [19:40<06:31, 23.05s/it]
Loading safetensors checkpoint shards:  77% Completed | 55/71 [20:04<06:13, 23.35s/it]
Loading safetensors checkpoint shards:  79% Completed | 56/71 [20:06<04:14, 16.96s/it]
Loading safetensors checkpoint shards:  80% Completed | 57/71 [20:29<04:21, 18.65s/it]
Loading safetensors checkpoint shards:  82% Completed | 58/71 [20:53<04:23, 20.28s/it]
Loading safetensors checkpoint shards:  83% Completed | 59/71 [21:17<04:15, 21.31s/it]
Loading safetensors checkpoint shards:  85% Completed | 60/71 [21:41<04:03, 22.11s/it]
Loading safetensors checkpoint shards:  86% Completed | 61/71 [22:04<03:45, 22.59s/it]
Loading safetensors checkpoint shards:  87% Completed | 62/71 [22:29<03:28, 23.15s/it]
Loading safetensors checkpoint shards:  89% Completed | 63/71 [22:53<03:07, 23.43s/it]
Loading safetensors checkpoint shards:  90% Completed | 64/71 [23:17<02:45, 23.62s/it]
Loading safetensors checkpoint shards:  92% Completed | 65/71 [23:41<02:22, 23.77s/it]
Loading safetensors checkpoint shards:  93% Completed | 66/71 [24:05<01:59, 23.85s/it]
Loading safetensors checkpoint shards:  94% Completed | 67/71 [24:29<01:35, 23.76s/it]
Loading safetensors checkpoint shards:  96% Completed | 68/71 [24:51<01:10, 23.47s/it]
Loading safetensors checkpoint shards:  97% Completed | 69/71 [25:14<00:46, 23.34s/it]
Loading safetensors checkpoint shards:  99% Completed | 70/71 [25:38<00:23, 23.42s/it]
Loading safetensors checkpoint shards: 100% Completed | 71/71 [26:01<00:00, 23.34s/it]
Loading safetensors checkpoint shards: 100% Completed | 71/71 [26:01<00:00, 22.00s/it]

(VllmWorkerProcess pid=577) INFO 04-11 13:12:06 [loader.py:458] Loading weights took 1561.63 seconds
(VllmWorkerProcess pid=535) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.42 seconds
(VllmWorkerProcess pid=514) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.35 seconds
(VllmWorkerProcess pid=479) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.59 seconds
(VllmWorkerProcess pid=493) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.47 seconds
INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.72 seconds
(VllmWorkerProcess pid=570) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.34 seconds
(VllmWorkerProcess pid=521) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.35 seconds
(VllmWorkerProcess pid=528) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.43 seconds
(VllmWorkerProcess pid=549) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.47 seconds
(VllmWorkerProcess pid=507) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.52 seconds
(VllmWorkerProcess pid=563) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.35 seconds
(VllmWorkerProcess pid=556) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.30 seconds
(VllmWorkerProcess pid=486) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.51 seconds
(VllmWorkerProcess pid=542) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.46 seconds
(VllmWorkerProcess pid=500) INFO 04-11 13:12:07 [loader.py:458] Loading weights took 1562.36 seconds
(VllmWorkerProcess pid=521) INFO 04-11 13:12:25 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.195885 seconds
(VllmWorkerProcess pid=542) INFO 04-11 13:12:25 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.194863 seconds
(VllmWorkerProcess pid=563) INFO 04-11 13:12:25 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.303328 seconds
(VllmWorkerProcess pid=528) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.632338 seconds
(VllmWorkerProcess pid=479) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.675280 seconds
(VllmWorkerProcess pid=556) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.677176 seconds
(VllmWorkerProcess pid=577) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.758098 seconds
(VllmWorkerProcess pid=535) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.754408 seconds
(VllmWorkerProcess pid=500) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.812992 seconds
(VllmWorkerProcess pid=514) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.788914 seconds
(VllmWorkerProcess pid=507) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1584.859842 seconds
(VllmWorkerProcess pid=570) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1585.127654 seconds
INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1585.206710 seconds
(VllmWorkerProcess pid=549) INFO 04-11 13:12:26 [model_runner.py:1146] Model loading took 21.2080 GiB and 1585.287564 seconds
(VllmWorkerProcess pid=493) INFO 04-11 13:12:27 [model_runner.py:1146] Model loading took 21.2080 GiB and 1586.487682 seconds
(VllmWorkerProcess pid=486) INFO 04-11 13:12:30 [model_runner.py:1146] Model loading took 21.2080 GiB and 1589.428879 seconds
WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=514) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=528) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=577) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=556) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=570) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=479) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=549) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=493) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=507) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=563) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=535) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=542) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=521) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=500) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=486) WARNING 04-11 13:12:41 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=507) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.70 seconds
(VllmWorkerProcess pid=507) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=507) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=570) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.75 seconds
(VllmWorkerProcess pid=570) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=570) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=479) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.72 seconds
(VllmWorkerProcess pid=479) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=479) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=486) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.79 seconds
(VllmWorkerProcess pid=486) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=486) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=556) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.78 seconds
(VllmWorkerProcess pid=556) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=556) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=542) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.77 seconds
(VllmWorkerProcess pid=542) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=542) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=528) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.81 seconds
(VllmWorkerProcess pid=528) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=528) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=514) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.78 seconds
(VllmWorkerProcess pid=514) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=514) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=500) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.83 seconds
(VllmWorkerProcess pid=500) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=500) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=521) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.78 seconds
(VllmWorkerProcess pid=521) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=521) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=577) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.79 seconds
(VllmWorkerProcess pid=577) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=577) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=549) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.80 seconds
(VllmWorkerProcess pid=549) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=549) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=535) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.79 seconds
(VllmWorkerProcess pid=535) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=535) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=493) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.79 seconds
(VllmWorkerProcess pid=493) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=493) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
(VllmWorkerProcess pid=563) INFO 04-11 13:12:48 [worker.py:267] Memory profiling takes 17.79 seconds
(VllmWorkerProcess pid=563) INFO 04-11 13:12:48 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=563) INFO 04-11 13:12:48 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
INFO 04-11 13:12:49 [worker.py:267] Memory profiling takes 18.05 seconds
INFO 04-11 13:12:49 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
INFO 04-11 13:12:49 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.83GiB; the rest of the memory reserved for KV Cache is 1.11GiB.
INFO 04-11 13:12:49 [executor_base.py:112] # cuda blocks: 1062, # CPU blocks: 3819
INFO 04-11 13:12:49 [executor_base.py:117] Maximum concurrency for 16384 tokens per request: 1.04x
(VllmWorkerProcess pid=528) INFO 04-11 13:13:32 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=535) INFO 04-11 13:13:32 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=507) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=514) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=577) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=556) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=542) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=549) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=479) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=500) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=570) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=563) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=493) INFO 04-11 13:13:33 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=486) INFO 04-11 13:13:34 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=521) INFO 04-11 13:13:34 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 04-11 13:13:34 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes: 100%|██████████| 4/4 [00:09<00:00,  2.47s/it]
INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=514) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=486) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=542) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=563) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=556) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=507) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 11 secs, took 0.22 GiB
(VllmWorkerProcess pid=535) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 11 secs, took 0.22 GiB
(VllmWorkerProcess pid=570) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=500) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=549) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=577) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=521) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=528) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 12 secs, took 0.22 GiB
(VllmWorkerProcess pid=493) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
(VllmWorkerProcess pid=479) INFO 04-11 13:13:44 [model_runner.py:1598] Graph capturing finished in 10 secs, took 0.22 GiB
INFO 04-11 13:13:44 [llm_engine.py:449] init engine (profile, create kv cache, warmup model) took 73.29 seconds
WARNING 04-11 13:13:44 [config.py:1164] Default sampling parameters have been overridden by the model's Hugging Face generation config recommended from the model creator. If this is not intended, please relaunch vLLM instance with `--generation-config vllm`.
INFO 04-11 13:13:44 [serving_chat.py:118] Using default chat sampling params from model: {'temperature': 0.6, 'top_p': 0.95}
INFO 04-11 13:13:44 [serving_completion.py:61] Using default completion sampling params from model: {'temperature': 0.6, 'top_p': 0.95}
INFO 04-11 13:13:44 [api_server.py:1081] Starting vLLM API server on http://192.168.10.225:8000
INFO 04-11 13:13:44 [launcher.py:26] Available routes are:
INFO 04-11 13:13:44 [launcher.py:34] Route: /openapi.json, Methods: GET, HEAD
INFO 04-11 13:13:44 [launcher.py:34] Route: /docs, Methods: GET, HEAD
INFO 04-11 13:13:44 [launcher.py:34] Route: /docs/oauth2-redirect, Methods: GET, HEAD
INFO 04-11 13:13:44 [launcher.py:34] Route: /redoc, Methods: GET, HEAD
INFO 04-11 13:13:44 [launcher.py:34] Route: /health, Methods: GET
INFO 04-11 13:13:44 [launcher.py:34] Route: /load, Methods: GET
INFO 04-11 13:13:44 [launcher.py:34] Route: /ping, Methods: GET, POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /tokenize, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /detokenize, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/models, Methods: GET
INFO 04-11 13:13:44 [launcher.py:34] Route: /version, Methods: GET
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/chat/completions, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/completions, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/embeddings, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /pooling, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /score, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/score, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/audio/transcriptions, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /rerank, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v1/rerank, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /v2/rerank, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /invocations, Methods: POST
INFO 04-11 13:13:44 [launcher.py:34] Route: /metrics, Methods: GET
INFO:     Started server process [1]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO 04-11 13:13:50 [chat_utils.py:396] Detected the chat template content format to be 'string'. You can set `--chat-template-content-format` to override this.
INFO 04-11 13:13:50 [logger.py:39] Received request chatcmpl-07cc674384ba4d2288748afc957e0f38: prompt: '<|begin▁of▁sentence|><|User|>test<|Assistant|><think>\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16378, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, lora_request: None, prompt_adapter_request: None.
INFO:     192.168.1.64:33556 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO 04-11 13:13:50 [engine.py:310] Added request chatcmpl-07cc674384ba4d2288748afc957e0f38.
INFO 04-11 13:13:51 [metrics.py:489] Avg prompt throughput: 0.8 tokens/s, Avg generation throughput: 0.1 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.1%, CPU KV cache usage: 0.0%.
INFO 04-11 13:13:51 [metrics.py:505] Prefix cache hit rate: GPU: 0.00%, CPU: 0.00%
INFO 04-11 13:13:56 [metrics.py:489] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 39.4 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 1.2%, CPU KV cache usage: 0.0%.
INFO 04-11 13:13:56 [metrics.py:505] Prefix cache hit rate: GPU: 0.00%, CPU: 0.00%
INFO 04-11 13:13:58 [engine.py:330] Aborted request chatcmpl-07cc674384ba4d2288748afc957e0f38.
INFO 04-11 13:14:08 [metrics.py:489] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 5.4 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 04-11 13:14:08 [metrics.py:505] Prefix cache hit rate: GPU: 0.00%, CPU: 0.00%
INFO 04-11 13:14:18 [metrics.py:489] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 0.0%, CPU KV cache usage: 0.0%.
INFO 04-11 13:14:18 [metrics.py:505] Prefix cache hit rate: GPU: 0.00%, CPU: 0.00%
INFO 04-11 13:14:28 [logger.py:39] Received request chatcmpl-2045e1d060494411b180b835e836d441: prompt: '<|begin▁of▁sentence|><|User|>Hi, how are you?<|Assistant|><think>\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16373, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, lora_request: None, prompt_adapter_request: None.
INFO:     192.168.1.64:33560 - "POST /v1/chat/completions HTTP/1.1" 200 OK
INFO 04-11 13:14:28 [engine.py:310] Added request chatcmpl-2045e1d060494411b180b835e836d441.
INFO 04-11 13:14:32 [logger.py:39] Received request chatcmpl-99f500469c4946fe905573643b691a92: prompt: '<|begin▁of▁sentence|><|User|>### Task:\nGenerate a concise, 3-5 word title with an emoji summarizing the chat history.\n### Guidelines:\n- The title should clearly represent the main theme or subject of the conversation.\n- Use emojis that enhance understanding of the topic, but avoid quotation marks or special formatting.\n- Write the title in the chat\'s primary language; default to English if multilingual.\n- Prioritize accuracy over excessive creativity; keep it clear and simple.\n### Output:\nJSON format: { "title": "your concise title here" }\n### Examples:\n- { "title": "📉 Stock Market Trends" },\n- { "title": "🍪 Perfect Chocolate Chip Recipe" },\n- { "title": "Evolution of Music Streaming" },\n- { "title": "Remote Work Productivity Tips" },\n- { "title": "Artificial Intelligence in Healthcare" },\n- { "title": "🎮 Video Game Development Insights" }\n### Chat History:\n<chat_history>\nUSER: Hi, how are you?\nASSISTANT: Hi! I\'m just a computer\nprogram, so I don\'t have feelings, but thanks for asking! 😊 How can I help you today?\n</chat_history><|Assistant|><think>\n', params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=1000, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, lora_request: None, prompt_adapter_request: None.
INFO 04-11 13:14:32 [engine.py:310] Added request chatcmpl-99f500469c4946fe905573643b691a92.
INFO 04-11 13:14:33 [metrics.py:489] Avg prompt throughput: 51.5 tokens/s, Avg generation throughput: 26.9 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 1.6%, CPU KV cache usage: 0.0%.
INFO 04-11 13:14:33 [metrics.py:505] Prefix cache hit rate: GPU: 0.00%, CPU: 0.00%
INFO 04-11 13:14:38 [metrics.py:489] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 38.0 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 2.7%, CPU KV cache usage: 0.0%.
INFO 04-11 13:14:38 [metrics.py:505] Prefix cache hit rate: GPU: 0.00%, CPU: 0.00%
INFO 04-11 13:14:42 [logger.py:39] Received request chatcmpl-fb779c9f25914e32a7b8dcb7cd1d5f4f: prompt: "<|begin▁of▁sentence|><|User|>Hi, how are you?<|Assistant|>\nHi! I'm just a computer\nprogram, so I don't have feelings, but thanks for asking! 😊 How can I help you today?<|end▁of▁sentence|><|User|>Are you an AI?<|Assistant|><think>\n", params: SamplingParams(n=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.6, top_p=0.95, top_k=-1, min_p=0.0, seed=None, stop=[], stop_token_ids=[], bad_words=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=16333, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None, guided_decoding=None, extra_args=None), prompt_token_ids: None, lora_request: None, prompt_adapter_request: None.
INFO:     192.168.1.64:37824 - "POST /v1/chat/completions HTTP/1.1" 200 OK

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@jinzhen-lin
Copy link
Contributor Author

I seem to be facing some overflow issues with this PR, which I don't face with mainline, possibly due to tp=16 - 16x3090 gpus:

CleanShot 2025-04-11 at 21 17 48@2x

env:

      - VLLM_USE_V1=0
      - VLLM_MARLIN_USE_ATOMIC_ADD=1
      - 
vllm serve /models/wanzhenchn_DeepSeek-R1-AWQ/
      --api-key b18766c98a9b8092dcb66033afabff4f
      --enable-reasoning 
      --reasoning-parser deepseek_r1
      --served-model-name deepseek-ai/DeepSeek-R1
      --gpu-memory-utilization 0.99
      --max-model-len 16384
      --max-seq-len-to-capture 16384
      --max-num-seqs 8 
      --trust-remote-code
      --tensor-parallel-size 16  
      --host 192.168.10.225 
      --port 8000 
      --dtype bfloat16 
      --max-num-batched-tokens 2048
      --enable-chunked-prefill 
      --enable-prefix-caching

Log

This problem is hard to reproduce, you could try to debug it by removing some arguments, for example, try to remove --enable-chunked-prefill and --enable-prefix-caching.

@davidsyoung
Copy link

I seem to be facing some overflow issues with this PR, which I don't face with mainline, possibly due to tp=16 - 16x3090 gpus:
CleanShot 2025-04-11 at 21 17 48@2x
env:

      - VLLM_USE_V1=0
      - VLLM_MARLIN_USE_ATOMIC_ADD=1
      - 
vllm serve /models/wanzhenchn_DeepSeek-R1-AWQ/
      --api-key b18766c98a9b8092dcb66033afabff4f
      --enable-reasoning 
      --reasoning-parser deepseek_r1
      --served-model-name deepseek-ai/DeepSeek-R1
      --gpu-memory-utilization 0.99
      --max-model-len 16384
      --max-seq-len-to-capture 16384
      --max-num-seqs 8 
      --trust-remote-code
      --tensor-parallel-size 16  
      --host 192.168.10.225 
      --port 8000 
      --dtype bfloat16 
      --max-num-batched-tokens 2048
      --enable-chunked-prefill 
      --enable-prefix-caching

Log

This problem is hard to reproduce, you could try to debug it by removing some arguments, for example, try to remove --enable-chunked-prefill and --enable-prefix-caching.

Tried this, removed chunked prefill, prefix caching, max-num-batched-tokens, unfort got error (latest commit):

Log
==========
== CUDA ==
==========

CUDA Version 12.4.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

INFO 04-12 05:09:49 [__init__.py:239] Automatically detected platform cuda.
INFO 04-12 05:09:53 [api_server.py:1034] vLLM API server version 0.8.3rc2.dev224+g49c0d11ec
INFO 04-12 05:09:53 [api_server.py:1035] args: Namespace(subparser='serve', model_tag='/models/cognitivecomputations_DeepSeek-V3-0324-AWQ/', config='', host='192.168.10.225', port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='b18766c98a9b8092dcb66033afabff4f', lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/models/cognitivecomputations_DeepSeek-V3-0324-AWQ/', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='bfloat16', kv_cache_dtype='auto', max_model_len=8192, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=16, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, disable_custom_all_reduce=False, block_size=None, enable_prefix_caching=None, prefix_caching_hash_algo='builtin', disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.99, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=8, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_token=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_config=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['deepseek-ai/DeepSeek-R1'], qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_chunked_mm_input=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False, dispatch_function=<function ServeSubcommand.cmd at 0x152e824e3920>)
WARNING 04-12 05:09:53 [utils.py:2175] Found ulimit of 40960 and failed to automatically increase with error current limit exceeds maximum limit. This can cause fd limit errors like `OSError: [Errno 24] Too many open files`. Consider increasing with ulimit -n
INFO 04-12 05:09:53 [config.py:209] Replacing legacy 'type' key with 'rope_type'
INFO 04-12 05:10:03 [config.py:676] This model supports multiple tasks: {'classify', 'reward', 'generate', 'score', 'embed'}. Defaulting to 'generate'.
INFO 04-12 05:10:04 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 04-12 05:10:04 [config.py:1697] Defaulting to use mp for distributed inference
INFO 04-12 05:10:08 [__init__.py:239] Automatically detected platform cuda.
INFO 04-12 05:10:10 [api_server.py:246] Started engine process with PID 327
INFO 04-12 05:10:12 [llm_engine.py:243] Initializing a V0 LLM engine (v0.8.3rc2.dev224+g49c0d11ec) with config: model='/models/cognitivecomputations_DeepSeek-V3-0324-AWQ/', speculative_config=None, tokenizer='/models/cognitivecomputations_DeepSeek-V3-0324-AWQ/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=16, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=deepseek-ai/DeepSeek-R1, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[8,4,2,1],"max_capture_size":8}, use_cached_outputs=True, 
WARNING 04-12 05:10:12 [multiproc_worker_utils.py:306] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 04-12 05:10:16 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=479) INFO 04-12 05:10:19 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:22 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=486) INFO 04-12 05:10:25 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:28 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=493) INFO 04-12 05:10:31 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:34 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=500) INFO 04-12 05:10:37 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:40 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=507) INFO 04-12 05:10:43 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:46 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=514) INFO 04-12 05:10:49 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:52 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=521) INFO 04-12 05:10:55 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:10:58 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=528) INFO 04-12 05:11:01 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:04 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=535) INFO 04-12 05:11:07 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:10 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=542) INFO 04-12 05:11:13 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:16 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=549) INFO 04-12 05:11:19 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:22 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=556) INFO 04-12 05:11:25 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:28 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=563) INFO 04-12 05:11:31 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:34 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=570) INFO 04-12 05:11:36 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 05:11:40 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=570) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=556) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=535) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=486) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=479) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=549) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=493) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=500) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=563) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=521) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=507) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=542) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=528) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=514) INFO 04-12 05:11:42 [cuda.py:191] Using Triton MLA backend.
WARNING 04-12 05:11:42 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) INFO 04-12 05:11:43 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
(VllmWorkerProcess pid=493) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=563) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=486) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=556) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=535) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=549) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=479) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=542) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=500) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) INFO 04-12 05:11:43 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=507) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=521) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=570) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=528) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=514) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) WARNING 04-12 05:11:43 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=479) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=521) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=479) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=521) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=577) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=500) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=507) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=577) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=542) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=507) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=500) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=542) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=514) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=570) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=528) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=514) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=556) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=570) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=528) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=563) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=556) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=486) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=563) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=486) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=493) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=535) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=493) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=549) INFO 04-12 05:11:54 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=535) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=549) INFO 04-12 05:11:54 [pynccl.py:69] vLLM is using nccl==2.21.5
WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=479) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=577) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=486) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=570) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=563) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=493) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=500) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=507) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=556) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=514) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=549) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=542) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=521) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=535) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=528) WARNING 04-12 05:11:57 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 04-12 05:11:57 [shm_broadcast.py:264] vLLM message queue communication handle: Handle(local_reader_ranks=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], buffer_handle=(15, 4194304, 6, 'psm_0a64431c'), local_subscribe_addr='ipc:///tmp/33bd88a2-4d72-4f35-97d1-591a3f91f500', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 04-12 05:11:57 [parallel_state.py:957] rank 0 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorkerProcess pid=577) INFO 04-12 05:11:57 [parallel_state.py:957] rank 15 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 15
(VllmWorkerProcess pid=570) INFO 04-12 05:11:57 [parallel_state.py:957] rank 14 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 14
(VllmWorkerProcess pid=563) INFO 04-12 05:11:57 [parallel_state.py:957] rank 13 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 13
(VllmWorkerProcess pid=556) INFO 04-12 05:11:57 [parallel_state.py:957] rank 12 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 12
(VllmWorkerProcess pid=549) INFO 04-12 05:11:57 [parallel_state.py:957] rank 11 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 11
(VllmWorkerProcess pid=507) INFO 04-12 05:11:57 [parallel_state.py:957] rank 5 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 5
(VllmWorkerProcess pid=528) INFO 04-12 05:11:57 [parallel_state.py:957] rank 8 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 8
(VllmWorkerProcess pid=514) INFO 04-12 05:11:57 [parallel_state.py:957] rank 6 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 6
(VllmWorkerProcess pid=500) INFO 04-12 05:11:57 [parallel_state.py:957] rank 4 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 4
(VllmWorkerProcess pid=493) INFO 04-12 05:11:57 [parallel_state.py:957] rank 3 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 3
(VllmWorkerProcess pid=535) INFO 04-12 05:11:57 [parallel_state.py:957] rank 9 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 9
(VllmWorkerProcess pid=542) INFO 04-12 05:11:57 [parallel_state.py:957] rank 10 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 10
(VllmWorkerProcess pid=486) INFO 04-12 05:11:57 [parallel_state.py:957] rank 2 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 2
(VllmWorkerProcess pid=479) INFO 04-12 05:11:57 [parallel_state.py:957] rank 1 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorkerProcess pid=521) INFO 04-12 05:11:57 [parallel_state.py:957] rank 7 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 7
INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=528) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=535) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=486) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=542) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=507) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=549) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=563) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=521) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=479) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=556) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=514) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=577) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=500) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=493) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=570) INFO 04-12 05:11:57 [model_runner.py:1110] Starting to load model /models/cognitivecomputations_DeepSeek-V3-0324-AWQ/...
(VllmWorkerProcess pid=528) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=542) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=486) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=479) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=563) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=535) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=507) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=549) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=521) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=556) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=514) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=577) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=570) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=493) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=500) WARNING 04-12 05:11:57 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
Loading safetensors checkpoint shards:   0% Completed | 0/36 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   3% Completed | 1/36 [00:46<27:04, 46.42s/it]
Loading safetensors checkpoint shards:   6% Completed | 2/36 [01:32<26:14, 46.30s/it]
Loading safetensors checkpoint shards:   8% Completed | 3/36 [02:19<25:38, 46.63s/it]
Loading safetensors checkpoint shards:  11% Completed | 4/36 [03:06<24:55, 46.74s/it]
Loading safetensors checkpoint shards:  14% Completed | 5/36 [03:52<24:01, 46.50s/it]
Loading safetensors checkpoint shards:  17% Completed | 6/36 [04:28<21:28, 42.95s/it]
Loading safetensors checkpoint shards:  19% Completed | 7/36 [05:13<21:06, 43.66s/it]
Loading safetensors checkpoint shards:  22% Completed | 8/36 [05:42<18:09, 38.90s/it]
Loading safetensors checkpoint shards:  25% Completed | 9/36 [06:30<18:47, 41.78s/it]
Loading safetensors checkpoint shards:  28% Completed | 10/36 [07:01<16:36, 38.34s/it]
Loading safetensors checkpoint shards:  31% Completed | 11/36 [07:05<11:36, 27.86s/it]
Loading safetensors checkpoint shards:  33% Completed | 12/36 [07:10<08:21, 20.88s/it]
Loading safetensors checkpoint shards:  36% Completed | 13/36 [07:15<06:14, 16.27s/it]
Loading safetensors checkpoint shards:  39% Completed | 14/36 [07:21<04:43, 12.91s/it]
Loading safetensors checkpoint shards:  42% Completed | 15/36 [07:25<03:34, 10.22s/it]
Loading safetensors checkpoint shards:  44% Completed | 16/36 [07:45<04:25, 13.27s/it]
Loading safetensors checkpoint shards:  47% Completed | 17/36 [08:32<07:27, 23.55s/it]
Loading safetensors checkpoint shards:  50% Completed | 18/36 [09:20<09:13, 30.73s/it]
Loading safetensors checkpoint shards:  53% Completed | 19/36 [10:08<10:09, 35.87s/it]
Loading safetensors checkpoint shards:  56% Completed | 20/36 [10:11<06:57, 26.11s/it]
Loading safetensors checkpoint shards:  58% Completed | 21/36 [10:56<07:56, 31.76s/it]
Loading safetensors checkpoint shards:  61% Completed | 22/36 [11:44<08:31, 36.54s/it]
Loading safetensors checkpoint shards:  64% Completed | 23/36 [12:31<08:35, 39.66s/it]
Loading safetensors checkpoint shards:  67% Completed | 24/36 [13:18<08:24, 42.03s/it]
Loading safetensors checkpoint shards:  69% Completed | 25/36 [14:04<07:55, 43.20s/it]
Loading safetensors checkpoint shards:  72% Completed | 26/36 [14:51<07:23, 44.35s/it]
Loading safetensors checkpoint shards:  75% Completed | 27/36 [15:37<06:44, 44.94s/it]
Loading safetensors checkpoint shards:  78% Completed | 28/36 [16:22<05:57, 44.68s/it]
Loading safetensors checkpoint shards:  81% Completed | 29/36 [17:07<05:13, 44.84s/it]
Loading safetensors checkpoint shards:  83% Completed | 30/36 [17:53<04:31, 45.28s/it]
Loading safetensors checkpoint shards:  86% Completed | 31/36 [18:39<03:47, 45.41s/it]
Loading safetensors checkpoint shards:  89% Completed | 32/36 [19:25<03:03, 45.79s/it]
Loading safetensors checkpoint shards:  92% Completed | 33/36 [20:13<02:18, 46.20s/it]
Loading safetensors checkpoint shards:  94% Completed | 34/36 [20:59<01:32, 46.18s/it]
Loading safetensors checkpoint shards:  97% Completed | 35/36 [21:45<00:46, 46.16s/it]
Loading safetensors checkpoint shards: 100% Completed | 36/36 [22:31<00:00, 46.03s/it]
Loading safetensors checkpoint shards: 100% Completed | 36/36 [22:31<00:00, 37.53s/it]

INFO 04-12 05:34:32 [loader.py:458] Loading weights took 1351.40 seconds
(VllmWorkerProcess pid=486) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.11 seconds
(VllmWorkerProcess pid=549) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.16 seconds
(VllmWorkerProcess pid=577) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.11 seconds
(VllmWorkerProcess pid=563) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.13 seconds
(VllmWorkerProcess pid=570) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.10 seconds
(VllmWorkerProcess pid=528) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.12 seconds
(VllmWorkerProcess pid=535) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.19 seconds
(VllmWorkerProcess pid=521) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.14 seconds
(VllmWorkerProcess pid=493) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.20 seconds
(VllmWorkerProcess pid=479) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.19 seconds
(VllmWorkerProcess pid=542) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.10 seconds
(VllmWorkerProcess pid=556) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.15 seconds
(VllmWorkerProcess pid=507) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.19 seconds
(VllmWorkerProcess pid=500) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.16 seconds
(VllmWorkerProcess pid=514) INFO 04-12 05:34:34 [loader.py:458] Loading weights took 1353.18 seconds
INFO 04-12 05:34:52 [model_runner.py:1146] Model loading took 21.5551 GiB and 1374.495161 seconds
(VllmWorkerProcess pid=535) INFO 04-12 05:34:52 [model_runner.py:1146] Model loading took 21.5551 GiB and 1375.234557 seconds
(VllmWorkerProcess pid=500) INFO 04-12 05:34:52 [model_runner.py:1146] Model loading took 21.5551 GiB and 1375.243825 seconds
(VllmWorkerProcess pid=521) INFO 04-12 05:34:52 [model_runner.py:1146] Model loading took 21.5551 GiB and 1375.322688 seconds
(VllmWorkerProcess pid=542) INFO 04-12 05:34:52 [model_runner.py:1146] Model loading took 21.5551 GiB and 1375.377357 seconds
(VllmWorkerProcess pid=556) INFO 04-12 05:34:53 [model_runner.py:1146] Model loading took 21.5551 GiB and 1375.801422 seconds
(VllmWorkerProcess pid=486) INFO 04-12 05:34:53 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.017350 seconds
(VllmWorkerProcess pid=479) INFO 04-12 05:34:53 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.041296 seconds
(VllmWorkerProcess pid=549) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.509355 seconds
(VllmWorkerProcess pid=514) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.594522 seconds
(VllmWorkerProcess pid=570) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.746925 seconds
(VllmWorkerProcess pid=563) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.906154 seconds
(VllmWorkerProcess pid=507) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.923624 seconds
(VllmWorkerProcess pid=577) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1376.941064 seconds
(VllmWorkerProcess pid=528) INFO 04-12 05:34:54 [model_runner.py:1146] Model loading took 21.5551 GiB and 1377.015945 seconds
(VllmWorkerProcess pid=493) INFO 04-12 05:34:56 [model_runner.py:1146] Model loading took 21.5551 GiB and 1378.671318 seconds
WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=521) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=556) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=563) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=528) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=514) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=542) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=500) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=493) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=535) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=507) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=577) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=549) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=570) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=479) WARNING 04-12 05:35:07 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=486) WARNING 04-12 05:35:08 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=500) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=542) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=479) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Ex(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=514) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=556) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:2(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=563) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcessception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=486) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:2 pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=535) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
ERROR 04-12 05:35:29 [engine.py:448] NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
ERROR 04-12 05:35:29 [engine.py:448] Traceback (most recent call last):
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 436, in run_mp_engine
ERROR 04-12 05:35:29 [engine.py:448]     engine = MQLLMEngine.from_vllm_config(
ERROR 04-12 05:35:29 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 128, in from_vllm_config
ERROR 04-12 05:35:29 [engine.py:448]     return cls(
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 82, in __init__
ERROR 04-12 05:35:29 [engine.py:448]     self.engine = LLMEngine(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 285, in __init__
ERROR 04-12 05:35:29 [engine.py:448]     self._initialize_kv_caches()
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 434, in _initialize_kv_caches
ERROR 04-12 05:35:29 [engine.py:448]     self.model_executor.determine_num_available_blocks())
ERROR 04-12 05:35:29 [engine.py:448]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 103, in determine_num_available_blocks
ERROR 04-12 05:35:29 [engine.py:448]     results = self.collective_rpc("determine_num_available_blocks")
ERROR 04-12 05:35:29 [engine.py:448]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 331, in collective_rpc
ERROR 04-12 05:35:29 [engine.py:448]     return self._run_workers(method, *args, **(kwargs or {}))
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/mp_distributed_executor.py", line 185, in _run_workers
ERROR 04-12 05:35:29 [engine.py:448]     driver_worker_output = run_method(self.driver_worker, sent_method,
ERROR 04-12 05:35:29 [engine.py:448]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
ERROR 04-12 05:35:29 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-12 05:35:29 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
ERROR 04-12 05:35:29 [engine.py:448]     self.model_runner.profile_run()
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-12 05:35:29 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
ERROR 04-12 05:35:29 [engine.py:448]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
ERROR 04-12 05:35:29 [engine.py:448]     self.execute_model(model_input, kv_caches, intermediate_tensors)
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-12 05:35:29 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
ERROR 04-12 05:35:29 [engine.py:448]     logits = self.model.compute_logits(hidden_or_intermediate_states,
ERROR 04-12 05:35:29 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
ERROR 04-12 05:35:29 [engine.py:448]     logits = self.logits_processor(self.lm_head, hidden_states,
ERROR 04-12 05:35:29 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-12 05:35:29 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-12 05:35:29 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
ERROR 04-12 05:35:29 [engine.py:448]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
ERROR 04-12 05:35:29 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
ERROR 04-12 05:35:29 [engine.py:448]     logits = self._gather_logits(logits)
ERROR 04-12 05:35:29 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
ERROR 04-12 05:35:29 [engine.py:448]     logits = tensor_model_parallel_gather(logits)
ERROR 04-12 05:35:29 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
ERROR 04-12 05:35:29 [engine.py:448]     return get_tp_group().gather(input_, dst, dim)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
ERROR 04-12 05:35:29 [engine.py:448]     return self.device_communicator.gather(input_, dst, dim)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
ERROR 04-12 05:35:29 [engine.py:448]     torch.distributed.gather(input_,
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
ERROR 04-12 05:35:29 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
ERROR 04-12 05:35:29 [engine.py:448]     work = group.gather(output_tensors, input_tensors, opts)
ERROR 04-12 05:35:29 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 05:35:29 [engine.py:448] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=549) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
9 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] Traceback (most recent call last):
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/multiproc_worker_utils.py", line 232, in _run_worker_process
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     output = run_method(worker, method, args, kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.model_runner.profile_run()
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1243, in profile_run
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self._dummy_run(max_num_batched_tokens, max_num_seqs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1369, in _dummy_run
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     self.execute_model(model_input, kv_caches, intermediate_tensors)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1816, in execute_model
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.model.compute_logits(hidden_or_intermediate_states,
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=493) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
 pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 712, in compute_logits
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self.logits_processor(self.lm_head, hidden_states,
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 70, in forward
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._get_logits(hidden_states, lm_head, embedding_bias)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 113, in _get_logits
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = self._gather_logits(logits)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/logits_processor.py", line 98, in _gather_logits
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     logits = tensor_model_parallel_gather(logits)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/communication_op.py", line 26, in tensor_model_parallel_gather
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return get_tp_group().gather(input_, dst, dim)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/parallel_state.py", line 341, in gather
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return self.device_communicator.gather(input_, dst, dim)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=507) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=521) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=528) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
9 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=570) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/vllm/distributed/device_communicators/base_device_communicator.py", line 86, in gather
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     torch.distributed.gather(input_,
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     return func(*args, **kwargs)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/distributed_c10d.py", line 4006, in gather
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]     work = group.gather(output_tensors, input_tensors, opts)
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=577) ERROR 04-12 05:35:29 [multiproc_worker_utils.py:238] RuntimeError: NCCL Error 1: unhandled cuda error (run with NCCL_DEBUG=INFO for details)
Traceback (most recent call last):
  File "/usr/local/bin/vllm", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/main.py", line 51, in main
    args.dispatch_function(args)
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/serve.py", line 27, in cmd
    uvloop.run(run_server(args))
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 1069, in run_server
    async with build_async_engine_client(args) as engine_client:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 146, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 269, in build_async_engine_client_from_engine_args
    raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 15 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '


@jinzhen-lin
Copy link
Contributor Author

Tried this, removed chunked prefill, prefix caching, max-num-batched-tokens, unfort got error (latest commit):

Your error may due to OOM, 16×24G is really pushing the limits for running this model.

I have tested the kernel with shapes of deepseek-r1-awq + tp16, but unfortunarely, I amd unable to reproduced this. Maybe you can look into more things. Your screenshot shows that the output is mostly correct, but there are a few instances of garbled text. So for now, I suspect the issue might lie in the attention, multithreading, or some other component, rather than this kernel itself.

@davidsyoung
Copy link

davidsyoung commented Apr 12, 2025

Tried this, removed chunked prefill, prefix caching, max-num-batched-tokens, unfort got error (latest commit):

Your error may due to OOM, 16×24G is really pushing the limits for running this model.

I have tested the kernel with shapes of deepseek-r1-awq + tp16, but unfortunarely, I amd unable to reproduced this. Maybe you can look into more things. Your screenshot shows that the output is mostly correct, but there are a few instances of garbled text. So for now, I suspect the issue might lie in the attention, multithreading, or some other component, rather than this kernel itself.

Well, I can normally run the DeepSeek R1 model at 16k ctx with this setup on AWQ, no issues. I can try that model again with the latest build. I don't have issues with mainline compared to this PR. Happy to keep testing if there's anything you can suggest. I'll trial the new build when it's done on CI.

Could it be group_size = 128 on the AWQ models I use?

@jinzhen-lin
Copy link
Contributor Author

Tried this, removed chunked prefill, prefix caching, max-num-batched-tokens, unfort got error (latest commit):

Your error may due to OOM, 16×24G is really pushing the limits for running this model.
I have tested the kernel with shapes of deepseek-r1-awq + tp16, but unfortunarely, I amd unable to reproduced this. Maybe you can look into more things. Your screenshot shows that the output is mostly correct, but there are a few instances of garbled text. So for now, I suspect the issue might lie in the attention, multithreading, or some other component, rather than this kernel itself.

Well, I can normally run the DeepSeek R1 model at 16k ctx with this setup on AWQ, no issues. I can try that model again with the latest build. I don't have issues with mainline compared to this PR. Happy to keep testing if there's anything you can suggest. I'll trial the new build when it's done on CI.

Could it be group_size = 128 on the AWQ models I use?

Yes, this kernel support group_size=32,64,128 and channelwise quantization.

@davidsyoung
Copy link

This is with the previous build before CI, but It looks to be a VRAM issue, that without --enable-chunked-prefill I go OOM.

Log

==========
== CUDA ==
==========

CUDA Version 12.4.1

Container image Copyright (c) 2016-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

INFO 04-12 07:52:09 [__init__.py:239] Automatically detected platform cuda.
INFO 04-12 07:52:12 [api_server.py:1034] vLLM API server version 0.8.3rc2.dev224+g49c0d11ec
INFO 04-12 07:52:12 [api_server.py:1035] args: Namespace(subparser='serve', model_tag='/models/wanzhenchn_DeepSeek-R1-AWQ/', config='', host='192.168.10.225', port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key='b18766c98a9b8092dcb66033afabff4f', lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/models/wanzhenchn_DeepSeek-R1-AWQ/', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='bfloat16', kv_cache_dtype='auto', max_model_len=2048, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=16, data_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, disable_custom_all_reduce=False, block_size=None, enable_prefix_caching=None, prefix_caching_hash_algo='builtin', disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.99, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=8, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_token=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=2048, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_config=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=['deepseek-ai/DeepSeek-R1'], qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_chunked_mm_input=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False, dispatch_function=<function ServeSubcommand.cmd at 0x148d3c4e3920>)
WARNING 04-12 07:52:12 [utils.py:2175] Found ulimit of 40960 and failed to automatically increase with error current limit exceeds maximum limit. This can cause fd limit errors like `OSError: [Errno 24] Too many open files`. Consider increasing with ulimit -n
INFO 04-12 07:52:12 [config.py:209] Replacing legacy 'type' key with 'rope_type'
INFO 04-12 07:52:22 [config.py:676] This model supports multiple tasks: {'classify', 'reward', 'generate', 'embed', 'score'}. Defaulting to 'generate'.
INFO 04-12 07:52:24 [awq_marlin.py:113] The model is convertible to awq_marlin during runtime. Using awq_marlin kernel.
INFO 04-12 07:52:24 [config.py:1697] Defaulting to use mp for distributed inference
INFO 04-12 07:52:28 [__init__.py:239] Automatically detected platform cuda.
INFO 04-12 07:52:30 [api_server.py:246] Started engine process with PID 327
INFO 04-12 07:52:31 [llm_engine.py:243] Initializing a V0 LLM engine (v0.8.3rc2.dev224+g49c0d11ec) with config: model='/models/wanzhenchn_DeepSeek-R1-AWQ/', speculative_config=None, tokenizer='/models/wanzhenchn_DeepSeek-R1-AWQ/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=16, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=awq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=deepseek-ai/DeepSeek-R1, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=None, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[8,4,2,1],"max_capture_size":8}, use_cached_outputs=True, 
WARNING 04-12 07:52:31 [multiproc_worker_utils.py:306] Reducing Torch parallelism from 64 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 04-12 07:52:35 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=479) INFO 04-12 07:52:38 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:52:41 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=486) INFO 04-12 07:52:44 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:52:47 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=493) INFO 04-12 07:52:50 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:52:53 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=500) INFO 04-12 07:52:56 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:52:59 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=507) INFO 04-12 07:53:02 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:05 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=514) INFO 04-12 07:53:08 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:11 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=521) INFO 04-12 07:53:14 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:17 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=528) INFO 04-12 07:53:20 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:23 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=535) INFO 04-12 07:53:26 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:28 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=542) INFO 04-12 07:53:31 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:34 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=549) INFO 04-12 07:53:37 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:40 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=556) INFO 04-12 07:53:43 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:46 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=563) INFO 04-12 07:53:49 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:52 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=570) INFO 04-12 07:53:55 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
INFO 04-12 07:53:58 [__init__.py:239] Automatically detected platform cuda.
(VllmWorkerProcess pid=528) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=479) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=486) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=556) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=493) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=535) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=570) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=563) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=549) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=500) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=514) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=521) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=542) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=507) INFO 04-12 07:54:00 [cuda.py:191] Using Triton MLA backend.
WARNING 04-12 07:54:00 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) INFO 04-12 07:54:01 [multiproc_worker_utils.py:225] Worker ready; awaiting tasks
(VllmWorkerProcess pid=528) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=500) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=521) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=486) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=514) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=507) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=563) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=556) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=493) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=535) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=479) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=542) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=549) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=570) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=577) INFO 04-12 07:54:01 [cuda.py:191] Using Triton MLA backend.
(VllmWorkerProcess pid=577) WARNING 04-12 07:54:01 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorkerProcess pid=479) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=500) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=479) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=507) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=500) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=507) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=528) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=528) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=570) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=549) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=570) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=549) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=542) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=535) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=542) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=535) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=486) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=577) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=486) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=577) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=563) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=563) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=514) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=514) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=521) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=521) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=493) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=556) INFO 04-12 07:54:12 [utils.py:991] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=493) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorkerProcess pid=556) INFO 04-12 07:54:12 [pynccl.py:69] vLLM is using nccl==2.21.5
WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=577) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=479) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=570) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=486) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=563) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=493) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=500) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=556) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=507) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=549) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=514) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=542) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=535) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=521) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorkerProcess pid=528) WARNING 04-12 07:54:15 [custom_all_reduce.py:97] Custom allreduce is disabled due to an unsupported world size: 16. Supported world sizes: [2, 4, 6, 8]. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 04-12 07:54:15 [shm_broadcast.py:264] vLLM message queue communication handle: Handle(local_reader_ranks=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], buffer_handle=(15, 4194304, 6, 'psm_ab4c83cd'), local_subscribe_addr='ipc:///tmp/4485a71f-ab24-49b3-bd93-9bcad57113e0', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 04-12 07:54:15 [parallel_state.py:957] rank 0 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorkerProcess pid=577) INFO 04-12 07:54:15 [parallel_state.py:957] rank 15 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 15
(VllmWorkerProcess pid=556) INFO 04-12 07:54:15 [parallel_state.py:957] rank 12 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 12
(VllmWorkerProcess pid=563) INFO 04-12 07:54:15 [parallel_state.py:957] rank 13 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 13
(VllmWorkerProcess pid=570) INFO 04-12 07:54:15 [parallel_state.py:957] rank 14 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 14
(VllmWorkerProcess pid=549) INFO 04-12 07:54:15 [parallel_state.py:957] rank 11 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 11
(VllmWorkerProcess pid=535) INFO 04-12 07:54:15 [parallel_state.py:957] rank 9 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 9
(VllmWorkerProcess pid=542) INFO 04-12 07:54:15 [parallel_state.py:957] rank 10 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 10
(VllmWorkerProcess pid=514) INFO 04-12 07:54:15 [parallel_state.py:957] rank 6 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 6
(VllmWorkerProcess pid=528) INFO 04-12 07:54:15 [parallel_state.py:957] rank 8 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 8
(VllmWorkerProcess pid=521) INFO 04-12 07:54:15 [parallel_state.py:957] rank 7 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 7
(VllmWorkerProcess pid=493) INFO 04-12 07:54:15 [parallel_state.py:957] rank 3 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 3
(VllmWorkerProcess pid=500) INFO 04-12 07:54:15 [parallel_state.py:957] rank 4 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 4
(VllmWorkerProcess pid=486) INFO 04-12 07:54:15 [parallel_state.py:957] rank 2 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 2
(VllmWorkerProcess pid=507) INFO 04-12 07:54:15 [parallel_state.py:957] rank 5 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 5
(VllmWorkerProcess pid=479) INFO 04-12 07:54:15 [parallel_state.py:957] rank 1 in world size 16 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorkerProcess pid=493) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=500) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=486) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=479) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=528) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=535) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=514) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=542) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=507) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=556) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=563) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=570) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=577) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=549) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=521) INFO 04-12 07:54:15 [model_runner.py:1110] Starting to load model /models/wanzhenchn_DeepSeek-R1-AWQ/...
(VllmWorkerProcess pid=493) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=500) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=486) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=528) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=479) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=535) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=514) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=556) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=507) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=542) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=563) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=570) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=577) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=549) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
(VllmWorkerProcess pid=521) WARNING 04-12 07:54:15 [utils.py:165] The model class DeepseekV3ForCausalLM has not defined `packed_modules_mapping`, this may lead to incorrect mapping of quantized or ignored modules
Loading safetensors checkpoint shards:   0% Completed | 0/71 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:   1% Completed | 1/71 [00:24<28:05, 24.08s/it]
Loading safetensors checkpoint shards:   3% Completed | 2/71 [00:47<27:17, 23.73s/it]
Loading safetensors checkpoint shards:   4% Completed | 3/71 [01:11<26:52, 23.72s/it]
Loading safetensors checkpoint shards:   6% Completed | 4/71 [01:34<26:21, 23.60s/it]
Loading safetensors checkpoint shards:   7% Completed | 5/71 [01:58<26:02, 23.67s/it]
Loading safetensors checkpoint shards:   8% Completed | 6/71 [02:21<25:27, 23.50s/it]
Loading safetensors checkpoint shards:  10% Completed | 7/71 [02:44<25:00, 23.44s/it]
Loading safetensors checkpoint shards:  11% Completed | 8/71 [03:08<24:40, 23.50s/it]
Loading safetensors checkpoint shards:  13% Completed | 9/71 [03:32<24:18, 23.53s/it]
Loading safetensors checkpoint shards:  14% Completed | 10/71 [03:55<23:55, 23.54s/it]
Loading safetensors checkpoint shards:  15% Completed | 11/71 [04:19<23:36, 23.61s/it]
Loading safetensors checkpoint shards:  17% Completed | 12/71 [04:42<23:05, 23.48s/it]
Loading safetensors checkpoint shards:  18% Completed | 13/71 [05:05<22:38, 23.42s/it]
Loading safetensors checkpoint shards:  20% Completed | 14/71 [05:29<22:21, 23.53s/it]
Loading safetensors checkpoint shards:  21% Completed | 15/71 [05:53<21:55, 23.48s/it]
Loading safetensors checkpoint shards:  23% Completed | 16/71 [06:16<21:24, 23.35s/it]
Loading safetensors checkpoint shards:  24% Completed | 17/71 [06:39<20:58, 23.31s/it]
Loading safetensors checkpoint shards:  25% Completed | 18/71 [07:02<20:31, 23.24s/it]
Loading safetensors checkpoint shards:  27% Completed | 19/71 [07:25<20:10, 23.28s/it]
Loading safetensors checkpoint shards:  28% Completed | 20/71 [07:49<19:53, 23.40s/it]
Loading safetensors checkpoint shards:  30% Completed | 21/71 [08:13<19:37, 23.54s/it]
Loading safetensors checkpoint shards:  31% Completed | 22/71 [08:37<19:15, 23.58s/it]
Loading safetensors checkpoint shards:  32% Completed | 23/71 [09:00<18:51, 23.58s/it]
Loading safetensors checkpoint shards:  34% Completed | 24/71 [09:24<18:29, 23.61s/it]
Loading safetensors checkpoint shards:  35% Completed | 25/71 [09:47<18:00, 23.48s/it]
Loading safetensors checkpoint shards:  37% Completed | 26/71 [10:10<17:35, 23.46s/it]
Loading safetensors checkpoint shards:  38% Completed | 27/71 [10:34<17:18, 23.61s/it]
Loading safetensors checkpoint shards:  39% Completed | 28/71 [10:58<16:52, 23.55s/it]
Loading safetensors checkpoint shards:  41% Completed | 29/71 [11:22<16:31, 23.60s/it]
Loading safetensors checkpoint shards:  42% Completed | 30/71 [11:45<16:06, 23.58s/it]
Loading safetensors checkpoint shards:  44% Completed | 31/71 [12:08<15:40, 23.51s/it]
Loading safetensors checkpoint shards:  45% Completed | 32/71 [12:32<15:21, 23.63s/it]
Loading safetensors checkpoint shards:  46% Completed | 33/71 [12:56<14:56, 23.59s/it]
Loading safetensors checkpoint shards:  48% Completed | 34/71 [13:19<14:30, 23.53s/it]
Loading safetensors checkpoint shards:  49% Completed | 35/71 [13:43<14:05, 23.49s/it]
Loading safetensors checkpoint shards:  51% Completed | 36/71 [13:56<11:55, 20.44s/it]
Loading safetensors checkpoint shards:  52% Completed | 37/71 [14:19<12:06, 21.37s/it]
Loading safetensors checkpoint shards:  54% Completed | 38/71 [14:43<12:07, 22.03s/it]
Loading safetensors checkpoint shards:  55% Completed | 39/71 [15:06<11:58, 22.45s/it]
Loading safetensors checkpoint shards:  56% Completed | 40/71 [15:30<11:43, 22.71s/it]
Loading safetensors checkpoint shards:  58% Completed | 41/71 [15:53<11:27, 22.92s/it]
Loading safetensors checkpoint shards:  59% Completed | 42/71 [16:17<11:08, 23.07s/it]
Loading safetensors checkpoint shards:  61% Completed | 43/71 [16:40<10:51, 23.28s/it]
Loading safetensors checkpoint shards:  62% Completed | 44/71 [17:04<10:29, 23.31s/it]
Loading safetensors checkpoint shards:  63% Completed | 45/71 [17:27<10:04, 23.27s/it]
Loading safetensors checkpoint shards:  65% Completed | 46/71 [17:49<09:31, 22.88s/it]
Loading safetensors checkpoint shards:  66% Completed | 47/71 [18:12<09:09, 22.90s/it]
Loading safetensors checkpoint shards:  68% Completed | 48/71 [18:35<08:46, 22.88s/it]
Loading safetensors checkpoint shards:  69% Completed | 49/71 [18:57<08:21, 22.81s/it]
Loading safetensors checkpoint shards:  70% Completed | 50/71 [19:20<07:56, 22.70s/it]
Loading safetensors checkpoint shards:  72% Completed | 51/71 [19:42<07:32, 22.61s/it]
Loading safetensors checkpoint shards:  73% Completed | 52/71 [20:06<07:13, 22.83s/it]
Loading safetensors checkpoint shards:  75% Completed | 53/71 [20:28<06:50, 22.83s/it]
Loading safetensors checkpoint shards:  76% Completed | 54/71 [20:51<06:27, 22.77s/it]
Loading safetensors checkpoint shards:  77% Completed | 55/71 [21:14<06:04, 22.81s/it]
Loading safetensors checkpoint shards:  79% Completed | 56/71 [21:16<04:08, 16.57s/it]
Loading safetensors checkpoint shards:  80% Completed | 57/71 [21:37<04:13, 18.07s/it]
Loading safetensors checkpoint shards:  82% Completed | 58/71 [22:01<04:15, 19.64s/it]
Loading safetensors checkpoint shards:  83% Completed | 59/71 [22:24<04:07, 20.58s/it]
Loading safetensors checkpoint shards:  85% Completed | 60/71 [22:47<03:54, 21.33s/it]
Loading safetensors checkpoint shards:  86% Completed | 61/71 [23:09<03:36, 21.67s/it]
Loading safetensors checkpoint shards:  87% Completed | 62/71 [23:32<03:18, 22.05s/it]
Loading safetensors checkpoint shards:  89% Completed | 63/71 [23:55<02:58, 22.32s/it]
Loading safetensors checkpoint shards:  90% Completed | 64/71 [24:18<02:37, 22.51s/it]
Loading safetensors checkpoint shards:  92% Completed | 65/71 [24:41<02:16, 22.67s/it]
Loading safetensors checkpoint shards:  93% Completed | 66/71 [25:03<01:52, 22.49s/it]
Loading safetensors checkpoint shards:  94% Completed | 67/71 [25:24<01:28, 22.09s/it]
Loading safetensors checkpoint shards:  96% Completed | 68/71 [25:45<01:05, 21.78s/it]
Loading safetensors checkpoint shards:  97% Completed | 69/71 [26:07<00:43, 21.65s/it]
Loading safetensors checkpoint shards:  99% Completed | 70/71 [26:19<00:18, 18.74s/it]
Loading safetensors checkpoint shards: 100% Completed | 71/71 [26:21<00:00, 13.80s/it]
Loading safetensors checkpoint shards: 100% Completed | 71/71 [26:21<00:00, 22.27s/it]

(VllmWorkerProcess pid=486) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.06 seconds
(VllmWorkerProcess pid=507) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.16 seconds
(VllmWorkerProcess pid=493) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.01 seconds
(VllmWorkerProcess pid=570) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.87 seconds
(VllmWorkerProcess pid=549) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.96 seconds
(VllmWorkerProcess pid=577) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.93 seconds
(VllmWorkerProcess pid=535) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.98 seconds
(VllmWorkerProcess pid=563) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.86 seconds
INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.03 seconds
(VllmWorkerProcess pid=528) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.00 seconds
(VllmWorkerProcess pid=542) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.94 seconds
(VllmWorkerProcess pid=479) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.93 seconds
(VllmWorkerProcess pid=500) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.16 seconds
(VllmWorkerProcess pid=556) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.94 seconds
(VllmWorkerProcess pid=514) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1581.89 seconds
(VllmWorkerProcess pid=521) INFO 04-12 08:20:41 [loader.py:458] Loading weights took 1582.01 seconds
INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.210946 seconds
(VllmWorkerProcess pid=542) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.195546 seconds
(VllmWorkerProcess pid=500) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.190996 seconds
(VllmWorkerProcess pid=528) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.263881 seconds
(VllmWorkerProcess pid=507) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.374184 seconds
(VllmWorkerProcess pid=521) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.418148 seconds
(VllmWorkerProcess pid=556) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.404558 seconds
(VllmWorkerProcess pid=514) INFO 04-12 08:20:59 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.526407 seconds
(VllmWorkerProcess pid=535) INFO 04-12 08:21:00 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.686956 seconds
(VllmWorkerProcess pid=563) INFO 04-12 08:21:00 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.704672 seconds
(VllmWorkerProcess pid=479) INFO 04-12 08:21:00 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.742247 seconds
(VllmWorkerProcess pid=570) INFO 04-12 08:21:00 [model_runner.py:1146] Model loading took 21.2080 GiB and 1604.719051 seconds
(VllmWorkerProcess pid=549) INFO 04-12 08:21:00 [model_runner.py:1146] Model loading took 21.2080 GiB and 1605.263926 seconds
(VllmWorkerProcess pid=577) INFO 04-12 08:21:01 [model_runner.py:1146] Model loading took 21.2080 GiB and 1605.622489 seconds
(VllmWorkerProcess pid=486) INFO 04-12 08:21:01 [model_runner.py:1146] Model loading took 21.2080 GiB and 1605.662619 seconds
(VllmWorkerProcess pid=493) INFO 04-12 08:21:01 [model_runner.py:1146] Model loading took 21.2080 GiB and 1605.909308 seconds
WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=514) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=535) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=549) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=500) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=479) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=563) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=493) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=528) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=542) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=570) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=577) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=507) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=556) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=521) WARNING 04-12 08:21:11 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=486) WARNING 04-12 08:21:12 [fused_moe.py:659] Using default MoE config. Performance might be sub-optimal! Config file not found at /usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/configs/E=256,N=14336,device_name=NVIDIA_GeForce_RTX_3090.json
(VllmWorkerProcess pid=486) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.49 seconds
(VllmWorkerProcess pid=486) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=486) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=549) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.55 seconds
(VllmWorkerProcess pid=549) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=549) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=570) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.55 seconds
(VllmWorkerProcess pid=570) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=570) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=479) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.56 seconds
(VllmWorkerProcess pid=479) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=479) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=521) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.56 seconds
(VllmWorkerProcess pid=521) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=521) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=507) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.59 seconds
(VllmWorkerProcess pid=507) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=507) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=577) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.56 seconds
(VllmWorkerProcess pid=577) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=577) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=493) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.57 seconds
(VllmWorkerProcess pid=493) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=493) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=514) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.56 seconds
(VllmWorkerProcess pid=514) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=514) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=535) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.58 seconds
(VllmWorkerProcess pid=535) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=535) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=563) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.57 seconds
(VllmWorkerProcess pid=563) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=563) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=556) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.60 seconds
(VllmWorkerProcess pid=556) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=556) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=528) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.57 seconds
(VllmWorkerProcess pid=528) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=528) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=500) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.60 seconds
(VllmWorkerProcess pid=500) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=500) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
(VllmWorkerProcess pid=542) INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.57 seconds
(VllmWorkerProcess pid=542) INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
(VllmWorkerProcess pid=542) INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
INFO 04-12 08:21:19 [worker.py:267] Memory profiling takes 17.76 seconds
INFO 04-12 08:21:19 [worker.py:267] the current vLLM instance can use total_gpu_memory (23.58GiB) x gpu_memory_utilization (0.99) = 23.35GiB
INFO 04-12 08:21:19 [worker.py:267] model weights take 21.21GiB; non_torch_memory takes 0.20GiB; PyTorch activation peak memory takes 0.37GiB; the rest of the memory reserved for KV Cache is 1.57GiB.
INFO 04-12 08:21:19 [executor_base.py:112] # cuda blocks: 1502, # CPU blocks: 3819
INFO 04-12 08:21:19 [executor_base.py:117] Maximum concurrency for 2048 tokens per request: 11.73x
(VllmWorkerProcess pid=507) INFO 04-12 08:22:02 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=577) INFO 04-12 08:22:03 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=570) INFO 04-12 08:22:03 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=542) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=563) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=514) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=535) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=556) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=521) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=549) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=528) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=493) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=486) INFO 04-12 08:22:04 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 04-12 08:22:05 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=500) INFO 04-12 08:22:05 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
(VllmWorkerProcess pid=479) INFO 04-12 08:22:05 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_utilization` or switching to eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
Capturing CUDA graph shapes:  75%|███████▌  | 3/4 [00:07<00:02,  2.65s/it]
ERROR 04-12 08:22:13 [engine.py:448] CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 23.58 GiB of which 128.00 MiB is free. Process 4109071 has 23.44 GiB memory in use. Of the allocated memory 22.82 GiB is allocated by PyTorch, with 22.00 MiB allocated in private pools (e.g., CUDA Graphs), and 36.20 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
ERROR 04-12 08:22:13 [engine.py:448] Traceback (most recent call last):
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 436, in run_mp_engine
ERROR 04-12 08:22:13 [engine.py:448]     engine = MQLLMEngine.from_vllm_config(
ERROR 04-12 08:22:13 [engine.py:448]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 128, in from_vllm_config
ERROR 04-12 08:22:13 [engine.py:448]     return cls(
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/multiprocessing/engine.py", line 82, in __init__
ERROR 04-12 08:22:13 [engine.py:448]     self.engine = LLMEngine(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 285, in __init__
ERROR 04-12 08:22:13 [engine.py:448]     self._initialize_kv_caches()
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/engine/llm_engine.py", line 447, in _initialize_kv_caches
ERROR 04-12 08:22:13 [engine.py:448]     self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 123, in initialize_cache
ERROR 04-12 08:22:13 [engine.py:448]     self.collective_rpc("initialize_cache",
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/executor_base.py", line 331, in collective_rpc
ERROR 04-12 08:22:13 [engine.py:448]     return self._run_workers(method, *args, **(kwargs or {}))
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/executor/mp_distributed_executor.py", line 185, in _run_workers
ERROR 04-12 08:22:13 [engine.py:448]     driver_worker_output = run_method(self.driver_worker, sent_method,
ERROR 04-12 08:22:13 [engine.py:448]                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils.py", line 2364, in run_method
ERROR 04-12 08:22:13 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 308, in initialize_cache
ERROR 04-12 08:22:13 [engine.py:448]     self._warm_up_model()
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/worker.py", line 338, in _warm_up_model
ERROR 04-12 08:22:13 [engine.py:448]     self.model_runner.capture_model(self.gpu_cache)
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 04-12 08:22:13 [engine.py:448]     return func(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1585, in capture_model
ERROR 04-12 08:22:13 [engine.py:448]     graph_runner.capture(**capture_inputs)
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/worker/model_runner.py", line 1954, in capture
ERROR 04-12 08:22:13 [engine.py:448]     self.model(
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 703, in forward
ERROR 04-12 08:22:13 [engine.py:448]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
ERROR 04-12 08:22:13 [engine.py:448]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/compilation/decorators.py", line 172, in __call__
ERROR 04-12 08:22:13 [engine.py:448]     return self.forward(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 660, in forward
ERROR 04-12 08:22:13 [engine.py:448]     hidden_states, residual = layer(positions, hidden_states, residual)
ERROR 04-12 08:22:13 [engine.py:448]                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 580, in forward
ERROR 04-12 08:22:13 [engine.py:448]     hidden_states = self.mlp(hidden_states)
ERROR 04-12 08:22:13 [engine.py:448]                     ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/models/deepseek_v2.py", line 159, in forward
ERROR 04-12 08:22:13 [engine.py:448]     final_hidden_states = self.experts(
ERROR 04-12 08:22:13 [engine.py:448]                           ^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return self._call_impl(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 04-12 08:22:13 [engine.py:448]     return forward_call(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 842, in forward
ERROR 04-12 08:22:13 [engine.py:448]     return self.forward_impl(hidden_states, router_logits)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 861, in forward_impl
ERROR 04-12 08:22:13 [engine.py:448]     final_hidden_states = self.quant_method.apply(
ERROR 04-12 08:22:13 [engine.py:448]                           ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/quantization/awq_marlin.py", line 491, in apply
ERROR 04-12 08:22:13 [engine.py:448]     topk_weights, topk_ids = FusedMoE.select_experts(
ERROR 04-12 08:22:13 [engine.py:448]                              ^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/layer.py", line 798, in select_experts
ERROR 04-12 08:22:13 [engine.py:448]     topk_weights, topk_ids = grouped_topk(
ERROR 04-12 08:22:13 [engine.py:448]                              ^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
ERROR 04-12 08:22:13 [engine.py:448]     return fn(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 878, in grouped_topk
ERROR 04-12 08:22:13 [engine.py:448]     @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
ERROR 04-12 08:22:13 [engine.py:448]     return fn(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/aot_autograd.py", line 1184, in forward
ERROR 04-12 08:22:13 [engine.py:448]     return compiled_fn(full_args)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 323, in runtime_wrapper
ERROR 04-12 08:22:13 [engine.py:448]     all_outs = call_func_at_runtime_with_args(
ERROR 04-12 08:22:13 [engine.py:448]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
ERROR 04-12 08:22:13 [engine.py:448]     out = normalize_as_list(f(args))
ERROR 04-12 08:22:13 [engine.py:448]                             ^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 672, in inner_fn
ERROR 04-12 08:22:13 [engine.py:448]     outs = compiled_fn(args)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 490, in wrapper
ERROR 04-12 08:22:13 [engine.py:448]     return compiled_fn(runtime_args)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/output_code.py", line 466, in __call__
ERROR 04-12 08:22:13 [engine.py:448]     return self.current_callable(inputs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/utils.py", line 2128, in run
ERROR 04-12 08:22:13 [engine.py:448]     return model(new_inputs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/tmp/torchinductor_root/5e/c5exfxaw6h3xusrjbccjrniiw4mr4teu7nrjlqd4e2ndsgsndcvl.py", line 329, in call
ERROR 04-12 08:22:13 [engine.py:448]     triton_poi_fused_add_sigmoid_0.run(arg0_1, arg1_1, buf0, 256, grid=grid(256), stream=stream0)
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 1034, in run
ERROR 04-12 08:22:13 [engine.py:448]     self.autotune_to_one_config(*args, grid=grid, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 911, in autotune_to_one_config
ERROR 04-12 08:22:13 [engine.py:448]     timings = self.benchmark_all_configs(*args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 886, in benchmark_all_configs
ERROR 04-12 08:22:13 [engine.py:448]     launcher: self.bench(launcher, *args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/triton_heuristics.py", line 787, in bench
ERROR 04-12 08:22:13 [engine.py:448]     return benchmarker.benchmark_gpu(kernel_call, rep=40)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/benchmarking.py", line 66, in wrapper
ERROR 04-12 08:22:13 [engine.py:448]     return fn(self, *args, **kwargs)
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/torch/_inductor/runtime/benchmarking.py", line 202, in benchmark_gpu
ERROR 04-12 08:22:13 [engine.py:448]     return self.triton_do_bench(_callable, **kwargs, return_mode="median")
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/triton/testing.py", line 120, in do_bench
ERROR 04-12 08:22:13 [engine.py:448]     cache = runtime.driver.active.get_empty_cache_for_benchmark()
ERROR 04-12 08:22:13 [engine.py:448]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448]   File "/usr/local/lib/python3.12/dist-packages/triton/backends/nvidia/driver.py", line 481, in get_empty_cache_for_benchmark
ERROR 04-12 08:22:13 [engine.py:448]     return torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
ERROR 04-12 08:22:13 [engine.py:448]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 04-12 08:22:13 [engine.py:448] torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 23.58 GiB of which 128.00 MiB is free. Process 4109071 has 23.44 GiB memory in use. Of the allocated memory 22.82 GiB is allocated by PyTorch, with 22.00 MiB allocated in private pools (e.g., CUDA Graphs), and 36.20 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Traceback (most recent call last):
  File "/usr/local/bin/vllm", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/main.py", line 51, in main
    args.dispatch_function(args)
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/cli/serve.py", line 27, in cmd
    uvloop.run(run_server(args))
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/usr/local/lib/python3.12/dist-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 1069, in run_server
    async with build_async_engine_client(args) as engine_client:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 146, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/vllm/entrypoints/openai/api_server.py", line 269, in build_async_engine_client_from_engine_args
    raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.

I will test new CI now with --enable-chunked-prefill to see if it makes a difference. Sorry that I can't easily debug.

Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
@jinzhen-lin
Copy link
Contributor Author

@mgoin This PR is ready now. The failed tests seems not related to this PR.

@mgoin mgoin enabled auto-merge (squash) April 14, 2025 21:26
@vllm-bot vllm-bot merged commit d06ba4e into vllm-project:main Apr 15, 2025
63 of 69 checks passed
@vivienfanghuagood
Copy link

What's the meaning of "Performance on DeepSeek-V3-AWQ (on 8*A800)"? Output tokens per seconds?

@jinzhen-lin
Copy link
Contributor Author

What's the meaning of "Performance on DeepSeek-V3-AWQ (on 8*A800)"? Output tokens per seconds?

Yes, but these values are copied from the statistics logs printed by vllm, and not from an end-to-end benchmark test. The actual results may be more complex. Additionally, the benchmark results are from a month ago, and in the past month, I have optimized the operators multiple times, so the current results should be slightly better than those.

yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Yang Wang <elainewy@meta.com>
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build force-merge kernel quantization ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants