@@ -187,7 +187,8 @@ template <>
187187struct hash <MatMulPrimitiveHandler::ClassMatmulCacheKey> {
188188 size_t operator ()(
189189 const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
190- return hash<dnnl_dim_t >()(val.b_n_size ) ^ hash<dnnl_dim_t >()(val.b_k_size );
190+ return hash<dnnl_dim_t >()(val.b_n_size ) ^ hash<dnnl_dim_t >()(val.b_k_size ) ^
191+ hash<int >()(static_cast <int >(val.b_type ));
191192 }
192193};
193194
@@ -216,7 +217,8 @@ bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
216217
217218bool operator ==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
218219 const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
219- return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size ;
220+ return l.b_n_size == r.b_n_size && l.b_k_size == r.b_k_size &&
221+ l.b_type == r.b_type ;
220222}
221223
222224bool operator ==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
@@ -493,8 +495,10 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
493495dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache (
494496 const MSizeCacheKey& key) {
495497 if (m_size_cache_.get () == nullptr ) {
496- ClassMatmulCacheKey key = {.b_n_size = b_n_size_, .b_k_size = b_k_size_};
497- m_size_cache_ = get_matul_class_primitive_cache (key, primitive_cache_size_);
498+ ClassMatmulCacheKey class_key = {
499+ .b_n_size = b_n_size_, .b_k_size = b_k_size_, .b_type = b_type_};
500+ m_size_cache_ =
501+ get_matul_class_primitive_cache (class_key, primitive_cache_size_);
498502 }
499503 return m_size_cache_->get_or_create (key, [&]() {
500504 dnnl::matmul::primitive_desc desc = this ->create_primitive_desc (key, false );
0 commit comments