[cpu][fix] Fix onednn_mm crash on consecutive matmuls with same M,K,N and different dtype #27472
Conversation
Code Review
This pull request correctly addresses a crash in onednn_mm that occurs during consecutive matrix multiplications with identical dimensions but different data types. The fix, which incorporates the weight dtype into the ClassMatmulCacheKey, is sound and properly implemented across the cache key definition, hash function, and equality operator. My review includes two suggestions to further improve the code: one enhances the robustness of the hash function to minimize collisions and maintain performance, and the other improves code clarity by renaming a shadowed variable.
csrc/cpu/dnnl_helper.cpp
Outdated
```cpp
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
       hash<int>()(static_cast<int>(val.b_type));
```
Using XOR ^ to combine hash values can lead to a higher rate of collisions, which can degrade the performance of the hash map. A more robust approach is to use a method that mixes the bits more thoroughly, such as a polynomial rolling hash. This is a common practice, for example in Java's hashCode implementation. This principle also applies to the other hash specializations in this file.
```diff
- return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
-        hash<int>()(static_cast<int>(val.b_type));
+ return (std::hash<dnnl_dim_t>()(val.b_n_size) * 31 + std::hash<dnnl_dim_t>()(val.b_k_size)) * 31 +
+        std::hash<int>()(static_cast<int>(val.b_type));
```
This is something for another PR; let's keep the scope of this PR clear.
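For reference, and purely to illustrate the reviewer's point rather than anything merged in this PR, a bit-mixing combiner in the spirit of boost::hash_combine might look like the sketch below. The type and field names here are hypothetical stand-ins, not the actual declarations in dnnl_helper.cpp:

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>

// Hypothetical stand-ins for the real types in dnnl_helper.cpp (illustration only).
using dim_t = int64_t;
enum class DataType : int { f32, bf16, f16 };

struct Key {
  dim_t b_n_size;
  dim_t b_k_size;
  DataType b_type;
};

// boost::hash_combine-style mixer: the shifts and the golden-ratio constant
// spread the bits so that equal values in different fields do not cancel
// each other out the way plain XOR can.
inline void hash_combine(std::size_t& seed, std::size_t value) {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

struct KeyHash {
  std::size_t operator()(const Key& key) const {
    std::size_t seed = 0;
    hash_combine(seed, std::hash<dim_t>()(key.b_n_size));
    hash_combine(seed, std::hash<dim_t>()(key.b_k_size));
    hash_combine(seed, std::hash<int>()(static_cast<int>(key.b_type)));
    return seed;
  }
};
```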
csrc/cpu/dnnl_helper.cpp
Outdated
```cpp
ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
```
The local variable key declared at line 497 shadows the function parameter key. This is confusing and error-prone, as it's not immediately obvious which key is being used in subsequent calls (get_matul_class_primitive_cache vs m_size_cache_->get_or_create). It's better to use a more descriptive name for the local cache key to avoid shadowing.
```diff
- ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
- m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
+ ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
+ m_size_cache_ = get_matul_class_primitive_cache(class_key, primitive_cache_size_);
```
good point
@bigPYJ1151 this should fix the crash reported in #27465 and #27244
… and different dtype

Makes weight dtype part of the cache key for ClassMatmulCacheKey to avoid having 2 onednn_mm(s) with same src/weight dimensions and different dtypes mapped to the same dnnl::matmul primitive

Fixes: vllm-project#27465

Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
Thanks for the fix :)
[cpu][fix] Fix onednn_mm crash on consecutive matmuls with same M,K,N and different dtype
Makes weight dtype part of the cache key for ClassMatmulCacheKey to avoid having two onednn_mm(s) with the same src/weight dimensions but different dtypes mapped to the same dnnl::matmul primitive.

Fixes: #27465
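For illustration only (the names below are simplified stand-ins, not the actual dnnl_helper.cpp code), the following sketch shows why the weight dtype has to be part of the cache key: with only {N, K} in the key, a second matmul with the same shapes but a different weight dtype hits the primitive cached for the first dtype.

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <tuple>

// Simplified stand-ins; the real code uses dnnl_dim_t and a oneDNN data-type enum.
enum class DType { f32, bf16 };

// Before the fix: the cache key ignores the weight dtype.
using OldKey = std::tuple<int64_t, int64_t>;          // {N, K}
// After the fix: the weight dtype is part of the key.
using NewKey = std::tuple<int64_t, int64_t, DType>;   // {N, K, dtype}

int main() {
  std::map<OldKey, std::string> old_cache;
  std::map<NewKey, std::string> new_cache;

  // First matmul: N = K = 4096, bf16 weights.
  old_cache[{4096, 4096}] = "primitive<bf16>";
  new_cache[{4096, 4096, DType::bf16}] = "primitive<bf16>";

  // Second matmul: same N, K but f32 weights.
  // Old key: lookup "hits" the bf16 primitive, which then fails at execution.
  std::cout << old_cache.count({4096, 4096}) << "\n";              // prints 1 (wrong hit)
  // New key: lookup misses, so a fresh f32 primitive gets created instead.
  std::cout << new_cache.count({4096, 4096, DType::f32}) << "\n";  // prints 0 (correct miss)
}
```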
Purpose
Fixes: #27465
Test Plan
Reproducer in #27465
Test Result
Reproducer in #27465 passes
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.