
Conversation


@bigPYJ1151 bigPYJ1151 commented Sep 17, 2025

Purpose

  • Wrap torch.cuda.Stream so the CPU backend does not break when CUDA streams are unavailable.
  • Fix the oneDNN linear contiguity check to avoid breakage in the torch.compile tensor-reuse case.

Test Plan

CI tests

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: jiang1.li <jiang1.li@intel.com>
@mergify mergify bot added the v1 label Sep 17, 2025

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces fixes for CPU-only execution. It correctly stubs out torch.cuda.Stream to prevent crashes in the CPU model runner and disables the unsupported Dual-Batch Overlap (DBO) feature on CPU. It also attempts to fix an issue with torch.compile by relaxing a contiguity check in the OneDNN matrix multiplication kernel. However, this change in the OneDNN kernel is likely to cause memory corruption, as it allows a non-contiguous output tensor without passing its memory layout (strides) to the underlying implementation. This is a critical issue that must be addressed.

   TORCH_CHECK(a.dim() == 2);
   TORCH_CHECK(a.stride(-1) == 1);
-  TORCH_CHECK(c.is_contiguous());
+  TORCH_CHECK(c.stride(-1) == 1);

critical

Relaxing the check from c.is_contiguous() to c.stride(-1) == 1 without providing the full tensor strides to the underlying OneDNN kernel is dangerous and will likely lead to memory corruption.

When c is not contiguous (e.g., it's a view of a larger tensor, which can be the case with torch.compile's tensor reuse), its rows are not packed together in memory. The MatMulPrimitiveHandler receives c.data_ptr() but does not appear to receive the stride for c's first dimension (unlike for tensor a, where a.stride(0) is passed via exec_args).

Without the stride information, the kernel will write output rows assuming a contiguous layout, overwriting memory that does not belong to c. This can cause silent data corruption and difficult-to-debug crashes.
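The reviewer's distinction between a unit last-dimension stride and full contiguity can be demonstrated with a NumPy analogue (strides work the same way for torch tensors); the variable names below are illustrative.

```python
import numpy as np

# A column-slice of a larger buffer has unit stride in its last dimension
# yet is NOT contiguous: its rows are not packed back to back in memory.
big = np.zeros((4, 8), dtype=np.float32)
c = big[:, :3]                      # view: 4 rows of 3 elements each

assert c.strides[-1] == c.itemsize  # unit stride in the last dim passes...
assert not c.flags["C_CONTIGUOUS"]  # ...but rows are 8*4 bytes apart, not 3*4

# A kernel that writes 4*3 values assuming packed rows would spill into
# big's other columns -- the silent corruption described above.
```

So `c.stride(-1) == 1` is a strictly weaker condition than `c.is_contiguous()`, and a kernel that only sees `c.data_ptr()` cannot tell the two cases apart.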

To fix this correctly, you must either:

  1. Pass the strides of c to MatMulPrimitiveHandler and ensure the OneDNN primitive is configured to use them. This would likely involve adding c.stride(0) to MatMulPrimitiveHandler::ExecArgs.
  2. If modifying the handler is not feasible, you should enforce contiguity. Instead of relaxing the check, you could create a temporary contiguous tensor for the output and copy it back to c if c was not originally contiguous.

Given the potential for silent memory corruption, this is a critical issue.
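Option 2 above (enforce contiguity via a temporary buffer) can be sketched in a few lines. This is a hedged NumPy illustration of the pattern, not the actual C++ fix; `safe_matmul_into` is a hypothetical name.

```python
import numpy as np

def safe_matmul_into(a, b, c):
    """Compute a @ b into c, tolerating a non-contiguous output view.

    If c is not contiguous (e.g. it is a view reused by a compiler pass),
    compute into a temporary contiguous buffer and copy the result back,
    so a kernel that assumes packed output rows never writes through c.
    """
    out = c if c.flags["C_CONTIGUOUS"] else np.empty_like(c, order="C")
    np.matmul(a, b, out=out)   # the underlying kernel may assume packed rows
    if out is not c:
        c[...] = out           # strided copy back into the original view
    return c
```

The copy costs one extra pass over the output, but it is correct for any output layout and needs no changes to the handler's argument plumbing.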

@jikunshang jikunshang enabled auto-merge (squash) September 17, 2025 08:43
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 17, 2025
@vllm-bot vllm-bot merged commit 9fccd04 into vllm-project:main Sep 17, 2025
90 of 93 checks passed
@bigPYJ1151 bigPYJ1151 deleted the fix-cpu-modelrunner branch September 17, 2025 13:10
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…vllm-project#25046)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: charlifu <charlifu@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1


3 participants