
Conversation

@jikunshang
Collaborator

@jikunshang jikunshang commented Mar 11, 2025

Update: submitted a new PR #19560 to simplify this.

This PR adds Intel GPU V1 engine support. It leverages the latest chunked_prefill kernel (same API as flash_attn_varlen_func v2) from IPEX 2.6. Thanks to this, XPU V1 support is quite smooth, and most of the code can be shared with the CUDA GPU side.

1. This PR mainly includes several parts:

a. Introduce the new chunked-prefill kernel implemented in IPEX.
b. Refine some XPU-related configs.
c. Add an IPEX_V1 attention backend based on the chunked-prefill kernel (see the sketch after this list).
d. Support vLLM V1 by adding xpu_worker/xpu_model_runner to handle the XPU code path.
e. Dockerfile/dependency updates.
f. Add a V1 test in CI.
I can split this into smaller pieces if necessary.
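
For context, here is a minimal sketch of how such a backend can drive the IPEX chunked-prefill kernel. It is not the PR's actual code: the wrapper name, argument names, and argument order below are illustrative assumptions; the only thing taken from this PR is that ipex.llm.modules.PagedAttention.flash_attn_varlen_func exposes a flash_attn_varlen_func-v2-style API over a paged KV cache.

    import intel_extension_for_pytorch as ipex  # IPEX >= 2.6 assumed

    def chunked_prefill_attention(out, query, key_cache, value_cache,
                                  cu_seqlens_q, cu_seqlens_kv,
                                  max_seqlen_q, max_seqlen_kv,
                                  softmax_scale, block_table):
        # Variable-length queries attend to the paged KV cache with a causal mask
        # and no dropout, matching the flash_attn_varlen_func v2 calling convention.
        return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
            out, query, key_cache, value_cache,
            cu_seqlens_q, cu_seqlens_kv,
            max_seqlen_q, max_seqlen_kv,
            softmax_scale,
            True,            # causal masking for chunked prefill
            block_table,
        )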

2. Why do we need an XPU worker/model_runner?

When we started supporting V1 on XPU last year, we followed the V0 design, implemented almost the same interfaces as the CUDA path, and tried to inherit from the CUDA path as much as possible. A separate worker/model runner is still needed because the CUDA path may contain CUDA-specific calls such as torch.Tensor.cuda(), and some features, like CUDA graphs, are not supported on XPU.

The main differences are:
A) Init of the device, distributed environment, and profiler: we cannot hard-code "cuda" or "nccl" (see the sketch after this list).
B) A workaround for incorrect GPU memory metrics on client GPUs.
C) XPUModelRunner's init method: we cannot reuse the parent's, since it contains CUDA-specific code such as multi_processor_count.
D) XPUModelRunner's prepare_inputs() method: we use flash attention V2, and some parameters differ from flash attention V3.
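
As an illustration of point A, here is a minimal sketch of device and distributed-environment init without "cuda"/"nccl" hard-coding. It assumes PyTorch XPU support plus the oneCCL bindings for PyTorch and is not the PR's actual XPUWorker.init_device(); the function name and defaults are illustrative.

    import torch
    import torch.distributed as dist

    def init_xpu_worker(local_rank: int, rank: int, world_size: int,
                        init_method: str = "env://"):
        # Select the XPU device instead of a hard-coded "cuda" device.
        device = torch.device(f"xpu:{local_rank}")
        torch.xpu.set_device(device)
        # Use the oneCCL backend instead of a hard-coded "nccl".
        dist.init_process_group(backend="ccl", init_method=init_method,
                                world_size=world_size, rank=rank)
        # Warm up the communicator on the XPU, mirroring the existing V0 xpu_worker.
        dist.all_reduce(torch.zeros(1).to(device))
        return device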

3. Performance: I did a quick benchmark of both V0 and V1 on a PVC 1100. V1 shows a 1.55x throughput boost over V0 (3328.59 vs. 2139.68). Please be aware that neither run has tuned parameters.
    Commands I used:
VLLM_USE_V1=0 python3 benchmarks/benchmark_throughput.py --model meta-llama/Meta-Llama-3-8B-Instruct --dataset benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enforce-eager
VLLM_USE_V1=1 python3 benchmarks/benchmark_throughput.py --model meta-llama/Meta-Llama-3-8B-Instruct --dataset benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enforce-eager
4. We have an internal test pipeline that covers the added features and a large number of models. We have not added any unit tests yet, since the existing UTs are mostly designed for CUDA. The V1-folder UT pass rate is currently below 30%, mostly due to hard-coded CUDA assumptions; we are working on fixing these failing UTs internally, and our internal branch's pass rate is above 80%. We will update this in follow-up PRs.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the documentation, ci/build, and v1 labels Mar 11, 2025
@jikunshang jikunshang marked this pull request as draft March 11, 2025 13:25
@mergify

mergify bot commented Mar 12, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jikunshang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@HumerousGorgon

I'm attempting to use this PR but every build I do, it ends up failing by telling me "Failed to import from vllm._C", then it claims to not be able to detect a GPU.
Any ideas?

@jikunshang
Collaborator Author

I'm attempting to use this PR but every build I do, it ends up failing by telling me "Failed to import from vllm._C", then it claims to not be able to detect a GPU. Any ideas?

Can you show me your build command and logs? Or can you build & run the XPU path without this PR?

@HumerousGorgon

I'm attempting to use this PR but every build I do, it ends up failing by telling me "Failed to import from vllm._C", then it claims to not be able to detect a GPU. Any ideas?

can you show me your build command and logs? or can you build & run xpu path without this PR?

I use the commands from the vLLM XPU installation instructions to build from source: first set up a uv environment, then install the required packages from requirements/xpu.txt, then build and install vLLM with the XPU backend. After that, I install IPEX 2.6.

I run the application with:

python -m vllm.entrypoints.openai.api_server --model "/llm/models/bge-m3" --device xpu --max_model_len 1024

And it spits out:

[W313 10:36:09.756158416 OperatorEntry.cpp:154] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
    registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:30477
       new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:468 (function operator())
INFO 03-13 10:36:11 [__init__.py:256] Automatically detected platform xpu.
[W313 10:36:11.963612370 OperatorEntry.cpp:154] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
    registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:30477
       new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:468 (function operator())
INFO 03-13 10:36:12 [api_server.py:912] vLLM API server version 0.7.4.dev412+gb1cc4dfe
INFO 03-13 10:36:12 [api_server.py:913] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/llm/models/bge-m3', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=1024, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=1, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='xpu', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, 
calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False)
INFO 03-13 10:36:12 [api_server.py:209] Started engine process with PID 636
INFO 03-13 10:36:12 [config.py:540] Found sentence-transformers tokenize configuration.
INFO 03-13 10:36:12 [config.py:2564] Downcasting torch.float32 to torch.float16.
[W313 10:36:13.163887652 OperatorEntry.cpp:154] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
    registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:30477
       new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:468 (function operator())
INFO 03-13 10:36:15 [__init__.py:256] Automatically detected platform xpu.
INFO 03-13 10:36:16 [config.py:540] Found sentence-transformers tokenize configuration.
INFO 03-13 10:36:16 [config.py:2564] Downcasting torch.float32 to torch.float16.
INFO 03-13 10:36:17 [config.py:436] Found sentence-transformers modules configuration.
INFO 03-13 10:36:17 [config.py:456] Found pooling configuration.
INFO 03-13 10:36:17 [config.py:577] This model supports multiple tasks: {'classify', 'embed', 'score', 'reward'}. Defaulting to 'embed'.
WARNING 03-13 10:36:17 [_logger.py:68] CUDA graph is not supported on XPU, fallback to the eager mode.
WARNING 03-13 10:36:17 [_logger.py:68] uni is not supported on XPU, fallback to ray distributed executor backend.
INFO 03-13 10:36:21 [config.py:436] Found sentence-transformers modules configuration.
INFO 03-13 10:36:21 [config.py:456] Found pooling configuration.
INFO 03-13 10:36:21 [config.py:577] This model supports multiple tasks: {'reward', 'score', 'embed', 'classify'}. Defaulting to 'embed'.
WARNING 03-13 10:36:21 [_logger.py:68] CUDA graph is not supported on XPU, fallback to the eager mode.
WARNING 03-13 10:36:21 [_logger.py:68] uni is not supported on XPU, fallback to ray distributed executor backend.
INFO 03-13 10:36:21 [llm_engine.py:235] Initializing a V0 LLM engine (v0.7.4.dev412+gb1cc4dfe) with config: model='/llm/models/bge-m3', speculative_config=None, tokenizer='/llm/models/bge-m3', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=xpu, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=/llm/models/bge-m3, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=False, chunked_prefill_enabled=False, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=PoolerConfig(pooling_type='CLS', normalize=True, softmax=None, step_tag_id=None, returned_token_ids=None), compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=True,
WARNING 03-13 10:36:21 [_logger.py:68] No existing RAY instance detected. A new instance will be launched with current node resources.
2025-03-13 10:36:22,655 INFO worker.py:1841 -- Started a local Ray instance.
INFO 03-13 10:36:23 [ray_distributed_executor.py:176] use_ray_spmd_worker: False
(pid=934) [W313 10:36:24.861693331 OperatorEntry.cpp:154] Warning: Warning only once for all operators,  other operators may also be overridden.
(pid=934)   Overriding a previously registered kernel for the same operator and the same dispatch key
(pid=934)   operator: aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()
(pid=934)     registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
(pid=934)   dispatch key: XPU
(pid=934)   previous kernel: registered at /pytorch/build/aten/src/ATen/RegisterCPU.cpp:30477
(pid=934)        new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:468 (function operator())
(pid=934) INFO 03-13 10:36:26 [__init__.py:256] Automatically detected platform xpu.
INFO 03-13 10:36:26 [ray_distributed_executor.py:350] non_carry_over_env_vars from config: set()
INFO 03-13 10:36:26 [ray_distributed_executor.py:352] Copying the following environment variables to workers: ['LD_LIBRARY_PATH']
INFO 03-13 10:36:26 [ray_distributed_executor.py:355] If certain env vars should NOT be copied to workers, add them to /root/.config/vllm/ray_non_carry_over_env_vars.json file
INFO 03-13 10:36:27 [xpu.py:35] Cannot use None backend on XPU.
INFO 03-13 10:36:27 [xpu.py:36] Using IPEX attention backend.
WARNING 03-13 10:36:27 [_logger.py:68] Failed to import from vllm._C with ModuleNotFoundError("No module named 'vllm._C'")
INFO 03-13 10:36:27 [importing.py:16] Triton not installed or not compatible; certain GPU-related functions will not be available.
INFO 03-13 10:36:27 [parallel_state.py:948] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
2025:03:13-10:36:27:(  636) |CCL_WARN| value of CCL_ATL_TRANSPORT changed to be ofi (default:mpi)
2025:03:13-10:36:27:(  636) |CCL_WARN| value of CCL_LOCAL_RANK changed to be 0 (default:-1)
2025:03:13-10:36:27:(  636) |CCL_WARN| value of CCL_LOCAL_SIZE changed to be 1 (default:-1)
2025:03:13-10:36:27:(  636) |CCL_WARN| value of CCL_PROCESS_LAUNCHER changed to be none (default:hydra)
2025:03:13-10:36:27:( 1542) |CCL_WARN| no membind support for NUMA node 0, skip thread membind
2025:03:13-10:36:27:(  636) |CCL_WARN| device_family is unknown, topology discovery could be incorrect, it might result in suboptimal performance
2025:03:13-10:36:27:(  636) |CCL_WARN| pidfd is not supported, fallbacks to drmfd exchange mode
2025:03:13-10:36:27:(  636) |CCL_ERROR| ze_fd_manager.cpp:214 fill_device_fds: condition fds[dev_idx] > 0 failed
open failed: fd: -1, errno: No such file or directory
ERROR 03-13 10:36:27 [worker_base.py:620] Error executing method 'init_device'. This might cause deadlock in distributed execution.
ERROR 03-13 10:36:27 [worker_base.py:620] Traceback (most recent call last):
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/vllm/vllm/worker/worker_base.py", line 612, in execute_method
ERROR 03-13 10:36:27 [worker_base.py:620]     return run_method(self, method, args, kwargs)
ERROR 03-13 10:36:27 [worker_base.py:620]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/vllm/vllm/utils.py", line 2238, in run_method
ERROR 03-13 10:36:27 [worker_base.py:620]     return func(*args, **kwargs)
ERROR 03-13 10:36:27 [worker_base.py:620]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/vllm/vllm/worker/worker_base.py", line 604, in init_device
ERROR 03-13 10:36:27 [worker_base.py:620]     self.worker.init_device()  # type: ignore
ERROR 03-13 10:36:27 [worker_base.py:620]     ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/vllm/vllm/worker/xpu_worker.py", line 82, in init_device
ERROR 03-13 10:36:27 [worker_base.py:620]     self.init_worker_distributed_environment()
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/vllm/vllm/worker/xpu_worker.py", line 180, in init_worker_distributed_environment
ERROR 03-13 10:36:27 [worker_base.py:620]     torch.distributed.all_reduce(torch.zeros(1).xpu())
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
ERROR 03-13 10:36:27 [worker_base.py:620]     return func(*args, **kwargs)
ERROR 03-13 10:36:27 [worker_base.py:620]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [worker_base.py:620]   File "/vllm-env/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2806, in all_reduce
ERROR 03-13 10:36:27 [worker_base.py:620]     work = group.allreduce([tensor], opts)
ERROR 03-13 10:36:27 [worker_base.py:620]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [worker_base.py:620] RuntimeError: oneCCL: ze_fd_manager.cpp:214 fill_device_fds: EXCEPTION: open failed: fd: -1, errno: No such file or directory
ERROR 03-13 10:36:27 [engine.py:411] oneCCL: ze_fd_manager.cpp:214 fill_device_fds: EXCEPTION: open failed: fd: -1, errno: No such file or directory
ERROR 03-13 10:36:27 [engine.py:411] Traceback (most recent call last):
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 402, in run_mp_engine
ERROR 03-13 10:36:27 [engine.py:411]     engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
ERROR 03-13 10:36:27 [engine.py:411]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 125, in from_engine_args
ERROR 03-13 10:36:27 [engine.py:411]     return cls(ipc_path=ipc_path,
ERROR 03-13 10:36:27 [engine.py:411]            ^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 77, in __init__
ERROR 03-13 10:36:27 [engine.py:411]     self.engine = LLMEngine(*args, **kwargs)
ERROR 03-13 10:36:27 [engine.py:411]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/engine/llm_engine.py", line 274, in __init__
ERROR 03-13 10:36:27 [engine.py:411]     self.model_executor = executor_class(vllm_config=vllm_config, )
ERROR 03-13 10:36:27 [engine.py:411]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/executor/executor_base.py", line 271, in __init__
ERROR 03-13 10:36:27 [engine.py:411]     super().__init__(*args, **kwargs)
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/executor/executor_base.py", line 52, in __init__
ERROR 03-13 10:36:27 [engine.py:411]     self._init_executor()
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/executor/ray_distributed_executor.py", line 114, in _init_executor
ERROR 03-13 10:36:27 [engine.py:411]     self._init_workers_ray(placement_group)
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/executor/ray_distributed_executor.py", line 393, in _init_workers_ray
ERROR 03-13 10:36:27 [engine.py:411]     self._run_workers("init_device")
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/executor/ray_distributed_executor.py", line 514, in _run_workers
ERROR 03-13 10:36:27 [engine.py:411]     self.driver_worker.execute_method(sent_method, *args, **kwargs)
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/worker/worker_base.py", line 621, in execute_method
ERROR 03-13 10:36:27 [engine.py:411]     raise e
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/worker/worker_base.py", line 612, in execute_method
ERROR 03-13 10:36:27 [engine.py:411]     return run_method(self, method, args, kwargs)
ERROR 03-13 10:36:27 [engine.py:411]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/utils.py", line 2238, in run_method
ERROR 03-13 10:36:27 [engine.py:411]     return func(*args, **kwargs)
ERROR 03-13 10:36:27 [engine.py:411]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/worker/worker_base.py", line 604, in init_device
ERROR 03-13 10:36:27 [engine.py:411]     self.worker.init_device()  # type: ignore
ERROR 03-13 10:36:27 [engine.py:411]     ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/worker/xpu_worker.py", line 82, in init_device
ERROR 03-13 10:36:27 [engine.py:411]     self.init_worker_distributed_environment()
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/vllm/vllm/worker/xpu_worker.py", line 180, in init_worker_distributed_environment
ERROR 03-13 10:36:27 [engine.py:411]     torch.distributed.all_reduce(torch.zeros(1).xpu())
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
ERROR 03-13 10:36:27 [engine.py:411]     return func(*args, **kwargs)
ERROR 03-13 10:36:27 [engine.py:411]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411]   File "/vllm-env/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2806, in all_reduce
ERROR 03-13 10:36:27 [engine.py:411]     work = group.allreduce([tensor], opts)
ERROR 03-13 10:36:27 [engine.py:411]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-13 10:36:27 [engine.py:411] RuntimeError: oneCCL: ze_fd_manager.cpp:214 fill_device_fds: EXCEPTION: open failed: fd: -1, errno: No such file or directory
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/root/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/root/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 413, in run_mp_engine
    raise e
  File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 402, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 125, in from_engine_args
    return cls(ipc_path=ipc_path,
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/engine/multiprocessing/engine.py", line 77, in __init__
    self.engine = LLMEngine(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/engine/llm_engine.py", line 274, in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/executor/executor_base.py", line 271, in __init__
    super().__init__(*args, **kwargs)
  File "/vllm-env/vllm/vllm/executor/executor_base.py", line 52, in __init__
    self._init_executor()
  File "/vllm-env/vllm/vllm/executor/ray_distributed_executor.py", line 114, in _init_executor
    self._init_workers_ray(placement_group)
  File "/vllm-env/vllm/vllm/executor/ray_distributed_executor.py", line 393, in _init_workers_ray
    self._run_workers("init_device")
  File "/vllm-env/vllm/vllm/executor/ray_distributed_executor.py", line 514, in _run_workers
    self.driver_worker.execute_method(sent_method, *args, **kwargs)
  File "/vllm-env/vllm/vllm/worker/worker_base.py", line 621, in execute_method
    raise e
  File "/vllm-env/vllm/vllm/worker/worker_base.py", line 612, in execute_method
    return run_method(self, method, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/utils.py", line 2238, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/worker/worker_base.py", line 604, in init_device
    self.worker.init_device()  # type: ignore
    ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/worker/xpu_worker.py", line 82, in init_device
    self.init_worker_distributed_environment()
  File "/vllm-env/vllm/vllm/worker/xpu_worker.py", line 180, in init_worker_distributed_environment
    torch.distributed.all_reduce(torch.zeros(1).xpu())
  File "/vllm-env/lib/python3.12/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/lib/python3.12/site-packages/torch/distributed/distributed_c10d.py", line 2806, in all_reduce
    work = group.allreduce([tensor], opts)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: oneCCL: ze_fd_manager.cpp:214 fill_device_fds: EXCEPTION: open failed: fd: -1, errno: No such file or directory
INFO 03-13 10:36:27 [ray_distributed_executor.py:127] Shutting down Ray distributed executor. If you see error log from logging.cc regarding SIGTERM received, please ignore because this is the expected termination process in Ray.
*** SIGTERM received at time=1741862188 on cpu 10 ***
PC: @     0x707a5c23c7f8  (unknown)  clock_nanosleep
    @     0x707a5c199520  (unknown)  (unknown)
[2025-03-13 10:36:28,136 E 636 636] logging.cc:484: *** SIGTERM received at time=1741862188 on cpu 10 ***
[2025-03-13 10:36:28,136 E 636 636] logging.cc:484: PC: @     0x707a5c23c7f8  (unknown)  clock_nanosleep
[2025-03-13 10:36:28,137 E 636 636] logging.cc:484:     @     0x707a5c199520  (unknown)  (unknown)
Exception ignored in atexit callback: <function shutdown at 0x70784655b240>
Traceback (most recent call last):
  File "/vllm-env/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/lib/python3.12/site-packages/ray/_private/worker.py", line 1910, in shutdown
    time.sleep(0.5)
  File "/vllm-env/lib/python3.12/site-packages/ray/_private/worker.py", line 1499, in sigterm_handler
    sys.exit(signum)
SystemExit: 15
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/vllm-env/vllm/vllm/entrypoints/openai/api_server.py", line 992, in <module>
    uvloop.run(run_server(args))
  File "/vllm-env/lib/python3.12/site-packages/uvloop/__init__.py", line 109, in run
    return __asyncio.run(
           ^^^^^^^^^^^^^^
  File "/root/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/asyncio/runners.py", line 195, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/root/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "uvloop/loop.pyx", line 1518, in uvloop.loop.Loop.run_until_complete
  File "/vllm-env/lib/python3.12/site-packages/uvloop/__init__.py", line 61, in wrapper
    return await main
           ^^^^^^^^^^
  File "/vllm-env/vllm/vllm/entrypoints/openai/api_server.py", line 947, in run_server
    async with build_async_engine_client(args) as engine_client:
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/entrypoints/openai/api_server.py", line 139, in build_async_engine_client
    async with build_async_engine_client_from_engine_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/share/uv/python/cpython-3.12.9-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 210, in __aenter__
    return await anext(self.gen)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/vllm-env/vllm/vllm/entrypoints/openai/api_server.py", line 233, in build_async_engine_client_from_engine_args
    raise RuntimeError(
RuntimeError: Engine process failed to start. See stack trace for the root cause.

@jikunshang
Collaborator Author

@HumerousGorgon Sorry, can you create an issue so we can discuss it there? I don't want to clutter the code review here.

@HumerousGorgon

@HumerousGorgon Sorry, can you create an issue so we can discuss it there? I don't want to clutter the code review here.

#14747
Thank you :)
I'm excited to use your build; XPU devices are fantastic.

@jikunshang jikunshang marked this pull request as ready for review March 15, 2025 00:43
@jikunshang jikunshang changed the title from "[Hardware][Intel GPU][WIP] add V1 engine support and chunked_prefill kernel" to "[Hardware][Intel GPU] Add V1 engine support and chunked_prefill kernel" on Mar 16, 2025
@jikunshang
Collaborator Author

@youkaichao @WoosukKwon Can you or someone else take a review? Thanks!

@mgoin
Member

mgoin commented Jun 6, 2025

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This is a comprehensive pull request that introduces V1 engine support for Intel GPUs, leveraging the chunked_prefill kernel from IPEX. The changes are well-structured, touching upon CI, Docker, requirements, core vLLM operations, platform-specific logic, and new V1 worker/runner components for XPU.

The PR description is excellent, providing clear rationale for design choices, performance benchmarks, and a roadmap for testing. This greatly aids in understanding the scope and impact of the changes.

Overall, the additions seem robust and well-integrated with the existing V1 architecture. I have a couple of points regarding the _ipex_ops.py and ipex_attn.py files that could enhance clarity and ensure correctness, detailed in the comments below. Great work on this significant feature enablement!

Summary of Findings

  • Parameter Usage in _ipex_ops.py::chunked_prefill: Several parameters in the chunked_prefill wrapper seem unused or hardcoded when calling the underlying IPEX kernel. Clarification or simplification could improve maintainability.
  • Type Consistency for Scale Factors in ipex_attn.py: There's a potential type mismatch for k_scale and v_scale parameters passed to reshape_and_cache_flash. While likely benign in the current non-FP8 context, confirming the IPEX op's behavior with tensor inputs for these float-expected parameters is advisable.

Merge Readiness

The pull request is substantial and introduces significant new functionality for Intel GPU support. The code is generally well-written and the rationale is clearly explained. The identified points are of medium severity and primarily relate to clarity and ensuring robustness of the interface with IPEX ops. Addressing these points would further strengthen the contribution.

Given the complexity and the note about ongoing UT development, thorough internal testing will be key. I am not authorized to approve pull requests, but based on this review, after addressing the comments, the PR appears to be in a good state for further testing and consideration by other reviewers for merging.

Comment on lines 247 to 261
Contributor

medium

The chunked_prefill function accepts several parameters (p_dropout, zero_tensors, return_softmax, gen_, seq_used_k) that are either hardcoded (e.g., p_dropout to 0.0, is_casual to True) or passed as None when calling ipex.llm.modules.PagedAttention.flash_attn_varlen_func.

Could you clarify if these parameters are intended as placeholders for future IPEX kernel capabilities, or to maintain API compatibility with other FlashAttention implementations?

If they are unlikely to be used by the IPEX backend, perhaps the signature could be simplified? Alternatively, adding comments explaining their fixed values or non-usage in the IPEX context might improve clarity for future maintainers. For example:

# p_dropout: float, # IPEX kernel does not support dropout, hence 0.0 is used.
# ... similar for other params

Collaborator Author

Updated: kept the used parameters and renamed chunked_prefill to flash_attn_varlen_func.

Comment on lines 204 to 205
Contributor

medium

The reshape_and_cache_flash function in vllm/_ipex_ops.py (which this calls) is defined with k_scale: float and v_scale: float. However, here layer._k_scale and layer._v_scale are passed, which are torch.Tensor objects.

Given the assertion on line 183 (layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), it implies that FP8 KV cache is not being used here, and these tensors likely contain the scalar value 1.0.

Does the underlying ipex.llm.modules.PagedAttention.reshape_and_cache_flash op correctly handle torch.Tensor inputs for scale factors, perhaps by implicitly converting tensor(1.0) to float(1.0), or are these scale factors ignored when not using an FP8-related kv_cache_dtype?

Ensuring type consistency or documenting the expected behavior by the IPEX op would be beneficial.
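
For illustration, a hedged sketch of the kind of fix being discussed (the helper name write_kv_to_cache and the import path are hypothetical; layer._k_scale_float / layer._v_scale_float and the reshape_and_cache_flash parameters follow the review comment above, and the actual change in the PR may differ):

    from vllm._ipex_ops import ipex_ops  # assumed import path

    def write_kv_to_cache(layer, key, value, key_cache, value_cache,
                          slot_mapping, kv_cache_dtype):
        # Convert the 0-dim scale tensors to plain floats so they match the
        # k_scale: float / v_scale: float annotations on the IPEX wrapper.
        k_scale = float(layer._k_scale_float)  # 1.0 when FP8 KV cache is not in use
        v_scale = float(layer._v_scale_float)
        ipex_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                         slot_mapping, kv_cache_dtype,
                                         k_scale, v_scale)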

Collaborator Author

nice catch! updated.

@jikunshang jikunshang force-pushed the ipex_26_v1_support branch 3 times, most recently from 13e97bc to 098aa07, on June 9, 2025 06:17
@xuechendi
Contributor

@mgoin, we have resolved all Gemini review comments, thanks!

@gemini-code-assist
Contributor

Thanks for the update, @xuechendi! Glad to hear the review comments have been addressed. Let me know if you need anything else.

@jikunshang jikunshang force-pushed the ipex_26_v1_support branch 2 times, most recently from 3deaa05 to b4ac463, on June 11, 2025 09:50
jikunshang and others added 17 commits June 12, 2025 01:10
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

some v1 fixes

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

remove useless file

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

remove

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

add V1 test and set spawn in docker env

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

add missing dependency

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

fix test

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update api name

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update api

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update default block size for v1

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update memory usage

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

fix rebase issues

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

fix rebase, spec decode meta set to none

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

add xpu v1 config check

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

add mem log

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

fix init cache

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

add xpu profiler for V1

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update rebase issue

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update prepare_inputs for perf

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

update

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

refine xpu_model_runner

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
…one by default. The modification involves adding a check to prevent potential null exceptions. (vllm-project#173)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Co-authored-by: yan <yan.ma@intel.com>
Co-authored-by: mayuyuace <qiming1.zhang@intel.com>

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
@mergify

mergify bot commented Jun 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jikunshang.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Labels

ci/build, documentation, needs-rebase, v1

7 participants