Skip to content

Commit 2080b05

Browse files
authored
[cpu][fix] Fix onednn_mm crash on consecutive matmuls with same M,K,N and different dtype (#27472)
Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com>
1 parent 6454afe commit 2080b05

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

csrc/cpu/dnnl_helper.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ template <>
187187
struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
188188
size_t operator()(
189189
const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
190-
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size);
190+
return hash<dnnl_dim_t>()(val.b_n_size) ^ hash<dnnl_dim_t>()(val.b_k_size) ^
191+
hash<int>()(static_cast<int>(val.b_type));
191192
}
192193
};
193194

@@ -216,7 +217,8 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
216217

217218
bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
218219
const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
219-
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size;
220+
return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
221+
l.b_type == r.b_type;
220222
}
221223

222224
bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
@@ -493,8 +495,10 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
493495
dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
494496
const MSizeCacheKey& key) {
495497
if (m_size_cache_.get() == nullptr) {
496-
ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_};
497-
m_size_cache_ = get_matul_class_primitive_cache(key, primitive_cache_size_);
498+
ClassMatmulCacheKey class_key = {
499+
.b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
500+
m_size_cache_ =
501+
get_matul_class_primitive_cache(class_key, primitive_cache_size_);
498502
}
499503
return m_size_cache_->get_or_create(key, [&]() {
500504
dnnl::matmul::primitive_desc desc = this->create_primitive_desc(key, false);

csrc/cpu/dnnl_helper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ class MatMulPrimitiveHandler : public DNNLMatMulPrimitiveHandler {
199199
struct ClassMatmulCacheKey {
200200
dnnl_dim_t b_n_size;
201201
dnnl_dim_t b_k_size;
202+
dnnl::memory::data_type b_type;
202203

203204
friend bool operator==(const ClassMatmulCacheKey& l,
204205
const ClassMatmulCacheKey& r);

0 commit comments

Comments
 (0)