Skip to content

Conversation

@fadara01
Copy link
Contributor

@fadara01 fadara01 commented Oct 24, 2025

[cpu][fix] Fix onednn_mm crash on consecutive matmuls with same M,K,N and different dtype

Makes weight dtype part of the cache key for ClassMatmulCacheKey to avoid having 2 onednn_mm(s) with same src/weight dimensions and different dtypes mapped to the same dnnl::matmul primitive

Fixes: #27465

Purpose

Fixes: #27465

Test Plan

Reproducer in #27465

Test Result

Reproducer in #27465 passes


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@fadara01 fadara01 requested a review from bigPYJ1151 as a code owner October 24, 2025 10:22
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly addresses a crash in onednn_mm that occurs during consecutive matrix multiplications with identical dimensions but different data types. The fix, which incorporates the weight dtype into the ClassMatmulCacheKey, is sound and properly implemented across the cache key definition, hash function, and equality operator. My review includes two suggestions to further improve the code: one enhances the robustness of the hash function to minimize collisions and maintain performance, and the other improves code clarity by renaming a shadowed variable.

Comment on lines 190 to 191
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
hash<int>()(static_cast<int>(val.b_type));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using XOR ^ to combine hash values can lead to a higher rate of collisions, which can degrade the performance of the hash map. A more robust approach is to use a method that mixes the bits more thoroughly, such as a polynomial rolling hash. This is a common practice, for example in Java's hashCode implementation. This principle also applies to the other hash specializations in this file.

Suggested change
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
hash<int>()(static_cast<int>(val.b_type));
return (std::hash<dnnl_dim_t>()(val.b_n_size) * 31 + std::hash<dnnl_dim_t>()(val.b_k_size)) * 31 + std::hash<int>()(static_cast<int>(val.b_type));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is something for another PR, let's keep the scope of this PR clear.

Comment on lines 497 to 498
ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The local variable key declared at line 497 shadows the function parameter key. This is confusing and error-prone, as it's not immediately obvious which key is being used in subsequent calls (get_matul_class_primitive_cache vs m_size_cache_->get_or_create). It's better to use a more descriptive name for the local cache key to avoid shadowing.

Suggested change
ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
m_size_cache_ = get_matul_class_primitive_cache(class_key, primitive_cache_size_);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point

@fadara01 fadara01 force-pushed the fix_onednn_mm_crash branch from 953eb68 to fa9f0e6 Compare October 24, 2025 10:24
@fadara01
Copy link
Contributor Author

@bigPYJ1151 this should fix the crash reported in #27465 and #27244

@fadara01 fadara01 force-pushed the fix_onednn_mm_crash branch from fa9f0e6 to b3ae6d4 Compare October 24, 2025 10:35
… and different dtype

Makes weight dtype part of the cache key for `ClassMatmulCacheKey` to avoid having 2 onednn_mm(s)
with same src/weight dimensions and different dtypes mapped to the same dnnl::matmul primitive

Fixes: vllm-project#27465

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Copy link
Member

@bigPYJ1151 bigPYJ1151 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix :)

@bigPYJ1151 bigPYJ1151 enabled auto-merge (squash) October 24, 2025 12:12
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 24, 2025
@bigPYJ1151 bigPYJ1151 merged commit 2080b05 into vllm-project:main Oct 24, 2025
21 checks passed
kingsmad pushed a commit to kingsmad/vllm that referenced this pull request Oct 25, 2025
… and different dtype (vllm-project#27472)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
rohin-garg pushed a commit to rohin-garg/vllm that referenced this pull request Oct 25, 2025
… and different dtype (vllm-project#27472)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
… and different dtype (vllm-project#27472)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
0xrushi pushed a commit to 0xrushi/vllm that referenced this pull request Oct 26, 2025
… and different dtype (vllm-project#27472)

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: onednn_mm crashes on consecutive bf16, f32 matmuls with same M,K,N

2 participants