Skip to content

Conversation

@dvrogozh
Copy link
Contributor

@dvrogozh dvrogozh commented Jul 9, 2025

Purpose

XPUModelRunner inherits from GPUModelRunner and customizes couple methods. In 2 cases customization is just dispatching logic for different torch backends (cuda vs. xpu). Furhter, as vLLM has generic tests for the v1 core not differentiating between cuda or xpu, it makes sense to have single GPUModelRunner covering both cuda and xpu as it's already covered by existing tests which just need to be enabled for XPU to get good coverage. This commit implements described approach.

Test Result

After the commit these tests which previously were failing now pass:

  • tests/v1/worker/test_gpu_model_runner.py::test_init_kv_cache_without_kv_sharing
  • tests/v1/worker/test_gpu_model_runner.py::test_init_kv_cache_with_kv_sharing_valid

The change in _sync_device() is taking effect in this test:

  • tests/v1/engine/test_llm_engine.py::test_engine_metrics

Overall, as of a3e4e85 with this change applied all tests in tests/v1/worker and tests/v1/engine are now passing on XPU.

CC: @Liangliang-Ma

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @dvrogozh, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request refactors the GPUModelRunner to be more device-agnostic, specifically accommodating both CUDA and XPU devices within a single implementation. By introducing conditional logic for device property initialization and synchronization, it eliminates the need for XPUModelRunner to override these methods, streamlining the codebase and enabling existing GPUModelRunner tests to cover XPU functionality.

Highlights

  • Unified Device Handling in GPUModelRunner: The GPUModelRunner's __init__ method now explicitly converts the device argument to a torch.device object, enabling robust device type checks (e.g., self.device.type).
  • Conditional Device Property Initialization: The _init_device_properties method in GPUModelRunner has been updated to conditionally retrieve device properties. It calls torch.cuda.get_device_properties only if the device type is 'cuda'; otherwise, device_properties and num_sms are set to None, making the method compatible with non-CUDA devices like XPU.
  • Device-Specific Synchronization: The _sync_device method in GPUModelRunner now dispatches synchronization calls based on the device type. It calls torch.cuda.synchronize() for CUDA devices and torch.xpu.synchronize() for XPU devices, centralizing the synchronization logic.
  • XPUModelRunner Simplification: The XPUModelRunner class has been simplified by removing its overridden _init_device_properties and _sync_device methods. Their functionality is now handled generically within the base GPUModelRunner, reducing code duplication and improving maintainability.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the v1 label Jul 9, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The code changes generalize the GPUModelRunner to support both CUDA and XPU devices. The changes involve dispatching device-specific calls based on self.device.type. The pull request includes suggestions to improve error handling and code robustness.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reviewers, without this change self.device might be a string and tests will be failing:

$ pytest -rsf tests/v1/worker/test_gpu_model_runner.py::test_init_kv_cache_without_kv_sharing
...
    def _init_device_properties(self) -> None:
        """Initialize attributes from torch.cuda.get_device_properties
        """
>       if self.device.type == "cuda":
           ^^^^^^^^^^^^^^^^
E       AttributeError: 'str' object has no attribute 'type'

@github-actions
Copy link

github-actions bot commented Jul 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

`XPUModelRunner` inherits from `GPUModelRunner` and customizes
couple methods. In 2 cases these customization is just dispatching
logic for different torch backends (cuda vs. xpu). Furhter, as vLLM
has generic tests not differentiating between cuda or xpu, it makes
sense to have single `GPUModelRunner` covering both cuda and xpu.
This commit implements described approach.

After the commit these tests which previously were failing now pass:
* `tests/v1/worker/test_gpu_model_runner.py::test_init_kv_cache_without_kv_sharing`
* `tests/v1/worker/test_gpu_model_runner.py::test_init_kv_cache_with_kv_sharing_valid`

The change in `_sync_device` is taking effect in this test:
* `tests/v1/engine/test_llm_engine.py::test_engine_metrics`

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
@Liangliang-Ma
Copy link
Contributor

Actually dispatching this two calls in device model runner was in this PR: #16441. I think we either use inheritance model runner class or use current_platform abstract, rather than if-else in gpu_model_runner.

@github-actions
Copy link

github-actions bot commented Oct 9, 2025

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Oct 9, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

stale Over 90 days of inactivity v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants